这是来自 Keras 的 CycleGAN 示例:CycleGAN Example Using Keras。
下面是我修改后、使用多个 GPU 的实现。为了实现自定义训练,我参考了 Custom training with tf.distribute.Strategy。
我希望让 Keras 的 CycleGAN 示例能够利用 GPU 快速运行。此外,我还需要处理并训练大量数据。CycleGAN 使用多个损失函数,train_step
本应返回 4 种损失;目前为了便于理解,我只返回其中一种。尽管如此,在 GPU 上的训练仍然很慢,我找不到背后的原因。
我是否错误地使用了 tf.distribute.Strategy?
"""
Title: CycleGAN
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
Date created: 2020/08/12
Last modified: 2020/08/12
Description: Implementation of CycleGAN.
"""
"""
## CycleGAN
CycleGAN is a model that aims to solve the image-to-image translation
problem. The goal of the image-to-image translation problem is to learn the
mapping between an input image and an output image using a training set of
aligned image pairs. However, obtaining paired examples isn't always feasible.
CycleGAN tries to learn this mapping without requiring paired input-output images, using cycle-consistent adversarial networks.
- [Paper](https://arxiv.org/pdf/1703.10593.pdf)
- [Original implementation](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)
"""
"""
## Setup
"""
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
# Turn off the tensorflow-datasets progress bars.
tfds.disable_progress_bar()
# Sentinel passed to tf.data calls so the parallelism level is chosen for us.
autotune = tf.data.experimental.AUTOTUNE
# Distribution strategy used for everything below; report how many replicas
# it keeps in sync.
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")
"""
## Prepare the dataset
In this example, we will be using the
[horse to zebra](https://www.tensorflow.org/datasets/catalog/cycle_gan#cycle_ganhorse2zebra)
dataset.
"""
# Fetch the horse <-> zebra splits through tensorflow-datasets. The metadata
# object returned alongside the splits is requested but not used.
dataset, _ = tfds.load(
    "cycle_gan/horse2zebra",
    with_info=True,
    as_supervised=True,
)
train_horses = dataset["trainA"]
train_zebras = dataset["trainB"]
test_horses = dataset["testA"]
test_zebras = dataset["testB"]

# Images are first resized to this (height, width) ...
orig_img_size = (286, 286)
# ... and training crops / network inputs have this (height, width, channels).
input_img_size = (256, 256, 3)

# Initializer for convolution kernels.
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
# Initializer for the gamma of the instance-normalization layers.
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

# Shuffle buffer length and per-replica batch size used by the pipelines.
buffer_size = 256
batch_size = 1
def normalize_img(img):
    """Cast *img* to float32 and rescale it as ``img / 127.5 - 1``.

    For pixel values in [0, 255] the result lies in [-1, 1].
    """
    as_float = tf.cast(img, dtype=tf.float32)
    return as_float / 127.5 - 1.0
def preprocess_train_image(img, label):
    """Augment and normalize one training image.

    Args:
        img: Image tensor to transform.
        label: Accepted so the function fits a supervised ``(image, label)``
            dataset; it is not used.

    Returns:
        The image after a random horizontal flip, a resize to
        ``orig_img_size``, a random crop of shape ``input_img_size`` and
        normalization by ``normalize_img``.
    """
    flipped = tf.image.random_flip_left_right(img)
    enlarged = tf.image.resize(flipped, list(orig_img_size))
    cropped = tf.image.random_crop(enlarged, size=list(input_img_size))
    return normalize_img(cropped)
def preprocess_test_image(img, label):
    """Resize and normalize one test image (no random augmentation).

    Args:
        img: Image tensor to transform.
        label: Accepted so the function fits a supervised ``(image, label)``
            dataset; it is not used.

    Returns:
        The image resized to the first two entries of ``input_img_size`` and
        normalized by ``normalize_img``.
    """
    target_height, target_width = input_img_size[0], input_img_size[1]
    resized = tf.image.resize(img, [target_height, target_width])
    return normalize_img(resized)
"""
## Create `Dataset` objects
"""
# Each replica processes `batch_size` samples, so a batch handed to the
# strategy holds that many samples per replica.
BATCH_SIZE_PER_REPLICA = batch_size
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Training pipelines.
# `cache()` is placed BEFORE the map: `preprocess_train_image` applies a random
# flip and a random crop, and caching its output would replay the exact same
# augmented images on every later pass over the data.
# `prefetch()` lets the input pipeline prepare upcoming batches while the
# current training step is still running.
train_horses = (
    train_horses.cache()
    .map(preprocess_train_image, num_parallel_calls=autotune)
    .shuffle(buffer_size)
    .batch(GLOBAL_BATCH_SIZE)
    .prefetch(autotune)
)
train_zebras = (
    train_zebras.cache()
    .map(preprocess_train_image, num_parallel_calls=autotune)
    .shuffle(buffer_size)
    .batch(GLOBAL_BATCH_SIZE)
    .prefetch(autotune)
)

