第七色在线视频,2021少妇久久久久久久久久,亚洲欧洲精品成人久久av18,亚洲国产精品特色大片观看完整版,孙宇晨将参加特朗普的晚宴

為了賬號(hào)安全,請(qǐng)及時(shí)綁定郵箱和手機(jī)立即綁定

深度學(xué)習(xí)模型中的參數(shù)數(shù)量(備忘)

標(biāo)簽:
人工智能

原文地址:huay’ blog/模型中的参数数量(备忘)

记录模型参数数量的计算方法

最早使用 tensorflow 的时候没怎么注意这个问题;
后面高级 API 用的多了,有点忘记怎么计算模型的参数数量了;
特此记录以作备忘

参数来源

模型的参数数量 = 每一层的参数数量之和

每一层的参数数量需要由该层的规模(n_units上一层的输出(n_features共同决定

全连接层

input_shape: [batch_size, n_features]

output_shape: [batch_size, n_units]

参数数量 = n_features * n_units

参数数量 = n_features * n_units + n_units (使用偏置)

卷积层

Conv1D

input_shape: [batch_size, max_steps, n_features]

output_shape: [batch_size, new_steps, n_filters]

kernel_size = [kernel_w]

kernel_shape = [kernel_w, n_features, n_filters]

参数数量 = kernel_w * n_features * n_filters

参数数量 = kernel_w * n_features * n_filters + n_filters (使用偏置)

Conv2D

input_shape: [batch_size, in_height, in_width, in_channels]

output_shape: [batch_size, new_height, new_width, out_channels]

kernel_size = [kernel_h, kernel_w]

kernel_shape = [kernel_h, kernel_w, n_features, n_filters]

参数数量 = kernel_h * kernel_w * n_features * n_filters

参数数量 = kernel_h * kernel_w * n_features * n_filters + n_filters (使用偏置)

Conv3D

类似 Conv1D/Conv2D,略

RNN 层

具体有哪几部分参数不是非常理解
从参数数量回看,大概了解了 RNN 的参数情况:

先看基础 RNN 的计算公式:

h(t)=f(Ux(t)+Wh(t−1)+b)h(t)=f(Ux(t)+Wh(t1)+b)


可以看到参数有 3 个:UU,WW,bb

下面是公式中每个字母的 shape:

U: `[n_units, n_features]`
x: `[n_features, 1]`
W: `[n_units, n_units]`
h: `[n_units, 1]`
b: `[n_units, 1]`123456

这里hh没有考虑 batch_size,实际上应该是 [batch_size, n_units]

如果你喜欢手写 rnn,而不是直接使用 dynamic_rnn,那你肯定写过这句 initial_state = cell.zero_state(batch_size, dtype=tf.float32),这就是公式中的hh,但是这不算参数,而属于输出部分

(其实也可以看做是参数,只是不通过反向传播更新)

有了以上的基础,LSTM 和 GRU 的参数有哪些,具体看它们的模型图就能知道了

下面是使用 TF 的测试代码:

# 参数batch_size = 16max_steps = 5n_features = 32n_units = 64# 测试inputs = tf.placeholder(tf.float32, [batch_size, max_steps, n_features])

cell = cell_fn(n_units)

outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)1234567891011121314

基础 RNN

cell_fn = tf.nn.rnn_cell.BasicRNNCell

测试结果:

[<tf.Variable 'rnn/basic_rnn_cell/kernel:0' shape=(96, 64) dtype=float32_ref>, <tf.Variable 'rnn/basic_rnn_cell/bias:0' shape=(64,) dtype=float32_ref>]12

参数数量 = (n_features + n_units) * n_units + n_units

根据 RNN 的计算公式f(Ux(t)+Wh(t−1)+b)f(Ux(t)+Wh(t1)+b),参数有 3 个:UU,WW,bb

其中每个字母的形状如下:

    U: `[n_units, n_features]`
    x: `[n_features, 1]`
    W: `[n_units, n_units]`
    h: `[n_units, 1]`
    b: `[n_units, 1]`123456

LSTM

cell_fn = tf.nn.rnn_cell.BasicLSTMCell
cell_fn = tf.nn.rnn_cell.LSTMCell

测试结果:

[<tf.Variable 'rnn/basic_lstm_cell/kernel:0' shape=(96, 256) dtype=float32_ref>, <tf.Variable 'rnn/basic_lstm_cell/bias:0' shape=(256,) dtype=float32_ref>]12

参数数量 = (n_features + n_units) * (n_units * 4) + (n_units * 4)

GRU

cell_fn = tf.nn.rnn_cell.GRUCell

测试结果:

[<tf.Variable 'rnn/gru_cell/gates/kernel:0' shape=(96, 128) dtype=float32_ref>, <tf.Variable 'rnn/gru_cell/gates/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'rnn/gru_cell/candidate/kernel:0' shape=(96, 64) dtype=float32_ref>, <tf.Variable 'rnn/gru_cell/candidate/bias:0' shape=(64,) dtype=float32_ref>]1234

参数数量 = (n_features + n_units) * (n_units * 3) + (n_units * 3)

GRU 比 LSTM 少了一个门(将遗忘门和输入门合成了一个单一的更新门),与结果一致

双向 rnn/lstm/gru

参数数量再乘 2 (当 cell_fw == cell_bw 时)

测试代码:

# 参数batch_size = 16max_steps = 5n_features = 32n_units = 64# 测试inputs = tf.placeholder(tf.float32, [batch_size, max_steps, n_features])

cell_fw = cell_fn(n_units)
cell_bw = cell_fn(n_units)

outputs, state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, dtype=tf.float32)

tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)


點(diǎn)擊查看更多內(nèi)容
TA 點(diǎn)贊

若覺(jué)得本文不錯(cuò),就分享一下吧!

評(píng)論

作者其他優(yōu)質(zhì)文章

正在加載中
  • 推薦
  • 評(píng)論
  • 收藏
  • 共同學(xué)習(xí),寫(xiě)下你的評(píng)論
感謝您的支持,我會(huì)繼續(xù)努力的~
掃碼打賞,你說(shuō)多少就多少
贊賞金額會(huì)直接到老師賬戶(hù)
支付方式
打開(kāi)微信掃一掃,即可進(jìn)行掃碼打賞哦
今天注冊(cè)有機(jī)會(huì)得

100積分直接送

付費(fèi)專(zhuān)欄免費(fèi)學(xué)

大額優(yōu)惠券免費(fèi)領(lǐng)

立即參與 放棄機(jī)會(huì)
微信客服

購(gòu)課補(bǔ)貼
聯(lián)系客服咨詢(xún)優(yōu)惠詳情

幫助反饋 APP下載

慕課網(wǎng)APP
您的移動(dòng)學(xué)習(xí)伙伴

公眾號(hào)

掃描二維碼
關(guān)注慕課網(wǎng)微信公眾號(hào)

舉報(bào)

0/150
提交
取消