Gemma: On Feb 21, 2024, Google announced a new series of lightweight, state-of-the-art open models named Gemma, built from the same research and technology used to create the Gemini models. In this blog we will go through a quick performance summary of Gemma, followed by a brief review of its key components along with their implementations.
Here are some key details to know from the official blog. You can try the demo here.
- The model is released in two sizes, Gemma 2B and Gemma 7B. Each size comes in pre-trained and instruction-tuned variants.
- Toolchains for inference and supervised fine-tuning are provided across all major frameworks, such as JAX, TensorFlow, and PyTorch, through Keras 3.0.
- Terms of use permit responsible commercial usage and distribution for all organizations, regardless of size.
Performance benchmarks:
In the Gemma technical report, the team evaluates the model on various standard academic benchmarks, groups the benchmarks by capability, and averages the scores within each group to compare against LLaMA 2 (7B and 13B) and Mistral 7B. The chart below shows the relative performance of Gemma and the other models. We can see that Gemma 7B outperforms LLaMA 2 7B and Mistral 7B by a significant margin, and that its performance is very similar to that of LLaMA 2 13B.
The model architecture is based on the transformer decoder proposed in the "Attention Is All You Need" paper. Gemma is trained with a context length of 8192 tokens; the other core parameters of the model are specified in the technical report as:
Gemma uses several improvements proposed after the original transformer paper. Let's have a look at them along with their implementations.
Multi-Query Attention
The implementation is available here. The highlighted code shows the key and value projections using self.num_key_value_heads instead of self.num_heads, so several query heads can share the same key and value heads.
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
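To make the effect of using fewer key/value heads concrete, here is a minimal sketch in plain Python (no PyTorch, so it runs without dependencies). The helper names and the example sizes (8 query heads, a head dimension of 256, 18 layers) are illustrative assumptions, not values read from the code above.

```python
def kv_head_for_query_head(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Index of the key/value head that a given query head reads from."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads
    return q_head // group_size


def kv_cache_elements(seq_len: int, num_layers: int, num_kv_heads: int, head_dim: int) -> int:
    """Number of cached values: one K and one V tensor per layer."""
    return 2 * num_layers * seq_len * num_kv_heads * head_dim


# Illustrative sizes (assumptions for this sketch).
num_heads, head_dim, num_layers, seq_len = 8, 256, 18, 8192

mha = kv_cache_elements(seq_len, num_layers, num_kv_heads=num_heads, head_dim=head_dim)
mqa = kv_cache_elements(seq_len, num_layers, num_kv_heads=1, head_dim=head_dim)

# With a single key/value head, every query head maps to KV head 0.
print([kv_head_for_query_head(h, num_heads, 1) for h in range(num_heads)])
# The KV cache shrinks by a factor equal to the number of query heads.
print(mha // mqa)  # 8
```

The query projection keeps its full size, so only k_proj, v_proj, and the cache get smaller. That is why the saving shows up mainly in memory use during generation.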
Embeddings
class GemmaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.register_buffer("inv_freq", None, persistent=False)

    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.inv_freq is None:
            self.inv_freq = 1.0 / (
                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
            )
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
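The class above only produces the cos and sin tables for rotary position embeddings. How they are applied to queries and keys is not shown, so the following plain-Python sketch assumes the common "rotate half" convention. The function names rope_angles, rotate_half, apply_rope, and dot are hypothetical helpers written for this illustration. The sketch checks the property that makes rotary embeddings useful: the query-key score depends only on the distance between the two positions.

```python
import math


def rope_angles(position: int, dim: int, base: float = 10000.0) -> list:
    """Angles for one position, laid out like torch.cat((freqs, freqs), dim=-1)."""
    inv_freq = [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
    freqs = [position * f for f in inv_freq]
    return freqs + freqs


def rotate_half(x: list) -> list:
    """Map (x1, x2) to (-x2, x1), where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(x: list, position: int, base: float = 10000.0) -> list:
    """Rotate vector x by the angles belonging to `position` (assumed convention)."""
    angles = rope_angles(position, len(x), base)
    rotated = rotate_half(x)
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rotated, angles)]


def dot(a: list, b: list) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


q = [0.3, -1.2, 0.7, 2.0]
k = [1.5, 0.4, -0.6, 0.9]

# Both position pairs are 2 apart, so the two scores should match.
score_a = dot(apply_rope(q, 5), apply_rope(k, 3))
score_b = dot(apply_rope(q, 12), apply_rope(k, 10))
print(abs(score_a - score_b) < 1e-9)  # True
```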
GeGLU Activations
class GELUActivation(nn.Module):
    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        if use_gelu_python:
            self.act = self._gelu_python
        else:
            self.act = nn.functional.gelu

    def _gelu_python(self, input: Tensor) -> Tensor:
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)
The GELU activation above is applied to the output of the gate_proj linear layer, and the result is then multiplied element-wise by the output of the up_proj linear layer, as shown below, to obtain GeGLU.
class GemmaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
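To trace the GeGLU computation by hand, here is a small plain-Python sketch of the same forward pass for a single token vector. The helper names (gelu, linear, geglu_mlp) and the toy weights are hypothetical, chosen only so the arithmetic is easy to follow. It uses the exact erf-based GELU from _gelu_python.

```python
import math


def gelu(x: float) -> float:
    """Exact GELU, the same formula as _gelu_python."""
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def linear(x: list, weight: list) -> list:
    """Bias-free linear layer; weight has shape [out_features][in_features]."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def geglu_mlp(x: list, gate_w: list, up_w: list, down_w: list) -> list:
    """down_proj(gelu(gate_proj(x)) * up_proj(x)) for a single token vector."""
    gate = [gelu(g) for g in linear(x, gate_w)]
    up = linear(x, up_w)
    hidden = [g * u for g, u in zip(gate, up)]  # element-wise gating
    return linear(hidden, down_w)


# Hypothetical toy weights: hidden_size=2, intermediate_size=3.
x = [1.0, -2.0]
gate_w = [[0.5, 0.0], [0.0, 0.5], [1.0, 1.0]]
up_w = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
down_w = [[1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]

print(geglu_mlp(x, gate_w, up_w, down_w))
```

The gate branch decides how much of each intermediate feature passes through, and the up branch carries the values. If the gate pre-activations are all zero, the output is zero regardless of up_proj.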
Normalization
1. Regarding the normalizer location, the technical report states that “both the input and the output of each transformer sub-layer is normalized, a deviation from the standard practice of solely normalizing one or the other”, but the Hugging Face implementation appears to have just one. In the official implementation, on the other hand, there are two normalizations (input as well as output).
class RMSNorm(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        x = self._norm(x.float()).type_as(x)
        if self.add_unit_offset:
            output = x * (1 + self.weight)
        else:
            output = x * self.weight
        return output
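The unit offset is easy to miss: the weight is initialized to zeros, so the effective scale (1 + weight) starts at one. The plain-Python sketch below mirrors the RMSNorm arithmetic on a single vector. The function name rms_norm and the sample values are assumptions made for this illustration.

```python
import math


def rms_norm(x: list, weight: list, eps: float = 1e-6, add_unit_offset: bool = True) -> list:
    """Divide x by its root mean square, then scale by (1 + weight) or weight."""
    mean_square = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    scale = [(1.0 + w) if add_unit_offset else w for w in weight]
    return [v * inv_rms * s for v, s in zip(x, scale)]


x = [3.0, -4.0, 0.0, 5.0]
zero_weight = [0.0] * len(x)  # mirrors the torch.zeros(dim) initialization

# With the unit offset, a zero weight leaves the normalized vector unscaled.
y = rms_norm(x, zero_weight)
rms_y = math.sqrt(sum(v * v for v in y) / len(y))
print(round(rms_y, 4))  # 1.0

# Without the offset, the same zero weight would wipe out the signal.
print(all(v == 0 for v in rms_norm(x, zero_weight, add_unit_offset=False)))  # True
```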
Below is the decoder layer from the official implementation, which applies RMSNorm twice: input_layernorm before self-attention and post_attention_layernorm before the MLP.
class GemmaDecoderLayer(nn.Module):
    def __init__(
        self,
        config: gemma_config.GemmaConfig,
    ):
        super().__init__()
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            quant=config.quant,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
        )
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
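Stripped of the attention and cache details, the forward pass above repeats one pattern twice: normalize, run the sub-layer, add the residual. The plain-Python sketch below records that order. The stand-in sub-layers (toy_attn, toy_mlp), the trace list, and the unscaled rms_norm are hypothetical simplifications, not part of the official code.

```python
import math


def rms_norm(x: list, eps: float = 1e-6) -> list:
    """Unscaled RMS normalization (learned weight omitted for brevity)."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms for v in x]


def decoder_layer(x: list, attn_fn, mlp_fn):
    """Apply x + f(norm(x)) for attention, then again for the MLP."""
    trace = []

    trace.append("input_layernorm")
    normed = rms_norm(x)
    trace.append("self_attn")
    x = [r + h for r, h in zip(x, attn_fn(normed))]  # first residual add

    trace.append("post_attention_layernorm")
    normed = rms_norm(x)
    trace.append("mlp")
    x = [r + h for r, h in zip(x, mlp_fn(normed))]  # second residual add
    return x, trace


# Hypothetical stand-ins for the real sub-layers.
def toy_attn(h: list) -> list:
    return [0.5 * v for v in h]


def toy_mlp(h: list) -> list:
    return [-0.25 * v for v in h]


out, order = decoder_layer([1.0, 2.0, 3.0], toy_attn, toy_mlp)
print(order)
# ['input_layernorm', 'self_attn', 'post_attention_layernorm', 'mlp']
```

In this layout the residual stream itself is never normalized inside the layer. Only the copy fed into each sub-layer is.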