
Hello,

It happens when I resume a previous training task. When I create my model and call 'model.load_state_dict(checkpoint['state_dict'])', it reports the error 'unexpected key "module.features.conv0.weight" in state_dict'.
However, it works when I first wrap the model for the GPUs with 'net = torch.nn.DataParallel(model, device_ids=[0, 1]); net.cuda()' and then call 'net.load_state_dict(checkpoint['state_dict'])'.
So could anyone tell me what the problem is?

Thanks,
Lyuwei
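A minimal reproduction of the mismatch, assuming a toy `nn.Linear(4, 2)` in place of the real model from the question. `nn.DataParallel` stores the network as a submodule named `module`, so every key in its state dict gains a `module.` prefix, and a bare model refuses those keys under the default `strict=True`. Current PyTorch releases raise a `RuntimeError` listing the unexpected and missing keys; the wording differs from the older message quoted in the question.

```python
import torch.nn as nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

# The wrapper registers `model` as a child called "module",
# so every state dict key is prefixed with "module.".
checkpoint = {"state_dict": wrapped.state_dict()}
print(list(checkpoint["state_dict"].keys()))  # ['module.weight', 'module.bias']

# A fresh, unwrapped model expects 'weight' and 'bias', so strict loading fails.
fresh = nn.Linear(4, 2)
load_error = None
try:
    fresh.load_state_dict(checkpoint["state_dict"])
except RuntimeError as err:
    load_error = err
    print("load failed:", err)
```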

@apaszke I was going to say that this is annoying for those without a GPU (they have to strip the 'module.' prefix manually), but since #3318 it is possible to use a multi-GPU module on a CPU-only machine.
That patch will be in the next release, I suppose?
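For the CPU-only case, a sketch of the wrap-then-load workaround, assuming a recent PyTorch release in which `nn.DataParallel` simply holds the module when no GPU is visible (I have not checked which release first included #3318). `nn.Linear(4, 2)` stands in for the real model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A state dict as saved from a DataParallel-wrapped model: every key starts with "module."
saved = nn.DataParallel(nn.Linear(4, 2)).state_dict()

# Wrap a fresh model the same way so the key names line up, then load.
model = nn.Linear(4, 2)
wrapper = nn.DataParallel(model)
wrapper.load_state_dict(saved)

# wrapper.module is the same object as `model`, now holding the loaded weights,
# so it can be used (or re-saved) without the wrapper.
plain = wrapper.module
```

Saving `plain.state_dict()` afterwards gives a checkpoint whose keys have no prefix.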

Same problem. Here is a quick script I used to solve the problem. Help for future visitors here! 😸

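import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--source", type=str, required=True)
parser.add_argument("--dest", type=str, required=True)
args = parser.parse_args()

# map_location="cpu" lets this run on a machine without a GPU
model_state = torch.load(args.source, map_location="cpu")
new_model_state = {}
for key, value in model_state.items():
    # DataParallel prefixes every key with "module."; strip it only where present
    if key.startswith("module."):
        key = key[len("module."):]
    new_model_state[key] = value
torch.save(new_model_state, args.dest)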

Run it as: python remove_module.py --source <source pickle file> --dest <destination pickle file>
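The same key rewrite can also be done in memory, without an intermediate file. This is a sketch only: `strip_module_prefix` is a hypothetical helper name, `nn.Linear(4, 2)` stands in for the real model, and the `{'state_dict': ...}` layout mirrors the checkpoint in the original question. It strips the prefix only from keys that actually have it, so running it on an already-clean state dict is harmless.

```python
import torch
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


torch.manual_seed(0)

# Simulate a checkpoint saved from a DataParallel-wrapped model.
dp = nn.DataParallel(nn.Linear(4, 2))
checkpoint = {"state_dict": dp.state_dict()}

# Load the cleaned keys into a bare, unwrapped model.
model = nn.Linear(4, 2)
model.load_state_dict(strip_module_prefix(checkpoint["state_dict"]))
```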


To me, KrnTneja's solution seems overly complicated and easy to break.

What most of you want is something like this:

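import torch
import torch.nn as nn

model = Model()  # your own nn.Module subclass
dp = nn.DataParallel(model)
# train dp
# ...
# Save the inner model's weights, so the keys carry no "module." prefix
torch.save(model.state_dict(), path)
# or, equivalently: torch.save(dp.module.state_dict(), path)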
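To see why saving `model.state_dict()` (or `dp.module.state_dict()`) avoids the problem, here is a small check with `nn.Linear(4, 2)` standing in for `Model()`: the wrapper keeps a reference to the very module you passed in, and only the wrapper's own state dict carries the prefix.

```python
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the thread's Model()
dp = nn.DataParallel(model)

# dp.module is the object passed in, so these two calls return identical keys.
inner_keys = list(model.state_dict().keys())
via_module_keys = list(dp.module.state_dict().keys())
# Only the wrapper's own state dict has the "module." prefix.
wrapper_keys = list(dp.state_dict().keys())

print(inner_keys)    # ['weight', 'bias']
print(wrapper_keys)  # ['module.weight', 'module.bias']
```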

When you resume training, do this:

model = Model()
# Load into the bare model first (its keys have no "module." prefix), then wrap it
model.load_state_dict(torch.load(path))
dp = nn.DataParallel(model)
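An end-to-end sketch of that save/resume cycle, with an in-memory `io.BytesIO` buffer in place of `path` and a hypothetical `make_model()` in place of `Model()`. The `map_location="cpu"` argument is my own addition; it lets a checkpoint written on a GPU machine load on one without a GPU.

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in for the thread's Model()
    return nn.Linear(4, 2)


torch.manual_seed(0)

# Save from the DataParallel run using the inner module, so keys are unprefixed.
trained = nn.DataParallel(make_model())
buffer = io.BytesIO()
torch.save(trained.module.state_dict(), buffer)
buffer.seek(0)

# Resume: load into the bare model first, then wrap it again.
model = make_model()
model.load_state_dict(torch.load(buffer, map_location="cpu"))
dp = nn.DataParallel(model)
```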
 