1 回答

TA貢獻(xiàn)1788條經(jīng)驗(yàn) 獲得超4個(gè)贊
由于您對(duì)使用可訓(xùn)練權(quán)重不感興趣(我將它們標(biāo)記為系數(shù)以將它們與可訓(xùn)練權(quán)重區(qū)分開),您可以連接輸出并將它們作為單個(gè)輸出傳遞給自定義損失函數(shù)。這意味著這些系數(shù)將在訓(xùn)練開始時(shí)可用。
您應(yīng)該提供如上所述的自定義損失函數(shù)。損失函數(shù)預(yù)計(jì)只接受 2 個(gè)參數(shù),因此您應(yīng)該使用這樣一個(gè)函數(shù)categorical_crossentropy,它也應(yīng)該熟悉您感興趣的參數(shù),例如coeffs和num_class。因此,我使用所需的參數(shù)實(shí)例化一個(gè)包裝函數(shù),然后將內(nèi)部實(shí)際損失函數(shù)作為主損失函數(shù)傳遞。
from tensorflow.keras.layers import Dense, Dropout, Input, Concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.keras import backend as K
def categorical_crossentropy_base(coeffs, num_class):
def categorical_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
"""Computes the categorical crossentropy loss.
Args:
y_true: tensor of true targets.
y_pred: tensor of predicted targets.
from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
we assume that `y_pred` encodes a probability distribution.
label_smoothing: Float in [0, 1]. If > `0` then smooth the labels.
Returns:
Categorical crossentropy loss value.
https://github.com/tensorflow/tensorflow/blob/v1.15.0/tensorflow/python/keras/losses.py#L938-L966
"""
y_pred1 = y_pred[:, :num_class] # the 1st prediction
y_pred2 = y_pred[:, num_class:2*num_class] # the 2nd prediction
y_pred3 = y_pred[:, 2*num_class:] # the 3rd prediction
# you should adapt the ground truth to contain all 3 ground truth of course
y_true1 = y_true[:, :num_class] # the 1st gt
y_true2 = y_true[:, num_class:2*num_class] # the 2nd gt
y_true3 = y_true[:, 2*num_class:] # the 3rd gt
loss1 = K.categorical_crossentropy(y_true1, y_pred1, from_logits=from_logits)
loss2 = K.categorical_crossentropy(y_true2, y_pred2, from_logits=from_logits)
loss3 = K.categorical_crossentropy(y_true3, y_pred3, from_logits=from_logits)
# combine the losses the way you like it
total_loss = coeffs[0]*loss1 + coeffs[1]*(loss1 - loss2) + coeffs[2]*(loss2 - loss3)
return total_loss
return categorical_crossentropy
in1 = Input((6373,))
enc1 = Dense(num_nodes)(in1)
enc1 = Dropout(0.3)(enc1)
enc1 = Dense(num_nodes, activation='relu')(enc1)
enc1 = Dropout(0.3)(enc1)
enc1 = Dense(num_nodes, activation='relu')(enc1)
out1 = Dense(units=num_class, activation='softmax')(enc1)
in2 = Input((512,))
enc2 = Dense(num_nodes, activation='relu')(in2)
enc2 = Dense(num_nodes, activation='relu')(enc2)
out2 = Dense(units=num_class, activation='softmax')(enc2)
in3 = Input((768,))
enc3 = Dense(num_nodes, activation='relu')(in3)
enc3 = Dense(num_nodes, activation='relu')(enc3)
out3 = Dense(units=num_class, activation='softmax')(enc3)
adam = Adam(lr=0.0001)
total_out = Concatenate(axis=1)([out1, out2, out3])
model = Model(inputs=[in1, in2, in3], outputs=[total_out])
coeffs = [1, 1, 1]
model.compile(loss=categorical_crossentropy_base(coeffs=coeffs, num_class=num_class), optimizer='adam', metrics=['accuracy'])
不過(guò),我不確定有關(guān)準(zhǔn)確性的指標(biāo)。但我認(rèn)為無(wú)需其他更改即可發(fā)揮作用。我也在使用K.categorical_crossentropy,但是您當(dāng)然也可以自由地使用其他實(shí)現(xiàn)來(lái)更改它。
添加回答
舉報(bào)