Training the pix2pix network

As with any other GAN, training the pix2pix network is a two-step process. In the first step, we train the discriminator network. In the second step, we train the adversarial network, which in turn trains the generator network. Let's start training the network.

Perform the following steps to train the pix2pix network:

  1. Start by defining the hyperparameters that are required for training:
epochs = 500
num_images_per_epoch = 400
batch_size = 1
img_width = 256
img_height = 256
num_channels = 1
input_img_dim = (256, 256, 1)
patch_dim = (256, 256)

# Specify dataset directory path
dataset_dir = "pix2pix-keras/pix2pix/data/facades_bw"
  1. Next, define the common optimizer, shown as follows:
from keras.optimizers import Adam

common_optimizer = Adam(learning_rate=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

For all networks, we will use the Adam optimizer with a learning rate of 1e-4, beta_1 of 0.9, beta_2 of 0.999, and epsilon of 1e-08.
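To show what beta_1, beta_2, and epsilon control, here is a single Adam update written out in NumPy. It follows the textbook form of the algorithm; Keras's own implementation may place epsilon slightly differently, so treat this as an illustration rather than a reproduction of the library code. The function name adam_step and the example values are hypothetical.

```python
import numpy as np

def adam_step(param, grad, m, v, t, learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-8):
    """One textbook Adam update; t is the 1-based step counter."""
    m = beta_1 * m + (1 - beta_1) * grad         # running mean of the gradient
    v = beta_2 * v + (1 - beta_2) * grad ** 2    # running mean of the squared gradient
    m_hat = m / (1 - beta_1 ** t)                # bias correction for the zero initialization
    v_hat = v / (1 - beta_2 ** t)
    param = param - learning_rate * m_hat / (np.sqrt(v_hat) + epsilon)
    return param, m, v

w = np.array([0.5, -0.5])
g = np.array([0.2, -0.4])
w_new, m, v = adam_step(w, g, np.zeros_like(w), np.zeros_like(w), t=1)
print(w_new)
```

On the first step, bias correction makes m_hat equal to the gradient and sqrt(v_hat) equal to its magnitude, so each weight moves by roughly learning_rate (1e-4) against the sign of its gradient.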

  1. Next, build and compile the PatchGAN discriminator network, as follows:
patchgan_discriminator = build_patchgan_discriminator()
patchgan_discriminator.compile(loss='binary_crossentropy', optimizer=common_optimizer)

To compile the discriminator model, use binary_crossentropy as the loss function and common_optimizer as the training optimizer.
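The discriminator's targets in this section are two-element one-hot rows: [1, 0] for generated images and [0, 1] for real ones, as set in generate_and_extract_patches. A minimal NumPy sketch of binary cross-entropy on such rows, averaged over the last axis much as Keras does per sample; the function name and the clipping constant are assumptions for illustration:

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Element-wise binary cross-entropy, averaged over the last axis."""
    y_pred = np.clip(y_pred, eps, 1 - eps)  # avoid log(0)
    bce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return bce.mean(axis=-1)

fake_label = np.array([[1, 0]])          # target for a generated image
real_label = np.array([[0, 1]])          # target for a real image
confident_fake = np.array([[0.9, 0.1]])  # discriminator output that says "fake"

print(binary_crossentropy(fake_label, confident_fake))  # small loss: correct answer
print(binary_crossentropy(real_label, confident_fake))  # large loss: wrong answer
```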

  1. Now build and compile the generator network, as follows:
unet_generator = build_unet_generator()
unet_generator.compile(loss='mae', optimizer=common_optimizer)

To compile the generator model, use mae (mean absolute error) as the loss function and common_optimizer as the training optimizer.
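The mae loss is the mean absolute (L1) difference between generated and target pixel values. A tiny NumPy illustration with made-up 2 x 2 arrays; mean_absolute_error is a hypothetical helper, not the Keras function:

```python
import numpy as np

def mean_absolute_error(y_true, y_pred):
    """Mean absolute (L1) error over all pixels."""
    return np.mean(np.abs(y_true - y_pred))

target = np.array([[0.0, 1.0], [0.5, 0.5]])
generated = np.array([[0.1, 0.8], [0.5, 0.9]])
print(mean_absolute_error(target, generated))  # (0.1 + 0.2 + 0.0 + 0.4) / 4
```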

  1. Next, build and compile the adversarial model, as follows:
adversarial_model = build_adversarial_model(unet_generator, patchgan_discriminator)
adversarial_model.compile(loss=['mae', 'binary_crossentropy'], loss_weights=[1E2, 1], optimizer=common_optimizer)

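To compile the adversarial model, use the list of losses ['mae', 'binary_crossentropy'] with loss_weights [1E2, 1] (a weight of 100 for the L1 loss on the generated image and 1 for the adversarial loss on the discriminator output), and common_optimizer as the training optimizer.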
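With loss_weights=[1E2, 1], Keras minimizes a weighted sum of the two output losses (ignoring any regularization terms), so the L1 reconstruction term counts 100 times as much as the adversarial term. The helper and loss values below are hypothetical, chosen only to show the arithmetic:

```python
def weighted_total_loss(losses, loss_weights):
    """Weighted sum of per-output losses for a multi-output model."""
    return sum(w * l for w, l in zip(loss_weights, losses))

# Hypothetical per-batch values: [mae on the generated image, binary cross-entropy on the discriminator output]
losses = [0.05, 0.70]
print(weighted_total_loss(losses, [1E2, 1]))  # 100 * 0.05 + 1 * 0.70
```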

  1. Now load the training, validation, and test datasets, as follows:
training_facade_photos, training_facade_labels = load_dataset(data_dir=dataset_dir, data_type='training', img_width=img_width, img_height=img_height)

test_facade_photos, test_facade_labels = load_dataset(data_dir=dataset_dir, data_type='testing', img_width=img_width, img_height=img_height)

validation_facade_photos, validation_facade_labels = load_dataset(data_dir=dataset_dir, data_type='validation', img_width=img_width, img_height=img_height)

The load_dataset function was defined in the Data preparation section. Each call returns two ndarrays, the facade photos and the corresponding facade labels, each with the shape (#total_images, 256, 256, 1).

  1. Add TensorBoard to visualize the training losses and the network graphs:
import time

from keras.callbacks import TensorBoard

tensorboard = TensorBoard(log_dir="logs/{}".format(time.time()))
tensorboard.set_model(unet_generator)
tensorboard.set_model(patchgan_discriminator)

  1. Next, create a for loop that runs once per epoch, as follows:
for epoch in range(epochs):
    print("Epoch:{}".format(epoch))
  1. Create two lists to store the losses for all mini-batches:
    dis_losses = []
    gen_losses = []

    # Counter that decides whether the discriminator sees real or generated images
    batch_counter = 1
  1. Next, create another loop inside the epochs loop that runs num_batches times, as follows:
    num_batches = int(training_facade_photos.shape[0] / batch_size)
    for index in range(num_batches):
        print("Batch:{}".format(index))

All of the code that trains the discriminator network and the adversarial network goes inside this loop.
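The loop bounds come from integer division: num_batches is the number of full mini-batches, and each batch is the slice [index * batch_size:(index + 1) * batch_size]. The sketch below uses a hypothetical generator function, iterate_minibatches, and dummy arrays to show that any images beyond the last full batch are skipped. With batch_size = 1, as in this section, nothing is dropped.

```python
import numpy as np

def iterate_minibatches(photos, labels, batch_size):
    """Yield (index, photo_batch, label_batch) using the same slicing as the training loop."""
    num_batches = int(photos.shape[0] / batch_size)  # partial last batch is dropped
    for index in range(num_batches):
        start, end = index * batch_size, (index + 1) * batch_size
        yield index, photos[start:end], labels[start:end]

# Dummy stand-ins for the facade arrays: 10 images of shape (256, 256, 1)
photos = np.zeros((10, 256, 256, 1), dtype=np.float32)
labels = np.ones((10, 256, 256, 1), dtype=np.float32)

batches = list(iterate_minibatches(photos, labels, batch_size=4))
print(len(batches))         # 2 full batches; the last 2 images are skipped
print(batches[0][1].shape)  # (4, 256, 256, 1)
```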

  1. Next, sample a mini-batch of training and validation data, shown as follows:
        train_facades_batch = training_facade_labels[index * batch_size:(index + 1) * batch_size]
        train_images_batch = training_facade_photos[index * batch_size:(index + 1) * batch_size]

        val_facades_batch = validation_facade_labels[index * batch_size:(index + 1) * batch_size]
        val_images_batch = validation_facade_photos[index * batch_size:(index + 1) * batch_size]
  1. Next, generate a batch of fake images and extract patches from them. Use the generate_and_extract_patches function as follows:
        patches, labels = generate_and_extract_patches(train_images_batch, train_facades_batch, unet_generator, batch_counter, patch_dim)

The generate_and_extract_patches function is defined as follows:

import numpy as np


def generate_and_extract_patches(images, facades, generator_model, batch_counter, patch_dim):
    # Alternately train the discriminator on generated (even counter) and real (odd counter) images
    if batch_counter % 2 == 0:
        # Generate fake images from the facade labels
        output_images = generator_model.predict(facades)

        # Ground truth [1, 0] marks every image in the batch as fake
        labels = np.zeros((output_images.shape[0], 2), dtype=np.uint8)
        labels[:, 0] = 1

    else:
        # Take real images
        output_images = images

        # Ground truth [0, 1] marks every image in the batch as real
        labels = np.zeros((output_images.shape[0], 2), dtype=np.uint8)
        labels[:, 1] = 1

    # Tile the (N, H, W, C) batch into non-overlapping patches:
    # axis 1 is the image height and axis 2 is the image width
    patches = []
    for y in range(0, output_images.shape[1], patch_dim[0]):
        for x in range(0, output_images.shape[2], patch_dim[1]):
            image_patches = output_images[:, y: y + patch_dim[0], x: x + patch_dim[1], :]
            patches.append(np.asarray(image_patches, dtype=np.float32))

    return patches, labels

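The preceding function alternates between two sources of images. When batch_counter is even, it uses the generator network to generate fake images and labels them [1, 0]; when batch_counter is odd, it takes the real images and labels them [0, 1]. It then extracts patches from the selected images. We now have a list of patches and their ground truth labels.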
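The patch extraction is meant to tile each image into non-overlapping patch_dim blocks over the height and width axes (axes 1 and 2 of an (N, H, W, C) batch). With this section's patch_dim = (256, 256) and 256 x 256 images, that yields a single patch covering the whole image. The sketch below runs the same tiling on a tiny dummy batch with a hypothetical 2 x 2 patch size so the grid is visible; extract_patches is an illustrative helper, not part of the original code.

```python
import numpy as np

def extract_patches(images, patch_dim):
    """Split an (N, H, W, C) batch into non-overlapping (ph, pw) patches.
    Each list entry holds one patch location for every image: shape (N, ph, pw, C)."""
    patches = []
    for y in range(0, images.shape[1], patch_dim[0]):
        for x in range(0, images.shape[2], patch_dim[1]):
            patches.append(images[:, y:y + patch_dim[0], x:x + patch_dim[1], :])
    return patches

batch = np.arange(2 * 4 * 4 * 1, dtype=np.float32).reshape(2, 4, 4, 1)
patches = extract_patches(batch, (2, 2))
print(len(patches))      # 4 patches: a 2 x 2 grid over each 4 x 4 image
print(patches[0].shape)  # (2, 2, 2, 1)
```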

  1. Now, train the discriminator network on the generated patches:
        d_loss = patchgan_discriminator.train_on_batch(patches, labels)

This will train the discriminator network on the extracted patches and the ground truth labels. 
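Because the labels depend only on batch_counter, the discriminator sees real and generated batches in turn. The hypothetical helper below reproduces just the label logic of generate_and_extract_patches; since batch_counter starts at 1, the first mini-batch of each epoch uses real images.

```python
import numpy as np

def discriminator_labels(batch_counter, batch_size):
    """Even counters -> generated images labeled [1, 0] (fake);
    odd counters -> real images labeled [0, 1] (real)."""
    labels = np.zeros((batch_size, 2), dtype=np.uint8)
    if batch_counter % 2 == 0:
        labels[:, 0] = 1
    else:
        labels[:, 1] = 1
    return labels

for counter in range(1, 5):
    print(counter, discriminator_labels(counter, batch_size=1).tolist())
```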

  1. Next, train the adversarial model. The adversarial model trains the generator network while the discriminator network stays frozen. Use the following code:
        # Label the generated images as real ([0, 1]) so the generator learns to fool the discriminator
        labels = np.zeros((train_images_batch.shape[0], 2), dtype=np.uint8)
        labels[:, 1] = 1

        # Train the adversarial model
        g_loss = adversarial_model.train_on_batch(train_facades_batch, [train_images_batch, labels])
  1. Increase the batch counter after the completion of each mini-batch:
        batch_counter += 1
  1. After each mini-batch, append the losses to the dis_losses and gen_losses lists:
        dis_losses.append(d_loss)
        gen_losses.append(g_loss)
  1. At the end of each epoch, write the average discriminator loss and the average generator loss to TensorBoard for visualization:
    write_log(tensorboard, 'discriminator_loss', np.mean(dis_losses), epoch)
    write_log(tensorboard, 'generator_loss', np.mean(gen_losses), epoch)
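If train_on_batch on the two-output adversarial model returns a list (in Keras this is typically the total loss followed by each output's loss), np.mean(gen_losses) averages all of those entries together. The sketch below uses hypothetical loss values to compare that with a per-component average taken with axis=0.

```python
import numpy as np

# Hypothetical per-batch losses collected over one epoch
dis_losses = [0.69, 0.65, 0.61]                 # one scalar per batch
gen_losses = [[5.7, 0.05, 0.70],                # assumed layout: [total, mae, bce]
              [4.9, 0.042, 0.70],
              [4.3, 0.036, 0.70]]

print(np.mean(dis_losses))             # average discriminator loss
print(np.mean(gen_losses))             # mixes total, MAE, and BCE into one number
print(np.mean(gen_losses, axis=0))     # per-component averages
print(np.mean(gen_losses, axis=0)[0])  # average total generator loss only
```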
  1. After every 10 epochs, use the generator network to generate a set of images:
    # After every 10th epoch, generate and save images for visualization
    if epoch % 10 == 0:
        # Sample a batch of validation data
        val_facades_batch = validation_facade_labels[0:5]
        val_images_batch = validation_facade_photos[0:5]

        # Generate images
        validation_generated_images = unet_generator.predict(val_facades_batch)

        # Save images
        save_images(val_images_batch, val_facades_batch, validation_generated_images, epoch, 'validation', limit=5)

Put the preceding code block inside the epochs loop. After every 10 epochs, it will generate a batch of fake images and save them to the results directory. Here, save_images() is a utility function defined as follows:
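With epochs = 500, the epoch % 10 == 0 check fires 50 times, from epoch 0 through epoch 490. A quick check of that schedule:

```python
epochs = 500
save_epochs = [epoch for epoch in range(epochs) if epoch % 10 == 0]
print(len(save_epochs), save_epochs[:3], save_epochs[-1])
```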

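import os

import numpy as np
from imageio import imwrite  # assumption: imwrite(uri, image) from imageio


def save_images(real_images, real_sketches, generated_images, num_epoch, dataset_name, limit):
    # Rescale pixel values back to [0, 255]
    real_sketches = real_sketches * 255.0
    real_images = real_images * 255.0
    generated_images = generated_images * 255.0

    # Save some images only
    real_sketches = real_sketches[:limit]
    generated_images = generated_images[:limit]
    real_images = real_images[:limit]

    # Place input, generated image, and ground truth side by side (axis 2 is the width)
    X = np.concatenate((real_sketches, generated_images, real_images), axis=2)

    # Save the first triplet as an 8-bit grayscale PNG; the results directory is created if needed
    os.makedirs('results', exist_ok=True)
    image = np.clip(X[0], 0, 255).astype(np.uint8).squeeze(axis=-1)
    imwrite('results/X_full_{}_{}.png'.format(dataset_name, num_epoch), image)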
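A note on the stacking step: for arrays with two or more dimensions, np.hstack joins along axis 1. For image batches shaped (N, H, W, C), axis 1 is the height, so hstack stacks the three images vertically; concatenating along axis=2 lays them out horizontally. The dummy arrays below only demonstrate the resulting shapes.

```python
import numpy as np

# Dummy batches shaped like the section's data: (N, 256, 256, 1)
real_sketches = np.zeros((5, 256, 256, 1))
generated_images = np.zeros((5, 256, 256, 1))
real_images = np.zeros((5, 256, 256, 1))

stacked_h = np.hstack((real_sketches, generated_images, real_images))
stacked_w = np.concatenate((real_sketches, generated_images, real_images), axis=2)

print(stacked_h.shape)  # (5, 768, 256, 1): joined along the height axis
print(stacked_w.shape)  # (5, 256, 768, 1): joined along the width axis
```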

This completes the training loop for the pix2pix network on the facades dataset. The epochs hyperparameter above is set to 500; train the network for around 1,000 epochs to get a good-quality generator network.