Gemma: The secret ingredient behind Gemini?

On February 21, 2024, Google announced Gemma, a new family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models. In this blog, we will look at a quick performance summary of Gemma, followed by a review of its key components along with their implementations.

Here are some key details to know from the official blog. You can try the demo here.

 

  • The model is released in two sizes, Gemma 2B and Gemma 7B, each with pre-trained and instruction-tuned variants.
  • Toolchains for inference and supervised fine-tuning across all major frameworks (JAX, PyTorch, and TensorFlow) through Keras 3.0.
  • Terms of use permit responsible commercial usage and distribution for all organizations, regardless of size.

Performance benchmarks:

In the Gemma technical report, the team evaluated the model on various standard academic benchmarks, grouped the benchmarks by capability, and averaged the scores within each group to compare Gemma with LLaMA 2 (7B and 13B) and Mistral 7B. The chart below shows the relative performance of Gemma and the other models: Gemma 7B outperforms LLaMA 2 7B and Mistral 7B by a significant margin, and its performance is very similar to that of LLaMA 2 13B.

[Figure: relative benchmark performance of Gemma, LLaMA 2 and Mistral 7B, grouped by capability]

The model architecture is based on the Transformer decoder proposed in the "Attention Is All You Need" paper. Gemma is trained with a context length of 8192 tokens; the other core parameters of the model are listed in the technical report:

[Table: core parameters of the Gemma 2B and 7B models, from the technical report]

Gemma uses several improvements proposed after the original Transformer paper. Let's look at them along with their implementations.

Multi-Query Attention 
From the model parameters above, we can see that the number of heads differs between Q and K/V in the 2B model, which uses a single head for K and V and 8 heads for Q (multi-query attention), whereas the 7B model uses 16 heads for each of Q, K, and V (standard multi-head attention).
The implementation is available here. In the code below, note that the key and value projections use self.num_key_value_heads instead of self.num_heads.
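# Query projection: num_heads heads (8 in Gemma 2B, 16 in Gemma 7B)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
# Key/value projections: num_key_value_heads heads (1 in Gemma 2B, i.e. multi-query attention)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
# Output projection maps the concatenated query-head outputs back to hidden_size
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)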
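To make the head sharing concrete, here is a minimal pure-Python sketch of causal attention in which groups of query heads share one K/V head. It is an illustration rather than the Hugging Face code, and the function names are hypothetical. Setting `num_kv_heads = 1` mimics the 2B configuration (multi-query attention), while `num_kv_heads = num_heads` gives the standard multi-head attention used by the 7B model.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def grouped_query_attention(q, k, v):
    """Causal attention where each group of query heads shares one K/V head.

    q:    [num_heads][seq_len][head_dim]
    k, v: [num_kv_heads][seq_len][head_dim]
    num_kv_heads == 1         -> multi-query attention (like Gemma 2B)
    num_kv_heads == num_heads -> standard multi-head attention (like Gemma 7B)
    """
    num_heads, num_kv_heads = len(q), len(k)
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads per K/V head
    seq_len, head_dim = len(q[0]), len(q[0][0])
    scale = 1.0 / math.sqrt(head_dim)

    out = []
    for h in range(num_heads):
        kv = h // group_size  # the K/V head this query head reads from
        head_out = []
        for t in range(seq_len):
            # Causal mask: position t only attends to positions 0..t
            scores = [
                scale * sum(a * b for a, b in zip(q[h][t], k[kv][s]))
                for s in range(t + 1)
            ]
            weights = softmax(scores)
            head_out.append(
                [sum(w * v[kv][s][d] for s, w in enumerate(weights)) for d in range(head_dim)]
            )
        out.append(head_out)
    return out
```

Because only `num_kv_heads` key and value heads are computed (and cached during generation), the 2B model stores one K/V head per layer instead of 8, an 8× smaller KV cache than a multi-head variant with 8 K/V heads would need.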
Embeddings 
Instead of absolute positional embeddings, Gemma uses rotary positional embeddings (RoPE) in each layer. Gemma also shares embeddings across inputs and outputs to reduce model size.
import torch
from torch import nn


class GemmaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim  # per-head dimension that gets rotated
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies are created lazily on the input's device in forward()
        self.register_buffer("inv_freq", None, persistent=False)

    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.inv_freq is None:
            # inv_freq[i] = 1 / base^(2i / dim), one frequency per pair of dimensions
            self.inv_freq = 1.0 / (
                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
            )
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Rotation angle for every (position, frequency) pair: [bs, seq_len, dim / 2]
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        # Duplicate so both halves of the head dimension share the same angles: [bs, seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
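The module above only produces the cos and sin tables. Below is a minimal pure-Python sketch of how such tables rotate a single query or key vector. It assumes the half-split pairing (dimension i with dimension i + dim/2) implied by `torch.cat((freqs, freqs), dim=-1)`; the function names are hypothetical.

```python
import math


def rope_rotate(vec, pos, base=10000.0):
    """Rotate one head vector to position `pos` with rotary position embeddings.

    Assumes the half-split layout: dimension i is paired with i + dim/2 and the
    pair is rotated by the angle pos * inv_freq[i], inv_freq[i] = 1 / base**(2i/dim).
    """
    dim = len(vec)
    if dim % 2:
        raise ValueError("head dimension must be even")
    half = dim // 2
    out = [0.0] * dim
    for i in range(half):
        angle = pos / base ** (2 * i / dim)
        cos_a, sin_a = math.cos(angle), math.sin(angle)
        x1, x2 = vec[i], vec[i + half]
        # x * cos + rotate(x) * sin, where rotate(x) swaps the halves as [-x2, x1]
        out[i] = x1 * cos_a - x2 * sin_a
        out[i + half] = x2 * cos_a + x1 * sin_a
    return out


def dot(a, b):
    """Plain dot product, standing in for an attention score q . k."""
    return sum(x * y for x, y in zip(a, b))
```

For example, rotating a query to position 7 and a key to position 3 gives the same dot product as rotating them to positions 14 and 10, because both pairs are 4 positions apart; rotating to position 0 leaves a vector unchanged. This relative-position behaviour is what lets RoPE replace absolute positional embeddings.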
GeGLU Activations
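The standard ReLU non-linearity is replaced by the GeGLU activation function, a variant of the GLU (Gated Linear Unit) activation:

$$\mathrm{GeGLU}(x, W, V, b, c) = \mathrm{GELU}(xW + b) \otimes (xV + c)$$

where ⊗ denotes element-wise multiplication (the Gemma MLP projections have no bias, so b = c = 0). Gemma uses the GELU implementation below: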
import math

import torch
from torch import Tensor, nn


class GELUActivation(nn.Module):
    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        if use_gelu_python:
            self.act = self._gelu_python
        else:
            self.act = nn.functional.gelu  # PyTorch's built-in (exact, erf-based) GELU

    def _gelu_python(self, input: Tensor) -> Tensor:
        # GELU(x) = x * Phi(x), where Phi is the standard normal CDF
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)

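In GemmaMLP below, the GELU activation is applied to the output of the gate_proj linear layer, and the result is multiplied element-wise by the output of the up_proj layer to obtain GeGLU; down_proj then projects the result back to the hidden size.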

from torch import nn
from transformers.activations import ACT2FN


class GemmaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]  # GELU activation (see above)

    def forward(self, x):
        # GeGLU: GELU(gate_proj(x)) * up_proj(x), then project back to hidden_size
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
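For a single token, the GemmaMLP forward pass can be sketched in plain Python as follows. This is a minimal illustration, assuming bias-free projections (as in GemmaMLP) and weights stored as [out_features][in_features] lists, the same layout nn.Linear uses; the helper names are hypothetical.

```python
import math


def gelu(x):
    """Exact (erf-based) GELU, the same formula as _gelu_python above."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def matvec(weight, x):
    """Multiply a [out_dim][in_dim] weight matrix by vector x (no bias)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def geglu_mlp(x, w_gate, w_up, w_down):
    """down_proj(GELU(gate_proj(x)) * up_proj(x)) for one token vector x."""
    gate = [gelu(g) for g in matvec(w_gate, x)]  # GELU(xW)
    up = matvec(w_up, x)                         # xV
    hidden = [g * u for g, u in zip(gate, up)]   # element-wise product (the ⊗ in GeGLU)
    return matvec(w_down, hidden)                # back to hidden_size
```

An all-zero gate produces an all-zero output, since GELU(0) = 0: the gate branch decides how much of the up_proj signal passes through, which is what distinguishes GeGLU from a plain GELU MLP.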
Normalization
Gemma uses RMSNorm as its normalization function. However, there are two interesting findings about how normalization is applied in Gemma:
1. The technical report describes the normalizer location as “both the input and the output of each transformer sub-layer is normalized, a deviation from the standard practice of solely normalizing one or the other”. In the code, however, each decoder layer (in both the Hugging Face and the official implementations) contains two RMSNorms, input_layernorm and post_attention_layernorm, and both normalize the input of a sub-layer: the first before self-attention and the second before the MLP. The sub-layer outputs are added to the residual stream without being normalized.
2. The RMSNorm includes an additional unit offset, scaling the normalized input by (1 + weight) instead of by weight, which is not mentioned in the technical report.
import torch
from torch import nn


class RMSNorm(torch.nn.Module):

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        # Zero-initialized: with the unit offset, the initial scale is 1 + 0 = 1
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps) over the last (hidden) dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32, then cast back to the input dtype
        x = self._norm(x.float()).type_as(x)
        if self.add_unit_offset:
            output = x * (1 + self.weight)
        else:
            output = x * self.weight
        return output
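The effect of the unit offset is easiest to see on a single vector. The sketch below is a hypothetical pure-Python mirror of the module above (the function name and list-based interface are illustrative); because `self.weight` starts at zero, the offset makes a freshly initialized layer scale by 1 rather than by 0.

```python
import math


def rms_norm(x, weight, eps=1e-6, add_unit_offset=True):
    """RMSNorm over one vector; with add_unit_offset the scale is (1 + weight)."""
    rms_inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    scale = [(1.0 + w) if add_unit_offset else w for w in weight]
    return [v * rms_inv * s for v, s in zip(x, scale)]
```

With `add_unit_offset=False`, the same zero weight would zero out every activation. The offset therefore lets the weight be initialized at zero while the layer still starts out as a plain RMS normalization whose output has (approximately) unit root mean square.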

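Below is the decoder layer from the official implementation. It applies RMSNorm before each of its two sub-layers: input_layernorm normalizes the input to self-attention, and post_attention_layernorm normalizes the input to the MLP (after the attention output has been added to the residual).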

from typing import Tuple

import torch
from torch import nn

# gemma_config, GemmaAttention, GemmaMLP and RMSNorm are defined in the
# official Gemma implementation and are not reproduced here.


class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
    ):
        super().__init__()
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            quant=config.quant,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant=config.quant,
        )
        # Both norms are pre-norms: each one normalizes the input of a sub-layer
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # Self Attention: normalize the input, attend, then add the residual
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
        )
        hidden_states = residual + hidden_states

        # MLP: normalize the updated residual stream, apply the MLP, add the residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
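To make the placement of the two RMSNorms explicit, the toy sketch below compares the pre-norm pattern of GemmaDecoderLayer with a hypothetical layer that also normalizes each sub-layer's output, which is one possible reading of the report's wording. Vectors are plain lists, the attention and MLP sub-layers are stand-in functions, and the norm has no learned weight (equivalent to a zero weight with the unit offset); all names are illustrative.

```python
import math


def rms_norm(x, eps=1e-6):
    """RMSNorm without a learned scale (same as a zero weight with the unit offset)."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms for v in x]


def add(a, b):
    """Element-wise addition: the residual connection."""
    return [u + w for u, w in zip(a, b)]


def pre_norm_layer(x, attn, mlp):
    """Placement used in GemmaDecoderLayer above: only sub-layer inputs are normalized."""
    x = add(x, attn(rms_norm(x)))  # input_layernorm -> self-attention -> residual add
    x = add(x, mlp(rms_norm(x)))   # post_attention_layernorm -> MLP -> residual add
    return x


def input_output_norm_layer(x, attn, mlp):
    """Hypothetical variant that also normalizes each sub-layer's output before the add."""
    x = add(x, rms_norm(attn(rms_norm(x))))
    x = add(x, rms_norm(mlp(rms_norm(x))))
    return x
```

In the pre-norm layer, a sub-layer that outputs zeros leaves the residual stream untouched, and the scale of each sub-layer's output flows into the residual as-is; in the second variant, the extra norm removes that scale before the output is added back, so the two placements generally produce different results.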
