第七色在线视频,2021少妇久久久久久久久久,亚洲欧洲精品成人久久av18,亚洲国产精品特色大片观看完整版,孙宇晨将参加特朗普的晚宴

為了賬號(hào)安全,請(qǐng)及時(shí)綁定郵箱和手機(jī)立即綁定
已解決430363個(gè)問題,去搜搜看,總會(huì)有你想問的

使用方括號(hào)對(duì) Pytorch 張量進(jìn)行子集化

使用方括號(hào)對(duì) Pytorch 張量進(jìn)行子集化

呼啦一陣風(fēng) 2022-07-26 16:00:51
我遇到了一行代碼,用于在 PyTorch 中將 3D 張量簡化為 2D 張量。3D 張量x的大小torch.Size([500, 50, 1])和這行代碼:x = x[lengths - 1, range(len(lengths))]用于減少x到大小為 的 2D 張量torch.Size([50, 1])。lengths也是一個(gè)torch.Size([50])包含值的形狀張量。請(qǐng)任何人解釋這是如何工作的?謝謝你。
查看完整描述

2 回答

?
一只萌萌小番薯

TA貢獻(xiàn)1795條經(jīng)驗(yàn) 獲得超7個(gè)贊

這里的關(guān)鍵特性是將張量的值lengths作為 的索引傳遞x。這里簡化的例子,我交換了容器的尺寸,所以 index dimenson 首先:


container = torch.arange(0, 50 )

container = f.reshape((5, 10))

>>>tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],

        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],

        [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],

        [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],

        [40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])


indices = torch.arange( 2, 7, dtype=torch.long )

>>>tensor([2, 3, 4, 5, 6])


print( container[ range( len(indices) ), indices] )

>>>tensor([ 2, 13, 24, 35, 46])    

注意:我們從一行中得到一件事(range( len(indices) )產(chǎn)生連續(xù)的行號(hào)),列號(hào)由索引[ row_number ]


查看完整回答
反對(duì) 回復(fù) 2022-07-26
?
素胚勾勒不出你

TA貢獻(xiàn)1827條經(jīng)驗(yàn) 獲得超9個(gè)贊

在被這種行為難住之后,我對(duì)此進(jìn)行了更多挖掘,發(fā)現(xiàn)它與多維 NumPy 數(shù)組的索引行為一致。使這種違反直覺的原因是兩個(gè)數(shù)組必須具有相同的長度這一不太明顯的事實(shí),即在這種情況下len(lengths)。


事實(shí)上,它的工作原理如下: *lengths確定您訪問第一個(gè)維度的順序。即,如果您有一個(gè)一維數(shù)組a = [0, 1, 2, ...., 500],并使用 list 訪問它b = [300, 200, 100],那么結(jié)果a[b] = [301, 201, 101](這也解釋了lengths - 1運(yùn)算符,它只會(huì)導(dǎo)致訪問的值與分別在b、 或lengths中使用的索引相同)。*range(len(lengths))然后 * 只需選擇第 - 行i中的第 - 個(gè)元素i。如果您有一個(gè)方陣,您可以將其解釋為矩陣的對(duì)角線。由于您只能訪問前兩個(gè)維度上每個(gè)位置的單個(gè)元素,因此可以將其存儲(chǔ)在一個(gè)維度中(從而將您的 3D 張量減少到 2D)。后一個(gè)維度簡單地保持“原樣”。


如果你想玩這個(gè),我強(qiáng)烈建議將range()值更改為更長/更短的值,這將導(dǎo)致以下錯(cuò)誤:


IndexError:形狀不匹配:索引數(shù)組無法與形狀(x,)(y,)一起廣播


其中x和y是您的特定長度值。


要以長形式編寫此訪問方法以了解“幕后”發(fā)生的情況,還請(qǐng)考慮以下示例:


import torch

x = torch.randint(500, 50, 1)

lengths = torch.tensor([2, 30, 1, 4])  # random examples to explore

diag = list(range(len(lengths)))  # [0, 1, 2, 3]

result = []

for i, row in enumerate(lengths):

    temp_tensor = x[row, :, :]  # temp_tensor.shape = [1, 50, 1]

    temp_tensor = temp_tensor.squeeze(0)[diag[i]]  # temp_tensor.shape = [1, 1]

    result.append(temp.tensor)


# back to pytorch

result = torch.tensor(result)

result.shape  # [4, 1]


查看完整回答
反對(duì) 回復(fù) 2022-07-26
  • 2 回答
  • 0 關(guān)注
  • 90 瀏覽
慕課專欄
更多

添加回答

舉報(bào)

0/150
提交
取消
微信客服

購課補(bǔ)貼
聯(lián)系客服咨詢優(yōu)惠詳情

幫助反饋 APP下載

慕課網(wǎng)APP
您的移動(dòng)學(xué)習(xí)伙伴

公眾號(hào)

掃描二維碼
關(guān)注慕課網(wǎng)微信公眾號(hào)