深入理解Pytorch中模型保存文件pth
一、Pytorch中模型保存和加载方法
本文在介绍Pytorch中模型保存文件pth之前,将先探讨如模型的保存/加载的方法。
三个核心函数:
- torch.save:把序列化的对象保存到硬盘。利用Python的pickle来实现序列化。模型、tensor以及字典都可以用该函数进行保存;
- torch.load:采用 pickle 将反序列化的对象从存储中加载进来。
- torch.nn.Module.load_state_dict:采用一个反序列化的state_dict加载一个模型的参数字典。
保存/加载模型
在Pytorch中,模型的保存和加载主要有两种方法,一种是保存/加载整个模型,另一种是只保存/加载模型参数。
1. 保存整个模型
这种方法保存和加载模型都是采用最简单的语法。这种方法将是采用Python的pickle模块来保存整个模型,它的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是pickle并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors 后采用都可能出现错误。
示例方法:
1 | # 保存整个模型 |
造成的影响:
1 | import torch |
如上图所示,保存的完整模型是一个字典,包含了模型、优化器、学习率调整器、迭代次数等信息。这种方法的缺点是,如果要加载模型,必须要保证模型的类别和结构不变,否则会报错。
例如,如果加载了之前的模型,但是想修改学习率之类的参数,由于加载的模型中保存了optimizer和scheduler,所以再次加载此文件时会使用之前的学习率。如果想要修改参数,其实只需要将模型权重加载进来就可以了,不需要再加载optimizer和scheduler。
如下面的代码:
1 | import torch |
或者在保存模型时,只保存模型的参数,如后面1.2节所示。
2. 仅保存模型参数(官方推荐)
当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。
示例代码:
1 | torch.save(model.state_dict(), PATH) # 保存模型参数 |
- 其中,model.state_dict() 返回一个字典,包含了模型的可学习参数,如卷积层的权重和偏置等。model.load_state_dict() 将保存的参数加载到模型中。
- 而load_state_dict() 方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用 torch.load() ,而不是直接**model.load_state_dict(PATH)**。
例如,若采取以下的网络结构:
1 | class DQN(nn.Module): |
保存模型参数的代码如下:
1 | # 保存模型参数 |
加载模型参数的代码如下:
1 | # 加载模型参数 |
如上图所示,保存的模型参数是一个字典,包含了模型的可学习参数,如卷积层的权重和偏置等。
二、权重和偏置
如上图所示,神经网络的权重和偏置是模型的可学习参数,是模型的核心,也是模型的灵魂。在训练模型时,权重和偏置是不断变化的,而在预测时,权重和偏置是固定的。
其中:
- 输入层各节点中,要素x、权重w和偏执b为输入,z为输出。
- weight、bais一般是从高斯分布中随机初始化的值。
- 权重表示可能性大小,偏置用于正确分类样本,保证输出值不能被随便激活。
计算出z后,对z使用激活函数σ。 - 激活函数用于向模型引入一些非线性,可以将神经元的输出幅度限制在一定范围,一般为(-1,1)、(0,1)。
- 常用的激活函数有sigmoid、tanh、ReLU。
整个过程为:首先前向传播(对变量和权重进行计算,最后求出误差);然后反向传播(遍历每个层每个连接对误差的贡献,然后调整权重和偏置)。
一般来说,全连接层和卷积层都会有权重和偏置,但是池化层,激活层,归一化层等就不一定需要权重和偏置。这些层的作用是对输入进行非线性变换,降采样,规范化等操作,不涉及参数的学习。
三、模型参数的二进制分析
1. torch.save() 底层函数分析
官方给出的函数作用解释:将序列化对象保存到磁盘。此函数使用Python的pickle实用程序进行序列化。使用此函数可以保存各种对象的模型、张量和字典。
1 | def save( |
官方note:“请注意,load_state_dict()函数接受字典对象,而不是保存对象的路径。这意味着,在将其传递给load_state_dict()函数之前,您必须对保存的state_dict进行反序列化。例如,您不能使用model.load_state_dict(PATH)加载。”
1 | torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL,_use_new_zipfile_serialization=True) |
参数:
- obj:保存的对象
- f:文件对象(必须实现write和flush方法)或字符串或os.PathLike对象,包含文件名
- pickle_module:用于序列化元数据和对象的模块
- pickle_protocol:可以指定以覆盖默认协议
也就是在save和load之间,需要进行序列化和反序列化的操作。
2.pickle 序列化和反序列化
因此必须提到pickle这样一个序列化模块,pickle模块实现了基本的数据序列和反序列化。通过pickle模块的序列化操作我们能够将程序中运行的对象信息保存到文件中去,永久存储;通过pickle模块的反序列化操作,我们能够从文件中创建上一次程序保存的对象。
pickle的源代码文件可见Lib/pickle.py
pickle的使用方法可见官方文档
简而言之,Pytorch保存整个module使用的是pickle库,由于这个库在保存类的时候,并不是保存类本身,而是只保存了类名和类定义的位置,在加载的时候,pickle库会找类定义的位置,去加载类的定义。这就导致了在保存和加载的时候,如果类的定义发生了变化,就会出现找不到类的定义的错误。
可以被序列化/反序列化的对象
下列类型可以被封存:
- None、True 和 False
- 整数、浮点数、复数
- str、byte、bytearray
- 只包含可封存对象的集合,包括 tuple、list、set 和 dict
- 定义在模块最外层的函数(使用 def 定义,lambda 函数则不可以)
- 定义在模块最外层的内置函数
- 定义在模块最外层的类
- 某些类实例,这些类的 dict 属性值或 getstate() 函数的返回值可以被封存(详情参阅 封存类实例 这一段)。
尝试封存不能被封存的对象会抛出 PicklingError 异常,异常发生时,可能有部分字节已经被写入指定文件中。尝试封存递归层级很深的对象时,可能会超出最大递归层级限制,此时会抛出 RecursionError 异常,可以通过 sys.setrecursionlimit() 调整递归层级,不过请谨慎使用这个函数,因为可能会导致解释器崩溃。
注意,函数(内置函数或用户自定义函数)在被封存时,引用的是函数全名。这意味着只有函数所在的模块名,与函数名会被封存,函数体及其属性不会被封存。因此,在解封的环境中,函数所属的模块必须是可以被导入的,而且模块必须包含这个函数被封存时的名称,否则会抛出异常。
同样的,类也只封存名称,所以在解封环境中也有和函数相同的限制。注意,类体及其数据不会被封存,所以在下面的例子中类属性 attr 不会存在于解封后的环境中:
1 | import pickle |
用Hex Fiend软件(Windows下的WinHex软件)查看file.pickle文件,可以如下所示,可以看到确实只封存了名称。
这些限制决定了为什么必须在一个模块的最外层定义可封存的函数和类。
类似的,在封存类的实例时,其类体和类数据不会跟着实例一起被封存,只有实例数据会被封存。这样设计是有目的的,在将来修复类中的错误、给类增加方法之后,仍然可以载入原来版本类实例的封存数据来还原该实例。如果你准备长期使用一个对象,可能会同时存在较多版本的类体,可以为对象添加版本号,这样就可以通过类的 setstate() 方法将老版本转换成新版本。