Hello everyone,
This is part of the code for a CycleGAN model that I have implemented; it is the part related to training.
I need to use TensorFlow checkpoints so that I can train the model over multiple runs, but I have no idea how to incorporate them. I haven't used functions like fit(), model(), compile()...
Would anyone be able to help me?
Training the CycleGAN:
import time

import tensorflow as tf

# NOTE: generator(), discriminator(), save_step() and the constants
# HEIGHT, WIDTH, CHANNEL and LAMBDA are defined elsewhere in my code.

#=======================================================================================================================
# CycleGAN architecture
#=======================================================================================================================
def cyclegan(input_A, input_B):
    # fake image generation
    BfromA = generateB(input_A, training=True)
    AfromB = generateA(input_B, training=True)
    # image reconstruction (cycle)
    regenAfromB = generateA(BfromA, training=True)
    regenBfromA = generateB(AfromB, training=True)
    # identity mapping (each generator applied to its own domain)
    gen_orig_A = generateA(input_A, training=True)
    gen_orig_B = generateB(input_B, training=True)
    # discriminating real images
    valid_A = discriminateA(input_A, training=True)
    valid_B = discriminateB(input_B, training=True)
    # discriminating fake images
    valid_AfromB = discriminateA(AfromB, training=True)
    valid_BfromA = discriminateB(BfromA, training=True)
    return regenAfromB, regenBfromA, gen_orig_A, gen_orig_B, valid_A, valid_B, valid_AfromB, valid_BfromA
#=======================================================================================================================
# Loss Functions - Optimizers
#=======================================================================================================================
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

def generator_loss(generated):
    return bce(tf.ones_like(generated), generated)

def discriminator_loss(real, generated):
    real_loss = bce(tf.ones_like(real), real)
    generated_loss = bce(tf.zeros_like(generated), generated)
    return real_loss + generated_loss

def cycle_loss(real, generated, LAMBDA):
    return LAMBDA * tf.reduce_mean(tf.abs(real - generated))

def identity_loss(real, same, LAMBDA):
    return LAMBDA * tf.reduce_mean(tf.abs(real - same))

# optimizers
gen_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
disc_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
#=======================================================================================================================
# Training session
#=======================================================================================================================
generateA = generator()
discriminateA = discriminator()
generateB = generator()
discriminateB = discriminator()

# (these Keras Input placeholders are not used by the custom loop below)
inputA = tf.keras.layers.Input(shape=[HEIGHT, WIDTH, CHANNEL])
inputB = tf.keras.layers.Input(shape=[HEIGHT, WIDTH, CHANNEL])

@tf.function
def train_step(inputA, inputB):
    with tf.GradientTape(persistent=True) as tape:
        regenA, regenB, gen_origA, gen_origB, disc_A, disc_B, disc_AfB, disc_BfA = cyclegan(inputA, inputB)
        A_gen_loss = generator_loss(disc_AfB)
        B_gen_loss = generator_loss(disc_BfA)
        total_cycle_loss = cycle_loss(inputA, regenA, LAMBDA) + cycle_loss(inputB, regenB, LAMBDA)
        A_identity_loss = identity_loss(inputA, gen_origA, LAMBDA)
        B_identity_loss = identity_loss(inputB, gen_origB, LAMBDA)
        total_A_gen_loss = A_gen_loss + total_cycle_loss + A_identity_loss
        total_B_gen_loss = B_gen_loss + total_cycle_loss + B_identity_loss
        A_disc_loss = discriminator_loss(disc_A, disc_AfB)
        B_disc_loss = discriminator_loss(disc_B, disc_BfA)
    # Gradients and optimizers (outside the tape context; the tape is
    # persistent because it is queried four times)
    A_generator_gradients = tape.gradient(total_A_gen_loss, generateA.trainable_variables)
    gen_optimizer.apply_gradients(zip(A_generator_gradients, generateA.trainable_variables))
    B_generator_gradients = tape.gradient(total_B_gen_loss, generateB.trainable_variables)
    gen_optimizer.apply_gradients(zip(B_generator_gradients, generateB.trainable_variables))
    A_discriminator_gradients = tape.gradient(A_disc_loss, discriminateA.trainable_variables)
    disc_optimizer.apply_gradients(zip(A_discriminator_gradients, discriminateA.trainable_variables))
    B_discriminator_gradients = tape.gradient(B_disc_loss, discriminateB.trainable_variables)
    disc_optimizer.apply_gradients(zip(B_discriminator_gradients, discriminateB.trainable_variables))
    del tape  # release the persistent tape's resources
# Training
def train(train_ds, epochs):
    for epoch in range(epochs):
        start = time.time()
        print("Starting epoch", epoch + 1)
        for image_x, image_y in train_ds:
            # tensors can be passed to the tf.function directly; converting
            # to numpy first only forces an unnecessary device-to-host copy
            train_step(image_x, image_y)
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
        save_step(input_path_A, sample_img, epoch, 'P', generateB, step_path)
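One way to do this without fit()/compile() is tf.train.Checkpoint together with tf.train.CheckpointManager. The sketch below is a minimal example, assuming the models and optimizers defined above are already built; checkpoint_path is a placeholder directory. It restores the latest checkpoint (if any) before training resumes and saves one at the end of every epoch:

# Minimal checkpointing sketch (assumes the models/optimizers defined above;
# checkpoint_path is a placeholder directory)
checkpoint_path = './checkpoints/cyclegan'
ckpt = tf.train.Checkpoint(generateA=generateA,
                           generateB=generateB,
                           discriminateA=discriminateA,
                           discriminateB=discriminateB,
                           gen_optimizer=gen_optimizer,
                           disc_optimizer=disc_optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# resume from the most recent checkpoint, if one exists
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Restored from', ckpt_manager.latest_checkpoint)

def train(train_ds, epochs):
    for epoch in range(epochs):
        start = time.time()
        print("Starting epoch", epoch + 1)
        for image_x, image_y in train_ds:
            train_step(image_x, image_y)
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
        save_step(input_path_A, sample_img, epoch, 'P', generateB, step_path)
        ckpt_manager.save()  # one checkpoint per epoch; max_to_keep prunes old ones

Because the two Adam optimizers are tracked as well, their moment (slot) variables are saved and restored, so a resumed run continues where the previous one stopped. If the epoch number also needs to survive across runs, a tf.Variable counter can be added to the Checkpoint and incremented inside the loop.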