3 回答

TA貢獻(xiàn)1852條經(jīng)驗(yàn) 獲得超1個(gè)贊
您是否檢查過(guò)您的訓(xùn)練/測(cè)試數(shù)據(jù)和訓(xùn)練/測(cè)試標(biāo)簽是否都是 numpy 數(shù)組?可能是您將 numpy 數(shù)組與列表混合在一起。

TA貢獻(xiàn)1818條經(jīng)驗(yàn) 獲得超3個(gè)贊
您可以通過(guò)在調(diào)用之前將標(biāo)簽轉(zhuǎn)換為數(shù)組來(lái)避免此錯(cuò)誤model.fit():
train_x = np.asarray(train_x)
train_y = np.asarray(train_y)
validation_x = np.asarray(validation_x)
validation_y = np.asarray(validation_y)

TA貢獻(xiàn)1829條經(jīng)驗(yàn) 獲得超7個(gè)贊
如果您在處理從該類繼承的自定義生成器時(shí)遇到此問(wèn)題keras.utils.Sequence,您可能必須確保不要混合使用 aKeras或tensorflow - Keras-import。
當(dāng)您必須切換到以前的tensorflow版本以實(shí)現(xiàn)兼容性時(shí)(例如 with cuDNN),這種情況尤其可能發(fā)生。
例如,如果您將其與tensorflow-version > 2 一起使用...
from keras.utils import Sequence
class generatorClass(Sequence):
def __init__(self, x_set, y_set, batch_size):
...
def __len__(self):
...
def __getitem__(self, idx):
return ...
...但是您實(shí)際上嘗試將此生成器安裝在tensorflow-version < 2 中,您必須確保Sequence從該版本導(dǎo)入 -class,例如:
keras = tf.compat.v1.keras
Sequence = keras.utils.Sequence
class generatorClass(Sequence):
...
添加回答
舉報(bào)