如何使用 tf.distribute.Strategy 修改 Keras CycleGAN 示例代码以在多个 GPU 上并行运行

这是来自 Keras 的 CycleGAN 示例:CycleGAN Example Using Keras。

下面是我修改后使用多个 GPU 的实现。为了实现自定义训练,我参考了 Custom training with tf.distribute.Strategy。

我希望 Keras 的 CycleGAN 示例能在 GPU 上快速运行,因为我需要处理并训练大量数据。CycleGAN 使用多个损失函数,train_step 原本会返回 4 种损失;为了便于理解,目前我只返回其中一种。然而,在 GPU 上的训练仍然很慢,我找不到原因。

我是否错误地使用了 tf.distribute.Strategy?

"""
Title: CycleGAN
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
Date created: 2020/08/12
Last modified: 2020/08/12
Description: Implementation of CycleGAN.
"""

"""
## CycleGAN
CycleGAN is a model that aims to solve the image-to-image translation
problem. The goal of the image-to-image translation problem is to learn the
mapping between an input image and an output image using a training set of
aligned image pairs. However, obtaining paired examples isn't always feasible.
CycleGAN tries to learn this mapping without requiring paired input-output images, using cycle-consistent adversarial networks.
- [Paper](https://arxiv.org/pdf/1703.10593.pdf)
- [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
"""

"""
## Setup
"""

import os
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import tensorflow_addons as tfa
import tensorflow_datasets as tfds

# tf.data's auto-tuning sentinel, passed as `num_parallel_calls` to the
# dataset `map` calls further down.
autotune = tf.data.experimental.AUTOTUNE
# Keep the tensorflow-datasets download/preparation output quiet.
tfds.disable_progress_bar()

# Data-parallel strategy used by the training code below: models and
# optimizers are built under `strategy.scope()` and each step is executed
# per replica through `strategy.run`.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

"""
## Prepare the dataset
In this example, we will be using the
[horse to zebra](https://www.tensorflow.org/datasets/catalog/cycle_gan#cycle_ganhorse2zebra)
dataset.
"""

# Load the horse<->zebra splits from tensorflow-datasets as (image, label)
# pairs; the DatasetInfo object returned alongside is not needed.
dataset, _ = tfds.load(
    "cycle_gan/horse2zebra",
    with_info=True,
    as_supervised=True,
)
# Domain X (horses) lives in the "A" splits, domain Y (zebras) in "B".
train_horses = dataset["trainA"]
test_horses = dataset["testA"]
train_zebras = dataset["trainB"]
test_zebras = dataset["testB"]

# Training images are first resized to this (height, width) ...
orig_img_size = (286, 286)
# ... and then randomly cropped to this (height, width, channels) model input.
input_img_size = (256, 256, 3)

# Convolution kernels and instance-norm scales are both drawn from N(0, 0.02^2).
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

# Shuffle buffer used by every input pipeline.
buffer_size = 256
# Per-replica batch size (multiplied by the replica count for the global batch).
# NOTE(review): a batch of 1 gives each GPU very little work per step compared
# with the per-step gradient all-reduce; consider raising it if memory allows.
batch_size = 1

def normalize_img(img):
    """Cast `img` to float32 and map pixel values linearly onto [-1, 1].

    Assumes 8-bit style input in [0, 255]: 0 maps to -1.0 and 255 to 1.0.
    """
    as_float = tf.cast(img, dtype=tf.float32)
    return as_float / 127.5 - 1.0


def preprocess_train_image(img, label):
    """Training-time preprocessing for one (image, label) element.

    Randomly mirrors the image horizontally, resizes it to `orig_img_size`,
    takes a random `input_img_size` crop and rescales pixels to [-1, 1].
    The label is accepted only to match the dataset's element structure and
    is discarded.
    """
    flipped = tf.image.random_flip_left_right(img)
    enlarged = tf.image.resize(flipped, list(orig_img_size))
    cropped = tf.image.random_crop(enlarged, size=list(input_img_size))
    return normalize_img(cropped)


