当前位置: 首页>JAVA>正文

PyTorch學習筆記(15) ——PyTorch中的contiguous

PyTorch學習筆記(15) ——PyTorch中的contiguous

本文轉載自栩風在知乎上的文章《PyTorch中的contiguous》。我覺得很好,特此轉載。

0. 前言

本文講解了pytorch中contiguous的含義、定義、實現,以及contiguous存在的原因,非contiguous時的解決辦法。并對比了numpy中的contiguous。


contiguous 本身是形容詞,表示連續的,關于 contiguous,PyTorch 提供了is_contiguouscontiguous(形容詞動用)兩個方法 ,分別用于判定Tensor是否是 contiguous 的,以及保證Tensor是contiguous的。

1.PyTorch中的is_contiguous是什么含義?

is_contiguous直觀的解釋: Tensor底層一維數組元素的存儲順序與Tensor按行優先一維展開的元素順序是否一致。

Tensor多維數組底層實現是使用一塊連續內存的1維數組(行優先順序存儲,下文描述),Tensor在元信息里保存了多維數組的形狀,在訪問元素時,通過多維度索引轉化成1維數組相對于數組起始位置的偏移量即可找到對應的數據。某些Tensor操作(如transpose、permute、narrow、expand)與原Tensor是共享內存中的數據,不會改變底層數組的存儲,但原來在語義上相鄰、內存里也相鄰的元素在執行這樣的操作后,在語義上相鄰,但在內存不相鄰,即不連續了(is not contiguous)。

如果想要變得連續使用contiguous方法,如果Tensor不是連續的,則會重新開辟一塊內存空間保證數據是在內存中是連續的,如果Tensor是連續的,則contiguous無操作。

1.1 行優先

行是指多維數組一維展開的方式,對應的是列優先。C/C++中使用的是行優先方式(row major),Matlab、Fortran使用的是列優先方式(column major),PyTorch中Tensor底層實現是C,也是使用行優先順序。舉例說明如下:

>>> t = torch.arange(12).reshape(3,4)
>>> t
tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]])

二維數組 t 如圖1:
在這里插入圖片描述

數組 t 在內存中實際以一維數組形式存儲,通過flatten方法查看 t 的一維展開形式,實際存儲形式與一維展開一致,如圖2,

>>> t.flatten()
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

在這里插入圖片描述

說明:圖1、圖2來自:What is the difference between contiguous and non-contiguous arrays? 圖1、圖2 中顏色相同的數據表示在同一行,不論是行優先順序、或是列優先順序,如果要訪問矩陣中的下一個元素都是通過偏移來實現,這個偏移量稱為步長(stride[1])。在行優先的存儲方式下,訪問行中相鄰元素物理結構需要偏移1個位置,在列優先存儲方式下偏移3個位置。

2. 為什么需要 contiguous ?

  • torch.view等方法操作需要連續的Tensor。

transposepermute 操作雖然沒有修改底層一維數組,但是新建了一份Tensor元信息,并在新的元信息中的 重新指定 stride。torch.view 方法約定了不修改數組本身,只是使用新的形狀查看數據。如果我們在 transpose、permute 操作后執行 view,Pytorch 會拋出以下錯誤:

invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension 
spans across two contiguous subspaces). Call .contiguous() before .view(). 
at /Users/soumith/b101_2/2019_02_08/wheel_build_dirs/wheel_3.6/pytorch/aten/src/TH/generic/THTensor.cpp:213

為什么 view 方法要求Tensor是連續的?考慮以下操作,

>>>t = torch.arange(12).reshape(3,4)
>>>t
tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]])
>>>t.stride()
(4, 1)
>>>t2 = t.transpose(0,1)
>>>t2
tensor([[ 0,  4,  8],[ 1,  5,  9],[ 2,  6, 10],[ 3,  7, 11]])
>>>t2.stride()
(1, 4)
>>>t.data_ptr() == t2.data_ptr() # 底層數據是同一個一維數組
True
>>>t.is_contiguous(),t2.is_contiguous() # t連續,t2不連續
(True, False)

t2 與 t 引用同一份底層數據 a: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], t2和t兩者僅是stride、shape不同。如果執行 t2.view(-1),期望返回b(但實際會報錯)[ 0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]

