1 回答

TA貢獻2080條經(jīng)驗 獲得超4個贊
據(jù)我了解,您的假設(shè)是正確的。
如果您檢查 github [keras/losses_utils.py][1] 第 260-269 行,您將看到它確實按預期執(zhí)行。 SUM將總結(jié)批量維度中的損失,并SUM_OVER_BATCH_SIZE除以總SUM損失數(shù)(批量大?。?/p>
def reduce_weighted_loss(weighted_losses,
reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
if reduction == ReductionV2.NONE:
loss = weighted_losses
else:
loss = tf.reduce_sum(weighted_losses)
if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
loss = _safe_mean(loss, _num_elements(weighted_losses))
return loss
您只需添加一對損失為零的輸出即可對前面的示例進行簡單檢查。
y_true = [[0., 2.], [0., 0.],[1.,1.]]
y_pred = [[3., 1.], [2., 5.],[1.,1.]]
mae = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)
mae(y_true, y_pred).numpy()
> 5.5
mae = tf.keras.losses.MeanAbsoluteError()
mae(y_true, y_pred).numpy()
> 1.8333
所以,你的假設(shè)是正確的。[1]:https://github.com/keras-team/keras/blob/v2.7.0/keras/utils/losses_utils.py#L25-L84
添加回答
舉報