3 回答
TA貢獻(xiàn)1852條經(jīng)驗(yàn) 獲得超1個(gè)贊
我已經(jīng)用這個(gè)方法解決了這個(gè)問題,不知道有沒有更簡(jiǎn)單的方法。
class subclass(Model):
def __init__(self):
...
def call(self, x):
...
def model(self):
x = Input(shape=(24, 24, 3))
return Model(inputs=[x], outputs=self.call(x))
if __name__ == '__main__':
sub = subclass()
sub.model().summary()
TA貢獻(xiàn)1786條經(jīng)驗(yàn) 獲得超11個(gè)贊
我想關(guān)鍵點(diǎn)是_init_graph_network類中的方法Network,它是Model. _init_graph_network如果在調(diào)用方法時(shí)指定inputs和outputs參數(shù),將被調(diào)用__init__。
所以會(huì)有兩種可能的方法:
手動(dòng)調(diào)用
_init_graph_network方法構(gòu)建模型圖。用輸入層和輸出重新初始化。
并且這兩種方法都需要輸入層和輸出(從 需要self.call)。
現(xiàn)在調(diào)用summary將給出確切的輸出形狀。但是它會(huì)顯示Input層,它不是子類模型的一部分。
from tensorflow import keras
from tensorflow.keras import layers as klayers
class MLP(keras.Model):
def __init__(self, input_shape=(32), **kwargs):
super(MLP, self).__init__(**kwargs)
# Add input layer
self.input_layer = klayers.Input(input_shape)
self.dense_1 = klayers.Dense(64, activation='relu')
self.dense_2 = klayers.Dense(10)
# Get output layer with `call` method
self.out = self.call(self.input_layer)
# Reinitial
super(MLP, self).__init__(
inputs=self.input_layer,
outputs=self.out,
**kwargs)
def build(self):
# Initialize the graph
self._is_graph_network = True
self._init_graph_network(
inputs=self.input_layer,
outputs=self.out
)
def call(self, inputs):
x = self.dense_1(inputs)
return self.dense_2(x)
if __name__ == '__main__':
mlp = MLP(16)
mlp.summary()
輸出將是:
Model: "mlp_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 16)] 0
_________________________________________________________________
dense (Dense) (None, 64) 1088
_________________________________________________________________
dense_1 (Dense) (None, 10) 650
=================================================================
Total params: 1,738
Trainable params: 1,738
Non-trainable params: 0
_________________________________________________________________
TA貢獻(xiàn)2051條經(jīng)驗(yàn) 獲得超10個(gè)贊
我解決問題的方式與Elazar提到的非常相似。覆蓋類中的函數(shù) summary() subclass。然后您可以在使用模型子類化時(shí)直接調(diào)用 summary() :
class subclass(Model):
def __init__(self):
...
def call(self, x):
...
def summary(self):
x = Input(shape=(24, 24, 3))
model = Model(inputs=[x], outputs=self.call(x))
return model.summary()
if __name__ == '__main__':
sub = subclass()
sub.summary()
添加回答
舉報(bào)
