Trying to use DistributedDataParallel on a GAN, but getting a runtime error about an in-place operation

I am trying to train a GAN with DistributedDataParallel on a single machine with 3 GPUs. Everything works fine before I wrap my models in DDP, but once I wrap them I get the following runtime error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [128]] is at version 5; expected version 4 instead.

I tried cloning every tensor involved in the gradient computation to work around the in-place operation (if there is one), but I can't find it.

The part of the code where the problem occurs is below:

import numpy as np
import torch
import torch.distributed as dist
from torch.autograd import Variable
from torch.nn.parallel import DistributedDataParallel as DDP

Tensor = torch.cuda.FloatTensor


# ----------
#  Training
# ----------

def train_gan(rank,world_size,opt):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank,world_size)

    if rank == 0:
        get_dataloader(rank,opt)
    dist.barrier()
    print(f"Rank {rank}/{world_size} training process passed data download barrier.\n")

    dataloader = get_dataloader(rank,opt)

    # Loss function
    adversarial_loss = torch.nn.BCELoss()
    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    generator.to(rank)
    discriminator.to(rank)

    generator_d = DDP(generator,device_ids=[rank])
    discriminator_d = DDP(discriminator,device_ids=[rank])


    # Optimizers
    # Since we are computing the average of several batches at once (an effective batch size of
    # world_size * batch_size) we scale the learning rate to match.
    optimizer_G = torch.optim.Adam(generator_d.parameters(),lr=opt.lr * opt.world_size,betas=(opt.b1,opt.b2))
    optimizer_D = torch.optim.Adam(discriminator_d.parameters(),lr=opt.lr * opt.world_size,betas=(opt.b1,opt.b2))

    losses = []

    for epoch in range(opt.n_epochs):
        for i,(imgs,_) in enumerate(dataloader):

            # Adversarial ground truths
            valid = Variable(Tensor(imgs.shape[0],1).fill_(1.0),requires_grad=False).to(rank)
            fake = Variable(Tensor(imgs.shape[0],1).fill_(0.0),requires_grad=False).to(rank)

            # Configure input
            real_imgs = Variable(imgs.type(Tensor)).to(rank)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Variable(Tensor(np.random.normal(0,1,(imgs.shape[0],opt.latent_dim)))).to(rank)

            # Generate a batch of images
            gen_imgs = generator_d(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator_d(gen_imgs),valid)

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator_d(real_imgs),valid)
            fake_loss = adversarial_loss(discriminator_d(gen_imgs.detach()),fake)
            d_loss = ((real_loss + fake_loss) / 2).to(rank)

            

            d_loss.backward()
            optimizer_D.step()

Answer from a704271485:

I ran into a similar error while trying to train a GAN with DistributedDataParallel. I noticed that the problem came from the BatchNorm layers in my discriminator.

Indeed, DistributedDataParallel synchronizes the batch-norm parameters on every forward pass (see the doc), modifying the variables in place. This causes problems if you run several forward passes in a row.
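
One quick way to check whether this is what is happening in your case is to list the discriminator's buffers and look for one whose size matches the [128] tensor in the error message. This is a minimal diagnostic sketch, assuming the Discriminator class from your question:

# Diagnostic sketch: BatchNorm running statistics are stored as module buffers.
# A buffer with 128 elements would match the torch.cuda.FloatTensor [128]
# reported in the RuntimeError.
discriminator = Discriminator()  # the asker's model class
for name, buf in discriminator.named_buffers():
    print(name, tuple(buf.shape))
# BatchNorm layers show up here as ...running_mean / ...running_var (plus
# num_batches_tracked). DDP broadcasts exactly these buffers on every forward
# pass when broadcast_buffers=True (the default).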

Converting my BatchNorm layers to SyncBatchNorm worked for me:

discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
discriminator = DDP(discriminator)

You probably want to do that anyway when using DistributedDataParallel.
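
In your train_gan, the conversion would go after moving the models to the device and before wrapping them in DDP. A sketch against the code in the question, assuming the same Generator, Discriminator, and setup helpers (the process group is already initialized by setup(rank, world_size)):

generator.to(rank)
discriminator.to(rank)

# Replace every BatchNorm layer with SyncBatchNorm so the running statistics
# are reduced across processes instead of being broadcast in place on every
# forward pass. Converting the generator is only needed if it also uses
# BatchNorm layers.
generator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(generator)
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)

generator_d = DDP(generator,device_ids=[rank])
discriminator_d = DDP(discriminator,device_ids=[rank])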

Alternatively, if you don't want to use SyncBatchNorm, you can set the broadcast_buffers parameter to False, but you probably don't want to do that, since it means your batch-norm statistics will no longer be synchronized across processes.

discriminator = DDP(discriminator,device_ids=[rank],broadcast_buffers=False)