mmdetection: changing the number of classes in pretrained model weights

Source: Internet

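Modify the class-dependent weights of a COCO pretrained model so that they match the number of classes in your own training dataset. The code is as follows.
To find out which entries need changing, print the keys of the checkpoint's state dict and modify the corresponding values: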

```python
# model_coco is the checkpoint loaded with torch.load in the script below
for key, value in model_coco["state_dict"].items():
    print(key)
```

Full script:

```python
# -*- coding: utf-8 -*-
# @Time    : 20-1-7 10:28 AM
# @Author  : wusaifei
# @FileName: Modify_category.py
# @Software: PyCharm
import torch


def main():
    # Generate COCO pretrained weights trimmed to this dataset's class count.
    # In mmdetection 1.x, num_classes includes the background class
    # (e.g. 10 object classes + 1 background = 11).
    num_classes = 11
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine.
    model_coco = torch.load("../checkpoints/cascade_rcnn_r50_fpn_1x_20190501-3b6211ab.pth",
                            map_location="cpu")

    # Print every parameter name to locate the class-dependent layers.
    for key, value in model_coco["state_dict"].items():
        print(key)

    ######################################################################
    # faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth
    # model_coco["state_dict"]["bbox_head.fc_cls.weight"] = \
    #     model_coco["state_dict"]["bbox_head.fc_cls.weight"][:num_classes, :]
    # model_coco["state_dict"]["bbox_head.fc_cls.bias"] = \
    #     model_coco["state_dict"]["bbox_head.fc_cls.bias"][:num_classes]
    ######################################################################

    # cascade_rcnn_r50_fpn_1x_20190501-3b6211ab.pth: three cascade stages,
    # each with its own classifier (bbox_head.0, bbox_head.1, bbox_head.2).
    # Keep only the first num_classes rows of each classifier.
    for stage in range(3):
        weight_key = "bbox_head.%d.fc_cls.weight" % stage
        bias_key = "bbox_head.%d.fc_cls.bias" % stage
        model_coco["state_dict"][weight_key] = \
            model_coco["state_dict"][weight_key][:num_classes, :]
        model_coco["state_dict"][bias_key] = \
            model_coco["state_dict"][bias_key][:num_classes]

    # Save the new model.
    torch.save(model_coco,
               "cascade_rcnn_r50_fpn_1x_coco_pretrained_weights_classes_%d.pth" % num_classes)


if __name__ == "__main__":
    main()
```

Download paths for the models: click here to download each network model.
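The script hard-codes key names for two specific checkpoints. As a more generic alternative, here is a minimal sketch that finds and trims the class-dependent head parameters by key suffix and shape. The helper names `find_class_dependent_keys` and `trim_class_heads` are hypothetical (not part of mmdetection), and the sketch assumes mmdetection 1.x conventions: COCO heads have 81 outputs (80 classes + background at index 0), and a class-specific box regression layer has 4 × 81 rows. Unlike the script, it also trims `fc_reg` when that layer turns out to be class-specific.

```python
# Hypothetical helpers (not part of mmdetection), assuming an mmdetection 1.x
# COCO checkpoint. They only use .shape and slicing, so they work on the torch
# tensors stored in checkpoint["state_dict"] without importing torch here.

COCO_NUM_CLASSES = 81  # 80 COCO classes + background (mmdetection 1.x convention)


def find_class_dependent_keys(shapes, num_coco_classes=COCO_NUM_CLASSES):
    """Return the keys whose first dimension is num_coco_classes or 4 * num_coco_classes.

    `shapes` maps parameter names to shape tuples. Matching is by shape only,
    so inspect the result before trimming.
    """
    return [key for key, shape in shapes.items()
            if shape and shape[0] in (num_coco_classes, 4 * num_coco_classes)]


def trim_class_heads(state_dict, num_classes, num_coco_classes=COCO_NUM_CLASSES):
    """Return a new state dict with fc_cls and class-specific fc_reg cut to num_classes.

    Keeps the first num_classes classifier rows (background is index 0 in
    mmdetection 1.x), the same [:num_classes] slicing the script uses.
    Class-agnostic fc_reg layers (4 rows) and all other parameters are kept as-is.
    """
    trimmed = {}
    for key, value in state_dict.items():
        # 0-d entries such as BatchNorm's num_batches_tracked have no rows.
        rows = value.shape[0] if len(value.shape) > 0 else None
        if key.endswith(("fc_cls.weight", "fc_cls.bias")) and rows == num_coco_classes:
            value = value[:num_classes]
        elif key.endswith(("fc_reg.weight", "fc_reg.bias")) and rows == 4 * num_coco_classes:
            # Class-specific regression stores 4 consecutive rows per class.
            value = value[:4 * num_classes]
        trimmed[key] = value
    return trimmed


# Usage sketch (paths taken from the script above; adjust to your setup):
# import torch
# ckpt = torch.load("../checkpoints/cascade_rcnn_r50_fpn_1x_20190501-3b6211ab.pth",
#                   map_location="cpu")
# print(find_class_dependent_keys({k: tuple(v.shape) for k, v in ckpt["state_dict"].items()}))
# ckpt["state_dict"] = trim_class_heads(ckpt["state_dict"], num_classes=11)
# torch.save(ckpt, "cascade_rcnn_r50_fpn_1x_coco_pretrained_weights_classes_11.pth")
```

`trim_class_heads` requires both a matching key suffix and the expected row count, so it leaves the class-agnostic regression layers of the cascade heads untouched. `find_class_dependent_keys` matches on shape alone and is meant for inspection: it reports any layer that happens to have 81 or 324 output rows, so check its output before trusting it.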


