Pruning and Knowledge Distillation

Llama 3.2 was released last month, introducing medium-sized vision language models (11B and 90B parameters) alongside lightweight, text-only models (1B and 3B). Notably, Meta applied pruning and knowledge distillation to build these lightweight models. Since I have not explored pruning before, I want to understand how it works and why Meta chose it for their lightweight models.

Llama 3.2 Lightweight Model Pruning & Distillation

Knowledge Distillation

I have extensive hands-on experience with knowledge distillation from my time at a medical AI startup specializing in speech-to-text (STT) technology for radiology. There, I helped develop a lightweight model using knowledge distillation, optimized for deployment on local computers with limited computational resources.

What is Knowledge Distillation?

While this post primarily focuses on pruning, I would like to briefly discuss knowledge distillation. Knowledge distillation is a method designed to transfer knowledge from a “teacher” model to a “student” model. Instead of relying solely on hard labels, this approach leverages the soft label outputs generated by the teacher model to provide more informative and nuanced guidance for training the student model. This process can be likened to the student model learning from the insights and expertise of the teacher model.

Consider an image classification task aimed at distinguishing among three animal classes: dogs, cats, and tigers. Traditionally, we might use one-hot encoded labels, such as [0, 1, 0] to represent a cat. However, this hard labeling discards potential nuances in the data.

In knowledge distillation, we first train a large, high-capacity teacher model. When presented with an image of a cat, this model might output a probability distribution like [0.05, 0.80, 0.15] for [dog, cat, tiger]. This soft label captures more information than the hard label, suggesting that while the image is most likely a cat, it shares some features with tigers and, to a lesser extent, dogs.
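The text does not say exactly how the teacher's soft labels are produced. In the standard formulation from Hinton et al., they come from a softmax over the teacher's logits, optionally softened by a temperature T. The sketch below assumes that formulation; the logits are hypothetical, chosen as the logarithms of the probabilities above so that T = 1 reproduces [0.05, 0.80, 0.15].

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Turn logits into probabilities; temperature > 1 flattens the result."""
    scaled = [z / temperature for z in logits]
    max_z = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - max_z) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for [dog, cat, tiger] on a cat image,
# chosen so that T = 1 gives back [0.05, 0.80, 0.15]
teacher_logits = [math.log(0.05), math.log(0.80), math.log(0.15)]

for t in (1.0, 4.0):
    soft_labels = softmax_with_temperature(teacher_logits, t)
    print(f"T={t}: {[round(p, 2) for p in soft_labels]}")
```

At T = 4 the distribution becomes roughly [0.23, 0.46, 0.30]: cat is still the top class, but the tiger and dog relationships become more visible to the student.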

The student model is then trained using these soft labels from the teacher model, often in combination with the original hard labels. The soft labels provide several advantages:

  1. They encode relationships between classes that are not captured by hard labels.
  2. They provide a form of regularization, potentially improving generalization.
  3. They can help the student model learn more efficiently, often achieving performance closer to the teacher model with fewer parameters.
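Combining soft and hard labels is usually done with a weighted loss. The sketch below assumes the common Hinton-style recipe: a KL-divergence term between temperature-softened teacher and student distributions, scaled by T², plus ordinary cross-entropy on the hard label. The weight alpha, the temperature, and the logits are illustrative values, not taken from this post or from Meta.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax with an optional temperature."""
    scaled = [z / temperature for z in logits]
    max_z = max(scaled)
    exps = [math.exp(z - max_z) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, hard_label,
                      temperature=4.0, alpha=0.5):
    """alpha * soft-label loss + (1 - alpha) * hard-label loss (illustrative)."""
    # Soft term: KL(teacher || student) on temperature-softened distributions
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student))
    soft_loss = kl * temperature ** 2  # T^2 keeps gradient scale comparable across T

    # Hard term: standard cross-entropy against the true class at T = 1
    hard_loss = -math.log(softmax(student_logits)[hard_label])

    return alpha * soft_loss + (1 - alpha) * hard_loss

# Hypothetical logits for [dog, cat, tiger]; the true class is cat (index 1)
teacher = [math.log(0.05), math.log(0.80), math.log(0.15)]
student = [0.2, 1.0, 0.9]
print(distillation_loss(student, teacher, hard_label=1))
```

When the student reproduces the teacher exactly, the soft term vanishes and only the hard-label cross-entropy remains; alpha controls how much the student follows the teacher versus the ground truth.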

Pruning

What is Pruning?

Returning to pruning: in gardening, the term refers to cutting unnecessary branches from a tree, and the idea carries over to deep learning. Pruning reduces the size of a neural network by removing less important or redundant components while trying to preserve the model’s performance. These components can be individual weights or larger units such as neurons or filters. The goal is to make the model more efficient: a smaller memory footprint, lower computational cost, and potentially faster inference. I believe pruning could help Meta build better wearable devices, and I suspect that Orion, its AR glasses prototype, is a product of this technology.

When a model is trained, not all neurons or parameters contribute equally to its final performance. Pruning identifies and removes those less critical components, resulting in a smaller and more optimized model. The Llama 3.2 release notes indicate that Meta employed structured pruning in a single-shot manner from the Llama 3.1 8B model. Before delving into structured pruning, let us first discuss unstructured pruning.

Unstructured Pruning

Unstructured pruning is a straightforward approach that removes individual parameters (weights) without regard to the larger structures, such as neurons or filters, to which they belong. Each parameter is scored, typically by the magnitude of the weight itself or by an activation-based measure, and parameters whose scores fall below a threshold are pruned (set to zero), on the premise that low-magnitude weights matter less.

For instance, consider a set of weights as follows:

weights = [0.002, 0.1, 0.5, 0.0003, 1.2, 0.75]

If the threshold for unstructured pruning is set at 0.01, the pruned weights would be:

pruned_weights = [0, 0.1, 0.5, 0, 1.2, 0.75]

It is important to note that pruning is performed based on the absolute value of each weight.
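The thresholding rule above fits in a one-line helper. The name prune_by_threshold is hypothetical; the sketch reproduces the example and shows that negative weights are judged by their magnitude.

```python
def prune_by_threshold(weights, threshold):
    """Zero every weight whose absolute value is below threshold (hypothetical helper)."""
    return [0 if abs(w) < threshold else w for w in weights]

weights = [0.002, 0.1, 0.5, 0.0003, 1.2, 0.75]
print(prune_by_threshold(weights, 0.01))  # [0, 0.1, 0.5, 0, 1.2, 0.75]

# Magnitude, not sign, decides: -0.005 is pruned, -0.3 is kept
print(prune_by_threshold([-0.005, -0.3], 0.01))  # [0, -0.3]
```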

Unstructured pruning has two main advantages. Because it targets individual weights, it can remove unnecessary weights anywhere in the network at a fine granularity, and because it only zeroes out weights rather than removing them, it leaves the model’s architecture unchanged. As a result, unstructured pruning can be applied to any deep learning architecture.

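PyTorch provides several utilities for unstructured pruning in its torch.nn.utils.prune module, for example random_unstructured (prunes weights at random), l1_unstructured (prunes the weights with the smallest absolute values in a tensor), and global_unstructured (applies a pruning method across several tensors at once).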
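In practice, pruning is often specified as a fraction of weights to remove rather than a fixed threshold, which is what the amount argument of torch.nn.utils.prune.l1_unstructured controls. The plain-Python sketch below (prune_smallest_fraction is a hypothetical helper, not PyTorch code) shows the idea: rank weights by absolute value, zero the smallest fraction, and keep a binary mask. PyTorch's pruning utilities similarly keep a mask (weight_mask) next to the original weights (weight_orig) rather than deleting anything.

```python
def prune_smallest_fraction(weights, amount):
    """Zero the `amount` fraction of weights with the smallest absolute value.

    Hypothetical helper illustrating fraction-based L1 unstructured pruning;
    returns the pruned weights and the binary mask that was applied.
    """
    n_prune = round(amount * len(weights))
    # Indices ordered from smallest to largest magnitude
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned_idx = set(order[:n_prune])
    mask = [0 if i in pruned_idx else 1 for i in range(len(weights))]
    pruned = [w * m for w, m in zip(weights, mask)]
    return pruned, mask

weights = [0.002, 0.1, 0.5, 0.0003, 1.2, 0.75]
pruned, mask = prune_smallest_fraction(weights, amount=0.5)
print(pruned)  # [0.0, 0.0, 0.5, 0.0, 1.2, 0.75]
print(mask)    # [0, 0, 1, 0, 1, 1]
```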

A study titled Post-training Deep Neural Network Pruning via Layer-wise Calibration reported that up to 50% of the weights can be removed with an accuracy loss of only about 1.5%. The paper discusses L2-normalized magnitude, in which weights are normalized by the L2 norm of their layer before being compared. This approach was found to be particularly effective for networks with batch normalization layers, presumably because batch normalization makes a layer’s output insensitive to the overall scale of its weights, so raw magnitudes are not directly comparable across layers.
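To make the idea concrete, here is one reading of L2-normalized magnitude pruning: score each weight as |w| divided by the L2 norm of its layer, then prune the lowest-scoring fraction across all layers at once. This is a simplified sketch under that assumption, using hypothetical layers; it is not the paper's full procedure, which, as its title indicates, also involves layer-wise calibration.

```python
import math

def l2_normalized_global_prune(layers, amount):
    """Prune `amount` of all weights globally, ranked by |w| / ||layer||_2.

    Hypothetical helper: a simplified reading of L2-normalized magnitude,
    without the paper's layer-wise calibration step.
    """
    scores = []  # (normalized score, layer index, weight index)
    for li, layer in enumerate(layers):
        norm = math.sqrt(sum(w * w for w in layer)) or 1.0  # guard against all-zero layers
        for wi, w in enumerate(layer):
            scores.append((abs(w) / norm, li, wi))
    n_prune = round(amount * len(scores))
    pruned = [list(layer) for layer in layers]
    for _, li, wi in sorted(scores)[:n_prune]:
        pruned[li][wi] = 0.0
    return pruned

layer_a = [10.0, 20.0, 0.5, 30.0]   # large-scale weights
layer_b = [0.1, 0.02, 0.3, 0.05]    # small-scale weights
print(l2_normalized_global_prune([layer_a, layer_b], amount=0.5))
# [[0.0, 20.0, 0.0, 30.0], [0.1, 0.0, 0.3, 0.0]]
```

Ranking the raw magnitudes globally would prune all of layer_b, whose weights are simply on a smaller scale; after normalization, each layer keeps its two relatively largest weights.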

Structured Pruning

Structured pruning takes a more systematic approach: instead of individual weights, it removes entire structures such as filters, channels, neurons, or layers. This changes the model’s architecture (layers become narrower, or fewer), but the network that remains keeps a regular, dense organization.

For example, suppose an intermediate layer of MobileNetV2 produces 128 feature maps (channels), and analysis shows that 30 of them have consistently low activation values across the dataset, indicating that they contribute little to the model’s output. Removing these 30 channels leaves 98, so the subsequent pointwise (1×1) convolution operates on a smaller input, making the network more efficient while keeping accuracy at an acceptable level.
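The MobileNetV2 example can be sketched with plain lists. Everything here is hypothetical: random numbers stand in for per-channel mean activations measured on a dataset and for the weights of the following pointwise convolution, stored as [out_channels][in_channels]. Removing 30 of 128 channels means deleting the matching input columns of the next layer.

```python
import random

def prune_low_activation_channels(mean_activations, next_weights, n_prune):
    """Remove the n_prune channels with the lowest mean activation.

    mean_activations: one average activation per channel (hypothetical data).
    next_weights: following pointwise conv weights, [out_channels][in_channels].
    Returns kept channel indices and the next layer's weights without the
    input columns that read from pruned channels.
    """
    order = sorted(range(len(mean_activations)), key=lambda c: mean_activations[c])
    pruned = set(order[:n_prune])
    keep = [c for c in range(len(mean_activations)) if c not in pruned]
    sliced = [[row[c] for c in keep] for row in next_weights]
    return keep, sliced

random.seed(0)
in_channels, out_channels = 128, 64  # 64 output channels is an arbitrary choice
mean_acts = [random.random() for _ in range(in_channels)]
next_w = [[random.gauss(0.0, 0.1) for _ in range(in_channels)] for _ in range(out_channels)]

keep, new_w = prune_low_activation_channels(mean_acts, next_w, n_prune=30)
print(len(keep), len(new_w), len(new_w[0]))  # 98 64 98
```

In a real network, the layer that produces these channels would also lose the corresponding filters (and batch normalization parameters), so both layers shrink.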

Examples of structured pruning include:

  • Filter Pruning: Uses criteria like the L1/L2 norm of filter weights, activation magnitude, or APoZ (Average Percentage of Zeros, the fraction of a neuron’s post-ReLU outputs that are zero) to decide which filters to remove.
  • Neuron Pruning: Focuses on neurons in fully connected layers, using activation-based or gradient-based metrics.
  • Channel Pruning: Removes entire feature maps, often guided by batch normalization scaling factors or activation magnitudes.
  • Layer Pruning: Removes whole layers based on their contribution to output performance or redundancy in representation.
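Two of the criteria in the list above are easy to compute directly. The sketch below uses hypothetical values: the L1 norm of each filter (flattened to a list of weights) for filter pruning, and APoZ for activation-based pruning, where a channel whose outputs are mostly zero after ReLU is a pruning candidate. For convolutional layers, each spatial position would count as an additional observation in the APoZ average; this sketch treats each observation as a single value per channel for simplicity.

```python
def filter_l1_norms(filters):
    """L1 norm of each filter, given as a flat list of its weights."""
    return [sum(abs(w) for w in f) for f in filters]

def apoz(activations):
    """Average Percentage of Zeros for each channel.

    activations: list of observations, each holding one post-ReLU value per channel.
    """
    n_obs = len(activations)
    n_channels = len(activations[0])
    return [sum(1 for obs in activations if obs[c] == 0) / n_obs
            for c in range(n_channels)]

# Hypothetical filters: the smallest L1 norm marks the first pruning candidate
filters = [[0.5, -0.25, 0.25], [0.0625, 0.125, -0.0625], [1.0, -1.0, 0.5]]
print(filter_l1_norms(filters))  # [1.0, 0.25, 2.5]

# Hypothetical post-ReLU activations: channel 0 is zero most of the time
acts = [[0.0, 1.2, 0.3], [0.0, 0.0, 0.7], [0.4, 0.9, 0.0], [0.0, 2.1, 0.5]]
print(apoz(acts))  # [0.75, 0.25, 0.25]
```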

Because the pruned network remains dense, just smaller, standard hardware and libraries can run it efficiently without special support for sparse computation. By scoring larger structures (filters, neurons, channels) by importance, structured pruning can substantially reduce model size and computation while preserving performance. The key is to use criteria that capture the significance of each structure, so that the network retains its most crucial components.
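The hardware point can be illustrated by counting multiplications. In the hypothetical sketch below, both pruning styles remove 8 of 16 weights from a 4×4 layer, but a dense kernel (what most hardware and libraries run by default) still multiplies every zero left by unstructured pruning, whereas the structurally pruned layer is simply a smaller dense matrix. Turning unstructured sparsity into real speedups generally requires sparse storage formats or specialized kernels.

```python
def dense_matvec(matrix, vector):
    """Dense matrix-vector product that also counts multiplications performed."""
    mults = 0
    out = []
    for row in matrix:
        acc = 0.0
        for w, x in zip(row, vector):
            acc += w * x  # a dense kernel multiplies zeros like any other weight
            mults += 1
        out.append(acc)
    return out, mults

W = [[0.9, 0.05, 0.4, 0.02],
     [0.03, 0.7, 0.01, 0.2],
     [0.5, 0.04, 0.06, 0.8],
     [0.07, 0.3, 0.6, 0.08]]
x = [1.0, 2.0, 3.0, 4.0]

# Unstructured: zero the 8 weights below 0.1; the matrix stays 4x4
W_unstructured = [[w if abs(w) >= 0.1 else 0.0 for w in row] for row in W]
# Structured: drop two whole rows (neurons), here rows 1 and 3 for illustration
W_structured = [W[0], W[2]]

_, m_unstructured = dense_matvec(W_unstructured, x)
_, m_structured = dense_matvec(W_structured, x)
print(m_unstructured, m_structured)  # 16 8
```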

Why Structured Pruning?

Structured pruning offers distinct advantages over unstructured pruning, particularly regarding fine-tuning. Following unstructured pruning, the resulting sparse distribution of non-zero weights can complicate subsequent training, as the remaining weights must compensate for the connections that have been removed. This compensation is further complicated by the irregular and sparse nature of the weight distribution, which can hinder effective gradient updates. In contrast, structured pruning maintains a dense and contiguous architecture among the remaining network components, facilitating a more straightforward and often more effective fine-tuning process compared to unstructured pruning.
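One concrete source of the extra bookkeeping: if the sparsity from unstructured pruning is to be preserved, the zeroed positions must stay at zero during fine-tuning, commonly by re-applying the pruning mask after each update, whereas a structurally pruned layer is just a smaller dense layer that can be updated normally. The sketch below uses plain SGD and hypothetical values to show the difference.

```python
def sgd_step(weights, grads, lr):
    """Plain SGD update on a flat list of weights."""
    return [w - lr * g for w, g in zip(weights, grads)]

def masked_sgd_step(weights, grads, mask, lr):
    """SGD update for an unstructured-pruned layer: pruned slots are re-zeroed."""
    updated = sgd_step(weights, grads, lr)
    return [w * m for w, m in zip(updated, mask)]

# Unstructured: full-size weight list plus a mask that must be maintained
w_sparse = [0.0, 0.5, 0.0, 1.2]
mask = [0, 1, 0, 1]
grads = [0.25, 0.5, -0.25, 0.25]
print(masked_sgd_step(w_sparse, grads, mask, lr=0.1))

# Structured: pruned units no longer exist, so a plain update suffices
w_dense = [0.5, 1.2]
print(sgd_step(w_dense, [0.5, 0.25], lr=0.1))
```

Without the mask, the plain update would move the pruned slots away from zero, silently undoing part of the pruning.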

I believe this is why Meta opted for structured pruning over unstructured pruning for their lightweight models: the pruned models went through substantial further training. In the post-training phase, Meta conducted several rounds of alignment on top of the pre-trained model, each involving supervised fine-tuning (SFT), rejection sampling (RS), and direct preference optimization (DPO). Now that I have a clearer understanding of pruning, I will try to dissect how Llama 3.2 was developed in my next post.

Oct 6, 2024

AI Enthusiast and a Software Engineer