# Test pipelines.
# `preprocess_test_image` has no randomness, so its output can be cached
# directly after the map.
test_horses = (
    test_horses.map(preprocess_test_image, num_parallel_calls=autotune)
    .cache()
    .shuffle(buffer_size)
    .batch(GLOBAL_BATCH_SIZE)
    .prefetch(autotune)
)
test_zebras = (
    test_zebras.map(preprocess_test_image, num_parallel_calls=autotune)
    .cache()
    .shuffle(buffer_size)
    .batch(GLOBAL_BATCH_SIZE)
    .prefetch(autotune)
)
# Visualize some samples: four rows, horses on the left, zebras on the right.
_, ax = plt.subplots(4, 2, figsize=(10, 15))
sample_batches = zip(train_horses.take(4), train_zebras.take(4))
for i, (horse_batch, zebra_batch) in enumerate(sample_batches):
    # Take the first image of each batch and undo the [-1, 1] normalization
    # so it can be displayed as uint8 pixels.
    horse = (((horse_batch[0] * 127.5) + 127.5).numpy()).astype(np.uint8)
    zebra = (((zebra_batch[0] * 127.5) + 127.5).numpy()).astype(np.uint8)
    ax[i, 0].imshow(horse)
    ax[i, 1].imshow(zebra)
# Save BEFORE showing: once `plt.show()` returns the figure may already have
# been released, in which case `savefig` would write an empty image.
plt.savefig('Visualize_Some_Samples')
plt.show()
plt.close()
# Building blocks used in the CycleGAN generators and discriminators
class ReflectionPadding2D(layers.Layer):
    """Reflection padding exposed as a Keras layer.

    Args:
        padding (tuple): Two integers, unpacked as
            ``(padding_width, padding_height)``. Each amount is applied on
            both sides of its axis.

    Returns:
        The input tensor padded with ``tf.pad(..., mode="REFLECT")``.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super().__init__(**kwargs)

    def call(self, input_tensor, mask=None):
        padding_width, padding_height = self.padding
        # `tf.pad` expects one [before, after] pair per axis. Axes 0 and 3 are
        # left untouched (assumes a rank-4 NHWC input — TODO confirm).
        padding_tensor = [
            [0, 0],
            [padding_height, padding_height],
            [padding_width, padding_width],
            [0, 0],
        ]
        return tf.pad(input_tensor, padding_tensor, mode="REFLECT")
def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="valid",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Apply a two-convolution residual block to *x*.

    Each convolution is preceded by a default ``ReflectionPadding2D`` and
    followed by instance normalization; *activation* is applied after the
    first normalization only. The block input is added to the result.

    Args:
        x: Input tensor; its last dimension sets the number of filters.
        activation: Callable applied after the first normalization.
        kernel_initializer: Initializer for both convolution kernels.
        kernel_size: Kernel size of both convolutions.
        strides: Strides of both convolutions.
        padding: Padding mode of both convolutions.
        gamma_initializer: Gamma initializer of both normalization layers.
        use_bias: Whether the convolutions use a bias term.

    Returns:
        ``layers.add([x, block(x)])``.
    """
    dim = x.shape[-1]
    input_tensor = x

    x = ReflectionPadding2D()(input_tensor)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = activation(x)

    x = ReflectionPadding2D()(x)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    # Skip connection: the shapes must match for the add to succeed.
    x = layers.add([input_tensor, x])
    return x
