SWA (Stochastic Weight Averaging)

Date: 2022-12-04 07:58:17


SWA (Stochastic Weight Averaging)

[Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407)
Stochastic Weight Averaging: toward the end of optimization, take k checkpoints along the optimization trajectory and average their weights to obtain the final network weights. This places the final weights closer to the center of a flat region of the loss surface, dampens weight oscillation, and yields a smoother solution that generalizes better than the one reached by conventional training.
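As a minimal sketch of the idea (not from the original post), averaging k checkpoints reduces to a running mean over their weights. The helper name average_checkpoints is mine, and it assumes each checkpoint is a plain state dict of floating-point tensors with identical keys:

import torch

def average_checkpoints(state_dicts):
    """Running mean over checkpoint weights: avg <- avg + (w - avg) / (n + 1)."""
    avg = {k: v.clone().float() for k, v in state_dicts[0].items()}
    for n, sd in enumerate(state_dicts[1:], start=1):
        for k in avg:
            avg[k] += (sd[k].float() - avg[k]) / (n + 1)
    return avg

# e.g. model.load_state_dict(average_checkpoints([torch.load(p) for p in ckpt_paths]))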

(Figure: the effect of SWA compared with conventional training.)

SWA and EMA

In the earlier post on EMA (Exponential Moving Average) we discussed the exponential moving average; SWA and EMA turn out to have a lot in common:

  • Both are applied outside the training step itself and do not affect the training process.
  • Like ensembling, both are averages over the weights: EMA is an exponential average that gives more weight to recent snapshots, whereas SWA weights every snapshot equally (see the toy example after this list).
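A minimal sketch of the two update rules on toy tensors (the snapshot values and beta below are made up for illustration); it mirrors the zero-initialized shadow buffer used in the implementation that follows:

import torch

w_snapshots = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0])]
beta = 0.9  # EMA decay factor

swa_buf = torch.zeros(1)  # shadow weights, zero-initialized like wa_buffer below
ema_buf = torch.zeros(1)
for n, w in enumerate(w_snapshots):
    swa_buf += (w - swa_buf) / (n + 1)         # SWA: equal-weight running mean
    ema_buf = beta * ema_buf + (1 - beta) * w  # EMA: recent snapshots weigh more

print(swa_buf)  # tensor([2.]) -- the plain mean of 1, 2, 3
print(ema_buf)  # tensor([0.5610]) -- exponentially weighted, still biased by the zero init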

The implementation below therefore follows the torchcontrib SWA implementation and adds an EMA option; the two differ only in how the shadow weights are updated.

import warnings
from collections import defaultdict

import torch
from torch.optim.optimizer import Optimizer


class WeightAverage(Optimizer):
    def __init__(self, optimizer, wa_start=None, wa_freq=None, wa_lr=None, mode='swa'):
        """Implementation adapted from:
        https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py
        Paper: Averaging Weights Leads to Wider Optima and Better Generalization

        Two averaging schemes: 'swa' (equal-weight running mean) and 'ema'
        (exponential moving average, selected by passing the decay factor as a float).
        Two usage modes: automatic mode and manual mode.

        Args:
            optimizer (torch.optim.Optimizer): optimizer to use with SWA
            wa_start (int): step at which weight averaging starts (automatic mode)
            wa_freq (int): number of steps between two averaging updates
            wa_lr (float): automatic mode: learning rate applied from wa_start onward
        """
        # A float `mode` is interpreted as the EMA decay factor beta.
        if isinstance(mode, float):
            self.mode = 'ema'
            self.beta = mode
        else:
            self.mode = mode
        self._auto_mode, (self.wa_start, self.wa_freq) = self._check_params(wa_start, wa_freq)
        self.wa_lr = wa_lr
        # Parameter checks
        if self._auto_mode:
            if wa_start < 0:
                raise ValueError("Invalid wa_start: {}".format(wa_start))
            if wa_freq < 1:
                raise ValueError("Invalid wa_freq: {}".format(wa_freq))
        else:
            if self.wa_lr is not None:
                warnings.warn("Some of wa_start, wa_freq is None, ignoring wa_lr")
            self.wa_lr = None
            self.wa_start = None
            self.wa_freq = None

        if self.wa_lr is not None and self.wa_lr < 0:
            raise ValueError("Invalid WA learning rate: {}".format(wa_lr))

        self.optimizer = optimizer
        self.defaults = self.optimizer.defaults
        self.param_groups = self.optimizer.param_groups
        self.state = defaultdict(dict)
        self.opt_state = self.optimizer.state

        for group in self.param_groups:
            # EMA does not need the number of averaged snapshots;
            # kept anyway for compatibility with SWA.
            group['n_avg'] = 0
            group['step_counter'] = 0

    @staticmethod
    def _check_params(swa_start, swa_freq):
        """Validate the parameters, decide the execution mode and cast them to int."""
        params = [swa_start, swa_freq]
        params_none = [param is None for param in params]
        if not all(params_none) and any(params_none):
            warnings.warn("Some of swa_start, swa_freq is None, ignoring other")
        for i, param in enumerate(params):
            if param is not None and not isinstance(param, int):
                params[i] = int(param)
                warnings.warn("Casting swa_start, swa_freq to int")
        return not any(params_none), params

    def _reset_lr_to_swa(self):
        """Apply the WA learning rate once wa_start has been reached."""
        if self.wa_lr is None:
            return
        for param_group in self.param_groups:
            if param_group['step_counter'] >= self.wa_start:
                param_group['lr'] = self.wa_lr

    def update_swa_group(self, group):
        """Update the averaged weights of one parameter group:
        stochastic weight averaging or exponential moving average."""
        for p in group['params']:
            param_state = self.state[p]
            if 'wa_buffer' not in param_state:
                param_state['wa_buffer'] = torch.zeros_like(p.data)
            buf = param_state['wa_buffer']
            if self.mode == 'swa':
                virtual_decay = 1 / float(group["n_avg"] + 1)
                diff = (p.data - buf) * virtual_decay  # buf + (p - buf) / (n + 1) = (p + n * buf) / (n + 1)
                buf.add_(diff)
            else:
                buf.mul_(self.beta).add_((1 - self.beta) * p.data)
        group["n_avg"] += 1

    def update_swa(self):
        """Manual mode: update the averaged weights of all parameter groups."""
        for group in self.param_groups:
            self.update_swa_group(group)

    def swap_swa_sgd(self):
        """1. Swap the averaged weights and the model weights.
        2. Call at the end of training and before evaluation."""
        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                if 'wa_buffer' not in param_state:
                    warnings.warn("WA wasn't applied to param {}; skipping it".format(p))
                    continue
                buf = param_state['wa_buffer']
                tmp = torch.empty_like(p.data)
                tmp.copy_(p.data)
                p.data.copy_(buf)
                buf.copy_(tmp)

    def step(self, closure=None):
        """1. Run the wrapped optimizer step.
        2. In automatic mode, update the averaged weights when due."""
        self._reset_lr_to_swa()
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            group["step_counter"] += 1
            steps = group["step_counter"]
            if self._auto_mode:
                if steps > self.wa_start and steps % self.wa_freq == 0:
                    self.update_swa_group(group)
        return loss

    def state_dict(self):
        """Pack opt_state (inner optimizer state), wa_state (averaging state)
        and param_groups (parameter groups)."""
        opt_state_dict = self.optimizer.state_dict()
        wa_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
                    for k, v in self.state.items()}
        opt_state = opt_state_dict["state"]
        param_groups = opt_state_dict["param_groups"]
        return {"opt_state": opt_state, "wa_state": wa_state,
                "param_groups": param_groups}

    def load_state_dict(self, state_dict):
        """Load the averaging state and the inner optimizer state."""
        wa_state_dict = {"state": state_dict["wa_state"],
                         "param_groups": state_dict["param_groups"]}
        opt_state_dict = {"state": state_dict["opt_state"],
                          "param_groups": state_dict["param_groups"]}
        super(WeightAverage, self).load_state_dict(wa_state_dict)
        self.optimizer.load_state_dict(opt_state_dict)
        self.opt_state = self.optimizer.state

    def add_param_group(self, param_group):
        """Add a parameter group to the optimizer's `param_groups`."""
        param_group['n_avg'] = 0
        param_group['step_counter'] = 0
        self.optimizer.add_param_group(param_group)

    @staticmethod
    def bn_update(loader, model, device=None):
        """Recompute BatchNorm running_mean and running_var with one pass over the loader."""
        if not _check_bn(model):
            return
        was_training = model.training
        model.train()
        momenta = {}
        model.apply(_reset_bn)
        model.apply(lambda module: _get_momenta(module, momenta))
        n = 0
        for input in loader:
            if isinstance(input, (list, tuple)):
                input = input[0]
            b = input.size(0)  # batch_size

            momentum = b / float(n + b)
            for module in momenta.keys():
                module.momentum = momentum

            if device is not None:
                input = input.to(device)

            model(input)
            n += b

        model.apply(lambda module: _set_momenta(module, momenta))
        model.train(was_training)


# BatchNorm utils
def _check_bn_apply(module, flag):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        flag[0] = True


def _check_bn(model):
    flag = [False]
    model.apply(lambda module: _check_bn_apply(module, flag))
    return flag[0]


def _reset_bn(module):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        module.running_mean = torch.zeros_like(module.running_mean)
        module.running_var = torch.ones_like(module.running_var)


def _get_momenta(module, momenta):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        momenta[module] = module.momentum


def _set_momenta(module, momenta):
    if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
        module.momentum = momenta[module]
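A hedged usage sketch of the WeightAverage wrapper above in automatic mode; model, loader and loss_fn are placeholders, and the hyperparameter values are illustrative only:

base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt = WeightAverage(base_opt, wa_start=10, wa_freq=5, wa_lr=0.05)  # mode='swa' by default
# opt = WeightAverage(base_opt, wa_start=10, wa_freq=5, mode=0.999)  # EMA with beta = 0.999

for x, y in loader:  # training loop (one epoch shown)
    opt.zero_grad()
    loss_fn(model(x), y).backward()
    opt.step()  # automatic mode: averages every wa_freq steps once step_counter > wa_start

opt.swap_swa_sgd()  # swap the averaged weights into the model before evaluation
WeightAverage.bn_update(loader, model)  # recompute BatchNorm statistics for the averaged weights

In manual mode (wa_start and wa_freq omitted), call opt.update_swa() yourself whenever a snapshot should be folded into the average.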

Reference: Stochastic Weight Averaging in PyTorch