我正在嘗試使用Keras庫注意實(shí)現(xiàn)序列2序列模型。該模型的框圖如下模型將輸入序列嵌入3D張量。然后,雙向lstm創(chuàng)建編碼層。接下來,將編碼后的序列發(fā)送到自定義關(guān)注層,該層返回具有每個隱藏節(jié)點(diǎn)的關(guān)注權(quán)重的2D張量。解碼器輸入作為一個熱矢量注入模型中?,F(xiàn)在在解碼器(另一個bistlm)中,解碼器輸入和注意力權(quán)重都作為輸入傳遞。解碼器的輸出被發(fā)送到具有softmax激活函數(shù)的時間分布密集層,以概率的方式獲得每個時間步長的輸出。該模型的代碼如下:encoder_input = Input(shape=(MAX_LENGTH_Input, ))embedded = Embedding(input_dim=vocab_size_input, output_dim= embedding_width, trainable=False)(encoder_input)encoder = Bidirectional(LSTM(units= hidden_size, input_shape=(MAX_LENGTH_Input,embedding_width), return_sequences=True, dropout=0.25, recurrent_dropout=0.25))(embedded)attention = Attention(MAX_LENGTH_Input)(encoder)decoder_input = Input(shape=(MAX_LENGTH_Output,vocab_size_output))merge = concatenate([attention, decoder_input])decoder = Bidirectional(LSTM(units=hidden_size, input_shape=(MAX_LENGTH_Output,vocab_size_output))(merge))output = TimeDistributed(Dense(MAX_LENGTH_Output, activation="softmax"))(decoder)問題是當(dāng)我連接注意層和解碼器輸入時。由于解碼器輸入是3D張量,而注意是2D張量,因此顯示以下錯誤:ValueError:Concatenate圖層需要輸入的形狀與concat軸一致,但匹配的軸除外。得到了輸入形狀:[(無,1024),(無,10,8281)]如何將2D注意張量轉(zhuǎn)換為3D張量?
在Keras上使用解碼器輸入seq2seq模型連接關(guān)注層
函數(shù)式編程
2021-05-03 15:12:47