def downsample(
    x,
    filters,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Convolution -> instance normalization -> optional activation.

    Args:
        x: Input tensor.
        filters: Number of convolution filters.
        activation: Callable applied to the normalized output; a falsy value
            skips the activation.
        kernel_initializer: Initializer for the convolution kernel.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        padding: Convolution padding mode.
        gamma_initializer: Gamma initializer of the normalization layer.
        use_bias: Whether the convolution uses a bias term.

    Returns:
        The transformed tensor.
    """
    x = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    if activation:
        x = activation(x)
    return x
def upsample(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    kernel_initializer=kernel_init,
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Transposed convolution -> instance normalization -> optional activation.

    Args:
        x: Input tensor.
        filters: Number of transposed-convolution filters.
        activation: Callable applied to the normalized output; a falsy value
            skips the activation.
        kernel_size: Kernel size of the transposed convolution.
        strides: Strides of the transposed convolution.
        padding: Padding mode of the transposed convolution.
        kernel_initializer: Initializer for the kernel.
        gamma_initializer: Gamma initializer of the normalization layer.
        use_bias: Whether the transposed convolution uses a bias term.

    Returns:
        The transformed tensor.
    """
    x = layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    if activation:
        x = activation(x)
    return x
def get_resnet_generator(
    filters=64,
    num_downsampling_blocks=2,
    num_residual_blocks=9,
    num_upsample_blocks=2,
    name=None,
    gamma_initializer=gamma_init,
):
    """Build the ResNet-style generator as a Keras functional model.

    Layout: 7x7 convolution stem -> downsampling blocks (filters doubled
    each time) -> residual blocks -> upsampling blocks (filters halved each
    time) -> 7x7 convolution to 3 channels with a tanh activation.

    Args:
        filters: Number of filters of the stem convolution.
        num_downsampling_blocks: How many ``downsample`` blocks to stack.
        num_residual_blocks: How many ``residual_block`` calls to stack.
        num_upsample_blocks: How many ``upsample`` blocks to stack.
        name: Model name; also used as prefix for the input layer name.
        gamma_initializer: Gamma initializer for the stem normalization.

    Returns:
        A ``keras.models.Model`` taking inputs of shape ``input_img_size``.
    """
    # Only derive an input-layer name when a model name was given, so the
    # default `name=None` does not fail on string concatenation.
    input_name = name + "_img_input" if name is not None else None
    img_input = layers.Input(shape=input_img_size, name=input_name)

    x = ReflectionPadding2D(padding=(3, 3))(img_input)
    x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
        x
    )
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = layers.Activation("relu")(x)

    # Downsampling
    for _ in range(num_downsampling_blocks):
        filters *= 2
        x = downsample(x, filters=filters, activation=layers.Activation("relu"))

    # Residual blocks
    for _ in range(num_residual_blocks):
        x = residual_block(x, activation=layers.Activation("relu"))

    # Upsampling
    for _ in range(num_upsample_blocks):
        filters //= 2
        x = upsample(x, filters, activation=layers.Activation("relu"))

    # Final block
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(3, (7, 7), padding="valid")(x)
    x = layers.Activation("tanh")(x)

    model = keras.models.Model(img_input, x, name=name)
    return model
"""
## Build the discriminators
The discriminators implement the following architecture:
`C64->C128->C256->C512`
"""
def get_discriminator(
    filters=64, num_downsampling=3, name=None, kernel_initializer=kernel_init
):
    """Build the convolutional discriminator as a Keras functional model.

    Layout: strided 4x4 convolution + LeakyReLU, then ``num_downsampling``
    ``downsample`` blocks that double the filter count each time (the last
    one uses stride 1), then a 4x4 convolution producing a single channel.

    Args:
        filters: Number of filters of the first convolution.
        num_downsampling: Number of ``downsample`` blocks after the first
            convolution.
        name: Model name; also used as prefix for the input layer name.
        kernel_initializer: Initializer for the first and last convolution.

    Returns:
        A ``keras.models.Model`` taking inputs of shape ``input_img_size``.
    """
    # Only derive an input-layer name when a model name was given, so the
    # default `name=None` does not fail on string concatenation.
    input_name = name + "_img_input" if name is not None else None
    img_input = layers.Input(shape=input_img_size, name=input_name)

    x = layers.Conv2D(
        filters,
        (4, 4),
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )(img_input)
    x = layers.LeakyReLU(0.2)(x)

    num_filters = filters
    for num_downsample_block in range(num_downsampling):
        num_filters *= 2
        # Every block halves the resolution except the last, which keeps it.
        is_last_block = num_downsample_block == num_downsampling - 1
        x = downsample(
            x,
            filters=num_filters,
            activation=layers.LeakyReLU(0.2),
            kernel_size=(4, 4),
            strides=(1, 1) if is_last_block else (2, 2),
        )

    x = layers.Conv2D(
        1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer
    )(x)

    model = keras.models.Model(inputs=img_input, outputs=x, name=name)
    return model
