HuggingFace Transformers
TL;DR
The Llama model forward logic lives in LlamaModel.forward: src/transformers/models/llama/modeling_llama.py, lines 518-624.
First, Set Up Your Environment
Clone the GitHub repository “byrzhm/llama2-hf-start” into your local workspace and run the following commands in a terminal.
git clone git@github.com:byrzhm/llama2-hf-start.git
cd llama2-hf-start
conda create -n llama2 python=3.10
conda activate llama2
pip install -r requirements.txt
Please make sure that
- your GPU memory is larger than 14GB and
- your HuggingFace account has access to the Llama2 models.
Choose Your Favourite IDE
You can use VSCode, PyCharm, or any other IDE to run Python programs. Here I use PyCharm: I set the Run/Debug configuration's script to run.py and step through the code in debug mode.
Going Deeper with the Fxxking Source Code
Preparing the Pretrained Model and Tokenizer
from_pretrained
model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"
# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="cuda",
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
The method AutoModelForCausalLM.from_pretrained initializes and returns a PyTorch model. During initialization it loads the weights onto the target device (if device_map is not set, it defaults to “cpu”). If torch_dtype is not set, the weights are loaded as torch.float32 by default.
For a 24 GB RTX 4090 GPU and the Llama2-7B model, leaving torch_dtype unset (instead of torch.float16) causes a GPU out-of-memory error: in the default FP32 format the 7B model requires roughly 7 billion parameters × 4 bytes per parameter = 28 GB, which exceeds the GPU memory capacity.
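As a rough sanity check of this arithmetic, you can count the parameters of the loaded model (a hypothetical snippet reusing the model object created above):
n_params = sum(p.numel() for p in model.parameters())        # ~6.7 billion for Llama-2-7B
print(f"fp16: ~{n_params * 2 / 1e9:.1f} GB, fp32: ~{n_params * 4 / 1e9:.1f} GB")
# fp16 (~13.5 GB) fits on a 24 GB card; fp32 (~27 GB) does not.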
To fetch the weights and configuration file, from_pretrained uses the cached_file function, which can retrieve them from the Internet or from the local cache. The cached_file function is defined in src/transformers/utils/hub.py.
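You can also call cached_file yourself to resolve a single file to its local cache path (a small sketch; it assumes your HF token is configured and has access to the gated repo):
from transformers.utils import cached_file

# Downloads (or reuses the cached copy of) config.json and returns its local path,
# e.g. somewhere under ~/.cache/huggingface/hub/.
config_path = cached_file("meta-llama/Llama-2-7b-chat-hf", "config.json")
print(config_path)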
To successfully create a pretrained model, AutoModelForCausalLM first needs to build a config instance by calling AutoConfig.from_pretrained, which is located in src/transformers/models/auto/configuration_auto.py.
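A minimal sketch of this step: loading the config alone is enough to see the model_type that drives the dispatch (the printed values are what the Llama-2-7B config should contain):
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
print(type(config).__name__)   # LlamaConfig
print(config.model_type)       # llama
print(config.num_hidden_layers, config.hidden_size)  # 32 4096 for the 7B variant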
After parsing the configuration file, we know that it is a Llama model, so AutoModelForCausalLM.from_pretrained dispatches to LlamaForCausalLM.from_pretrained; see src/transformers/models/auto/auto_factory.py.
return model_class.from_pretrained(
    pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
The method LlamaForCausalLM.from_pretrained calls PreTrainedModel.from_pretrained due to inheritance.
Fetching Checkpoint and Creating Models
It first attempts to fetch the weights file "model.safetensors". Because the Llama-2-7b-chat-hf model does not have "model.safetensors" but instead ships "model-00001-of-00002.safetensors" and "model-00002-of-00002.safetensors", cached_file returns None. It then attempts to fetch model.safetensors.index.json. Once cached_file successfully returns the path to model.safetensors.index.json, from_pretrained knows that the weights are sharded and calls get_checkpoint_shard_files to download the sharded weights.
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded:
    # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
    resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
        pretrained_model_name_or_path,
        resolved_archive_file,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        resume_download=resume_download,
        local_files_only=local_files_only,
        token=token,
        user_agent=user_agent,
        revision=revision,
        subfolder=subfolder,
        _commit_hash=commit_hash,
    )
After get_checkpoint_shard_files completes, resolved_archive_file will be ['/path/to/model-00001-of-00002.safetensors', '/path/to/model-00002-of-00002.safetensors'], and sharded_metadata will be the content of model.safetensors.index.json plus all the checkpoint keys (e.g., 'lm_head.weight').
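You can inspect the index file yourself to see how parameter names map to shard files (a hedged sketch using the same caching helper as above):
import json
from transformers.utils import cached_file

index_path = cached_file("meta-llama/Llama-2-7b-chat-hf", "model.safetensors.index.json")
with open(index_path) as f:
    index = json.load(f)
print(index["metadata"])                      # e.g. {"total_size": ...}
print(index["weight_map"]["lm_head.weight"])  # the shard file that holds this tensor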
Then it will instantiate a model. Notice that at this point the weights are not yet loaded into CPU or GPU memory. To instantiate a LlamaForCausalLM instance, three __init__ functions are called: LlamaForCausalLM.__init__, PreTrainedModel.__init__, and torch.nn.Module.__init__.
with ContextManagers(init_contexts):
    # Let's make sure we don't run the init function of buffer modules
    model = cls(config, *model_args, **model_kwargs)
In LlamaForCausalLM.__init__, a LlamaModel instance is created. Creating the LlamaModel instance likewise goes through three __init__ functions: LlamaModel.__init__, PreTrainedModel.__init__, and torch.nn.Module.__init__.
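Conceptually, this “build the model without allocating real weights” step behaves like constructing the model under accelerate's init_empty_weights context, where every parameter lives on the meta device (a simplified sketch, not the exact code path from_pretrained takes):
from accelerate import init_empty_weights
from transformers import AutoConfig, LlamaForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
with init_empty_weights():
    empty_model = LlamaForCausalLM(config)        # no real memory is allocated for the weights
print(next(empty_model.parameters()).device)      # meta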
Loading the Weights into Memory
Now we will load the weights from disk to memory. The loading task is done in PreTrainedModel._load_pretrained_model.
with ContextManagers(load_contexts):
    (
        model,
        missing_keys,
        unexpected_keys,
        mismatched_keys,
        offload_index,
        error_msgs,
    ) = cls._load_pretrained_model(
        model,
        state_dict,
        loaded_state_dict_keys,  # XXX: rename?
        resolved_archive_file,
        pretrained_model_name_or_path,
        ignore_mismatched_sizes=ignore_mismatched_sizes,
        sharded_metadata=sharded_metadata,
        _fast_init=_fast_init,
        low_cpu_mem_usage=low_cpu_mem_usage,
        device_map=device_map,
        offload_folder=offload_folder,
        offload_state_dict=offload_state_dict,
        dtype=torch_dtype,
        hf_quantizer=hf_quantizer,
        keep_in_fp32_modules=keep_in_fp32_modules,
        gguf_path=gguf_path,
        weights_only=weights_only,
    )
The following code fragment actually loads weights into memory.
if len(resolved_archive_file) > 1:
    resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")

assign_to_params_buffers = None
for shard_file in resolved_archive_file:
    # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
    if shard_file in disk_only_shard_files:
        continue
    map_location = None
    if (
        device_map is not None
        and hf_quantizer is not None
        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
        and hf_quantizer.quantization_config.quant_type == "int4_weight_only"
    ):
        map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
    state_dict = load_state_dict(
        shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
    )

    # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
    # matching the weights in the model.
    mismatched_keys += _find_mismatched_keys(
        state_dict,
        model_state_dict,
        original_loaded_keys,
        add_prefix_to_model,
        remove_prefix_from_model,
        ignore_mismatched_sizes,
    )
    if low_cpu_mem_usage:
        if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
            for key, param in model_to_load.state_dict().items():
                if param.device == torch.device("meta"):
                    set_module_tensor_to_device(
                        model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
                    )
        else:
            new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                model_to_load,
                state_dict,
                start_prefix,
                expected_keys,
                device_map=device_map,
                offload_folder=offload_folder,
                offload_index=offload_index,
                state_dict_folder=state_dict_folder,
                state_dict_index=state_dict_index,
                dtype=dtype,
                hf_quantizer=hf_quantizer,
                is_safetensors=is_safetensors,
                keep_in_fp32_modules=keep_in_fp32_modules,
                unexpected_keys=unexpected_keys,
            )
            error_msgs += new_error_msgs
    else:
        # Sharded checkpoint or whole but low_cpu_mem_usage==True
        if assign_to_params_buffers is None:
            assign_to_params_buffers = check_support_param_buffer_assignment(
                model_to_load, state_dict, start_prefix
            )
        error_msgs += _load_state_dict_into_model(
            model_to_load, state_dict, start_prefix, assign_to_params_buffers
        )

    # force memory release
    del state_dict
    gc.collect()
More about "PreTrainedModel._load_pretrained_model"
We load the sharded checkpoint files into memory one by one.
First, we need to load the shard_file into memory in state_dict format by calling load_state_dict, which internally calls safe_load_file, an alias of safetensors.torch.load_file.
state_dict = load_state_dict(
    shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
)
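If you want to peek at a shard yourself, safetensors.torch.load_file is all you need (a standalone sketch; shard_path is a placeholder for one of the cached shard files):
from safetensors.torch import load_file

shard_path = "/path/to/model-00001-of-00002.safetensors"   # placeholder path
state_dict = load_file(shard_path)                         # dict: tensor name -> torch.Tensor (on CPU)
first_key = next(iter(state_dict))
print(first_key, state_dict[first_key].shape, state_dict[first_key].dtype)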
If the target device is GPU, _load_state_dict_into_meta_model will be called. If CPU, _load_state_dict_into_model will be called. In most cases, we will use GPU to accelerate inference, so we will examine _load_state_dict_into_meta_model.
First, each parameter is cast to the proper data type, for example torch.float16 or torch.bfloat16. Then the parameters are loaded into GPU memory one by one through set_module_tensor_to_device.
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
Inside set_module_tensor_to_device, torch.Tensor.to performs the device conversion. The following code comes from the PyTorch documentation and demonstrates the conversion process.
tensor = torch.randn(2, 2) # Initially dtype=float32, device=cpu
tensor.to(torch.float64)
cuda0 = torch.device('cuda:0')
tensor.to(cuda0)
tensor.to(cuda0, dtype=torch.float64)
other = torch.randn((), dtype=torch.float64, device=cuda0)
tensor.to(other, non_blocking=True)
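set_module_tensor_to_device itself comes from accelerate. Here is a small, self-contained sketch of what it does for one parameter, using a toy nn.Linear instead of the real Llama weights (assumes a CUDA device is available):
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

with init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)         # parameters start on the meta device
set_module_tensor_to_device(
    layer, "weight", "cuda:0",
    value=torch.randn(4096, 4096),              # pretend this tensor came from a checkpoint shard
    dtype=torch.float16,                        # cast while moving, as the loader does
)
print(layer.weight.device, layer.weight.dtype)  # expected: cuda:0 torch.float16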
AutoTokenizer.from_pretrained is quite similar, so we are not going to study it here.
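For completeness, here is what the tokenizer produces for a short prompt (a quick illustrative snippet; the exact token ids depend on the tokenizer version):
encoded = tokenizer("Hello, Llama!", return_tensors="pt")
print(encoded["input_ids"])                        # a tensor of token indices, starting with the BOS id 1
print(tokenizer.decode(encoded["input_ids"][0]))   # "<s> Hello, Llama!"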
Creating a Pipeline for Inference
To enable inference, we first need to create a pipeline object. In run.py, we create one by calling the pipeline(...) function and assigning the result to the variable pipe.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.1
)
The first argument passed to the pipeline function is the task, which in this case is "text-generation". Since “text-generation” is a pre-defined task, check_task returns the pre-defined pipeline class for it, transformers.pipelines.text_generation.TextGenerationPipeline, which is stored in the targeted_task dictionary.
if task in custom_tasks:
    # ...
else:
    normalized_task, targeted_task, task_options = check_task(task)
    if pipeline_class is None:
        pipeline_class = targeted_task["impl"]
pipeline then instantiates and returns the pre-defined pipeline object. Here is the return statement of the pipeline function; the instantiation triggers TextGenerationPipeline.__init__ and Pipeline.__init__.
return pipeline_class(model=model, framework=framework, task=task, **kwargs)
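You can confirm the dispatch on the pipe object created earlier (a quick check, not part of the original script):
from transformers.pipelines.text_generation import TextGenerationPipeline

print(type(pipe).__name__)                       # TextGenerationPipeline
print(isinstance(pipe, TextGenerationPipeline))  # True
print(pipe.task)                                 # text-generation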
Starting Inference
Now we are ready to run inference with the pipeline object. In run.py, we call pipe(...) to start inference. Pipeline objects implement the special method __call__, so they can be called like functions.
print(pipe(prompt_template.format(prompt=prompt))[0]['generated_text'])
Pipeline.__call__ eventually ends up in Pipeline.run_single:
return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
Pipeline.run_single is quite simple:
def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
    model_inputs = self.preprocess(inputs, **preprocess_params)
    model_outputs = self.forward(model_inputs, **forward_params)
    outputs = self.postprocess(model_outputs, **postprocess_params)
    return outputs
The Pipeline.run_single method involves three main steps: preprocessing the input, running the model, and postprocessing the output.
- First, it converts the input prompt inputs into tokens model_inputs by calling TextGenerationPipeline.preprocess.
- Then, it feeds the model_inputs to the model and gets the model_outputs by calling Pipeline.forward.
- Finally, it decodes the model_outputs back to text by calling TextGenerationPipeline.postprocess.
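A rough manual equivalent of these three steps, reusing model, tokenizer, prompt, and prompt_template from run.py (a simplified sketch; the real pipeline also handles batching, stopping criteria, and return formatting):
model_inputs = tokenizer(prompt_template.format(prompt=prompt), return_tensors="pt").to(model.device)  # preprocess
generated = model.generate(**model_inputs, max_new_tokens=512, do_sample=True, temperature=0.7)        # forward
print(tokenizer.decode(generated[0], skip_special_tokens=True))                                        # postprocess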
More about "TextGenerationPipeline.preprocess"
TextGenerationPipeline.preprocess uses the tokenizer passed in to convert the input prompt text into tokens, more precisely a sequence of token indices.
inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs)
More about "Pipeline.forward"
Here is the implementation of Pipeline.forward.
with self.device_placement():
    # ...
    inference_context = self.get_inference_context()
    with inference_context():
        model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
        model_outputs = self._forward(model_inputs, **forward_params)
        model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
    # ...
return model_outputs
First, we need to make sure that the input tensors have been moved from CPU to GPU memory. This is done by calling Pipeline._ensure_tensor_on_device.
Now the weights and inputs are fully prepared, and we are ready to run inference on the GPU. This is implemented by TextGenerationPipeline._forward, which calls LlamaForCausalLM.generate internally.
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
The generated_sequence is actually a sequence of token indices. The function then returns a dictionary containing the generated_sequence, the input_ids, and the prompt_text.
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
Finally, we call Pipeline._ensure_tensor_on_device again to move model_outputs from GPU memory back to CPU memory for further processing on the CPU.
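Conceptually, _ensure_tensor_on_device just walks the (possibly nested) inputs or outputs and moves every tensor it finds; a simplified re-implementation for illustration:
import torch

def move_to_device(obj, device):
    # Simplified stand-in for Pipeline._ensure_tensor_on_device: recurse into dicts/lists,
    # move tensors, leave everything else untouched.
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    return obj

outputs_on_cpu = move_to_device({"logits": torch.randn(1, 8, 32000)}, torch.device("cpu"))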
More about "LlamaForCausalLM.generate"
Calling LlamaForCausalLM.generate actually goes to GenerationMixin.generate. Under the TextGenerationPipeline configuration, execution falls into the following branch.
# ...
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
    # 11. expand input_ids with `num_return_sequences` additional sequences per batch
    input_ids, model_kwargs = self._expand_inputs_for_generation(
        input_ids=input_ids,
        expand_size=generation_config.num_return_sequences,
        is_encoder_decoder=self.config.is_encoder_decoder,
        **model_kwargs,
    )

    # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
    result = self._sample(
        input_ids,
        logits_processor=prepared_logits_processor,
        stopping_criteria=prepared_stopping_criteria,
        generation_config=generation_config,
        synced_gpus=synced_gpus,
        streamer=streamer,
        **model_kwargs,
    )
# ...
Here is the main loop in the LLM inference process, where the model processes input tokens in stages (prefill and decode) to generate the next token until the sequence is complete.
is_prefill = True
while self._has_unfinished_sequences(
    this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):
    # prepare model inputs
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

    # prepare variable output controls (note: some models won't accept all output controls)
    model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
    model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

    if is_prefill:
        outputs = self(**model_inputs, return_dict=True)
        is_prefill = False
    else:
        outputs = model_forward(**model_inputs, return_dict=True)

    # ...
    # concat
The self(...) call within the loop and the model_forward function both invoke the LlamaForCausalLM.forward method, which processes the input tokens through the model layers to generate the next token in the sequence.
LlamaForCausalLM.forward calls LlamaModel.forward, which in turn calls Embedding.forward, LlamaRotaryEmbedding.forward, and so on; see the LlamaModel.forward implementation.
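To make the prefill/decode split concrete, here is a stripped-down greedy loop written against the public model API, reusing model and tokenizer from above (a sketch only; the real _sample adds sampling, stopping criteria, attention masks, and more):
import torch

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids.to(model.device)
generated = input_ids
past_key_values = None
with torch.no_grad():
    for _ in range(16):
        if past_key_values is None:
            outputs = model(input_ids=generated, use_cache=True)             # prefill: process the whole prompt
        else:
            outputs = model(input_ids=next_token, past_key_values=past_key_values,
                            use_cache=True)                                  # decode: only the newest token
        past_key_values = outputs.past_key_values                            # reuse the KV cache
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)   # greedy pick
        generated = torch.cat([generated, next_token], dim=-1)
print(tokenizer.decode(generated[0], skip_special_tokens=True))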
More about "TextGenerationPipeline.postprocess"
What TextGenerationPipeline.postprocess does is basically use the tokenizer to decode the sequence of token indices back into text.
# Decode text
text = self.tokenizer.decode(
    sequence,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
Additional Details
Implementation of the Attention Mechanism
It uses torch.nn.functional.scaled_dot_product_attention to compute scaled dot-product attention.
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=causal_mask,
    dropout_p=self.attention_dropout if self.training else 0.0,
    is_causal=is_causal,
)
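A self-contained toy call showing the expected tensor shapes (batch, num_heads, seq_len, head_dim); is_causal=True applies the causal mask internally:
import torch
import torch.nn.functional as F

q = torch.randn(1, 32, 10, 128)   # Llama-2-7B uses 32 heads with head_dim 128
k = torch.randn(1, 32, 10, 128)
v = torch.randn(1, 32, 10, 128)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)                  # torch.Size([1, 32, 10, 128])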
Implementation of Sampling
We call nn.functional.softmax to turn the next-token scores into a probability distribution over the vocabulary, and torch.multinomial to sample from that distribution.
# token selection
if do_sample:
    probs = nn.functional.softmax(next_token_scores, dim=-1)
    # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
    next_tokens = torch.argmax(next_token_scores, dim=-1)
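A tiny self-contained illustration of the two branches with made-up logits:
import torch
import torch.nn as nn

next_token_scores = torch.tensor([[2.0, 1.0, 0.1, -1.0]])     # toy logits over a 4-token vocabulary
probs = nn.functional.softmax(next_token_scores, dim=-1)      # roughly [0.64, 0.23, 0.10, 0.03]
sampled = torch.multinomial(probs, num_samples=1).squeeze(1)  # random draw: usually 0, sometimes 1, 2, or 3
greedy = torch.argmax(next_token_scores, dim=-1)              # always 0 for these logits
print(sampled, greedy)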