L1 & L2 regularization
L1 (Lasso) and L2 (Ridge) regularization address overfitting by penalizing large weights in the model (the combined objective is written out after the list below).
- L1 Regularization (Lasso): Adds the sum of the absolute values of all weights to the loss function.
  - Encourages sparsity: some weights can be driven exactly to zero, effectively removing unimportant features.
  - Preferred when feature selection is desired.
- L2 Regularization (Ridge): Adds the sum of the squared weights to the loss function.
  - Discourages excessively large weights but generally doesn't drive them to zero, resulting in a smoother weight distribution.
  - More stable and efficient; usually the default choice.
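Written out, the regularized objective combines both penalties (here $\lambda_1$ and $\lambda_2$ denote the L1 and L2 strengths, and either can be set to zero to disable that penalty):

$$\mathcal{L}_{\text{total}} = \mathcal{L}(\hat{y}, y) + \lambda_1 \sum_i |w_i| + \lambda_2 \sum_i w_i^2$$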
Implementing in PyTorch
Using weight_decay
Some of PyTorch's optimizers, such as SGD and Adam, implement L2 regularization internally; the weight_decay argument controls the strength of the L2 penalty.
A higher value of weight_decay leads to stronger regularization.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001)
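Note that for plain SGD, weight_decay is mathematically equivalent to adding an L2 penalty to the loss. For adaptive optimizers like Adam, the coupled L2 term interacts with the per-parameter scaling, which is why PyTorch also provides AdamW with decoupled weight decay:

# AdamW decouples weight decay from the gradient-based update
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)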
Manual implementation
import torch
from torch import nn

def L1_regularization(model):
    # Sum of the absolute values of all parameters (L1 penalty)
    l1_reg = 0.
    for param in model.parameters():
        l1_reg += torch.sum(torch.abs(param))
    return l1_reg

def L2_regularization(model):
    # Sum of the squared values of all parameters (L2 penalty)
    l2_reg = 0.
    for param in model.parameters():
        l2_reg += torch.sum(torch.square(param))
    return l2_reg

def loss_with_penalty(pred, target, model, l1_weight=0., l2_weight=0.):
    # Base loss (MSE as an example) plus the weighted penalties
    loss = nn.functional.mse_loss(pred, target)
    l1_penalty = L1_regularization(model) * l1_weight
    l2_penalty = L2_regularization(model) * l2_weight
    return loss + l1_penalty + l2_penalty
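A minimal sketch of how these helpers might be used in one training step (the linear model and random tensors below are just placeholders):

# Hypothetical example: a tiny linear model on random data
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.randn(32, 10)  # placeholder inputs
y = torch.randn(32, 1)   # placeholder targets

pred = model(x)
loss = loss_with_penalty(pred, y, model, l1_weight=1e-4, l2_weight=1e-4)

optimizer.zero_grad()
loss.backward()
optimizer.step()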