第七色在线视频,2021少妇久久久久久久久久,亚洲欧洲精品成人久久av18,亚洲国产精品特色大片观看完整版,孙宇晨将参加特朗普的晚宴

為了賬號(hào)安全,請(qǐng)及時(shí)綁定郵箱和手機(jī)立即綁定
已解決430363個(gè)問題,去搜搜看,總會(huì)有你想問的

如何為PyTorch中的掩碼R-CNN預(yù)測(cè)中的圖像生成準(zhǔn)確的掩碼?

如何為PyTorch中的掩碼R-CNN預(yù)測(cè)中的圖像生成準(zhǔn)確的掩碼?

慕斯709654 2022-09-27 16:15:12
我已經(jīng)訓(xùn)練了一個(gè)掩碼RCNN網(wǎng)絡(luò),例如蘋果的分割。我能夠加載權(quán)重并為我的測(cè)試圖像生成預(yù)測(cè)。正在生成的掩碼似乎位于正確的位置,但掩模本身沒有真正的形式。它看起來(lái)就像一堆像素訓(xùn)練是根據(jù)本文中的數(shù)據(jù)集完成的,以下是用于訓(xùn)練和生成權(quán)重的代碼的github鏈接預(yù)測(cè)代碼如下。(我省略了我創(chuàng)建路徑變量并分配路徑的部分)import osimport globimport numpy as npimport pandas as pdimport cv2 as cvimport fileinputimport torchimport torch.utils.dataimport torchvisionfrom data.apple_dataset import AppleDatasetfrom torchvision.models.detection.faster_rcnn import FastRCNNPredictorfrom torchvision.models.detection.mask_rcnn import MaskRCNNPredictorimport utility.utils as utilsimport utility.transforms as Tfrom PIL import Imagefrom matplotlib import pyplot as plt%matplotlib inlinedef get_transform(train):    transforms = []    transforms.append(T.ToTensor())    if train:        transforms.append(T.RandomHorizontalFlip(0.5))    return T.Compose(transforms)def get_maskrcnn_model_instance(num_classes):    # load an instance segmentation model pre-trained pre-trained on COCO    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)    # get number of input features for the classifier    in_features = model.roi_heads.box_predictor.cls_score.in_features    # replace the pre-trained head with a new one    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)    # now get the number of input features for the mask classifier    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels    hidden_layer = 256    # and replace the mask predictor with a new one    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)    return modelnum_classes = 2device = torch.device('cpu')model = get_maskrcnn_model_instance(num_classes)checkpoint = torch.load('model_49.pth', map_location=device)model.load_state_dict(checkpoint['model'], strict=False)dataset_test = AppleDataset(test_image_files_path, get_transform(train=False))img, _ = dataset_test[1]model.eval()with torch.no_grad():    prediction = model([img.to(device)])
查看完整描述

1 回答

?
溫溫醬

TA貢獻(xiàn)1752條經(jīng)驗(yàn) 獲得超4個(gè)贊

來(lái)自掩碼 R-CNN 的預(yù)測(cè)具有以下結(jié)構(gòu):


在推理過(guò)程中,模型只需要輸入張量,并將后處理的預(yù)測(cè)作為 ,每個(gè)輸入圖像返回一個(gè)。的字段如下:List[Dict[Tensor]]Dict


boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between 0 and H and 0 and W  

labels (Int64Tensor[N]): the predicted labels for each image  

scores (Tensor[N]): the scores or each prediction  

masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range.

您可以使用 OpenCV 和函數(shù)來(lái)繪制蒙版,如下所示:findContoursdrawContours


img_cv = cv2.imread('input.jpg', cv2.COLOR_BGR2RGB)


for i in range(len(prediction[0]['masks'])):

    # iterate over masks

    mask = prediction[0]['masks'][i, 0]

    mask = mask.mul(255).byte().cpu().numpy()

    contours, _ = cv2.findContours(

            mask.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)

    cv2.drawContours(img_cv, contours, -1, (255, 0, 0), 2, cv2.LINE_AA)


cv2.imshow('img output', img_cv)


查看完整回答
反對(duì) 回復(fù) 2022-09-27
  • 1 回答
  • 0 關(guān)注
  • 183 瀏覽
慕課專欄
更多

添加回答

舉報(bào)

0/150
提交
取消
微信客服

購(gòu)課補(bǔ)貼
聯(lián)系客服咨詢優(yōu)惠詳情

幫助反饋 APP下載

慕課網(wǎng)APP
您的移動(dòng)學(xué)習(xí)伙伴

公眾號(hào)

掃描二維碼
關(guān)注慕課網(wǎng)微信公眾號(hào)