def preprocess_test_image(img, label):
    """Deterministic test-time preprocessing.

    Resizes straight to the model's (height, width) and rescales pixels to
    [-1, 1]; no augmentation. The label is discarded.
    """
    height, width = input_img_size[0], input_img_size[1]
    return normalize_img(tf.image.resize(img, [height, width]))


"""
## Create `Dataset` objects
"""
BATCH_SIZE_PER_REPLICA = batch_size
# Each step consumes one global batch, which MirroredStrategy splits across
# the replicas (BATCH_SIZE_PER_REPLICA examples each).
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync


def _build_train_dataset(images):
    """Return the batched training pipeline for one image domain.

    The raw decoded images are cached *before* the random flip/crop so a new
    augmentation is drawn every epoch; caching after the random `map` would
    replay the first epoch's flips and crops forever. `drop_remainder=True`
    gives every replica a full per-replica batch and a static batch dimension,
    and `prefetch` overlaps input preparation with the training step.
    """
    return (
        images.cache()
        .map(preprocess_train_image, num_parallel_calls=autotune)
        .shuffle(buffer_size)
        .batch(GLOBAL_BATCH_SIZE, drop_remainder=True)
        .prefetch(autotune)
    )


def _build_test_dataset(images):
    """Return the batched test pipeline for one image domain.

    Test preprocessing is deterministic, so the preprocessed images can be
    cached directly.
    """
    return (
        images.map(preprocess_test_image, num_parallel_calls=autotune)
        .cache()
        .shuffle(buffer_size)
        .batch(GLOBAL_BATCH_SIZE)
        .prefetch(autotune)
    )


# Apply the preprocessing operations to the training data
train_horses = _build_train_dataset(train_horses)
train_zebras = _build_train_dataset(train_zebras)

# Apply the preprocessing operations to the test data
test_horses = _build_test_dataset(test_horses)
test_zebras = _build_test_dataset(test_zebras)

# Visualize some samples: the first image of four training batches per domain.
fig, ax = plt.subplots(4, 2, figsize=(10, 15))
for i, (horse_batch, zebra_batch) in enumerate(
    zip(train_horses.take(4), train_zebras.take(4))
):
    # Undo normalize_img: map [-1, 1] back to [0, 255] for display.
    horse = ((horse_batch[0] * 127.5) + 127.5).numpy().astype(np.uint8)
    zebra = ((zebra_batch[0] * 127.5) + 127.5).numpy().astype(np.uint8)
    ax[i, 0].imshow(horse)
    ax[i, 1].imshow(zebra)
# Save before show(): with a blocking backend the figure is closed when show()
# returns, and saving afterwards would write an empty image.
fig.savefig('Visualize_Some_Samples')
plt.show()
plt.close(fig)


# Building blocks used in the CycleGAN generators and discriminators

class ReflectionPadding2D(layers.Layer):
    """Implements Reflection Padding as a layer.

    Args:
        padding(tuple): Amount of padding for the spatial dimensions, given as
            ``(padding_width, padding_height)``.
    Returns:
        A padded tensor with the same type as the input tensor.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def call(self, input_tensor, mask=None):
        padding_width, padding_height = self.padding
        # One [before, after] pair per axis of a rank-4 tensor; assumes
        # channels-last (batch, height, width, channels) input, matching
        # `input_img_size`. Batch and channel axes are left unpadded.
        padding_tensor = [
            [0, 0],
            [padding_height, padding_height],
            [padding_width, padding_width],
            [0, 0],
        ]
        return tf.pad(input_tensor, padding_tensor, mode="REFLECT")

    def get_config(self):
        """Include `padding` so the layer can be re-created from its config."""
        config = super(ReflectionPadding2D, self).get_config()
        config.update({"padding": self.padding})
        return config


def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="valid",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Residual block: two (reflection pad -> conv -> instance norm) stages.

    The first stage is followed by `activation`; the block output is the
    element-wise sum of the input and the second stage. With the default
    kernel_size/strides/padding, the 1-pixel reflection padding plus a 3x3
    'valid' convolution keeps the spatial size, so the skip connection adds
    tensors of equal shape.

    Args:
        x: input feature map; its last dimension sets the conv filter count.
        activation: callable (e.g. a Keras activation layer) applied once,
            after the first normalization.
        kernel_initializer, kernel_size, strides, padding, use_bias:
            forwarded to both Conv2D layers.
        gamma_initializer: scale initializer for both instance norms.
    Returns:
        The output tensor, same channel count as `x`.
    """
    dim = x.shape[-1]
    input_tensor = x

    x = ReflectionPadding2D()(input_tensor)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = activation(x)

    x = ReflectionPadding2D()(x)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    return layers.add([input_tensor, x])


