第七色在线视频,2021少妇久久久久久久久久,亚洲欧洲精品成人久久av18,亚洲国产精品特色大片观看完整版,孙宇晨将参加特朗普的晚宴

為了賬號安全,請及時綁定郵箱和手機立即綁定
已解決430363個問題,去搜搜看,總會有你想問的

訓練深度學習模型時出錯

訓練深度學習模型時出錯

HUX布斯 2023-03-22 10:54:01
所以我設計了一個 CNN 并使用以下參數(shù)進行編譯,training_file_loc = "8-SignLanguageMNIST/sign_mnist_train.csv"testing_file_loc = "8-SignLanguageMNIST/sign_mnist_test.csv"def getData(filename):    images = []    labels = []    with open(filename) as csv_file:        file = csv.reader(csv_file, delimiter = ",")        next(file, None)                for row in file:            label = row[0]            data = row[1:]            img = np.array(data).reshape(28,28)                        images.append(img)            labels.append(label)                images = np.array(images).astype("float64")        labels = np.array(labels).astype("float64")            return images, labelstraining_images, training_labels = getData(training_file_loc)testing_images, testing_labels = getData(testing_file_loc)print(training_images.shape, training_labels.shape)print(testing_images.shape, testing_labels.shape)training_images = np.expand_dims(training_images, axis = 3)testing_images = np.expand_dims(testing_images, axis = 3)training_datagen = ImageDataGenerator(    rescale = 1/255,    rotation_range = 45,    width_shift_range = 0.2,    height_shift_range = 0.2,    shear_range = 0.2,    zoom_range = 0.2,    horizontal_flip = True,    fill_mode = "nearest")training_generator = training_datagen.flow(    training_images,    training_labels,    batch_size = 64,)validation_datagen = ImageDataGenerator(    rescale = 1/255,    rotation_range = 45,    width_shift_range = 0.2,    height_shift_range = 0.2,    shear_range = 0.2,    zoom_range = 0.2,    horizontal_flip = True,    fill_mode = "nearest")validation_generator = training_datagen.flow(    testing_images,    testing_labels,    batch_size = 64,])但是,當我運行 model.fit() 時,出現(xiàn)以下錯誤,ValueError: Shapes (None, 1) and (None, 24) are incompatible將損失函數(shù)更改為 后sparse_categorical_crossentropy,程序運行良好。我不明白為什么會這樣。誰能解釋這一點以及這些損失函數(shù)之間的區(qū)別?
查看完整描述

2 回答

?
largeQ

TA貢獻2039條經(jīng)驗 獲得超8個贊

問題是,categorical_crossentropy期望單熱編碼標簽,這意味著,對于每個樣本,它期望一個長度張量,num_classes其中l(wèi)abel第 th 個元素設置為 1,其他所有元素都為 0。


另一方面,sparse_categorical_crossentropy直接使用整數(shù)標簽(因為這里的用例是大量的類,所以單熱編碼標簽會浪費大量零的內(nèi)存)。我相信,但我無法證實這一點,它categorical_crossentropy比它的稀疏對應物運行得更快。


對于您的情況,對于 26 個類,我建議使用非稀疏版本并將您的標簽轉換為單熱編碼,如下所示:


def getData(filename):

    images = []

    labels = []

    with open(filename) as csv_file:

        file = csv.reader(csv_file, delimiter = ",")

        next(file, None)

        

        for row in file:

            label = row[0]

            data = row[1:]

            img = np.array(data).reshape(28,28)

            

            images.append(img)

            labels.append(label)

        

        images = np.array(images).astype("float64")

        labels = np.array(labels).astype("float64")

        

    return images, tf.keras.utils.to_categorical(labels, num_classes=26) # you can omit num_classes to have it computed from the data

旁注:除非你有理由使用float64圖像,否則我會切換到float32(它將數(shù)據(jù)集所需的內(nèi)存減半,并且模型可能會將它們轉換為float32第一個操作)


查看完整回答
反對 回復 2023-03-22
?
BIG陽

TA貢獻1859條經(jīng)驗 獲得超6個贊

很簡單,對于輸出類為整數(shù)的分類問題,使用 sparse_categorical_crosentropy,對于標簽在一個熱編碼標簽中轉換的問題,我們使用 categorical_crosentropy。



查看完整回答
反對 回復 2023-03-22
  • 2 回答
  • 0 關注
  • 179 瀏覽
慕課專欄
更多

添加回答

舉報

0/150
提交
取消
微信客服

購課補貼
聯(lián)系客服咨詢優(yōu)惠詳情

幫助反饋 APP下載

慕課網(wǎng)APP
您的移動學習伙伴

公眾號

掃描二維碼
關注慕課網(wǎng)微信公眾號