哦哇資訊網

Meta 內部都在用 FX 工具:利用 Graph Transformation 最佳化 PyTorch 模型

由 HyperAI超神經 發表于 美食2023-01-10

PyTorch 中的 graph mode 在效能方面表示更為出色,本文介紹 Torch。FX 這個強大工具,可以捕捉和最佳化 PyTorch 程式 graph。

一、簡介

PyTorch 支援兩種執行模式:eager mode 和 graph mode。

eager mode 中,模型中的運算子在讀取時會立即執行,它易於使用,對機器學習從業者更友好,因此被設定為預設的執行模式。

graph mode 中,運算子先被合成一個 graph,然後作為一個整體進行編譯和執行,它的效能更高,因此在實際生產中大量使用。

具體來說,graph mode 支援運算元融合,兩個運算元透過合併,可以降低或本地化記憶體讀取以及核心啟動總開銷。

融合可以是橫向 (horizontal) 的:

採取應用於多個 operand 的單一操作(如 BatchNorm),並將這些 operand 合併到一個數組中。

融合也可以是縱向 (vertical) 的:

將一個核心與另一個核心合併,後者需要使用第一個核心的輸出(如 ReLU 後接卷積)。

Torch。FX(縮寫為 FX)是一個公開可用的工具包,作為 PyTorch 軟體包的一部分,支援 graph mode 的執行。它可以:

從 PyTorch 程式中獲取 graph

允許開發者在獲取的 graph 上編寫 transformation

Meta 內部先前已經在用 FX 來最佳化生產模型 (production model) 的訓練吞吐量 (training throughput)。本文將透過介紹 Meta 開發的基於 FX 的最佳化,來展示利用圖結構轉換 (graph transformation) 最佳化 PyTorch 部署模型效能的方法。

二、背景

embedding table 廣泛存在於推薦系統中,

本節將介紹 FX 和 embedding table 的背景知識。

2。1。 FX

圖 1 是一個簡單示例,演示瞭如何用 FX 轉換 PyTorch 程式,

它包含三個步驟:

從程式中獲取 graph

修改 graph(在本例中,我們用 GELU 代替 RELU)

從修改後的 graph 中生成一個新程式

圖1:在 PyTorch 模組中用 GELU 取代 RELU 的 FX

FX API 為檢查和轉換 PyTorch 程式 graph 還提供了許多其他功能。

2。2。 embedding table

圖2:批尺寸=1 的稀疏特徵 embedding table 示意圖

在推薦系統中,

稀疏特徵(例如,User ID,Story ID)由 embedding table 表示。

embedding table E 是一個 HxD 矩陣,其中 H 是雜湊大小,D 是嵌入向量維度。E 的每一行都是一個浮點數向量。

feature hashing 的作用是將一個稀疏特徵對映到 E的索引列表中,例如 [S1,S2,。。。,Sk],其中 0≤Si

為了充分利用 GPU,稀疏特徵通常為批處理。

批處理中的每個實體都有自己的索引列表。如果一個批次有 B 個實體,可以簡單理解為一個表徵有 B 個索引列表。

更為嚴謹的表示方法是將 B 個索引列表合併成一個索引列表,並新增一個索引長度的列表(該批中的每個實體都有一個長度 length)。

例如,如果一批包含 3 個實體,其索引列表如下:

Entity 1: indices = [10, 20]

Entity 2: indices = [5, 9, 77, 81]

Entity 3: indices = [15, 20, 45]

則完整批尺寸的 indice 和 length 將是:

Indices = [10, 20, 5, 9, 77, 81, 15, 20, 45]

Lengths = [2, 4, 3]

而整個 batch 的 embedding table 查詢,輸出為是一個 BxD 矩陣。

三、3 種 FX Transformation

PyTorch 更新了 3 個 FX transformation,以加速對 embedding table 的訪問,本節將逐一介紹。

下文 3。1 關於將多個小輸入張量結合成一個大張量的轉換;3。2 關於將多個平行計算鏈融合成一個計算鏈的轉換;3。3 關於將通訊與計算重疊的轉換。

3。1 結合輸入稀疏特徵

batch 中的每個輸入稀疏特徵,都可以表示為兩個列表:一個索引列表和一個 B length 列表,其中 B 表示批尺寸。

在 PyTorch 中,這兩個列表都可以以張量的形式存在。

當 PyTorch 模型在 GPU 上執行時,embedding table 通常儲存在 GPU 記憶體中(它更接近 GPU,讀寫頻寬比 CPU 記憶體更高)。

需要使用輸入稀疏特徵時,兩個張量都要先從 CPU 複製到 GPU。然而每個主機到裝置的記憶體複製都需要啟動核心,這對於實際的資料傳輸來說,會更加耗費時間。

如果一個模型使用了多個輸入稀疏特徵,這種複製可能成為效能瓶頸(例如,1000 個輸入稀疏特徵將需要從主機到裝置複製 2000 個張量)。

一個減少主機到裝置 memcpy 數量的最佳化方法,就是在多個輸入稀疏特徵傳送到裝置之前,先將其進行組合。

例如,給定以下三個輸入特徵:

Feature_A: indices = [106, 211, 7], lengths = [2, 1]

Feature_B: indices = [52, 498, 616, 870, 1013], lengths = [3, 2]

