提高多類圖像分類器的準(zhǔn)確性
我正在使用 Food-101 數(shù)據(jù)集構(gòu)建分類器。該數(shù)據(jù)集具有預(yù)定義的訓(xùn)練集和測(cè)試集,均已標(biāo)記。它共有 101,000 張圖像。我正在嘗試為 top-1 構(gòu)建一個(gè) >=90% 準(zhǔn)確度的分類器模型。我目前坐在 75%。提供的訓(xùn)練集不干凈。但是現(xiàn)在,我想知道我可以改進(jìn)我的模型的一些方法以及我做錯(cuò)了什么。我已將訓(xùn)練圖像和測(cè)試圖像劃分到各自的文件夾中。在這里,我使用 0.2 的訓(xùn)練數(shù)據(jù)集通過(guò)運(yùn)行 5 個(gè)時(shí)期來(lái)驗(yàn)證學(xué)習(xí)者。np.random.seed(42)data = ImageList.from_folder(path).split_by_rand_pct(valid_pct=0.2).label_from_re(pat=file_parse).transform(size=224).databunch()top_1 = partial(top_k_accuracy, k=1)learn = cnn_learner(data, models.resnet50, metrics=[accuracy, top_1], callback_fns=ShowGraph)learn.fit_one_cycle(5)epoch train_loss valid_loss accuracy top_k_accuracy time0 2.153797 1.710803 0.563498 0.563498 19:261 1.677590 1.388702 0.637096 0.637096 18:292 1.385577 1.227448 0.678746 0.678746 18:363 1.154080 1.141590 0.700924 0.700924 18:344 1.003366 1.124750 0.707063 0.707063 18:25在這里,我試圖找到學(xué)習(xí)率。在講座中的表現(xiàn)非常標(biāo)準(zhǔn):learn.lr_find()learn.recorder.plot(suggestion=True)LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.Min numerical gradient: 1.32E-06Min loss divided by 10: 6.31E-08使用 1e-06 的學(xué)習(xí)率再運(yùn)行 5 個(gè) epoch。將其保存為 stage-2learn.fit_one_cycle(5, max_lr=slice(1.e-06))learn.save('stage-2')epoch train_loss valid_loss accuracy top_k_accuracy time0 0.940980 1.124032 0.705809 0.705809 18:181 0.989123 1.122873 0.706337 0.706337 18:242 0.963596 1.121615 0.706733 0.706733 18:383 0.975916 1.121084 0.707195 0.707195 18:274 0.978523 1.123260 0.706403 0.706403 17:04之前我總共運(yùn)行了 3 個(gè)階段,但模型沒(méi)有改進(jìn)超過(guò) 0.706403,所以我不想重復(fù)。下面是我的混淆矩陣。我為糟糕的決議道歉。這是 Colab 的功勞。因?yàn)槲乙呀?jīng)創(chuàng)建了一個(gè)額外的驗(yàn)證集,所以我決定使用測(cè)試集來(lái)驗(yàn)證已保存的 stage-2 模型,看看它的表現(xiàn)如何:path = '/content/food-101/images'data_test = ImageList.from_folder(path).split_by_folder(train='train', valid='test').label_from_re(file_parse).transform(size=224).databunch()learn.load('stage-2')learn.validate(data_test.valid_dl)這是結(jié)果:[0.87199837, tensor(0.7584), tensor(0.7584)]
查看完整描述