"""
## Build the CycleGAN model
"""
class CycleGan(keras.Model):
    """Bundles two generators and two discriminators into one trainable model.

    ``gen_G`` maps domain X to Y and ``gen_F`` maps Y to X (horses and zebras
    in this script); ``disc_X`` and ``disc_Y`` judge images of their
    respective domain.
    """

    def __init__(
        self,
        generator_G,
        generator_F,
        discriminator_X,
        discriminator_Y,
        lambda_cycle=10.0,
        lambda_identity=0.5,
    ):
        """Store the four sub-models and the loss weights.

        Args:
            generator_G: Generator applied to X images.
            generator_F: Generator applied to Y images.
            discriminator_X: Discriminator for X images.
            discriminator_Y: Discriminator for Y images.
            lambda_cycle: Weight of the cycle loss (also scales the identity
                loss).
            lambda_identity: Additional weight of the identity loss.
        """
        super().__init__()
        self.gen_G, self.gen_F = generator_G, generator_F
        self.disc_X, self.disc_Y = discriminator_X, discriminator_Y
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def compile(
        self,
        gen_G_optimizer,
        gen_F_optimizer,
        disc_X_optimizer,
        disc_Y_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn,
    ):
        """Attach one optimizer per sub-model and the four loss callables."""
        super().compile()
        # Optimizers
        self.gen_G_optimizer = gen_G_optimizer
        self.gen_F_optimizer = gen_F_optimizer
        self.disc_X_optimizer = disc_X_optimizer
        self.disc_Y_optimizer = disc_Y_optimizer
        # Loss callables
        self.generator_loss_fn = gen_loss_fn
        self.discriminator_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one optimization step on a ``(real_x, real_y)`` batch.

        All four sub-models are updated. Only the total loss of generator G
        is returned.
        """
        # x is horse, y is zebra.
        real_x, real_y = batch_data

        # The tape is persistent because four separate gradient computations
        # are taken from the same forward pass.
        with tf.GradientTape(persistent=True) as tape:
            # Translations: x -> fake y, y -> fake x.
            fake_y = self.gen_G(real_x, training=True)
            fake_x = self.gen_F(real_y, training=True)

            # Cycles: x -> y -> x and y -> x -> y.
            cycled_x = self.gen_F(fake_y, training=True)
            cycled_y = self.gen_G(fake_x, training=True)

            # Identity mappings: each generator fed an image of its target
            # domain.
            same_x = self.gen_F(real_x, training=True)
            same_y = self.gen_G(real_y, training=True)

            # Discriminator outputs on real and generated images.
            disc_real_x = self.disc_X(real_x, training=True)
            disc_fake_x = self.disc_X(fake_x, training=True)
            disc_real_y = self.disc_Y(real_y, training=True)
            disc_fake_y = self.disc_Y(fake_y, training=True)

            # Generator adversarial losses.
            gen_G_loss = self.generator_loss_fn(disc_fake_y)
            gen_F_loss = self.generator_loss_fn(disc_fake_x)

            # Generator cycle losses.
            cycle_loss_G = self.cycle_loss_fn(real_y, cycled_y) * self.lambda_cycle
            cycle_loss_F = self.cycle_loss_fn(real_x, cycled_x) * self.lambda_cycle

            # Generator identity losses.
            id_loss_G = (
                self.identity_loss_fn(real_y, same_y)
                * self.lambda_cycle
                * self.lambda_identity
            )
            id_loss_F = (
                self.identity_loss_fn(real_x, same_x)
                * self.lambda_cycle
                * self.lambda_identity
            )

            # Total generator losses.
            total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
            total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F

            # Discriminator losses.
            disc_X_loss = self.discriminator_loss_fn(disc_real_x, disc_fake_x)
            disc_Y_loss = self.discriminator_loss_fn(disc_real_y, disc_fake_y)

        networks = (self.gen_G, self.gen_F, self.disc_X, self.disc_Y)
        losses = (total_loss_G, total_loss_F, disc_X_loss, disc_Y_loss)
        optimizers = (
            self.gen_G_optimizer,
            self.gen_F_optimizer,
            self.disc_X_optimizer,
            self.disc_Y_optimizer,
        )

        # Compute every gradient first ...
        all_grads = [
            tape.gradient(loss, network.trainable_variables)
            for loss, network in zip(losses, networks)
        ]
        # ... and only then update the weights, generators before
        # discriminators.
        for opt, grads, network in zip(optimizers, all_grads, networks):
            opt.apply_gradients(zip(grads, network.trainable_variables))

        return total_loss_G
# Open a strategy scope: the losses, the four networks, the optimizer and the
# compiled CycleGAN model are all created inside it.
with strategy.scope():

    def _mean_over_non_batch_axes(loss_values):
        """Average *loss_values* over every axis except the leading one.

        Keras losses built with ``Reduction.NONE`` only reduce the last axis,
        so image-shaped inputs still yield one value per spatial position.
        ``tf.nn.compute_average_loss`` sums ALL values it receives before
        dividing by the global batch size, which would inflate the loss by
        the number of positions. Averaging here leaves one value per sample.
        """
        rank = loss_values.shape.rank
        if rank is None or rank < 2:
            # Already one value per sample (or rank unknown): nothing to do.
            return loss_values
        return tf.reduce_mean(loss_values, axis=list(range(1, rank)))

    mae_loss_fn = keras.losses.MeanAbsoluteError(
        reduction=tf.keras.losses.Reduction.NONE
    )

    # Loss function for evaluating cycle consistency loss
    def cycle_loss_fn(real, cycled):
        """Mean absolute error between *real* and *cycled*, scaled for the
        global batch size."""
        cycle_loss = _mean_over_non_batch_axes(mae_loss_fn(real, cycled))
        return tf.nn.compute_average_loss(
            cycle_loss, global_batch_size=GLOBAL_BATCH_SIZE
        )

    # Loss function for evaluating identity mapping loss
    def identity_loss_fn(real, same):
        """Mean absolute error between *real* and *same*, scaled for the
        global batch size."""
        identity_loss = _mean_over_non_batch_axes(mae_loss_fn(real, same))
        return tf.nn.compute_average_loss(
            identity_loss, global_batch_size=GLOBAL_BATCH_SIZE
        )

    # Loss function for evaluating adversarial loss
    adv_loss_fn = keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.NONE
    )

    # Define the loss function for the generators
    def generator_loss_fn(fake):
        """Squared error between the discriminator output *fake* and ones."""
        fake_loss = _mean_over_non_batch_axes(adv_loss_fn(tf.ones_like(fake), fake))
        return tf.nn.compute_average_loss(
            fake_loss, global_batch_size=GLOBAL_BATCH_SIZE
        )

    # Define the loss function for the discriminators
    def discriminator_loss_fn(real, fake):
        """Average of the squared errors of *real* against ones and *fake*
        against zeros."""
        real_loss = _mean_over_non_batch_axes(adv_loss_fn(tf.ones_like(real), real))
        fake_loss = _mean_over_non_batch_axes(adv_loss_fn(tf.zeros_like(fake), fake))
        real_loss = tf.nn.compute_average_loss(
            real_loss, global_batch_size=GLOBAL_BATCH_SIZE
        )
        fake_loss = tf.nn.compute_average_loss(
            fake_loss, global_batch_size=GLOBAL_BATCH_SIZE
        )
        return (real_loss + fake_loss) * 0.5

    # Get the generators
    gen_G = get_resnet_generator(name="generator_G")
    gen_F = get_resnet_generator(name="generator_F")

    # Get the discriminators
    disc_X = get_discriminator(name="discriminator_X")
    disc_Y = get_discriminator(name="discriminator_Y")

    # Create cycle gan model
    cycle_gan_model = CycleGan(
        generator_G=gen_G,
        generator_F=gen_F,
        discriminator_X=disc_X,
        discriminator_Y=disc_Y,
    )

    # NOTE(review): a single Adam instance is shared by all four networks, so
    # its step counter advances four times per training step. Consider giving
    # each network its own optimizer.
    optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

    # Compile the model
    cycle_gan_model.compile(
        gen_G_optimizer=optimizer,
        gen_F_optimizer=optimizer,
        disc_X_optimizer=optimizer,
        disc_Y_optimizer=optimizer,
        gen_loss_fn=generator_loss_fn,
        disc_loss_fn=discriminator_loss_fn,
        cycle_loss_fn=cycle_loss_fn,
        identity_loss_fn=identity_loss_fn,
    )
# Pair each horse batch with a zebra batch, then let the strategy distribute
# the paired dataset.
train_dist_dataset = strategy.experimental_distribute_dataset(
    tf.data.Dataset.zip(
        (train_horses, train_zebras),
    )
)
# `strategy.run` executes the training step on every replica with that
# replica's share of the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
    """Run one CycleGAN training step on all replicas.

    Args:
        dataset_inputs: One element of the distributed training dataset.

    Returns:
        The per-replica values returned by ``cycle_gan_model.train_step``,
        summed across replicas.
    """
    replica_results = strategy.run(
        cycle_gan_model.train_step,
        args=(dataset_inputs,),
    )
    return strategy.reduce(
        tf.distribute.ReduceOp.SUM,
        replica_results,
        axis=None,
    )
"""
## Train the end-to-end model
"""
for epoch in range(1):
    # Accumulate the summed loss and the batch count over one pass of the
    # distributed training dataset.
    all_loss, num_batches = 0.0, 0.0
    for one_batch in train_dist_dataset:
        step_loss = distributed_train_step(one_batch)
        all_loss = all_loss + step_loss
        num_batches = num_batches + 1
    # Report the mean loss per batch for this epoch.
    train_loss = all_loss / num_batches
    print(train_loss)