def downsample(
    x,
    filters,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Conv2D -> instance norm -> optional activation.

    With the default stride of 2 and 'same' padding this halves the spatial
    resolution.

    Args:
        x: input feature map.
        filters: number of output channels.
        activation: callable applied last, or a falsy value to skip it.
        kernel_initializer, kernel_size, strides, padding, use_bias:
            forwarded to Conv2D.
        gamma_initializer: scale initializer for the instance norm.
    Returns:
        The output tensor.
    """
    x = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    if activation:
        x = activation(x)
    return x


def upsample(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    kernel_initializer=kernel_init,
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Conv2DTranspose -> instance norm -> optional activation.

    With the default stride of 2 and 'same' padding this doubles the spatial
    resolution.

    Args:
        x: input feature map.
        filters: number of output channels.
        activation: callable applied last, or a falsy value to skip it.
        kernel_size, strides, padding, kernel_initializer, use_bias:
            forwarded to Conv2DTranspose.
        gamma_initializer: scale initializer for the instance norm.
    Returns:
        The output tensor.
    """
    x = layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    if activation:
        x = activation(x)
    return x




def get_resnet_generator(
    filters=64,
    num_downsampling_blocks=2,
    num_residual_blocks=9,
    num_upsample_blocks=2,
    name=None,
    gamma_initializer=gamma_init,
):
    """Build the ResNet-style CycleGAN generator as a functional Keras model.

    Layout: reflection-padded 7x7 stem -> `num_downsampling_blocks`
    downsample blocks -> `num_residual_blocks` residual blocks ->
    `num_upsample_blocks` transposed-conv upsample blocks -> reflection-padded
    7x7 conv to 3 channels followed by tanh.

    Args:
        filters: channels of the stem; doubled per downsampling block and
            halved per upsampling block.
        num_downsampling_blocks: number of downsample blocks.
        num_residual_blocks: number of residual blocks.
        num_upsample_blocks: number of upsample blocks.
        name: model name; also prefixes the input layer name when given.
        gamma_initializer: scale initializer for the stem's instance norm.
    Returns:
        A `keras.Model` mapping `input_img_size` images to 3-channel images
        in [-1, 1] (tanh output).
    """
    # `name + ...` would raise TypeError for the default name=None.
    input_name = None if name is None else name + "_img_input"
    img_input = layers.Input(shape=input_img_size, name=input_name)
    x = ReflectionPadding2D(padding=(3, 3))(img_input)
    x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
        x
    )
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = layers.Activation("relu")(x)

    # Downsampling
    for _ in range(num_downsampling_blocks):
        filters *= 2
        x = downsample(x, filters=filters, activation=layers.Activation("relu"))

    # Residual blocks
    for _ in range(num_residual_blocks):
        x = residual_block(x, activation=layers.Activation("relu"))

    # Upsampling
    for _ in range(num_upsample_blocks):
        filters //= 2
        x = upsample(x, filters=filters, activation=layers.Activation("relu"))

    # Final block: padding by 3 on each side lets the 7x7 'valid' conv keep
    # the spatial size.
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(3, (7, 7), padding="valid")(x)
    x = layers.Activation("tanh")(x)

    return keras.models.Model(img_input, x, name=name)


