当前位置: 首页>C语言>正文

PyTorch中的contiguous解讀

PyTorch中的contiguous解讀

本文講解了pytorch中contiguous的含義、定義、實現,以及contiguous存在的原因,非contiguous時的解決辦法。并對比了numpy中的contiguous。


contiguous 本身是形容詞表示連續的關于 contiguous,PyTorch 提供了is_contiguouscontiguous(形容詞動用)兩個方法 ,分別用于判定Tensor是否是 contiguous 的,以及保證Tensor是contiguous的。

PyTorch中的is_contiguous是什么含義?

is_contiguous直觀的解釋是Tensor底層一維數組元素的存儲順序與Tensor按行優先一維展開的元素順序是否一致

Tensor多維數組底層實現是使用一塊連續內存的1維數組(行優先順序存儲,下文描述),Tensor在元信息里保存了多維數組的形狀,在訪問元素時,通過多維度索引轉化成1維數組相對于數組起始位置的偏移量即可找到對應的數據。某些Tensor操作(如transpose、permute、narrow、expand)與原Tensor是共享內存中的數據,不會改變底層數組的存儲,但原來在語義上相鄰、內存里也相鄰的元素在執行這樣的操作后,在語義上相鄰,但在內存不相鄰,即不連續了(is not contiguous)。

如果想要變得連續使用contiguous方法,如果Tensor不是連續的,則會重新開辟一塊內存空間保證數據是在內存中是連續的,如果Tensor是連續的,則contiguous無操作。

行優先

行是指多維數組一維展開的方式,對應的是列優先。C/C++中使用的是行優先方式(row major),Matlab、Fortran使用的是列優先方式(column major),PyTorch中Tensor底層實現是C,也是使用行優先順序。舉例說明如下:

>>> t = torch.arange(12).reshape(3,4)
>>> t
tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]])

二維數組 t 如圖1:

圖1. 3X4矩陣行優先存儲邏輯結構

數組 t 在內存中實際以一維數組形式存儲,通過 flatten 方法查看 t 的一維展開形式,實際存儲形式與一維展開一致,如圖2,

>>> t.flatten()
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

圖2. 3X4矩陣行優先存儲物理結構

而列優先的存儲邏輯結構如圖3。

圖3. 3X4矩陣列優先存儲邏輯結構

使用列優先存儲時,一維數組中元素順序如圖4:

圖4. 3X4矩陣列優先存儲物理結構

說明:圖1、圖2、圖3、圖4來自:What is the difference between contiguous and non-contiguous arrays?

圖1、圖2、圖3、圖4 中顏色相同的數據表示在同一行,不論是行優先順序、或是列優先順序,如果要訪問矩陣中的下一個元素都是通過偏移來實現,這個偏移量稱為步長(stride[1])。在行優先的存儲方式下,訪問行中相鄰元素物理結構需要偏移1個位置,在列優先存儲方式下偏移3個位置。

為什么需要 contiguous

1. torch.view等方法操作需要連續的Tensor。

transpose、permute 操作雖然沒有修改底層一維數組,但是新建了一份Tensor元信息,并在新的元信息中的 重新指定 stride。torch.view 方法約定了不修改數組本身,只是使用新的形狀查看數據。如果我們在 transpose、permute 操作后執行 view,Pytorch 會拋出以下錯誤:

invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension 
spans across two contiguous subspaces). Call .contiguous() before .view(). 
at /Users/soumith/b101_2/2019_02_08/wheel_build_dirs/wheel_3.6/pytorch/aten/src/TH/generic/THTensor.cpp:213

為什么 view 方法要求Tensor是連續的[2]?考慮以下操作,

>>>t = torch.arange(12).reshape(3,4)
>>>t
tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]])
>>>t.stride()
(4, 1)
>>>t2 = t.transpose(0,1)
>>>t2
tensor([[ 0,  4,  8],[ 1,  5,  9],[ 2,  6, 10],[ 3,  7, 11]])
>>>t2.stride()
(1, 4)
>>>t.data_ptr() == t2.data_ptr() # 底層數據是同一個一維數組
True
>>>t.is_contiguous(),t2.is_contiguous() # t連續,t2不連續
(True, False)

t2 與 t 引用同一份底層數據 a,如下:

[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]

,兩者僅是stride、shape不同。如果執行 t2.view(-1) ,期望返回以下數據 b(但實際會報錯):

