How do I update running_mean and running_var in a custom BatchNorm built on PyTorch's built-in BatchNorm?

I have been trying to implement a custom batch normalization function so that I can extend it to a multi-GPU setup, specifically with PyTorch's DataParallel module. The custom batchnorm works fine on 1 GPU, but when I scale to 2 or more GPUs the running mean and variance are updated inside the forward function, yet once the call returns from the network they are back at their initial values of 0 and 1.

The warning section of torch.nn.DataParallel states: "In each forward, module is replicated on each device, so any updates to the running module in forward will be lost. For example, if module has a counter attribute that is incremented in each forward, it will always stay at the initial value, because the update is done on the replicas which are destroyed after forward." However, I am not sure how to retain the mean and variance on the default device.
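
To make the warning concrete, here is a minimal sketch (the Counter class and all names below are made up for illustration, not from the question) of the counter example it describes; with two or more GPUs the attribute is bumped on the per-device replicas, so the module you actually hold on to never sees the change:

import torch
import torch.nn as nn

class Counter(nn.Module):
    def __init__(self):
        super(Counter, self).__init__()
        self.counter = 0               # plain attribute, as in the warning's example

    def forward(self, x):
        self.counter += 1              # incremented on the replica created for each device
        return x

model = nn.DataParallel(Counter()).cuda()
_ = model(torch.randn(8, 3).cuda())
print(model.module.counter)            # with 2+ GPUs this still prints 0

With a single GPU, DataParallel skips replication and calls the wrapped module directly, so the update sticks, which is consistent with the custom batchnorm behaving correctly in the 1-GPU case.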

Below is the code, along with the results I get with multi-GPU training. The code uses the BatchNorm implementation provided here.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torch.nn.parameter import Parameter

class ptrblck_BatchNorm2d(nn.BatchNorm2d):
    def __init__(self,num_features,eps=1e-5,momentum=0.1,affine=True,track_running_stats=True):
        super(ptrblck_BatchNorm2d,self).__init__(
            num_features,eps,momentum,affine,track_running_stats)

    def forward(self,input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0,2,3])
            # use biased var in train
            var = input.var([0,2,3],unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None,:,None,None]) / (torch.sqrt(var[None,:,None,None] + self.eps))
        if self.affine:
            input = input * self.weight[None,:,None,None] + self.bias[None,:,None,None]

        return input


class net(nn.Module):
    def __init__(self):
        super(net,self).__init__()
        self.conv1 = nn.Conv2d(3,64,kernel_size=3,padding=1)
        self.bn1 = ptrblck_BatchNorm2d(64)
        print("==> printing bn1 mean when init")
        print(self.bn1.running_mean)
        print("==> printing bn1 when init")
        print(self.bn1.running_mean)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.pool = nn.MaxPool2d(kernel_size=2,stride=2)
        self.classifier = nn.Linear(64,10)

    def forward(self,x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = self.avgpool(x)

        x = x.view(x.size(0),-1)
        x = self.classifier(x)
        print("======================================================")
        print("==> printing bn1 running mean from NET during forward")
        print(net.module.bn1.running_mean)
        print("==> printing bn1 running mean from SELF. during forward")
        print(self.bn1.running_mean)
        print("==> printing bn1 running var from NET during forward")
        print(net.module.bn1.running_var)
        print("==> printing bn1 running mean from SELF. during forward")
        print(self.bn1.running_var)
        return x

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32,padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))])


transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))])


trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True,num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data',train=False,transform=transform_test)
testloader = torch.utils.data.DataLoader(testset,shuffle=False,num_workers=2)
classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

# Model
print('==> Building model..')
net = net()
net = torch.nn.DataParallel(net).cuda()
print('Number of GPU {}'.format(torch.cuda.device_count()))

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr=0.1,momentum=0.9,weight_decay=5e-4)

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx,(inputs,targets) in enumerate(trainloader):
        inputs,targets = inputs.cuda(),targets.cuda()
        outputs = net(inputs)
        loss = criterion(outputs,targets)
        print("====================================================")
        print("==> printing bn1 running mean FROM net after forward")
        print(net.module.bn1.running_mean)
        print("==> printing bn1 running var FROM net after forward")
        print(net.module.bn1.running_var)

        break
        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()

        # train_loss += loss.item()
        # _,predicted = outputs.max(1)
        # total += targets.size(0)
        # correct += predicted.eq(targets).sum().item()

        # break


for epoch in range(0,1):
    train(epoch)

Results:

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Building model..
==> printing bn1 mean when init
tensor([0., 0., 0.,  ..., 0., 0., 0.])
==> printing bn1 when init
tensor([0., 0., 0.,  ..., 0., 0., 0.])
Number of GPU 2

Epoch: 0
======================================================
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([ 0.0053,0.0010,-0.0077,-0.0290,0.0241,0.0258,-0.0048,0.0151,-0.0133,0.0080,0.0197,-0.0042,-0.0188,0.0233,0.0310,-0.0230,0.0222,0.0119,-0.0220,-0.0169,-0.0342,-0.0025,0.0338,-0.0070,0.0202,0.0050,0.0108,0.0008,0.0363,0.0347,-0.0106,0.0082,0.0128,0.0074,0.0111,-0.0030,-0.0089,0.0070,-0.0262,-0.0029,0.0053,-0.0136,-0.0183,0.0045,-0.0014,-0.0221,0.0132,0.0064,0.0388,-0.0008,0.0400,-0.0187,0.0397,-0.0131,-0.0176,0.0035,0.0055,-0.0270,0.0066,-0.0149,0.0135],device='cuda:0')
==> printing bn1 running var from NET during forward
tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')
==> printing bn1 running var from SELF. during forward
tensor([0.9665,0.9073,0.9220,1.0947,1.0687,0.9624,0.9252,0.9131,0.9066,0.9536,0.9258,0.9203,1.0359,0.9690,1.1066,1.0636,0.9135,0.9644,0.9373,0.9846,0.9696,0.9454,1.0459,0.9245,0.9778,0.9709,0.9352,0.9995,0.9657,0.9510,1.0943,1.0171,0.9298,1.0747,0.9341,0.9635,0.9978,0.9303,0.9261,0.9137,0.9569,1.0066,1.0463,0.9955,0.9621,0.9172,0.9836,0.9817,0.9086,0.9576,1.0905,0.9861,0.9661,1.1773,0.9345,1.0904,0.9133,1.0660,0.9164,0.9058,0.9446,0.9225,1.0914,0.9292],device='cuda:0')
======================================================
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([-0.0020,0.0002,-0.0103,-0.0426,0.0386,0.0311,-0.0059,-0.0140,0.0145,0.0218,-0.0281,0.0284,0.0449,-0.0329,-0.0107,0.0278,0.0135,-0.0123,-0.0260,-0.0214,-0.0423,-0.0035,0.0410,-0.0097,0.0276,0.0102,-0.0001,0.0483,0.0451,-0.0078,0.0190,-0.0004,0.0196,-0.0028,-0.0332,-0.0110,-0.0210,-0.0226,-0.0088,-0.0314,0.0125,-0.0003,0.0505,-0.0312,0.0086,0.0544,-0.0245,0.0528,-0.0086,0.0063,0.0042,-0.0339,0.0061,-0.0277,0.0092],device='cuda:1')
==> printing bn1 running var from NET during forward
tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')
==> printing bn1 running var from SELF. during forward
tensor([1.,0.9072,0.9211,1.0999,1.0714,0.9610,0.9209,0.9125,0.9063,0.9553,0.9260,0.9189,1.0386,0.9706,1.1139,1.0610,0.9121,0.9660,0.9366,0.9886,0.9683,1.0511,0.9227,0.9792,0.9704,0.9330,0.9989,0.9476,1.1008,1.0191,0.9294,1.0814,0.9320,0.9642,1.0006,0.9287,0.9254,0.9128,0.9559,1.0100,1.0521,0.9972,0.9168,0.9849,0.9803,0.9083,0.9556,1.0946,0.9865,0.9651,1.1880,1.0959,0.9116,1.0706,0.9149,0.9057,0.9450,0.9215,1.0972,0.9261],device='cuda:1')
====================================================
==> printing bn1 running mean FROM net after forward
tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')
==> printing bn1 running var FROM net after forward
tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')

How can I make sure the running estimates on the default device are the ones that get updated? For now I am not looking into synchronized BatchNorm.

a123qqq answered: How do I update running_mean and running_var in a custom BatchNorm built on PyTorch's built-in BatchNorm?

Replace

self.running_mean = (...)

with

self.running_mean.copy_(...)

and do the same for self.running_var. That does the job: the plain assignment rebinds the attribute to a brand-new tensor on the per-device replica, which is thrown away after forward, whereas the in-place copy_ writes into the existing buffer, so the update reaches the buffer that the module on the default device still holds.
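
With that change, the running-stats update inside forward() would look like the following, a minimal sketch of just the affected with torch.no_grad() block (everything else in the custom BatchNorm posted above stays the same):

            # inside the `if self.training:` branch of ptrblck_BatchNorm2d.forward()
            with torch.no_grad():
                # write into the existing buffers instead of rebinding the attributes,
                # so the update survives the destruction of the DataParallel replicas
                self.running_mean.copy_(exponential_average_factor * mean
                                        + (1 - exponential_average_factor) * self.running_mean)
                # update running_var with the unbiased variance
                self.running_var.copy_(exponential_average_factor * var * n / (n - 1)
                                       + (1 - exponential_average_factor) * self.running_var)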
