這是WGAN-GP的損失函數(shù)gen_sample = model.generator(input_gen)disc_real = model.discriminator(real_image, reuse=False)disc_fake = model.discriminator(gen_sample, reuse=True)disc_concat = tf.concat([disc_real, disc_fake], axis=0)# Gradient penaltyalpha = tf.random_uniform( shape=[BATCH_SIZE, 1, 1, 1], minval=0., maxval=1.)differences = gen_sample - real_imageinterpolates = real_image + (alpha * differences)gradients = tf.gradients(model.discriminator(interpolates, reuse=True), [interpolates])[0] # why [0]slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))gradient_penalty = tf.reduce_mean((slopes-1.)**2)d_loss_real = tf.reduce_mean(disc_real)d_loss_fake = tf.reduce_mean(disc_fake)disc_loss = -(d_loss_real - d_loss_fake) + LAMBDA * gradient_penaltygen_loss = - d_loss_fake發(fā)電機損耗震蕩,值這么大。我的問題是:發(fā)電機損耗是正常的還是異常的?
1 回答

喵喔喔
TA貢獻1735條經(jīng)驗 獲得超5個贊
需要注意的一件事是您的梯度懲罰計算是錯誤的。以下行:
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
實際上應該是:
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1,2,3]))
您在第一個軸上減少,但漸變基于 alpha 值顯示的圖像,因此您必須在軸上減少[1,2,3]
。
代碼中的另一個錯誤是生成器損失是:
gen_loss = d_loss_real - d_loss_fake
對于梯度計算,這沒有區(qū)別,因為生成器的參數(shù)僅包含在 d_loss_fake 中。然而,對于發(fā)電機損失的價值,這在世界上造成了很大的不同,這也是為什么會如此震蕩的原因。
歸根結(jié)底,您應該查看您關心的實際性能指標,以確定 GAN 的質(zhì)量,例如初始分數(shù)或 Fréchet 初始距離 (FID),因為鑒別器和生成器的損失僅具有輕微的描述性。
添加回答
舉報
0/150
提交
取消