使用 torch.nn.DataParallel() 时如何访问类对象?

我想使用带有多个 GPU 的 PyTorch 训练我的模型。我包括以下行:

model = torch.nn.DataParallel(model,device_ids=opt.gpu_ids)

然后,我尝试访问在我的模型定义中定义的优化器:

G_opt = model.module.optimizer_G

但是,我遇到了一个错误:

AttributeError: 'DataParallel' 对象没有属性 optimizer_G

我认为这与我的模型定义中优化器的定义有关。当我在没有 torch.nn.DataParallel 的情况下使用单个 GPU 时,它可以工作。但它不适用于多 GPU,即使我使用 module 调用并且我找不到解决方案。

这是模型定义:

class MyModel(torch.nn.Module):
    ...
   self.optimizer_G = torch.optim.Adam(params,lr=opt.lr,betas=(opt.beta1,0.999))   

如果您想查看完整代码,我在 GitHub 中使用了 Pix2PixHD 实现。

谢谢, 最好的。

编辑:我使用 model.module.module.optimizer_G 解决了这个问题。

yexi831020zqymy 回答:使用 torch.nn.DataParallel() 时如何访问类对象?

暂时没有好的解决方案,如果你有好的解决方案,请发邮件至:iooj@foxmail.com
本文链接:https://www.f2er.com/768248.html

大家都在问