pytorch模型保存与加载

pytorch模型保存与加载

第一种方式

image-20230210141948373

采用这种方式保存的模型是模型的结构和模型的参数。

在加载模型时:

image-20230210142011454

此时模型的参数也被保存下来了,可以通过如下方式查看,即通过debug的模式,在的模型加载之后查看模型的参数:

image-20230210142038612

第二种方式

image-20230210141746039

这种方式保存的是模型的参数,其中模型后面的vgg16.state_dict() 表示将模型的参数保存为字典类型的数据。这是一种比较推荐的保存方式,因为,如果是自己定义的模型采用第一种的保存方式则在别的地方加载这个模型的时候就会找不到模型的结构。

第二种方式的模型加载

image-20230210142451555

这种方式保存的模型首先要取得模型的结构,然后通过load_state_dict(torch.load('model.pth')) 来将之前保存的模型的参数加载到模型当中。

通常来说应该首先采用第二种方式保存和加载模型。

自定义模型的保存和加载

如果采用第一种方式保存:

image-20230210142828356

在另一个文件中加载保存的这个模型时:

image-20230210142947692

会出错,即找不到这个模型的结构。

此时需要将原先的网络模型的结构复制到或者通过import 引入到这个文件中:

image-20230210143124732

这样才可以正常加载模型的参数,这样与最开始模型的定义与实现不一样的地方在于:不需要写model=Tudui() 这句话来创建模型的对象。

所以一帮保存和加载模型都使用第二种方式。


pytorch模型保存与加载
http://csgituser.github.io/2023/03/08/pytorch模型保存与加载/
Author
Museum
Posted on
March 8, 2023
Licensed under