✅ 用 .pth 保存的只是模型的权重参数(state_dict)
❌ 并不会保存模型的结构(即 class MyModel(nn.Module): ... 的代码)
✅ 解决方法:保存整个模型对象(包括结构 + 权重)
PyTorch 提供了另一种方式:
pythonCopyEdit# 保存完整模型对象(包含结构和参数)
torch.save(model, 'models/full_model.pt')
加载时:
pythonCopyEdit# 加载完整模型(不需要再次定义 class)
model = torch.load('models/full_model.pt')
model.eval()
✅ 优点 vs 缺点
| 方法 | 保存结构 | 灵活性 | 推荐用途 |
|---|---|---|---|
torch.save(model.state_dict()) | ❌ | 高 | 通常训练 & 复现 |
torch.save(model) | ✅ | 低 | 快速部署 / demo |
❗ 需要注意的是:
torch.save(model)这种方式使用了 Python 的 pickle,因此保存/加载时 模型定义代码必须在内存中或不能缺失依赖环境,对长期部署和迁移性略差。
✅ 实用建议
最推荐的做法是:
- 使用
.pth保存权重 + Git 保存模型定义代码; - 如果你想 一键部署/展示给他人运行,那就用
.pt保存完整模型。