Multi-GPU testing: TypeError: 'int' object is not iterable

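After changing my deep transfer learning code from single-GPU training to multi-GPU training, the test part of the code raises the following error: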

Traceback (most recent call last):
  File "../main.py", line 155, in <module>
    main(args)
  File "../main.py", line 57, in main
    t_correct = test(args, model, tar_test_loader, cuda_stat)
  File "../main.py", line 136, in test
    s_output, _, _ = model(data, data, target)
  File "/home/maxin2/project/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/maxin2/project/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    return self.gather(outputs, self.output_device)
  File "/home/maxin2/project/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 180, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/maxin2/project/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 76, in gather
    res = gather_map(outputs)
  File "/home/maxin2/project/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 71, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/home/maxin2/project/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py",zip(*outputs)))
TypeError: 'int' object is not iterable

I can't find the cause of the error. Here is part of my test code:

import torch

def test(args, model, target_test_loader, cuda_stat):
    model.eval()
    test_loss = 0.0
    correct = 0
    # reduction='sum' so the total can be averaged over the whole dataset afterwards
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    len_target_dataset = len(target_test_loader.dataset)
    with torch.no_grad():
        for data, target in target_test_loader:
            if cuda_stat:
                data, target = data.cuda(), target.cuda()
            # Only the first output (class scores) is used at test time
            s_output, _, _ = model(data, data, target)
            test_loss += criterion(s_output, target).item()  # sum up batch loss
            pred = torch.max(s_output, 1)[1]  # index of the largest score = predicted class
            print(pred)
            correct += torch.sum(pred == target).item()
        test_loss /= len_target_dataset
        # Four placeholders for four arguments
        print(args.test_dir, '  Test set: Loss: {:.6f}  accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len_target_dataset, 100. * correct / len_target_dataset))
    return correct
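The model's `forward` is not shown, so what follows is only a hypothesis: one common pattern in transfer-learning models is that the two extra outputs (the ones discarded as `_, _`) are loss terms that fall back to a plain Python `0` in eval mode. `nn.DataParallel` has to merge every output across GPUs, and the traceback shows it recursing into `gather_map` until it calls `zip(*outputs)` on something that is not iterable. A single GPU never reaches this code because nothing is gathered. Below is a simplified, torch-free sketch modeled on the `gather_map` recursion visible in the traceback; `FakeTensor` is a hypothetical stand-in for one GPU's tensor, and the real PyTorch function may differ between versions.

```python
class FakeTensor:
    """Hypothetical stand-in for one GPU's slice of a torch.Tensor."""

    def __init__(self, values):
        self.values = list(values)


def gather_map(outputs):
    """Simplified model of gather_map in torch/nn/parallel/scatter_gather.py.

    `outputs` holds one item per GPU; the items are merged recursively.
    """
    out = outputs[0]
    if isinstance(out, FakeTensor):
        # The real code concatenates the per-GPU tensors on output_device.
        return FakeTensor(v for shard in outputs for v in shard.values)
    if out is None:
        return None
    if isinstance(out, dict):
        return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
    # Anything else is treated as a sequence and zipped element-wise;
    # for a plain int, zip(*outputs) raises TypeError.
    return type(out)(map(gather_map, zip(*outputs)))


# Two GPUs, each returning (class_scores, loss_a, loss_b) with int placeholders.
bad = [(FakeTensor([1, 2]), 0, 0), (FakeTensor([3, 4]), 0, 0)]
try:
    gather_map(bad)
except TypeError as exc:
    print("gather failed:", exc)

# Returning None for the unused outputs lets the gather succeed.
good = [(FakeTensor([1, 2]), None, None), (FakeTensor([3, 4]), None, None)]
scores, loss_a, loss_b = gather_map(good)
print(scores.values, loss_a, loss_b)
```

If this hypothesis matches the real model, the corresponding change would be in `forward`: return `None`, or a tensor such as `torch.zeros(1, device=...)`, instead of `0` from the branches that are unused at test time. A tensor placeholder comes back with one entry per GPU after gathering.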
quanlong123456 answered: Multi-GPU testing: TypeError: 'int' object is not iterable

No good solution yet. If you have one, please email iooj@foxmail.com
Link to this article: https://www.f2er.com/41504.html