可以看出,在 a 的基礎上使用一個新的 stride 無法直接得到 b ,需要先使用 t2 的 stride (1, 4) 轉換到 t2 的結構,再基于 t2 的結構使用 stride (1,) 轉換為形狀為 (12,)的 b 。但這不是view工作的方式,view僅在底層數組上使用指定的形狀進行變形,即使 view 不報錯,它返回的數據是 [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 而對t2而言,顯然我們的目標是獲取[ 0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11], 那么,如果我們想得到后者,該怎么辦呢?

>>>t3 = t2.contiguous()
>>>t3
tensor([[ 0,  4,  8],[ 1,  5,  9],[ 2,  6, 10],[ 3,  7, 11]])
>>>t3.data_ptr() == t2.data_ptr() # 底層數據不是同一個一維數組
False

可以看到,我們用t3 = t2.contiguous()即可, 開辟了一塊新的內存空間給t3,也就是說:t與t2 底層數據指針一致,t3 與 t2 底層數據指針不一致,說明確實重新開辟了內存空間。

3. 為什么不在view 方法中默認調用contiguous方法?

  • ① 歷史原因
    因為歷史上view方法已經約定了共享底層數據內存,返回的Tensor底層數據不會使用新的內存,如果在view中調用了contiguous方法,則可能在返回Tensor底層數據中使用了新的內存,這樣打破了之前的約定,破壞了對之前的代碼兼容性。為了解決用戶使用便捷性問題,PyTorch在0.4版本以后提供了reshape方法,實現了類似于 tensor.contigous().view(*args)的功能,如果不關心底層數據是否使用了新的內存,則使用reshape方法更方便。 [3]

  • ② 出于性能考慮(保證Tensor語義順序和邏輯順序的一致性)

對連續的Tensor來說,語義上相鄰的元素,在內存中也是連續的,訪問相鄰元素是矩陣運算中經常用到的操作,語義和內存順序的一致性是緩存友好的(What is a “cache-friendly” code?[4]),在內存中連續的數據可以(但不一定)被高速緩存預取,以提升CPU獲取操作數據的速度。transposepermute 后使用 contiguous 方法則會重新開辟一塊內存空間保證數據是在邏輯順序和內存中是一致的,連續內存布局減少了CPU對對內存的請求次數(訪問內存比訪問寄存器慢100倍[5]),相當于空間換時間。

4. PyTorch中判讀張量是否連續的實現

PyTorch中通過調用 is_contiguous 方法判斷 tensor 是否連續,底層實現為 TH 庫中THTensor.isContiguous 方法:

int THTensor_(isContiguous)(const THTensor *self)
{long z = 1;int d;for(d = self->nDimension-1; d >= 0; d--){if(self->size[d] != 1){if(self->stride[d] == z)z *= self->size[d];elsereturn 0;}}return 1;
}

為方便加上一些調試信息,翻譯為 Python 代碼如下:

def isContiguous(tensor):"""判斷tensor是否連續    :param torch.Tensor tensor: :return: bool"""z = 1d = tensor.dim() - 1size = tensor.size()stride = tensor.stride()print("stride={} size={}".format(stride, size))while d >= 0:if size[d] != 1:if stride[d] == z:print("dim {} stride is {}, next stride should be {} x {}".format(d, stride[d], z, size[d]))z *= size[d]                else:print("dim {} is not contiguous. stride is {}, but expected {}".format(d, stride[d], z))return Falsed -= 1return True

判定上文中 t、t2 是否連續的輸出如下:

>>>isContiguous(t)
stride=(4, 1) size=torch.Size([3, 4])
dim 1 stride is 1, next stride should be 1 x 4
dim 0 stride is 4, next stride should be 4 x 3True
>>>isContiguous(t2)
stride=(1, 4) size=torch.Size([4, 3])
dim 1 is not contiguous. stride is 4, but expected 1False

isContiguous 實現可以看出,最后1維的 stride 必須為z = 1(邏輯步長),這是合理的,最后1維即邏輯結構上最內層數組,其相鄰元素間隔位數為1,按行優先順序排列時,最內層數組相鄰元素間隔應該為1。

參考資料

[3] view() after transpose() raises non contiguous error #764
[4] What is a “cache-friendly” code?
[5] 計算機緩存Cache以及Cache Line詳解

https://www.zydui.com/af833VW8CDQ9VAF8I.html
>

相关文章:

  • PyTorch學習筆記(15) ——PyTorch中的contiguous
  • 英雄聯盟英雄名單部分功能
  • 英雄聯盟皮膚爬蟲
  • java中for循環執行順序
  • for循環執行順序詳解(避坑)
  • DS18B20序列號的讀取
  • 上古卷軸java怎么刷_上古卷軸5快速升級方法一覽 教你如何快速升級
  • 上古卷軸 java_我打通了197KB的《上古卷軸四:湮滅》,諾基亞手機上的那一種...
  • 上古世紀服務器維護,9月22日臨時維護修改會員排隊問題服務器擴容公告
  • 上古世紀服務器維護真情禮,【已開服】4月15日經典服例行維護版本更新公告
  • 上古卷軸ol java_上古卷軸ol怎么滿級快
  • 全網最詳細解釋memcached中的flags含義
  • FLAGS寄存器 標志寄存器 英文全稱 方便記憶
  • Java爬蟲實戰第二篇:IOS、安卓應用爬蟲
  • 什么是storedownloaded,為什么在Mac上運行?
  • CSDN富文本編輯器回車行間距過大問題的解決
  • 批量處理word所有回車行
  • shell判斷字符串變量是否為空,包括純空格、空串、回車行是空白行等
  • 計算機畢業設計之開山車行二手車交易系統
  • Java實現“xx車行管理系統”
  • Java面試案例-車行易
  • 車行管理系統 java小作業
  • 3個躺著賺錢的神仙副業
  • 計算機中的windows任務管理器在哪,Win10系統中的explorer.exe在哪?怎么重啟Windows資源管理器?...
  • Win11查看文件資源管理器選項卡的方法
  • 為什么你總get不到增長玩法背后的邏輯?
  • JDBC連接MySQL數據庫(一)
  • 從零學Java(18)之三元運算符
  • CGU APAC 2017盛大開幕,七彩虹與英偉達聯手打造電競盛宴
  • AP 計算機 華麗逆襲-----被麻省理工計算機博士老師反復勸退的學生逆襲的肺腑之言