第七色在线视频,2021少妇久久久久久久久久,亚洲欧洲精品成人久久av18,亚洲国产精品特色大片观看完整版,孙宇晨将参加特朗普的晚宴

為了賬號(hào)安全,請(qǐng)及時(shí)綁定郵箱和手機(jī)立即綁定
已解決430363個(gè)問題,去搜搜看,總會(huì)有你想問的

使用子類模型時(shí),model.summary() 無法打印輸出形狀

使用子類模型時(shí),model.summary() 無法打印輸出形狀

嗶嗶one 2021-12-09 14:43:09
這是創(chuàng)建keras模型的兩種方法,但是output shapes兩種方法的匯總結(jié)果不同。顯然,前者打印的信息更多,更容易檢查網(wǎng)絡(luò)的正確性。import tensorflow as tffrom tensorflow.keras import Input, layers, Modelclass subclass(Model):    def __init__(self):        super(subclass, self).__init__()        self.conv = layers.Conv2D(28, 3, strides=1)    def call(self, x):        return self.conv(x)def func_api():    x = Input(shape=(24, 24, 3))    y = layers.Conv2D(28, 3, strides=1)(x)    return Model(inputs=[x], outputs=[y])if __name__ == '__main__':    func = func_api()    func.summary()    sub = subclass()    sub.build(input_shape=(None, 24, 24, 3))    sub.summary()輸出:_________________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================input_1 (InputLayer)         (None, 24, 24, 3)         0         _________________________________________________________________conv2d (Conv2D)              (None, 22, 22, 28)        784       =================================================================Total params: 784Trainable params: 784Non-trainable params: 0__________________________________________________________________________________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================conv2d_1 (Conv2D)            multiple                  784       =================================================================Total params: 784Trainable params: 784Non-trainable params: 0_________________________________________________________________那么,我應(yīng)該如何使用子類方法來獲取output shape摘要()?
查看完整描述

3 回答

?
小怪獸愛吃肉

TA貢獻(xiàn)1852條經(jīng)驗(yàn) 獲得超1個(gè)贊

我已經(jīng)用這個(gè)方法解決了這個(gè)問題,不知道有沒有更簡(jiǎn)單的方法。


class subclass(Model):

    def __init__(self):

        ...

    def call(self, x):

        ...


    def model(self):

        x = Input(shape=(24, 24, 3))

        return Model(inputs=[x], outputs=self.call(x))




if __name__ == '__main__':

    sub = subclass()

    sub.model().summary()


查看完整回答
反對(duì) 回復(fù) 2021-12-09
?
Qyouu

TA貢獻(xiàn)1786條經(jīng)驗(yàn) 獲得超11個(gè)贊

我想關(guān)鍵點(diǎn)是_init_graph_network類中的方法Network,它是Model_init_graph_network如果在調(diào)用方法時(shí)指定inputsoutputs參數(shù),將被調(diào)用__init__

所以會(huì)有兩種可能的方法:

  1. 手動(dòng)調(diào)用_init_graph_network方法構(gòu)建模型圖。

  2. 用輸入層和輸出重新初始化。

并且這兩種方法都需要輸入層和輸出(從 需要self.call)。

現(xiàn)在調(diào)用summary將給出確切的輸出形狀。但是它會(huì)顯示Input層,它不是子類模型的一部分。

from tensorflow import keras

from tensorflow.keras import layers as klayers


class MLP(keras.Model):

    def __init__(self, input_shape=(32), **kwargs):

        super(MLP, self).__init__(**kwargs)

        # Add input layer

        self.input_layer = klayers.Input(input_shape)


        self.dense_1 = klayers.Dense(64, activation='relu')

        self.dense_2 = klayers.Dense(10)


        # Get output layer with `call` method

        self.out = self.call(self.input_layer)


        # Reinitial

        super(MLP, self).__init__(

            inputs=self.input_layer,

            outputs=self.out,

            **kwargs)


    def build(self):

        # Initialize the graph

        self._is_graph_network = True

        self._init_graph_network(

            inputs=self.input_layer,

            outputs=self.out

        )


    def call(self, inputs):

        x = self.dense_1(inputs)

        return self.dense_2(x)


if __name__ == '__main__':

    mlp = MLP(16)

    mlp.summary()

輸出將是:


Model: "mlp_1"

_________________________________________________________________

Layer (type)                 Output Shape              Param #   

=================================================================

input_1 (InputLayer)         [(None, 16)]              0         

_________________________________________________________________

dense (Dense)                (None, 64)                1088      

_________________________________________________________________

dense_1 (Dense)              (None, 10)                650       

=================================================================

Total params: 1,738

Trainable params: 1,738

Non-trainable params: 0

_________________________________________________________________


查看完整回答
反對(duì) 回復(fù) 2021-12-09
?
侃侃無極

TA貢獻(xiàn)2051條經(jīng)驗(yàn) 獲得超10個(gè)贊

我解決問題的方式與Elazar提到的非常相似。覆蓋類中的函數(shù) summary() subclass。然后您可以在使用模型子類化時(shí)直接調(diào)用 summary() :


class subclass(Model):

    def __init__(self):

        ...

    def call(self, x):

        ...


    def summary(self):

        x = Input(shape=(24, 24, 3))

        model = Model(inputs=[x], outputs=self.call(x))

        return model.summary()


if __name__ == '__main__':

    sub = subclass()

    sub.summary()


查看完整回答
反對(duì) 回復(fù) 2021-12-09
  • 3 回答
  • 0 關(guān)注
  • 936 瀏覽
慕課專欄
更多

添加回答

舉報(bào)

0/150
提交
取消
微信客服

購(gòu)課補(bǔ)貼
聯(lián)系客服咨詢優(yōu)惠詳情

幫助反饋 APP下載

慕課網(wǎng)APP
您的移動(dòng)學(xué)習(xí)伙伴

公眾號(hào)

掃描二維碼
關(guān)注慕課網(wǎng)微信公眾號(hào)