"""
## Build the discriminators
The discriminators implement the following architecture:
`C64->C128->C256->C512`
"""

def get_discriminator(
    filters=64, num_downsampling=3, name=None, kernel_initializer=kernel_init
):
    """Build the patch discriminator `C64 -> C128 -> C256 -> C512` (defaults).

    A stride-2 4x4 conv + LeakyReLU stem is followed by `num_downsampling`
    downsample blocks that double the channel count each time. All but the
    last block use stride 2; the last uses stride 1. A final 1-channel 4x4
    conv produces a spatial map of scores rather than a single scalar.

    Args:
        filters: channels of the stem convolution.
        num_downsampling: number of downsample blocks after the stem.
        name: model name; also prefixes the input layer name when given.
        kernel_initializer: initializer for every convolution kernel.
    Returns:
        A `keras.Model` taking `input_img_size` images.
    """
    # `name + ...` would raise TypeError for the default name=None.
    input_name = None if name is None else name + "_img_input"
    img_input = layers.Input(shape=input_img_size, name=input_name)
    x = layers.Conv2D(
        filters,
        (4, 4),
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )(img_input)
    x = layers.LeakyReLU(0.2)(x)

    num_filters = filters
    # Honour `num_downsampling` (previously hard-coded to 3); the default
    # reproduces the original stride pattern (2, 2, 1).
    for num_downsample_block in range(num_downsampling):
        num_filters *= 2
        if num_downsample_block < num_downsampling - 1:
            block_strides = (2, 2)
        else:
            block_strides = (1, 1)
        x = downsample(
            x,
            filters=num_filters,
            activation=layers.LeakyReLU(0.2),
            kernel_initializer=kernel_initializer,
            kernel_size=(4, 4),
            strides=block_strides,
        )

    x = layers.Conv2D(
        1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer
    )(x)

    return keras.models.Model(inputs=img_input, outputs=x, name=name)



"""
## Build the CycleGAN model
"""

