第七色在线视频,2021少妇久久久久久久久久,亚洲欧洲精品成人久久av18,亚洲国产精品特色大片观看完整版,孙宇晨将参加特朗普的晚宴

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

在 Tensorflow 中 - 是否可以將特定的卷積濾波器鎖定在一個層中,或者將它們完全刪除?

在 Tensorflow 中 - 是否可以將特定的卷積濾波器鎖定在一個層中,或者將它們完全刪除?

HUH函數(shù) 2022-09-27 16:44:32
在 Tensorflow 中使用遷移學習時,我知道可以通過執(zhí)行以下操作來鎖定層,使其無法進行進一步的訓練:for layer in pre_trained_model.layers:     layer.trainable = False是否可以將特定濾鏡鎖定在圖層中?如 - 如果整個圖層包含 64 個過濾器,是否可以:只鎖定其中一些,似乎包含合理的過濾器并重新訓練那些沒有的過濾器?或從圖層中刪除看起來不合理的過濾器,并在沒有它們的情況下重新訓練?(例如,查看重新訓練的過濾器是否會發(fā)生很大變化)
查看完整描述

1 回答

?
千萬里不及你

TA貢獻1784條經驗 獲得超9個贊

一種可能的解決方案是實現(xiàn)自定義層,該層將卷積拆分為單獨的卷積,并將每個通道(具有一個輸出通道的卷積)設置為 或 設置為 。例如:number of filterstrainablenot trainable


import tensorflow as tf

import numpy as np


class Conv2DExtended(tf.keras.layers.Layer):

    def __init__(self, filters, kernel_size, **kwargs):

        self.filters = filters

        self.conv_layers = [tf.keras.layers.Conv2D(1, kernel_size, **kwargs) for _ in range(filters)]

        super().__init__()


    def build(self, input_shape):

        _ = [l.build(input_shape) for l in self.conv_layers]

        super().build(input_shape)


    def set_trainable(self, channels):

        """Sets trainable channels."""

        for i in channels:

            self.conv_layers[i].trainable = True


    def set_non_trainable(self, channels):

        """Sets not trainable channels."""

        for i in channels:

            self.conv_layers[i].trainable = False


    def call(self, inputs):

        results = [l(inputs) for l in self.conv_layers]

        return tf.concat(results, -1)

和用法示例:


inputs = tf.keras.layers.Input((28, 28, 1))

conv = Conv2DExtended(filters=4, kernel_size=(3, 3))

conv.set_non_trainable([1, 2]) # only channels 0 and 3 are trainable

res = conv(inputs)

res = tf.keras.layers.Flatten()(res)

res = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(res)


model = tf.keras.models.Model(inputs, res)

model.compile(optimizer=tf.keras.optimizers.SGD(),

              loss='binary_crossentropy',

              metrics=['accuracy'])

model.fit(np.random.normal(0, 1, (10, 28, 28, 1)),

          np.random.randint(0, 2, (10)),

          batch_size=2,

          epochs=5)


查看完整回答
反對 回復 2022-09-27
  • 1 回答
  • 0 關注
  • 110 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯(lián)系客服咨詢優(yōu)惠詳情

幫助反饋 APP下載

慕課網(wǎng)APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網(wǎng)微信公眾號