[ 0,  4,  8,  1,  5,  9,  2,  6, 10,  3,  7, 11]

a 的基礎上使用一個新的 stride 無法直接得到 b ,需要先使用 t2 的 stride (1, 4) 轉換到 t2 的結構,再基于 t2 的結構使用 stride (1,) 轉換為形狀為 (12,)的 b但這不是view工作的方式view 僅在底層數組上使用指定的形狀進行變形,即使 view 不報錯,它返回的數據是:

[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]

這是不滿足預期的。使用contiguous方法后返回新Tensor t3,重新開辟了一塊內存,并使用照 t2 的按行優先一維展開的順序存儲底層數據。

>>>t3 = t2.contiguous()
>>>t3
tensor([[ 0,  4,  8],[ 1,  5,  9],[ 2,  6, 10],[ 3,  7, 11]])
>>>t3.data_ptr() == t2.data_ptr() # 底層數據不是同一個一維數組
False

可以發現 t與t2 底層數據指針一致,t3 與 t2 底層數據指針不一致,說明確實重新開辟了內存空間。

為什么不在view 方法中默認調用contiguous方法?

因為歷史上view方法已經約定了共享底層數據內存,返回的Tensor底層數據不會使用新的內存,如果在view中調用了contiguous方法,則可能在返回Tensor底層數據中使用了新的內存,這樣打破了之前的約定,破壞了對之前的代碼兼容性。為了解決用戶使用便捷性問題,PyTorch在0.4版本以后提供了reshape方法,實現了類似于 tensor.contigous().view(*args)的功能,如果不關心底層數據是否使用了新的內存,則使用reshape方法更方便。 [3]

2. 出于性能考慮

連續的Tensor,語義上相鄰的元素,在內存中也是連續的,訪問相鄰元素是矩陣運算中經常用到的操作,語義和內存順序的一致性是緩存友好的(What is a “cache-friendly” code?[4]),在內存中連續的數據可以(但不一定)被高速緩存預取,以提升CPU獲取操作數據的速度。transpose、permute 后使用 contiguous 方法則會重新開辟一塊內存空間保證數據是在邏輯順序和內存中是一致的,連續內存布局減少了CPU對對內存的請求次數(訪問內存比訪問寄存器慢100倍[5]),相當于空間換時間。


PyTorch中張量是否連續的定義

對于任意的 k 維張量 t ,如果滿足對于所有 i,第 i 維相鄰元素間隔 = 第 i+1 維相鄰元素間隔 與 第 i+1 維長度的乘積,則 t 是連續的。

  • 使用 表示第 i 維相鄰元素之間間隔的位數,稱為步長,可通過 stride 方法獲得。
  • 使用 表示固定其他維度時,第 i 維元素數量。

PyTorch中判讀張量是否連續的實現

PyTorch中通過調用 is_contiguous 方法判斷 tensor 是否連續,底層實現為 TH 庫中THTensor.isContiguous 方法,為方便加上一些調試信息,翻譯為 Python 代碼如下:

def isContiguous(tensor):"""判斷tensor是否連續    :param torch.Tensor tensor: :return: bool"""z = 1d = tensor.dim() - 1size = tensor.size()stride = tensor.stride()print("stride={} size={}".format(stride, size))while d >= 0:if size[d] != 1:if stride[d] == z:print("dim {} stride is {}, next stride should be {} x {}".format(d, stride[d], z, size[d]))z *= size[d]                else:print("dim {} is not contiguous. stride is {}, but expected {}".format(d, stride[d], z))return Falsed -= 1return True

判定上文中 t、t2 是否連續的輸出如下:

>>>isContiguous(t)
stride=(4, 1) size=torch.Size([3, 4])
dim 1 stride is 1, next stride should be 1 x 4
dim 0 stride is 4, next stride should be 4 x 3True
>>>isContiguous(t2)
stride=(1, 4) size=torch.Size([4, 3])
dim 1 is not contiguous. stride is 4, but expected 1False

isContiguous 實現可以看出,最后1維的 stride 必須為1(邏輯步長),這是合理的,最后1維即邏輯結構上最內層數組,其相鄰元素間隔位數為1,按行優先順序排列時,最內層數組相鄰元素間隔應該為1。

numpy中張量是否連續的定義