class CycleGan(keras.Model):
    """Container for the two CycleGAN generators and two discriminators.

    `gen_G` maps domain X (horses) to Y (zebras), `gen_F` maps Y to X, and
    `disc_X` / `disc_Y` score images of their respective domain.
    `train_step` performs one update of all four networks on one batch.
    In this script it is invoked per replica through `strategy.run`.
    """

    def __init__(
        self,
        generator_G,
        generator_F,
        discriminator_X,
        discriminator_Y,
        lambda_cycle=10.0,
        lambda_identity=0.5,
    ):
        super(CycleGan, self).__init__()
        self.gen_G = generator_G
        self.gen_F = generator_F
        self.disc_X = discriminator_X
        self.disc_Y = discriminator_Y
        # Weight of the cycle-consistency terms.
        self.lambda_cycle = lambda_cycle
        # Identity terms are weighted by lambda_cycle * lambda_identity.
        self.lambda_identity = lambda_identity

    def compile(
        self,
        gen_G_optimizer,
        gen_F_optimizer,
        disc_X_optimizer,
        disc_Y_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn,
    ):
        """Store one optimizer per network and the four loss callables.

        Under a distribution strategy the loss callables are expected to
        return each replica's share of the global-batch mean (e.g. via
        `tf.nn.compute_average_loss`), because gradients are applied as-is.
        """
        super(CycleGan, self).compile()
        self.gen_G_optimizer = gen_G_optimizer
        self.gen_F_optimizer = gen_F_optimizer
        self.disc_X_optimizer = disc_X_optimizer
        self.disc_Y_optimizer = disc_Y_optimizer
        self.generator_loss_fn = gen_loss_fn
        self.discriminator_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Update both generators and both discriminators on one batch.

        Args:
            batch_data: a `(real_x, real_y)` pair of image batches
                (x = horses, y = zebras).
        Returns:
            The total loss of generator G for this batch (the F and
            discriminator losses are computed and applied but not returned).
        """
        # x is Horse and y is zebra
        real_x, real_y = batch_data

        # persistent=True: four separate gradient() calls are made below.
        with tf.GradientTape(persistent=True) as tape:
            # Horse to fake zebra
            fake_y = self.gen_G(real_x, training=True)
            # zebra to fake horse -> y2x
            fake_x = self.gen_F(real_y, training=True)

            # Cycle (Horse to fake zebra to fake horse): x -> y -> x
            cycled_x = self.gen_F(fake_y, training=True)
            # Cycle (zebra to fake horse to fake zebra) y -> x -> y
            cycled_y = self.gen_G(fake_x, training=True)

            # Identity mapping
            same_x = self.gen_F(real_x, training=True)
            same_y = self.gen_G(real_y, training=True)

            # Discriminator output
            disc_real_x = self.disc_X(real_x, training=True)
            disc_fake_x = self.disc_X(fake_x, training=True)

            disc_real_y = self.disc_Y(real_y, training=True)
            disc_fake_y = self.disc_Y(fake_y, training=True)

            # Generator adversarial loss
            gen_G_loss = self.generator_loss_fn(disc_fake_y)
            gen_F_loss = self.generator_loss_fn(disc_fake_x)

            # Generator cycle loss
            cycle_loss_G = self.cycle_loss_fn(real_y, cycled_y) * self.lambda_cycle
            cycle_loss_F = self.cycle_loss_fn(real_x, cycled_x) * self.lambda_cycle

            # Generator identity loss
            id_loss_G = (
                self.identity_loss_fn(real_y, same_y)
                * self.lambda_cycle
                * self.lambda_identity
            )
            id_loss_F = (
                self.identity_loss_fn(real_x, same_x)
                * self.lambda_cycle
                * self.lambda_identity
            )

            # Total generator loss
            total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
            total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F

            # Discriminator loss
            disc_X_loss = self.discriminator_loss_fn(disc_real_x, disc_fake_x)
            disc_Y_loss = self.discriminator_loss_fn(disc_real_y, disc_fake_y)

        # Get the gradients for the generators
        grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
        grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)

        # Get the gradients for the discriminators
        disc_X_grads = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
        disc_Y_grads = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)

        # A persistent tape holds its recorded intermediates until it is
        # released; drop it as soon as all gradients have been taken.
        del tape

        # Update the weights of the generators
        self.gen_G_optimizer.apply_gradients(
            zip(grads_G, self.gen_G.trainable_variables)
        )
        self.gen_F_optimizer.apply_gradients(
            zip(grads_F, self.gen_F.trainable_variables)
        )

        # Update the weights of the discriminators
        self.disc_X_optimizer.apply_gradients(
            zip(disc_X_grads, self.disc_X.trainable_variables)
        )
        self.disc_Y_optimizer.apply_gradients(
            zip(disc_Y_grads, self.disc_Y.trainable_variables)
        )

        return total_loss_G
        


# Open a strategy scope: models and optimizers are created inside it so their
# variables are managed by `strategy`.
with strategy.scope():
    # Reduction.NONE: each loss below is reduced explicitly with respect to
    # GLOBAL_BATCH_SIZE in `_average_per_example`.
    mae_loss_fn = keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.NONE)
    # Loss function for evaluating adversarial loss (least-squares GAN).
    adv_loss_fn = keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)

    def _average_per_example(per_element_loss):
        """Reduce an unreduced Keras loss to this replica's share of the mean.

        Keras MAE/MSE with Reduction.NONE only average the last axis, so for
        image-shaped inputs the result still has spatial axes. Passing that to
        `tf.nn.compute_average_loss` would sum over every pixel and inflate
        the loss by H*W, with a different factor for the image losses than
        for the discriminator-patch losses. That would upset the lambda
        weighting. Average every non-batch axis first, then divide the sum
        by the global batch size.
        """
        per_example = tf.reduce_mean(
            per_element_loss, axis=list(range(1, per_element_loss.shape.rank))
        )
        return tf.nn.compute_average_loss(
            per_example, global_batch_size=GLOBAL_BATCH_SIZE
        )

    # Loss function for evaluating cycle consistency loss
    def cycle_loss_fn(real, cycled):
        return _average_per_example(mae_loss_fn(real, cycled))

    # Loss function for evaluating identity mapping loss
    def identity_loss_fn(real, same):
        return _average_per_example(mae_loss_fn(real, same))

    # Define the loss function for the generators: push D(fake) towards 1.
    def generator_loss_fn(fake):
        return _average_per_example(adv_loss_fn(tf.ones_like(fake), fake))

    # Define the loss function for the discriminators: D(real) -> 1, D(fake) -> 0.
    def discriminator_loss_fn(real, fake):
        real_loss = _average_per_example(adv_loss_fn(tf.ones_like(real), real))
        fake_loss = _average_per_example(adv_loss_fn(tf.zeros_like(fake), fake))
        return (real_loss + fake_loss) * 0.5

    # Get the generators
    gen_G = get_resnet_generator(name="generator_G")
    gen_F = get_resnet_generator(name="generator_F")

    # Get the discriminators
    disc_X = get_discriminator(name="discriminator_X")
    disc_Y = get_discriminator(name="discriminator_Y")

    # Create cycle gan model
    cycle_gan_model = CycleGan(
        generator_G=gen_G, generator_F=gen_F, discriminator_X=disc_X, discriminator_Y=disc_Y
    )

    # One Adam per network. A single shared instance would advance its step
    # counter four times per train step, skewing Adam's bias correction.
    # Newer Keras optimizers also refuse to update variables they were not
    # built with.
    gen_G_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    gen_F_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    disc_X_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    disc_Y_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

    # Compile the model
    cycle_gan_model.compile(
        gen_G_optimizer=gen_G_optimizer,
        gen_F_optimizer=gen_F_optimizer,
        disc_X_optimizer=disc_X_optimizer,
        disc_Y_optimizer=disc_Y_optimizer,
        gen_loss_fn=generator_loss_fn,
        disc_loss_fn=discriminator_loss_fn,
        cycle_loss_fn=cycle_loss_fn,
        identity_loss_fn=identity_loss_fn,
    )


# Pair horse and zebra batches, then let the strategy split every global
# batch across its replicas.
train_dist_dataset = strategy.experimental_distribute_dataset(
    tf.data.Dataset.zip((train_horses, train_zebras))
)


@tf.function
def distributed_train_step(dataset_inputs):
    """Run `cycle_gan_model.train_step` on every replica and sum the results.

    The loss functions already divide by GLOBAL_BATCH_SIZE, so summing the
    per-replica values gives the mean over the global batch.
    """
    replica_losses = strategy.run(
        cycle_gan_model.train_step, args=(dataset_inputs,)
    )
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM, replica_losses, axis=None
    )

"""
## Train the end-to-end model
"""
for epoch in range(1):
    # TRAIN LOOP: accumulate the replica-summed loss of every global batch.
    all_loss = 0.0
    num_batches = 0
    for one_batch in train_dist_dataset:
        all_loss += distributed_train_step(one_batch)
        num_batches += 1
    if num_batches == 0:
        # An empty dataset would otherwise raise ZeroDivisionError below.
        print("epoch {}: no training batches".format(epoch))
        continue
    # Mean over batches of the value returned by train_step (generator G loss).
    train_loss = all_loss / num_batches
    print(train_loss)
afei198602 回答:如何使用 tf.strategy 修改 Keras CycleGAN 示例代码以在 GPU 上并行运行

暂时没有好的解决方案,如果你有好的解决方案,请发邮件至:iooj@foxmail.com
本文链接:https://www.f2er.com/750972.html

大家都在问