Feature_C: indices = [2011, 19, 351, 790], lengths = [1, 3]

組合後的形式為:

Features_A_B_C: indices = [106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790], lengths = [2, 1, 3, 2, 1, 3]

所以不需要從主機到裝置複製 3x2=6 個張量,只需要複製 2 個張量。

圖 3(b) 描述了這種最佳化的實現,它包含兩個元件:

CPU 端:

輸入 pipeline 被修改為將所有稀疏特徵的 indices 組合成一個張量,所有 length 組合成另一個張量。然後將這兩個張量複製到 GPU 上。

GPU 端:

使用 FX,在模型 graph 中插入一個Permute_and_Split 運算元,從合併的張量中恢復單個特徵 indices 和 length 張量,並將其傳送至下游的相應節點。

最佳化前:兩個張量都要從 CPU 複製到 GPU

最佳化後:將輸入稀疏特徵進行組合

3。2 從訪問 embedding table 開始的計算鏈橫向融合

在一個生產模型中,每個 GPU 上有 10 個 embedding table 很常見。出於效能方面的考慮,

對這些 table 的查詢被分到一組,這樣它們的輸出就被串聯在一個大張量中

(見圖 4(a)中的紅色部分)。

為了對單個特徵輸出進行計算,

使用 Split 運算元將大張量分成 N 個小張量

(其中 N 為特徵的數量),然後將所需的計算應用於每個張量。

如圖 4(a) 所示,應用於每個特徵輸出 O 的計算是Tanh(LayerNorm(O))。所有的計算結果都被串聯成一個大的張量,然後傳遞給下游的運算元(圖 4(a) 中的 Op1)。

這裡主要的 runtime cost 是 GPU 核心啟動的開銷。

例如,圖 4(a) 中的 GPU 核心的啟動次數為 2*N+3(圖中的每個橢圓都表示一個 GPU 核心)。這會影響效能,因為 LayerNorm 和 Tanh 在 GPU 上的執行時間,與它們的核心啟動時間相比很短。

此外,Split 運算元可能會建立一個額外的嵌入向量輸出張量的副本,消耗額外的 GPU 記憶體。

用 FX 來實現一種叫做橫向融合 (horizontal fusion) 的最佳化,可以大大減少 GPU 核心的啟動次數

(在這個例子中,最佳化後的 GPU 核心啟動次數為 5,見圖 4(b))。

使用 Add_middle_dim 運算元代替顯式 Split,將 shape 為 (B, NxD) 的 2D 嵌入張量重塑為 shape 為 (B, N, D) 的 3D 張量。接下來將一個單一的 LayerNorm 應用到它的最後一維。對 LayerNorm 的結果應用一個 Tanh。最後,用 Remove_middle_dim 運算元將 Tanh 的結果恢復成 2D 張量。

由於 Add_middle_dim 和 Remove_middle_dim 只是重塑張量,

並沒有建立額外的副本,所以也可以減少 GPU 記憶體的消耗。

最佳化前:所有輸出被串聯到一個大張量中

Meta 內部都在用 FX 工具:利用 Graph Transformation 最佳化 PyTorch 模型

進行橫向融合最佳化後

3。3 計算與通訊間的重疊 (overlap)

面向投產的推薦模型的訓練,通常是在分散式 GPU 系統上完成的。

由於每個 GPU 的裝置記憶體容量不足以容納模型中的所有 embedding table,因此需要將其分佈在多個 GPU 上。

在訓練步驟中,GPU 需要從其他 GPU 上的 embedding table 中讀取/寫入特徵值。

這被稱為 all-to-all 通訊,可能是影響效能的重要原因。

透過 FX 實現一個 transformation,可以將計算與 all-to-all 通訊重疊。

圖 5(a) 顯示了一個具備嵌入向量 table 訪問 (EmbeddingAllToAll) 及其他運算元的模型 graph 例項。如圖 5(b) 所示,在沒有任何最佳化的情況下,它們會在一個 GPU 流上順序執行。

使用FX將 EmbeddingAllToAll 分成 EmbeddingAllToAll_Request和EmbeddingAllToAll_Wait,並在它們之間安排獨立的運算元。

Meta 內部都在用 FX 工具:利用 Graph Transformation 最佳化 PyTorch 模型

圖5:計算與通訊的重疊

3。4 總結

Meta 內部都在用 FX 工具:利用 Graph Transformation 最佳化 PyTorch 模型

表1:本節討論的最佳化及解決的相應效能瓶頸

為了發現哪些模型會從這些 transformation 中受益,開發人員對 MAIProf 收集的執行在 Meta 資料中心的模型的效能資料進行分析。

得出與 eager mode 相比,這些 transformation 在一組生產模型上實現了 2-3 倍的速度提升。

四、結語

從效能角度考量,PyTorch 中的 graph mode 比生產環境中使用的 eager mode 更受歡迎。

FX 是一個強大的工具,可以捕捉和最佳化 PyTorch 程式 graph。本文展示了三種 FX transformation,用於最佳化 Meta 內部的生產推薦模型。

最後希望更多 PyTorch 開發者可以使用 graph transformation 來提升模型的效能。

—— 完 ——

TAG: GPU張量graphFxtable