對于任意的 N 維張量 t ,如果滿足第 k 維相鄰元素間隔 = 第 K+1維 至 最后一維的長度的乘積,則 t 是連續的。

  • 使用 表示行優先模式下,第 維度相鄰兩個元素之間在內存中間隔的字節數,可通過 strides 屬性獲得。
  • 使用表示每個元素的字節數(根據數據類型而定,如PyTorch中 int32 類型是 4,int64 是8)。
  • 使用 表示固定其他維度時,第 維元素的個數,即 t.shape[j]。

?

PyTorch與numpy中contiguous定義的關系

PyTorch和numpy中對于contiguous的定義看起來有差異,本質上是一致的。

首先對于 stride的定義指的都是某維度下,相鄰元素之間的間隔,PyTorch中的 stride 是間隔的位數(可看作邏輯步長),而numpy 中的 stride 是間隔的字節數(可看作物理步長),兩種 stride 的換算關系為:

再看對于 stride 的計算公式,PyTorch 和 numpy 從不同角度給出了公式。PyTorch 給出的是一個遞歸式定義,描述了兩個相鄰維度 stride 與 size 之間的關系。numpy 給出的是直接定義,描述了 stride 與 shape 的關系。PyTorch中的 size 與 numpy 中的 shape 含義一致,都是指 tensor 的形狀。 都是指當固定其他維度時,該維度下元素的數量。

參考

  1. ^訪問相鄰元素所需要跳過的位數或字節數?python - How to understand numpy strides for layman? - Stack Overflow
  2. ^Munging PyTorch's tensor shape from (C, B, H) to (B, C*H)?python - Munging PyTorch's tensor shape from (C, B, H) to (B, C*H) - Stack Overflow
  3. ^view() after transpose() raises non contiguous error #764?view() after transpose() raises non contiguous error · Issue #764 · pytorch/pytorch · GitHub
  4. ^What is a “cache-friendly” code??c++ - What is a "cache-friendly" code? - Stack Overflow
  5. ^計算機緩存Cache以及Cache Line詳解?計算機緩存Cache以及Cache Line詳解 - 知乎
  6. ^Tensor.view方法對連續的描述?torch.Tensor — PyTorch 1.10.1 documentation
  7. ^行優先布局的stride?The N-dimensional array (ndarray) — NumPy v1.22 Manual

https://www.zydui.com/af832UG8CDQ9VB1YB.html
>

相关文章:

  • 詳解PyTorch中的contiguous
  • PyTorch中的contiguous解讀
  • Pytorch中contiguous()函數理解
  • ios自定義UITabBar-仿寫掌上英雄聯盟的UITabBar
  • 基于Cocos2d-x的英雄聯盟皮膚選擇菜單
  • lol-登陸英雄聯盟出錯
  • JS中雙層for循環執行順序
  • 關于for循環執行順序
  • 上古卷軸5boss計算機丟失,上古卷軸5常見BUG解決辦法
  • 上古世紀服務器維護真情禮,4月9日例行維護懷舊服合服公告
  • 塔羅牌張數
  • 工程管理中的工程技術
  • 銳派出品:LOL新年特輯S4各類細節之下路篇
  • 藍城兄弟Q4業績背后,垂直社區具備多少想象力?
  • Mac OS啟動服務優化高級篇(launchd tuning)
  • #Geek Point# 為什么現在要去印度看一看?
  • vm 流程運行mac os_什么是“商務”流程,為什么在我的Mac上運行?
  • # 陌生人社交產品:需求、困境與破局之道
  • mac 不受信任在哪里更改_什么是受信任的,為什么它可以在Mac上運行?
  • 車行軌跡分類實踐
  • 智慧車行預約小程序 v9.1
  • i12藍牙耳機充電倉怎么看充滿電_車行藍牙耳機價格高性價比的選擇
  • 車行平安
  • 論文閱讀——《基于卷積神經網絡的車行環境多類障礙物檢測與識別》
  • 飛槳開發者創意薈:PaddleHub一鍵部署,AI創意實現原來如此簡單
  • eclipse左側欄目即包資源管理器怎么打開
  • 卷毛機器人符文_卷毛S6娜美輔助天賦 娜美輔助符文天賦S6最新
  • 天賦介紹
  • 蘋果怎么沒有4g信號還無服務器,不顯示4g信號怎么回事?蘋果手機不顯示4g信號的解決方法...
  • c4D體積生成和Quad Remesher重新拓撲減面插件