扫二维码与项目经理沟通
我们在微信上24小时期待你的声音
解答本文疑问/技术咨询/运营咨询/技术建议/互联网交流
model.load_state_dict(torch.load('/home/yangjy/projects/Jane_git_tf/weights/con_model/best1_2022-12-02-09-36.pth', map_location=device))
Question?情形:
新的model是需要两个模型作前期的处理后的结果,如model1得到feature1,model2得到feature2,最终(现在训练的)model(model3)需要学习的是根据feature1和feature2进行整合和特征学习正确分辨出最终的结果。这个时候model3在第一次训练做初始化的时候需要加载model1和model2的权重,但是后来训练的时候如果初始权重是之前训练好的model3的权重,就不要再加载model1和model2的权重后再加载model3的权重,机器在加载的过程中都是需要消耗时间的,一方面是资源成本的浪费,无论是时间成本还是内存占用率都是很大的消耗;其次刚刚我发现,这样重复性加载时影响最终的模型训练效果的,模型在加载权重的过程个人建议不要写在模型初始化的过程中,这种不灵活的写法,很可能会产生bias!!
class model3(nn.Module):
def __init__(self, num_classes, device,model1_path, model2_path,freeeze_pretain):
super(Conv_con, self).__init__()
self.device = device
self.model1= MiniConvNext(num_classes=5, depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768], )
self.model2 = MiniConvNext(num_classes=1, depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768], )
self.freeeze_pretain = freeeze_pretain
self._init_weights()
self.fctl = _FCtL(512, 512)
self.norm = LayerNorm(512, eps=1e-6, data_format="channels_last")
self.head = nn.Linear(512, num_classes)
self.model1_path= model1_path
self.model2_path= model2_path
self.set_pretrained_weight()
def set_pretrained_weight(self):
if self.model1_path:
pretrained_dict = torch.load(self.model1_path, map_location=self.device)
model_dict = self.model1.state_dict()
# 1. filter out unnecessary keys
pretrained_dict_b = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict_b)
self.model1.load_state_dict(model_dict)
self.model1.eval()
if self.model2_path:
pretrained_dict = torch.load(self.model2_path, map_location=self.device)
model_dict = self.model2.state_dict()
# 1. filter out unnecessary keys
pretrained_dict_b = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict_b)
self.model2.load_state_dict(model_dict)
if self.freeeze_pretain: # ?????fctl????????????
self.model2.eval()
if self.freeeze_pretain: # ??FCTL
for name, para in self.model2.named_parameters():
para.requires_grad = False
for name, para in self.model1.named_parameters():
para.requires_grad = False
else: # ??FCTL?global??
for name, para in self.model1.named_parameters():
para.requires_grad = False
def get_pretrained_weight(self):
for name, parm in self.model2.named_parameters():
print(f'{name}:{parm.requires_grad}')
def forward(self, x, y):
if self.freeeze_pretain: # ?????????????????????
with torch.no_grad():
feature1 = self.model1(x)
feature2 = self.model2(y)
else:
with torch.no_grad():
feature1 = self.model1(y)
feature2 = self.model2(x)
features = self.fctl(feature1 ,feature2 ,) # ??global?????roi????
features_1 = self.norm(features.mean([-2, -1]))
out = self.head(features_1)
return out
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.2)
nn.init.constant_(m.bias, 0)
劝大家不要这样写!不要把权重加载的事情放在初始化里面,追悔莫及!
思考:
其实我自己刚开始觉得这种重复加载权重应该是没有问题的,因为model1和model2只是model3的一部分,我先加载model1和model2的权重,最后加载model3的权重也是会覆盖刚刚加载的model1和model2的权重的,但是结果好像并不像我想想的那么简单。因为我用训练好的权重去预测,先加载model1,再加载model2,之后加载model3,之后得到的结果惊掉下巴!虽然再训练过程中在验证集上准确率不低,但是…所以用验证集验证是不是我的权重保存有问题。check后发现没有问题,之后检查数据集也没有问题,代码也没问题,label的错误之前犯过了,也不妨再检查一遍没有问题。所以我又重新定义了不加载权重的predict_model ,直接加载model3的权重,这次在验证集上的结果才是正常的。至于原因,个人还在探索,搞明白再和大家分享。
你是否还在寻找稳定的海外服务器提供商?创新互联www.cdcxhl.cn海外机房具备T级流量清洗系统配攻击溯源,准确流量调度确保服务器高可用性,企业级服务器适合批量采购,新人活动首月15元起,快前往官网查看详情吧
我们在微信上24小时期待你的声音
解答本文疑问/技术咨询/运营咨询/技术建议/互联网交流