

Out of memory with TensorFlow GradientTape, but only when I append to a list

搖曳的薔薇 2022-11-24 15:25:19
I have been working with a CNN on a dataset of shape (1000, 3253). I run the gradient computation through GradientTape, but it keeps running out of memory. However, if I remove the line that appends the gradient result to a list, the script runs through all epochs. I am not entirely sure why this happens, and I am also new to TensorFlow and GradientTape. Any suggestions or comments would be appreciated.

# create a batch loop
for x, y_true in train_dataset:
    # create a tape to record actions
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        x_var = tf.Variable(x)
        tape.watch([model.trainable_variables, x_var])
        y_pred = model(x_var, training=True)
        tape.stop_recording()
        loss = los_func(y_true, y_pred)
    epoch_loss_avg.update_state(loss)
    epoch_accuracy.update_state(y_true, y_pred)

    # pdb.set_trace()
    gradients, something = tape.gradient(loss, (model.trainable_variables, x_var))
    # sa_input.append(tape.gradient(loss, x_var))
    del tape

    # apply gradients
    sa_input.append(something)
    opti_func.apply_gradients(zip(gradients, model.trainable_variables))

train_loss_results.append(epoch_loss_avg.result())
train_accuracy_results.append(epoch_accuracy.result())

1 Answer

呼喚遠(yuǎn)方


Since you are new to TF2, I suggest you read this guide. It covers training, evaluation, and prediction (inference) with models in TensorFlow 2.0 in two broad situations:

  1. When using built-in APIs for training and validation (such as model.fit(), model.evaluate(), model.predict()). This is covered in the section "Using built-in training & evaluation loops".

  2. When writing custom loops from scratch using eager execution and the GradientTape object. This is covered in the section "Writing your own training & evaluation loops from scratch"; see the minimal sketch after this list.
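
For reference, here is a minimal sketch of that second pattern. This is not your code: the names model, optimizer, loss_fn, and train_dataset are assumed placeholders you would define yourself.

# Minimal custom-training-loop sketch (assumed names: model, optimizer,
# loss_fn, train_dataset -- define these to match your own setup)
for x_batch, y_batch in train_dataset:
    with tf.GradientTape() as tape:
        logits = model(x_batch, training=True)   # forward pass
        loss = loss_fn(y_batch, logits)          # scalar loss
    # gradients w.r.t. the trainable weights only
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))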

Below is a program in which I compute the gradients after each epoch and append them to a list. At the end of the program, for simplicity, I convert the list to a NumPy array.

Code - this program throws an OOM error if I use a deeper network with more layers and larger filter sizes.

# Importing dependencies
%tensorflow_version 2.x
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import datasets
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Dropout, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.layers import BatchNormalization
import numpy as np
import tensorflow as tf

# Import data
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Build the model
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(Flatten())
model.add(Dense(64, activation='relu'))
model.add(Dense(10))

# Model summary
model.summary()

# Compile the model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Gradient list and loss function
epoch_gradient = []
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Define the gradient function
@tf.function
def get_gradient_func(model):
    with tf.GradientTape() as tape:
        logits = model(train_images, training=True)
        loss = loss_fn(train_labels, logits)
    grad = tape.gradient(loss, model.trainable_weights)
    model.optimizer.apply_gradients(zip(grad, model.trainable_variables))
    return grad

# Define the required callback function
class GradientCalcCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
        grad = get_gradient_func(model)
        epoch_gradient.append(grad)

epoch = 4

print(train_images.shape, train_labels.shape)

model.fit(train_images, train_labels, epochs=epoch, validation_data=(test_images, test_labels), callbacks=[GradientCalcCallback()])

# Convert to a 2-dimensional array of (epoch, gradients) type
gradient = np.asarray(epoch_gradient)
print("Total number of epochs run:", epoch)

Output -


Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_12 (Conv2D)           (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 15, 15, 32)        0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 13, 13, 64)        18496     
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 6, 6, 64)          0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 4, 4, 64)          36928     
_________________________________________________________________
flatten_4 (Flatten)          (None, 1024)              0         
_________________________________________________________________
dense_11 (Dense)             (None, 64)                65600     
_________________________________________________________________
dense_12 (Dense)             (None, 10)                650       
=================================================================
Total params: 122,570
Trainable params: 122,570
Non-trainable params: 0
_________________________________________________________________
(50000, 32, 32, 3) (50000, 1)
Epoch 1/4
1563/1563 [==============================] - 109s 70ms/step - loss: 1.7026 - accuracy: 0.4081 - val_loss: 1.4490 - val_accuracy: 0.4861
Epoch 2/4
1563/1563 [==============================] - 145s 93ms/step - loss: 1.2657 - accuracy: 0.5506 - val_loss: 1.2076 - val_accuracy: 0.5752
Epoch 3/4
1563/1563 [==============================] - 151s 96ms/step - loss: 1.1103 - accuracy: 0.6097 - val_loss: 1.1122 - val_accuracy: 0.6127
Epoch 4/4
1563/1563 [==============================] - 152s 97ms/step - loss: 1.0075 - accuracy: 0.6475 - val_loss: 1.0508 - val_accuracy: 0.6371
Total number of epochs run: 4

Hope this answers your question. Happy learning.
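
One further note on the OOM itself: tape.gradient() returns live EagerTensors, and appending them to a Python list keeps every iteration's gradients resident in device (GPU) memory, so usage grows with each append. A common workaround, sketched below under the assumption that the rest of your loop stays unchanged, is to copy the gradient to host memory with .numpy() before appending:

# Sketch: store a host-side NumPy copy of the input gradient instead of
# the live EagerTensor, so device memory can be released each iteration.
# Assumes tape, loss, x_var, model, sa_input, and opti_func come from
# the loop in the question.
gradients, x_grad = tape.gradient(loss, (model.trainable_variables, x_var))
opti_func.apply_gradients(zip(gradients, model.trainable_variables))
sa_input.append(x_grad.numpy())   # NumPy copy lives on the host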

