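Saving and Loading Models
This series summarizes how model structure is handled during PyTorch training, including model definition, parameter initialization methods, and model saving and loading.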
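Contents
- 0 Blog directory
- 1 Saving and loading
- 1.1 Save source code
- 1.2 Load source code
- 2 General forms
- 2.1 Saving the whole network
- 2.2 Saving network parameters
- 2.3 Saving additional state
- 3 CPN
- 3.1 CPN model saving -- train
- 3.2 CPN model loading -- test
- 3.3 CPN model loading -- resume
- 3.4 CPN model loading -- finetuning
- 4 Additional details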
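0 Blog directory
PyTorch Model Training (0) - CPN Source Code Analysis
PyTorch Model Training (1) - Model Definition
PyTorch Model Training (2) - Model Initialization
PyTorch Model Training (3) - Model Saving and Loading
PyTorch Model Training (4) - Loss Function
PyTorch Model Training (5) - Optimizer
PyTorch Model Training (6) - Data Loading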
1 Saving and loading
1.1 Save source code
torch.save uses Python's pickle module to serialize an object and write it to a file on disk.
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    """Saves an object to a disk file.

    See also: :ref:`recommend-saving-models`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string
           containing a file name
        pickle_module: module used for pickling metadata and objects
           (Python's pickle by default)
        pickle_protocol: can be specified to override the default protocol

    .. warning::
        If you are using Python 2, torch.save does NOT support StringIO.StringIO
        as a valid file-like object. This is because the write method should return
        the number of bytes written; StringIO.write() does not do this.

        Please use something like io.BytesIO instead.

    Example:
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    # Delegates to the lower-level _save, which is more involved and not covered here.
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
torch.save can serialize models, tensors, and dictionaries of all kinds. By convention, PyTorch model files use the extension .pt, .pth, or .pkl.
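The docstring example can be turned into a small round trip. The sketch below is only an illustration, not part of the PyTorch source: it saves a tensor and a plain dictionary into in-memory buffers and reads them back. The names tensor_buffer, dict_buffer, and the dictionary keys are made up for this example.

```python
import io

import torch

# Save a tensor and a plain dictionary into in-memory buffers.
x = torch.tensor([0, 1, 2, 3, 4])
tensor_buffer = io.BytesIO()
torch.save(x, tensor_buffer)

dict_buffer = io.BytesIO()
torch.save({"step": 3, "weights": torch.ones(2, 2)}, dict_buffer)

# Rewind before reading: torch.load reads from the current position.
tensor_buffer.seek(0)
dict_buffer.seek(0)
x_restored = torch.load(tensor_buffer)
d_restored = torch.load(dict_buffer)

print(torch.equal(x, x_restored))
print(d_restored["step"], tuple(d_restored["weights"].shape))
```

A file name works the same way as a buffer; the buffer is used here only so the example leaves nothing on disk.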
1.2 Load source code
torch.load uses pickle's unpickling facilities to deserialize a saved object file back into memory.
def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
    """
    User extensions can register their own location tags and tagging and
    deserialization methods using `register_package`.

    Args:
        f: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
        map_location: a function, torch.device, string or a dict specifying how to
            remap storage locations
            (use it to choose whether tensors land on a GPU or the CPU; by default
            each tensor is restored to the device it was saved from)
        pickle_module: module used for unpickling metadata and objects (has to
            match the pickle_module used to serialize file)
        pickle_load_args: optional keyword arguments passed over to
            ``pickle_module.load`` and ``pickle_module.Unpickler``, e.g.,
            ``encoding=...``
            (useful for encoding conflicts when switching Python versions)

    .. note::
        When you call :meth:`torch.load()` on a file which contains GPU tensors, those tensors
        will be loaded to GPU by default. You can call `torch.load(.., map_location='cpu')`
        and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.

    .. note::
        In Python 3, when loading files saved by Python 2, you may encounter
        ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``. This is
        caused by the difference of handling in byte strings in Python2 and
        Python 3. You may use extra ``encoding`` keyword argument to specify how
        these objects should be loaded, e.g., ``encoding='latin1'`` decodes them
        to strings using ``latin1`` encoding, and ``encoding='bytes'`` keeps them
        as byte arrays which can be decoded later with ``byte_array.decode(...)``.

    Example:
        # By default, tensors go back to the device they were saved from
        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
        # Load tensor from io.BytesIO object (the file must be opened in binary mode)
        >>> with open('tensor.pt', 'rb') as f:
                buffer = io.BytesIO(f.read())
        >>> torch.load(buffer)
    """
    new_fd = False
    # A string or path argument is opened here and closed again in the finally block.
    if isinstance(f, str) or \
            (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
            (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
        new_fd = True
        f = open(f, 'rb')
    try:
        return _load(f, map_location, pickle_module, **pickle_load_args)
    finally:
        if new_fd:
            f.close()
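The map_location argument is the part of torch.load that most often matters in practice. As a minimal sketch (the buffer and variable names are illustrative, and the device choice assumes nothing beyond what torch.cuda.is_available() reports), the snippet below loads the same saved tensor once onto the CPU and once onto whichever device is available at run time.

```python
import io

import torch

# Save a small tensor into an in-memory buffer.
buffer = io.BytesIO()
torch.save(torch.arange(4.0), buffer)

# Force every storage onto the CPU, regardless of where it was saved.
buffer.seek(0)
t_cpu = torch.load(buffer, map_location="cpu")

# Pick the target device at run time and pass it as a torch.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
buffer.seek(0)
t_dev = torch.load(buffer, map_location=device)

print(t_cpu.device, t_dev.device)
```

Loading with map_location='cpu' first and moving the model afterwards is the pattern the docstring note recommends for avoiding a GPU memory surge.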
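2 General forms
As the source shows, PyTorch can save a model in several ways and under several file extensions; just make sure the loading method matches the way the file was saved.
The common ways to save and load a PyTorch model are as follows: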
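2.1 Saving the whole network
torch.save(model, PATH)
model = torch.load(PATH)
With this approach you do not have to rebuild the network and then fill in its weights, because the whole module object was pickled. Note that pickle stores only a reference to the model's class, not its code, so the class definition must still be importable when the file is loaded.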
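A minimal sketch of the whole-network round trip follows. It uses only built-in layers so that the class references stored by pickle are importable anywhere. One assumption to flag: the weights_only=False argument is passed because recent PyTorch releases restrict torch.load to tensors and plain containers by default; this requires a PyTorch version that accepts that argument, and it should only be used on files you trust.

```python
import io

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Pickle the whole module object (structure and weights together).
buffer = io.BytesIO()
torch.save(model, buffer)
buffer.seek(0)

# Assumption: a PyTorch release that accepts weights_only; unpickling
# arbitrary objects is only safe for trusted files.
restored = torch.load(buffer, weights_only=False)

# The restored module reproduces the original outputs exactly.
x = torch.randn(3, 4)
same_output = torch.equal(model(x), restored(x))
print(type(restored).__name__, same_output)
```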
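2.2 Saving network parameters
This approach is fast and takes less space.
torch.save(model.state_dict(), PATH)
model.load_state_dict(torch.load(PATH))
or, with the model held in another variable (the object name in front of .state_dict() is missing from the source; net is a stand-in):
torch.save(net.state_dict(), final_model_state_file)
net.load_state_dict(torch.load(final_model_state_file))
Only the model parameters are saved and loaded. When reloading, you must define the network model yourself, and its parameter names and structure must match those in the saved file. The target may be a sub-network (for example, only the first few layers of VGG), in which case the saved keys have to be filtered or load_state_dict called with strict=False. This approach is relatively flexible and makes it easy to modify the network.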
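The sub-network case can be sketched as follows. The layer names conv1, relu1, and fc and the helper make_full are invented for this illustration; the point is only that a model containing a subset of the saved keys can load them with strict=False, and that load_state_dict reports which keys were missing or unexpected.

```python
import io
from collections import OrderedDict

import torch
import torch.nn as nn


def make_full():
    # Hypothetical "full" network; layer names are illustrative only.
    return nn.Sequential(OrderedDict([
        ("conv1", nn.Conv2d(1, 4, 3)),
        ("relu1", nn.ReLU()),
        ("fc", nn.Linear(4, 2)),
    ]))


full = make_full()
buffer = io.BytesIO()
torch.save(full.state_dict(), buffer)
buffer.seek(0)
state = torch.load(buffer)

# Same architecture: the default strict loading succeeds.
clone = make_full()
clone.load_state_dict(state)

# Sub-network with only the first layer: extra saved keys are ignored.
head = nn.Sequential(OrderedDict([("conv1", nn.Conv2d(1, 4, 3))]))
result = head.load_state_dict(state, strict=False)

print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

With strict=True (the default) the second load would raise an error because of the fc keys, which is usually the safer behavior when the architectures are supposed to match.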
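2.3 Saving additional state
Experiments often need more than the weights, for example the optimizer state. That can be saved as follows:
torch.save({
    'epoch': epochID + 1,
    'state_dict': model.state_dict(),
    'best_loss': lossMIN,
    'optimizer': optimizer.state_dict(),
    # The two values below are missing from the source; loss_alpha and
    # loss_gamma are stand-ins for the custom loss function's two parameters.
    'alpha': loss_alpha,
    'gamma': loss_gamma,
    # The file extension is also missing from the source; '.pth' is assumed.
}, checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth')
The dictionary stores the epoch ID, the model state_dict, the minimum loss, the optimizer state, and the two parameters of a custom loss function. The matching loader: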
def load_checkpoint(model, checkpoint_PATH, optimizer):
    # Skip loading when no checkpoint path is given.
    if checkpoint_PATH is not None:
        model_CKPT = torch.load(checkpoint_PATH)
        model.load_state_dict(model_CKPT['state_dict'])
        print('loading checkpoint!')
        optimizer.load_state_dict(model_CKPT['optimizer'])
    return model, optimizer
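To see what restoring the optimizer actually buys, here is a small self-contained sketch under illustrative assumptions (a single nn.Linear layer, SGD with momentum, an in-memory buffer instead of a file). The fresh optimizer is deliberately created with a different learning rate; after load_state_dict it carries the saved hyperparameters.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer holds momentum buffers.
loss = model(torch.randn(5, 3)).pow(2).mean()
loss.backward()
optimizer.step()

# Save model weights, optimizer state, and bookkeeping in one dictionary.
buffer = io.BytesIO()
torch.save({
    "epoch": 1,
    "state_dict": model.state_dict(),
    "best_loss": loss.item(),
    "optimizer": optimizer.state_dict(),
}, buffer)
buffer.seek(0)

# Rebuild fresh objects, then restore them from the checkpoint.
new_model = nn.Linear(3, 1)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.01)
ckpt = torch.load(buffer, map_location="cpu")
new_model.load_state_dict(ckpt["state_dict"])
new_optimizer.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt["epoch"]

print(start_epoch, new_optimizer.param_groups[0]["lr"])
```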
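However, we may have modified part of the network, for example by adding or removing layers. In that case the saved parameters must be filtered before loading: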
def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
    if checkpoint != 'No':
        print("loading checkpoint...")
        model_dict = model.state_dict()
        modelCheckpoint = torch.load(checkpoint)
        pretrained_dict = modelCheckpoint['state_dict']
        # Filter: keep only the saved entries whose keys exist in the current model
        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
        model_dict.update(new_dict)
        # Report how many parameters were taken from the checkpoint
        print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
        model.load_state_dict(model_dict)
        print("loaded finished!")
        # Pass loadOptimizer=False if the optimizer state should not be restored
        if loadOptimizer:
            optimizer.load_state_dict(modelCheckpoint['optimizer'])
            print('loaded! optimizer')
        else:
            print('not loaded optimizer')
    else:
        print('No checkpoint is included')
    return model, optimizer
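The filter above matches on key names only. If a layer keeps its name but changes shape (a classifier head with a different number of classes, say), load_state_dict would still fail. A possible extension, shown here as a sketch with made-up layer names backbone and head and a hypothetical build helper, also compares tensor shapes before copying.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def build(num_classes):
    # Hypothetical two-layer network; names are illustrative only.
    return nn.Sequential(OrderedDict([
        ("backbone", nn.Linear(8, 16)),
        ("head", nn.Linear(16, num_classes)),
    ]))


pretrained = build(num_classes=10)
model = build(num_classes=3)  # same key names, different head shape

pretrained_dict = pretrained.state_dict()
model_dict = model.state_dict()

# Keep an entry only if the key exists AND the tensor shapes agree.
compatible = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
skipped = sorted(set(pretrained_dict) - set(compatible))

model_dict.update(compatible)
model.load_state_dict(model_dict)

print("loaded:", sorted(compatible))
print("skipped:", skipped)
```

The skipped entries keep their freshly initialized values, which is normally what is wanted for a replaced head.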
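3 CPN
3.1 CPN model saving -- train
# The checkpoint argument is missing from the source; checkpoint_dir is a
# stand-in for the output directory.
save_model({
    'epoch': epoch + 1,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, checkpoint=checkpoint_dir)
This saves the model parameters together with the training state needed to continue later.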
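3.2 CPN model loading -- test
# The arguments of the path join are missing from the source; checkpoint_dir and
# checkpoint_name are stand-ins, and the '.pth' extension is assumed.
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name + '.pth')
checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint['state_dict'])
print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_file, checkpoint['epoch']))
At test time only the model parameters matter, so the optimizer state in the checkpoint is ignored.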
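Loading the weights is usually followed by switching the model to evaluation mode and disabling gradient tracking. The sketch below is not taken from the CPN code; it uses a toy network with a Dropout layer (the build helper and the epoch value are invented) to show that, after model.eval(), repeated forward passes give identical results.

```python
import io

import torch
import torch.nn as nn


def build():
    # Toy network with dropout, so train and eval modes behave differently.
    return nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 2))


trained = build()
buffer = io.BytesIO()
torch.save({"epoch": 7, "state_dict": trained.state_dict()}, buffer)
buffer.seek(0)

model = build()
checkpoint = torch.load(buffer, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
model.eval()  # dropout becomes a no-op; BatchNorm would use running statistics

x = torch.randn(2, 4)
with torch.no_grad():  # no autograd bookkeeping during inference
    out1 = model(x)
    out2 = model(x)

print(checkpoint["epoch"], torch.equal(out1, out2))
```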
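3.3 CPN model loading -- resume
# Several identifiers are missing from the source. args.resume (the checkpoint
# path), log_dir, and log_name below are stand-ins.
if args.resume:
    if isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        pretrained_dict = checkpoint['state_dict']
        model.load_state_dict(pretrained_dict)
        args.start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
        logger = Logger(join(log_dir, log_name), resume=True)
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
else:
    logger = Logger(join(log_dir, log_name))
    logger.set_names(['Epoch', 'LR', 'Train Loss'])
Resume means continuing training from a previously saved model, which is useful when training is interrupted or some settings need adjusting. In general it requires that the checkpoint capture the full training state at save time, much like the solverstate file that Caffe writes during training.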
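Why the optimizer state and the epoch counter belong in the checkpoint can be checked with a toy experiment. The sketch below is an illustration only: the train helper, the file names, and the deterministic per-epoch data are all invented. It trains once for four epochs without interruption, then trains two epochs, stops, and resumes for the remaining two, and compares the final weights.

```python
import os
import tempfile

import torch
import torch.nn as nn


def train(num_epochs, ckpt_path):
    """Train a toy model, resuming from ckpt_path if it already exists."""
    torch.manual_seed(0)
    model = nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    start_epoch = 0
    if os.path.isfile(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
        start_epoch = ckpt["epoch"]
    for epoch in range(start_epoch, num_epochs):
        # Deterministic data that depends only on the epoch index.
        x = torch.full((4, 2), 0.1 * (epoch + 1))
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch.save({
            "epoch": epoch + 1,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }, ckpt_path)
    return model


with tempfile.TemporaryDirectory() as tmp:
    # Four epochs without interruption.
    full = train(4, os.path.join(tmp, "full.pt"))
    # Two epochs, then a second call that resumes and finishes epochs 3 and 4.
    path = os.path.join(tmp, "resumed.pt")
    train(2, path)
    resumed = train(4, path)

match = torch.allclose(full.weight, resumed.weight)
print(match)
```

If the optimizer state were left out of the checkpoint, the momentum buffers would restart from zero and the two runs would generally diverge.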
3.4 CPN model loading -- finetuning
# ResNet, Bottleneck, model_zoo, and model_urls are defined elsewhere in the module.
def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        print('Initialize with pre-trained ResNet')
        from collections import OrderedDict
        state_dict = model.state_dict()
        pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
        # Copy only the pretrained entries whose keys exist in this model
        for k, v in pretrained_state_dict.items():
            if k not in state_dict:
                continue
            state_dict[k] = v
        print('successfully load '+str(len(state_dict.keys()))+' keys')
        model.load_state_dict(state_dict)
    return model
Finetuning differs somewhat from resume. What is usually called finetuning (transfer learning) is essentially loading pretrained weights and then continuing to train. When loading, you may select only the parameters you need, and you may also freeze part of the parameters.
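Freezing part of the parameters can be sketched as follows. This is a generic illustration rather than the CPN code: the layer names backbone and head are invented, and the step that loads pretrained weights is left out. Parameters with requires_grad set to False receive no gradient, and only the remaining ones are handed to the optimizer.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(OrderedDict([
    ("backbone", nn.Linear(8, 16)),  # imagine these weights are pretrained
    ("relu", nn.ReLU()),
    ("head", nn.Linear(16, 3)),      # new layer to be trained
]))

# Freeze the backbone: no gradients are computed for its parameters.
for param in model.backbone.parameters():
    param.requires_grad = False

# Give the optimizer only the parameters that remain trainable.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01)

backbone_before = model.backbone.weight.clone()
head_before = model.head.weight.clone()

loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
optimizer.step()

print(torch.equal(model.backbone.weight, backbone_before))
print(torch.equal(model.head.weight, head_before))
```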
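4 Additional details
1) model.state_dict
In PyTorch a state_dict is an ordinary Python dictionary object. For a model, it maps each layer to its parameter tensors, such as the layer's weights and biases.
Note: only layers that hold learnable parameters (convolutional layers, linear layers, and so on) or registered buffers (such as BatchNorm's running_mean) have entries in the model's state_dict; parameter-free layers such as ReLU do not appear.
The Optimizer object also has a state_dict, which contains the optimizer's state and the hyperparameters in use, such as lr, momentum, and weight_decay.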
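A quick way to check these statements is to print the keys. The following sketch uses an arbitrary three-layer model chosen for illustration: the convolution and BatchNorm layers contribute entries (including BatchNorm buffers), the ReLU at index 2 contributes none, and the optimizer's state_dict has the two top-level keys state and param_groups.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 2, 3),   # index 0: weight and bias
    nn.BatchNorm2d(2),    # index 1: weight, bias, and running-statistics buffers
    nn.ReLU(),            # index 2: no parameters, so no entries
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

model_keys = list(model.state_dict().keys())
optim_keys = sorted(optimizer.state_dict().keys())

for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
print(optim_keys)
print(optimizer.state_dict()["param_groups"][0]["lr"])
```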
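2) OrderedDict
OrderedDict is the ordered dictionary from the collections module. Most dictionary objects inside a model use it, Sequential for example:
from collections import OrderedDict
import torch.nn as nn

# Example of using Sequential
model = nn.Sequential(
    nn.Conv2d(1,20,5),
    nn.ReLU(),
    nn.Conv2d(20,64,5),
    nn.ReLU()
)
# Example of using Sequential with OrderedDict
model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1,20,5)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(20,64,5)),
    ('relu2', nn.ReLU())
]))
Before Python 3.7 the built-in dict, being a hash table, did not guarantee any iteration order, which could occasionally cause trouble. The collections module provides OrderedDict for cases where the order of entries matters. Since Python 3.7 a plain dict also preserves insertion order, but OrderedDict still offers order-sensitive equality and reordering methods such as move_to_end.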
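The remaining differences between OrderedDict and a plain dict can be seen with the standard library alone. In this small sketch the keys and values are arbitrary placeholders.

```python
from collections import OrderedDict

a = OrderedDict([("conv1", 1), ("relu1", 2)])
b = OrderedDict([("relu1", 2), ("conv1", 1)])

# OrderedDict equality takes the order of the entries into account.
order_sensitive = (a == b)

# Plain dict equality ignores order.
plain_equal = (dict(a) == dict(b))

# Entries can be moved explicitly; here conv1 goes to the end.
a.move_to_end("conv1")
keys_after_move = list(a.keys())

print(order_sensitive, plain_equal, keys_after_move)
```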