Vanilla Meta Llama 3 in 256 Lines of Code (Inference only)
I just finished watching Andrej Karpathy’s video Let’s reproduce GPT-2 (124M). It is fascinating to break a complex LLM system down into such a simplified version that anyone with basic machine learning knowledge can understand it. I’m not an LLM researcher, but breaking down an open-source LLM doesn’t seem difficult, since there is plenty of public code and plenty of blog posts available. In this post, I will show a vanilla Llama 3 implementation that loads the pre-trained weights into the network and produces output matching the Hugging Face implementation. I’m doing all this just for fun, so if there are any mistakes, don’t hesitate to let me know.
Hugging Face Llama 3
First, we’ll work with the Hugging Face Llama 3. The goals are:
- To know the exact network layers of the Llama 3 pre-trained model, so we can build a compatible model.
- To get the generated answer from the original Llama 3 model, so we can compare our results with it and verify our implementation.
First, we print the name of each layer in the pre-trained Llama 3 model:
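A minimal sketch for this step (assuming access to the gated meta-llama/Meta-Llama-3-8B-Instruct checkpoint):
from transformers import AutoModelForCausalLM
# load the Hugging Face checkpoint and list every parameter name and shape
model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
for name, param in model_hf.named_parameters():
    print(name, param.shape)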
Then, we get an example output from the Hugging Face Llama 3:
import transformers
import torch
from transformers import set_seed
pipeline = transformers.pipeline(
"text-generation",
model="meta-llama/Meta-Llama-3-8B-Instruct",
device="cuda",
use_cache=False,
)
set_seed(42)
output = pipeline("Hello, I'm a language model,", max_length=100, num_return_sequences=1, truncation=True, do_sample=False)
print(output)
# --------------------------------------------------------------------------------------------------- #
# output:
[{'generated_text': "Hello, I'm a language model, and I'm here to help you with your questions. I can provide information on a wide range of topics, from science and history to entertainment and culture. I can also help you with language-related tasks, such as grammar and vocabulary practice, and even assist with writing and proofreading. So, what's on your mind? What do you want to talk about or ask me? I'm all ears!"}]
My vanilla Llama 3
Here is my vanilla Llama 3. Some of the code is adapted from, or directly borrowed from, the Hugging Face Llama 3 implementation and the Meta Llama 3 GitHub repository.
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import nn
from tokenizer import Tokenizer
from torch.nn import functional as F
@dataclass
class ModelArgs:
dim: int = 4096
n_kv_heads: int = 8
    vocab_size: int = 128256  # 128000 BPE tokens + 256 special tokens
n_layers: int = 32
n_heads: int = 32
ffn_dim_multiplier: float = 1.3
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
norm_eps: float = 1e-5
rope_theta: float = 500000
max_seq_len: int = 2048
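# With these defaults (Llama 3 8B): head_dim = dim // n_heads = 4096 // 32 = 128, and
# n_heads // n_kv_heads = 32 // 8 = 4, i.e. grouped-query attention where every
# 4 query heads share one key/value head.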
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
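# Note: this is the Hugging Face "half-split" RoPE convention: (q * cos) + (rotate_half(q) * sin)
# rotates each pair (x_i, x_{i + d/2}) by the angle position * inv_freq_i. It matches Meta's
# interleaved complex formulation up to a permutation of the head dimensions, and the
# Hugging Face checkpoint's q_proj/k_proj weights are already permuted accordingly.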
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
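# Example: with n_kv_heads=8 and n_rep=4, a key/value tensor of shape (bs, 8, seqlen, 128)
# becomes (bs, 32, seqlen, 128), so every query head has a matching key/value head.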
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
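# RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight. Unlike LayerNorm, there is no mean
# subtraction and no bias term.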
class Attention(nn.Module):
def __init__(self, model_args: ModelArgs) -> None:
super().__init__()
self.dim, self.n_heads = model_args.dim, model_args.n_heads
self.head_dim = model_args.dim // model_args.n_heads
self.n_kv_heads = model_args.n_kv_heads
self.n_rep = self.n_heads // self.n_kv_heads
self.q_proj = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=model_args.max_seq_len, base=model_args.rope_theta)
def forward(self, x, pos_ids):
bs, seqlen, _ = x.shape
xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
xq = xq.view(bs, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)
xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(xv, pos_ids)
xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin)
# repeat k/v heads if n_kv_heads < n_heads
        xk = repeat_kv(xk, self.n_rep)  # (bs, n_heads, seqlen, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, n_heads, seqlen, head_dim)
        # causal mask: each position attends only to itself and earlier positions
output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
output = output.transpose(
1, 2
).contiguous() # (bs, seqlen, n_local_heads, head_dim)
output = output.view(bs, seqlen, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, model_args: ModelArgs) -> None:
super().__init__()
hidden_dim = int(2 * model_args.dim * 4 / 3)
hidden_dim = int(model_args.ffn_dim_multiplier * hidden_dim)
hidden_dim = model_args.multiple_of * ((hidden_dim + model_args.multiple_of - 1) // model_args.multiple_of)
self.gate_proj = nn.Linear(model_args.dim, hidden_dim, bias=False)
self.up_proj = nn.Linear(model_args.dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, model_args.dim, bias=False)
def forward(self, x):
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
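# SwiGLU feed-forward: down_proj(silu(gate_proj(x)) * up_proj(x)). With the default args the
# hidden dimension comes out to 14336: int(1.3 * int(2 * 4 * 4096 / 3)) = 14198, rounded up
# to the next multiple of 256.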
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, model_args: ModelArgs) -> None:
super().__init__()
self.self_attn = Attention(model_args)
self.mlp = MLP(model_args)
self.input_layernorm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
self.post_attention_layernorm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
def forward(self, x, pos_ids):
h = x + self.self_attn(self.input_layernorm(x), pos_ids)
out = h + self.mlp(self.post_attention_layernorm(h))
return out
class GPT(nn.Module):
def __init__(self, model_args: ModelArgs) -> None:
super().__init__()
self.embed_tokens = nn.Embedding(model_args.vocab_size, model_args.dim)
self.layers = nn.ModuleDict()
for layer_id in range(model_args.n_layers):
self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)
self.lm_head = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
def forward(self, x):
bs, seqlen = x.shape
pos_ids = torch.arange(seqlen, device=x.device).unsqueeze(0).expand(bs, -1)
h = self.embed_tokens(x)
for layer in self.layers.values():
            h = layer(h, pos_ids)
h = self.norm(h)
output = self.lm_head(h)
return output
@classmethod
def from_pretrained(cls, model_type):
config = ModelArgs()
model = GPT(config)
sd = model.state_dict()
sd_keys = sd.keys()
# init a huggingface/transformers model
from transformers import AutoModelForCausalLM
model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
sd_hf = model_hf.state_dict()
# copy while ensuring all of the parameters are aligned and match in names and shapes
sd_keys_hf = sd_hf.keys()
assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
for k in sd_keys_hf:
# vanilla copy over the other parameters
assert sd_hf[k].shape == sd[k.replace('model.', '')].shape
with torch.no_grad():
sd[k.replace('model.', '')].copy_(sd_hf[k])
return model
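# Example key mapping: 'model.layers.0.self_attn.q_proj.weight' (Hugging Face) maps to
# 'layers.0.self_attn.q_proj.weight' in our model; 'lm_head.weight' has no 'model.' prefix
# and is copied as-is.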
# --------------------------------------------------------------------------------------------------- #
num_return_sequences = 1
max_length = 100
# model = GPT(ModelArgs())
model = GPT.from_pretrained("llama3")
# print model layers
sd = model.state_dict()
for k, v in sd.items():
print(k, v.shape)
model.eval()
model.cuda()
# prefix tokens
enc = Tokenizer(model_path="llama3/tokenizer.model")
tokens = enc.encode("Hello, I'm a language model,", bos=False, eos=False)
tokens = torch.tensor(tokens, dtype=torch.long)
tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
x = tokens.to('cuda')
torch.manual_seed(42)
torch.cuda.manual_seed(42)
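# Greedy decoding: top-k with k=1 followed by multinomial sampling always picks the argmax
# token, which matches do_sample=False in the Hugging Face pipeline above.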
while x.size(1) < max_length:
with torch.no_grad():
logits = model(x)
logits = logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
topk_probs, topk_indices = torch.topk(probs, 1, dim=-1)
ix = torch.multinomial(topk_probs, num_samples=1)
xcol = torch.gather(topk_indices, -1, ix)
x = torch.cat((x, xcol), dim=1)
for i in range(num_return_sequences):
tokens = x[i, :max_length].tolist()
try:
        # Find the first <|eot_id|> token (id 128009), Llama 3's end-of-turn token
index = tokens.index(128009)
# Cut off all tokens from this index onward
tokens = tokens[:index]
except ValueError:
# Handle the case where 128009 is not in the list
print("Token 128009 is not in the list. No changes made.")
decoded = enc.decode(tokens)
print(">", decoded)
Results:
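If the implementation is correct, the greedy decoding above should reproduce the Hugging Face pipeline output shown earlier.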