真正的元兇是缺乏批次不變性。
就在今天,由 OpenAI 前 CTO Mira Murati 成立于今年 2 月的人工智能初創(chuàng)公司 Thinking Machines Lab,發(fā)了第一篇文章 ——《克服 LLM 推理中的不確定性》(Defeating Nondeterminism in LLM Inference)。
這篇博客屬于 Thinking Machines Lab 新提出的博客欄目 Connectionism,意為「連接主義」。該公司表示:「我們相信,分享才能讓科學(xué)更好地發(fā)展。Connectionism 將涵蓋與我們的研究一樣廣泛的主題:從核函數(shù)數(shù)值計(jì)算到提示工程。Connectionism 這一名稱可以追溯到 AI 的早期年代。它曾是 20 世紀(jì) 80 年代的一個(gè)研究分支,專注于神經(jīng)網(wǎng)絡(luò)及其與生物大腦的相似性?!?
此外,Thinking Machines Lab 聯(lián)合創(chuàng)始人、著名技術(shù)博主翁荔(Lilian Weng)還在轉(zhuǎn)推中透露了一個(gè)消息,Connection Machine,即「連接機(jī)」,難道他們的產(chǎn)品要來(lái)了?
真是讓人期待呢。
地址:https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/
博客主要作者為 Horace He,這位 PyTorch 核心開發(fā)者于今年 3 月從 Meta 離職,加入了 Thinking Machines。
接下來(lái)看博客原文內(nèi)容。
可復(fù)現(xiàn)性(reproducibility)是科學(xué)進(jìn)步的基石。然而,從大語(yǔ)言模型中獲得可復(fù)現(xiàn)的結(jié)果卻非常困難。
例如,你可能會(huì)發(fā)現(xiàn):即使是向 ChatGPT 提出同一個(gè)問(wèn)題多次,也可能得到不同的回答。這本身并不令人意外,因?yàn)檎Z(yǔ)言模型生成結(jié)果的過(guò)程涉及采樣 —— 這個(gè)過(guò)程會(huì)將模型的輸出轉(zhuǎn)換為一個(gè)概率分布,并以概率方式選擇一個(gè) token。
更令人驚訝的是,即使我們將溫度參數(shù)調(diào)到 0(理論上使采樣過(guò)程變?yōu)榇_定性),大語(yǔ)言模型的 API 在實(shí)際中仍然不是確定性的。研究者已經(jīng)對(duì)此有諸多討論。
即使是在你自己的硬件上,使用開源推理庫(kù)(如 vLLM 或 SGLang)運(yùn)行推理,采樣過(guò)程依然不是確定性的。
為什么大語(yǔ)言模型的推理引擎不是確定性的呢?
一個(gè)常見的假設(shè)是:浮點(diǎn)運(yùn)算的非結(jié)合性(non-associativity)與并發(fā)執(zhí)行的某種組合會(huì)導(dǎo)致不確定性,這取決于哪個(gè)并發(fā)核心首先完成。我們將這種解釋稱為「LLM 推理不確定性的『并發(fā) + 浮點(diǎn)』假設(shè)」。例如,一篇最近的 arXiv 論文(arXiv:2506.09501)寫道:
GPU 中的浮點(diǎn)運(yùn)算具有非結(jié)合性(non-associativity),意味著 (a+b)+c≠a+(b+c),這是由于精度有限和舍入誤差所致。這一特性會(huì)直接影響 transformer 架構(gòu)中注意力得分和 logit 的計(jì)算,因?yàn)樵诙嗑€程中進(jìn)行的并行操作,其執(zhí)行順序不同會(huì)導(dǎo)致結(jié)果差異。
雖然這個(gè)假設(shè)并不完全錯(cuò)誤,但它并沒(méi)有揭示事情的全貌。
例如,即使在 GPU 上,對(duì)相同的數(shù)據(jù)反復(fù)進(jìn)行相同的矩陣乘法運(yùn)算,每次的結(jié)果也都是每一位都相同的。我們確實(shí)在使用浮點(diǎn)數(shù),GPU 也確實(shí)具有高度并發(fā)性。
那為什么在這個(gè)測(cè)試中卻看不到不確定性呢?
要理解大語(yǔ)言模型推理不確定性的真正原因,我們必須更深入地探究。
不幸的是,甚至連「LLM 推理是確定性」的這一說(shuō)法的定義都很難明確。或許令人困惑的是,以下這些看似矛盾的說(shuō)法實(shí)際上同時(shí)都是真實(shí)的:
GPU 上的一些核(kernel)是不確定性的。
然而,語(yǔ)言模型在前向傳播過(guò)程中使用的所有核都是確定性的。
此外,像 vLLM 這樣的 LLM 推理服務(wù)器的前向傳播過(guò)程,也可以被認(rèn)為是確定性的。
盡管如此,從使用推理服務(wù)器的任何用戶的角度來(lái)看,結(jié)果卻是不確定性的。
在這篇文章中,我們將解釋為什么「并發(fā) + 浮點(diǎn)」假設(shè)沒(méi)有達(dá)到目的,揭露 LLM 推理不確定性背后的真正罪魁禍?zhǔn)?,并解釋如何克服不確定性并在 LLM 推理中獲得真正可重復(fù)的結(jié)果。
原罪:浮點(diǎn)數(shù)的非結(jié)合性
在討論不確定性之前,有必要先解釋一下為什么存在數(shù)值差異。畢竟,我們通常將機(jī)器學(xué)習(xí)模型視為遵循交換律或結(jié)合律等結(jié)構(gòu)性規(guī)則的數(shù)學(xué)函數(shù)。我們的機(jī)器學(xué)習(xí)庫(kù)難道不應(yīng)該提供數(shù)學(xué)上正確的結(jié)果嗎?
罪魁禍?zhǔn)资歉↑c(diǎn)非結(jié)合性(floating-point non-associativity)。也就是說(shuō),對(duì)于浮點(diǎn)數(shù) a、b、c,有:
諷刺的是,正是打破結(jié)合律讓浮點(diǎn)數(shù)變得有用。
浮點(diǎn)數(shù)之所以有用,是因?yàn)樗鼈冊(cè)试S動(dòng)態(tài)的精度。為了便于解釋,我們將使用十進(jìn)制(而不是二進(jìn)制),其中浮點(diǎn)數(shù)的格式為:尾數(shù) * 10^ 指數(shù)。這里還將使用 3 位數(shù)字作為尾數(shù),1 位數(shù)字作為指數(shù)。(注:在計(jì)算機(jī)科學(xué)中,尾數(shù)(mantissa,或有效數(shù))是浮點(diǎn)數(shù)中用來(lái)表示精度的部分,它決定了數(shù)字的有效數(shù)字位數(shù)和精度。)
例如,對(duì)于值 3450,我們可以將其精確表示為 3.45*10^3。我們也可以將更小的值(例如 0.486)表示為 4.86*10^-1。這樣,浮點(diǎn)數(shù)既可以表示非常小的值,也可以表示非常大的值。在科學(xué)領(lǐng)域,我們可以說(shuō)浮點(diǎn)數(shù)使我們能夠保持有效數(shù)的個(gè)數(shù)恒定。
如果兩個(gè)浮點(diǎn)數(shù)的指數(shù)相同,它們的加法運(yùn)算看起來(lái)與整數(shù)加法類似。例如:
但是,如果兩個(gè)浮點(diǎn)數(shù)的指數(shù)不同,例如 1230 和 23.4,又會(huì)發(fā)生什么情況呢?理論上,它們的和應(yīng)該是 1253.4。然而,由于浮點(diǎn)數(shù)運(yùn)算只能保留 3 位有效數(shù)字,因此結(jié)果會(huì)被舍入為 1.25×103(或 1250)。
表示 1230 需要 3 位有效數(shù)字,表示 23.4 也需要 3 位有效數(shù)字。但是,這兩個(gè)數(shù)相加的結(jié)果(1253.4)卻需要 5 位有效數(shù)字才能精確表示。因此,我們的浮點(diǎn)數(shù)格式必須舍棄最后兩位(34)。某種程度上,這相當(dāng)于我們?cè)谙嗉又埃瑢⒃瓉?lái)的 23.4 四舍五入為 20.0。
然而,這樣做會(huì)導(dǎo)致信息丟失。請(qǐng)注意,只要我們對(duì)兩個(gè)不同階位(即不同指數(shù))的浮點(diǎn)數(shù)進(jìn)行加法運(yùn)算,就會(huì)發(fā)生這種情況。而實(shí)際應(yīng)用中,我們經(jīng)常需要對(duì)不同指數(shù)的浮點(diǎn)數(shù)進(jìn)行加法運(yùn)算。事實(shí)上,如果我們能夠保證所有浮點(diǎn)數(shù)的指數(shù)都相同,那么我們完全可以只使用整數(shù)!
換句話說(shuō),每次以不同順序相加浮點(diǎn)數(shù)時(shí),結(jié)果都有可能完全不同。舉個(gè)極端的例子,對(duì)于某個(gè)數(shù)組,根據(jù)加法順序的不同,其求和結(jié)果可能出現(xiàn) 102 種不同的結(jié)果。
雖然這是導(dǎo)致輸出結(jié)果不一致的根本原因,但它并不能直接解釋不確定性行為的來(lái)源。它也無(wú)法幫助我們理解為什么浮點(diǎn)數(shù)的加法順序會(huì)改變、這種情況在什么時(shí)候發(fā)生、以及我們?nèi)绾伪苊馑?
答案藏在核函數(shù)(kernel)的實(shí)現(xiàn)方式中。
為什么核函數(shù)計(jì)算中數(shù)字加法順序并非總是固定的?
如前所述,解釋核函數(shù)計(jì)算中數(shù)字加法順序不一致的一個(gè)常見原因是「并發(fā)性 + 浮點(diǎn)運(yùn)算」假設(shè)。
該假設(shè)認(rèn)為,如果并發(fā)線程的執(zhí)行順序是不可預(yù)測(cè)的,并且累加操作的順序依賴于并發(fā)線程的執(zhí)行順序(例如原子加法 /atomic adds),那么最終的累加結(jié)果也會(huì)變得不可預(yù)測(cè)。
然而,令人困惑的是,盡管這種現(xiàn)象會(huì)導(dǎo)致核函數(shù)計(jì)算結(jié)果的不確定性,但并發(fā)機(jī)制(以及原子加法)實(shí)際上與大型語(yǔ)言模型推理中的不確定性無(wú)關(guān)!
為了解釋真正的罪魁禍?zhǔn)资鞘裁?,我們首先需要了解為什么現(xiàn)代 GPU 核函數(shù)很少需要使用原子加法。
什么時(shí)候需要使用原子加法操作?
GPU 通常會(huì)同時(shí)在多個(gè)核心(即流處理器)上并行運(yùn)行程序。由于這些核心之間沒(méi)有內(nèi)置同步機(jī)制,因此如果它們需要相互通信,就會(huì)很麻煩。例如,如果所有核心都需要對(duì)同一個(gè)元素進(jìn)行累加,就可以使用原子加法(有時(shí)也稱為 fetch-and-add)。原子加法是不確定性的,結(jié)果的累加順序完全取決于哪個(gè)核心先完成計(jì)算。
具體來(lái)說(shuō),假設(shè)你要使用 100 個(gè)核心對(duì)一個(gè)包含 100 個(gè)元素的向量進(jìn)行求和(例如 torch.sum ())。雖然可以并行加載所有 100 個(gè)元素,但最終我們必須將結(jié)果匯總為一個(gè)值。一種實(shí)現(xiàn)方法是使用某種原子加法操作,硬件保證所有加法操作都會(huì)執(zhí)行,但并不保證執(zhí)行順序。
原子加法操作可以確保每個(gè)核心的計(jì)算結(jié)果都能最終反映在總和中。但是,它并不能保證這些結(jié)果的累加順序。累加順序完全取決于哪個(gè)核心先完成計(jì)算,這是一種不確定性行為。
因此,多次執(zhí)行相同的并行程序可能會(huì)產(chǎn)生不同的結(jié)果。這通常就是人們所說(shuō)的不確定性,即,使用完全相同的輸入數(shù)據(jù)執(zhí)行兩次相同的程序,但最終結(jié)果卻可能不同。這被稱為運(yùn)行間不確定性(run-to-run nondeterminism),例如,運(yùn)行兩次完全相同的 Python 腳本,即使依賴庫(kù)版本完全相同,結(jié)果也可能不同。
雖然并發(fā)的原子加法操作會(huì)使核函數(shù)的執(zhí)行結(jié)果變得不可預(yù)測(cè),但對(duì)于大多數(shù)核函數(shù)來(lái)說(shuō),原子加法并非必需。
事實(shí)上,在 LLM 的典型前向傳播過(guò)程中,通常根本不需要使用原子加法。這可能令人感到意外,因?yàn)椴⑿谢?jì)算中的歸約操作通??梢詮脑蛹臃ㄖ蝎@益。但實(shí)際上,原子加法在大多數(shù)情況下并非必需,主要原因有兩點(diǎn)。
1. 通常情況下,批處理維度上的并行性已經(jīng)足夠,因此我們無(wú)需在歸約維度上進(jìn)行并行化。
2. 隨著時(shí)間的推移,大多數(shù)神經(jīng)網(wǎng)絡(luò)庫(kù)都采用了各種策略,以在不犧牲性能的情況下實(shí)現(xiàn)結(jié)果的可預(yù)測(cè)性。
由于上述兩個(gè)因素,對(duì)于絕大多數(shù)神經(jīng)網(wǎng)絡(luò)操作來(lái)說(shuō),不使用原子加法幾乎不會(huì)帶來(lái)性能損失。
當(dāng)然,仍然有少數(shù)常見操作在不使用原子加法時(shí)會(huì)遭遇顯著的性能下降。例如,PyTorch 中的 scatter_add(即 a [b] += c)。不過(guò),在大語(yǔ)言模型中唯一常用且依賴原子加法的操作,是 FlashAttention 的反向傳播(backward)。
然而,LLM 的前向傳播過(guò)程中并不涉及任何需要原子加法的操作。因此,LLM 的前向過(guò)程本質(zhì)上是運(yùn)行間確定的(即每次運(yùn)行結(jié)果一致)。
維基百科上寫道:一個(gè)確定性算法是在給定特定輸入的情況下,始終產(chǎn)生相同輸出的算法。而在這里,只要輸入完全相同(即推理服務(wù)器處理的請(qǐng)求完全一致),前向傳播就總是會(huì)生成完全相同的輸出。
然而,前向傳播本身是確定性的并不意味著整個(gè)系統(tǒng)也是確定性的。比如,如果某個(gè)請(qǐng)求的輸出依賴于并行用戶的請(qǐng)求(例如 batch-norm 這樣的操作),那么由于每個(gè)請(qǐng)求都無(wú)法預(yù)知其他并發(fā)請(qǐng)求的內(nèi)容,從單個(gè)請(qǐng)求的視角來(lái)看,整個(gè) LLM 推理過(guò)程就會(huì)是不確定性的。
事實(shí)證明,我們的請(qǐng)求輸出確實(shí)依賴于其他并發(fā)用戶的請(qǐng)求。但這并不是因?yàn)榭?batch 泄露了信息,而是因?yàn)槲覀兊那跋騻鞑ミ^(guò)程缺乏批次不變性(batch invariance),這導(dǎo)致同一個(gè)請(qǐng)求的輸出會(huì)受到前向傳播中 batch size(batch size)變化的影響。
批次不變性與確定性
為了說(shuō)明什么是批次不變性,我們可以簡(jiǎn)化問(wèn)題,只關(guān)注矩陣乘法(matmul)。你可以假設(shè)所有的 matmul 實(shí)現(xiàn)都是運(yùn)行間確定的,也就是說(shuō),同樣的輸入,每次運(yùn)行都會(huì)得到相同的結(jié)果。
但它們并不是批次不變的。換句話說(shuō),當(dāng) batch size 發(fā)生變化時(shí),batch 中的每個(gè)元素可能會(huì)得到不同的計(jì)算結(jié)果。
從數(shù)學(xué)角度來(lái)看,這是一種相當(dāng)反常的性質(zhì)。理論上,矩陣乘法在 batch 維度上應(yīng)當(dāng)是獨(dú)立的,batch 中其他元素的存在與否,或 batch 的大小,都不應(yīng)影響某個(gè)具體元素的計(jì)算結(jié)果。
然而,我們通過(guò)實(shí)驗(yàn)證據(jù)可以發(fā)現(xiàn),現(xiàn)實(shí)情況并非如此。
請(qǐng)注意,這里的確定性是指每次運(yùn)行結(jié)果都相同。如果你多次運(yùn)行該腳本,它會(huì)始終返回相同的結(jié)果。
但是,如果將非批處理不變的核函數(shù)用作更大推理系統(tǒng)的一部分,則整個(gè)系統(tǒng)可能變得不確定性。當(dāng)你向推理端點(diǎn)發(fā)送請(qǐng)求時(shí),從用戶角度來(lái)看,服務(wù)器的負(fù)載情況是不可預(yù)測(cè)的。負(fù)載決定了核函數(shù)的 batch size,從而影響每個(gè)請(qǐng)求的最終結(jié)果。
如果你把某種核函數(shù)不具備不變性的屬性(例如:batch size)與該屬性本身的不確定性(例如:服務(wù)器負(fù)載情況)組合在一起,就會(huì)得到一個(gè)不確定性的系統(tǒng)。
換句話說(shuō),幾乎所有大語(yǔ)言模型推理端點(diǎn)之所以是不確定的,主要原因就是負(fù)載(以及由此決定的 batch size)本身具有不確定性!這種不確定性并非僅限于 GPU,使用 CPU 或 TPU 運(yùn)行的 LLM 推理端點(diǎn)也會(huì)存在同樣的問(wèn)題。因此,如果我們想避免推理服務(wù)器中的不確定性,就必須確保核函數(shù)對(duì) batch size 具有不變性。
為了理解如何實(shí)現(xiàn)這一點(diǎn),我們首先需要了解為什么核函數(shù)默認(rèn)情況下并不具備批處理不變性。
我們?nèi)绾问购司哂信尾蛔冃裕?
為了確保 Transformer 模型的實(shí)現(xiàn)與 batch size 無(wú)關(guān),我們必須確保模型中的每個(gè)核心模塊都與 batch size 無(wú)關(guān)。幸運(yùn)的是,我們可以假設(shè)每個(gè)逐點(diǎn)運(yùn)算(pointwise operation)都與 batch size 無(wú)關(guān)。因此,我們只需要擔(dān)心涉及的 3 個(gè)操作:RMSNorm、矩陣乘法和注意力。
巧合的是,這些操作的難度正好是依次遞增的。要想在保持合理性能的同時(shí)實(shí)現(xiàn)批次不變性,每一種操作都需要一些額外的考量。我們先從 RMSNorm 開始談起。
RMSNorm
RMSNorm 實(shí)現(xiàn)方式:
批次不變性的要求是,無(wú)論核函數(shù)的 batch size 如何,每個(gè)元素的歸約順序都必須保持不變。需要注意的是,這并不意味著我們必須始終使用相同的歸約策略。例如,即使我們改變了要進(jìn)行歸約的元素?cái)?shù)量,只要?dú)w約順序不變,我們的算法仍然可以滿足批處理不變性的要求。
因此,只有當(dāng) batch size 影響到歸約策略時(shí),我們才會(huì)打破批次不變性。
讓我們來(lái)看一下 RMSNorm 的標(biāo)準(zhǔn)并行化策略。一般來(lái)說(shuō),并行算法都會(huì)從盡量減少核心之間的通信中獲益。在這里,為了方便討論,你可以假設(shè)我們所說(shuō)的核心(cores)就是指 SM(Streaming Multiprocessors,流處理多處理器)。更具體地說(shuō),這里重要的性質(zhì)是:核函數(shù)啟動(dòng)的線程塊(threadblocks)數(shù)量多于 SM 的數(shù)量。
基于這一點(diǎn),一種可行的策略就是:將每個(gè) batch 元素分配給一個(gè)核心,就像上圖展示的那樣。
當(dāng)我們?cè)黾?batch size 時(shí),并不會(huì)影響歸約策略;如果 batch size = 200 已經(jīng)能為核函數(shù)提供足夠的并行性,那么 batch size = 2000 顯然也同樣能夠提供足夠的并行性。
另一方面,減小 batch size 也會(huì)帶來(lái)一些挑戰(zhàn)。由于我們?yōu)槊總€(gè)批次元素分配一個(gè)核心,減小 batch size 會(huì)導(dǎo)致核心數(shù)量大于批次元素?cái)?shù)量,從而造成部分核心閑置。遇到這種情況,優(yōu)秀的核函數(shù)工程師會(huì)采用前面提到的解決方案之一(原子加法或分段求和),從而保持良好的并行性,進(jìn)而提升性能。然而,這會(huì)改變求和策略,導(dǎo)致該核函數(shù)不再具備 batch size 不變的特性。
最簡(jiǎn)單的解決方案就是直接忽略這些情況。這并不是完全不合理的,因?yàn)楫?dāng) batch size 很小時(shí),核函數(shù)通常本來(lái)就能很快執(zhí)行,因此即使出現(xiàn)一些減速,也不會(huì)造成災(zāi)難性的影響。
如果我們必須優(yōu)化這種場(chǎng)景,一種方法是:始終使用一種在極小 batch size 下也能提供足夠并行度的歸約策略。這樣的策略會(huì)在 batch size 較大時(shí)導(dǎo)致過(guò)度并行,從而無(wú)法達(dá)到峰值性能,但它可以讓我們?cè)谡麄€(gè) batch size 范圍內(nèi)都獲得尚可(雖然不是最佳)的性能表現(xiàn)。
批次不變矩陣乘法
從本質(zhì)上講,你可以把矩陣乘法看作是一次逐點(diǎn)運(yùn)算后接一次歸約。那么,如果我們通過(guò)將輸出劃分為小塊來(lái)并行化矩陣乘法,就能得到一種類似的數(shù)據(jù)并行核函數(shù)策略,使得每一次歸約都在單個(gè)核心內(nèi)完成。
與 RMSNorm 類似,矩陣乘法的批次維度(M 和 N)也可能變得過(guò)小,迫使我們必須沿歸約維度(K)進(jìn)行拆分。盡管有兩個(gè)批次維度,矩陣乘法仍然需要每個(gè)核心有更多的工作量才能有效利用張量核心。例如,對(duì)于一個(gè) [1024, K] x [K, 1024] 的矩陣乘法和一個(gè)標(biāo)準(zhǔn)的 [128, 128] 二維 tile 大小,數(shù)據(jù)并行策略最多只能將其分配到 64 個(gè)核心上,這不足以使 GPU 達(dá)到飽和。
在矩陣乘法中沿歸約維度進(jìn)行拆分被稱為 Split-K 矩陣乘法。與 RMSNorm 的情況一樣,使用這種策略會(huì)破壞批次不變性。
矩陣乘法還有一個(gè)額外的復(fù)雜性,即張量核心指令。對(duì)于歸約操作,我們可以一次只處理一行;但高效的矩陣乘法核函數(shù)必須一次性操作一整個(gè) tile。
每條張量核心指令(例如 wgmma.mma_async.sync.aligned.m64n128k16)在內(nèi)部可能有不同的歸約順序。選擇不同張量核心指令的一個(gè)原因可能是 batch size 非常小。例如,如果我們使用的張量核心 PTX 指令操作的是一個(gè)長(zhǎng)度為 256 的 tile,但 batch size 只有 32,那我們幾乎浪費(fèi)了所有的計(jì)算資源!當(dāng) batch size 為 1 時(shí),最快的核函數(shù)通常根本不使用張量核心。
因此,確保矩陣乘法批次不變性的最簡(jiǎn)單方法是:編譯一個(gè)固定的核函數(shù)配置,并將其用于所有形狀的計(jì)算。盡管這會(huì)損失一些性能,但在 LLM 推理場(chǎng)景下,這種損失通常不是災(zāi)難性的。特別是,Split-K 策略在 M 和 N 維度都很小時(shí)才最被需要,而幸運(yùn)的是,在我們的應(yīng)用場(chǎng)景中,N 維度(即模型維度)通常都相當(dāng)大!
批次不變性注意力機(jī)制
在實(shí)現(xiàn)了矩陣乘法的批次不變性之后,注意力機(jī)制又引入了兩個(gè)額外的難題 —— 這也很貼切,因?yàn)樗冒瑑纱尉仃嚦朔ā?
1. 與 RMSNorm 和矩陣乘法僅在特征維度上進(jìn)行歸約不同,注意力機(jī)制現(xiàn)在需要在特征維度和序列維度上都進(jìn)行歸約。
2. 因此,注意力機(jī)制必須處理各種影響序列處理方式的推理優(yōu)化(例如分塊預(yù)填充、前綴緩存等)。
因此,為了在 LLM 推理中實(shí)現(xiàn)確定性,我們的數(shù)值計(jì)算必須對(duì)兩個(gè)因素保持不變:一是單次處理的請(qǐng)求數(shù)量,二是每個(gè)請(qǐng)求在推理引擎中的切分方式。
我們首先來(lái)了解一下注意力機(jī)制的標(biāo)準(zhǔn)并行策略,該策略最初由 FlashAttention-2 提出。與 RMSNorm 和矩陣乘法類似,其默認(rèn)策略是數(shù)據(jù)并行策略。由于歸約是沿著鍵 / 值(K/V)張量進(jìn)行的,因此數(shù)據(jù)并行策略只能沿著查詢(Q)張量進(jìn)行并行化。
例如,根據(jù)推理引擎的選擇,一個(gè)序列可能被分成幾個(gè)部分處理(如在分塊預(yù)填充中),也可能一次性處理完畢(如果預(yù)填充未被分割)。為了實(shí)現(xiàn)批次不變性,對(duì)于一個(gè)給定的 token,其歸約順序必須獨(dú)立于其所在序列中同時(shí)被處理的其他 token 的數(shù)量。
如果你將 KV 緩存中的 K/V 值與當(dāng)前正在處理的 token 的 K/V 值分開進(jìn)行歸約(就像在 vLLM 的 Triton 注意力核函數(shù)中那樣),這個(gè)目標(biāo)就無(wú)法實(shí)現(xiàn)。例如,在處理序列中的第 1000 個(gè)查詢 token 時(shí),無(wú)論 KV 緩存中有 0 個(gè) token(預(yù)填充階段)還是 999 個(gè) token(解碼階段),其歸約順序都必須完全相同。
為解決此問(wèn)題,我們可以在注意力核函數(shù)運(yùn)行前就更新 KV 緩存和頁(yè)表,從而確保無(wú)論處理多少個(gè) token,我們的鍵和值始終具有一致的內(nèi)存布局。
加上這一額外處理(以及前文提到的所有措施,如使用一致的 tile 大小),我們便能實(shí)現(xiàn)一個(gè)批次不變性的注意力機(jī)制!
然而,這里存在一個(gè)重要問(wèn)題。與矩陣乘法不同,LLM 推理中的注意力計(jì)算形狀通常確實(shí)需要一個(gè)拆分 - 歸約核函數(shù)(split-reduction kernel),這類核函數(shù)常被稱為 Split-KV 或 FlashDecoding。這是因?yàn)槿绻覀儾谎刂鴼w約維度進(jìn)行并行,就只能沿著批次維度、頭維度和查詢長(zhǎng)度維度進(jìn)行并行。
在注意力的解碼階段,查詢長(zhǎng)度非常?。ㄍǔ?1),因此除非 batch size 非常大,否則我們往往無(wú)法使 GPU 達(dá)到飽和狀態(tài)。不幸的是,這種情況不像在 RMSNorm 和矩陣乘法中那樣容易被忽略。例如,如果你的 KV 緩存非常長(zhǎng),即使只處理一個(gè)請(qǐng)求,注意力核函數(shù)的計(jì)算也可能耗時(shí)很長(zhǎng)。
此外,常用于注意力的拆分 - 歸約策略也給批次不變性帶來(lái)了挑戰(zhàn)。例如,F(xiàn)lashInfer 的平衡調(diào)度算法會(huì)選擇能夠使 GPU 所有核心飽和的最大拆分大小,這使得其歸約策略并非批次不變的。然而,與 RMSNorm / 矩陣乘法不同,無(wú)論 batch size 如何,僅僅選擇一個(gè)固定的拆分?jǐn)?shù)量是不夠的。
相反,為了實(shí)現(xiàn)批次不變性,我們必須采用固定拆分大小策略。換言之,我們固定的不是拆分的數(shù)量,而是每個(gè)拆分塊的大小,這樣最終會(huì)得到一個(gè)可變的拆分?jǐn)?shù)量。通過(guò)這種方式,我們可以保證無(wú)論正在處理多少個(gè) token,我們總是執(zhí)行完全相同的歸約順序。
實(shí)現(xiàn)
我們基于 vLLM,通過(guò)利用其 FlexAttention 后端和 torch.Library,提供了一個(gè)確定性推理的演示。通過(guò) torch.Library,我們能夠以一種非侵入式的方式替換掉大部分相關(guān)的 PyTorch 算子。
你可以在 thinking-machines-lab/batch-invariant-ops 找到「批次不變性」核函數(shù)庫(kù),以及在「確定性」模式下運(yùn)行的 vLLM 示例。
地址:https://github.com/thinking-machines-lab/batch_invariant_ops
實(shí)驗(yàn)
完成結(jié)果的不確定性程度如何?
我們使用 Qwen3-235B-A22B-Instruct-2507 模型,在溫度為 0 的設(shè)置下,使用提示詞「Tell me about Richard Feynman」(非思考模式)采樣了 1000 次完成結(jié)果,每次生成 1000 個(gè) token。
令人驚訝的是,我們得到了 80 個(gè)不同的完成結(jié)果,其中最常見的一個(gè)出現(xiàn)了 78 次。
通過(guò)觀察這些結(jié)果的差異,我們發(fā)現(xiàn)它們?cè)谇?102 個(gè) token 上實(shí)際上是完全相同的!
首次出現(xiàn)差異是在第 103 個(gè) token。所有的結(jié)果都生成了「Feynman was born on May 11, 1918, in」這個(gè)序列。然而,接下來(lái),其中 992 次結(jié)果生成了「Queens, New York」,而另外 8 次則生成了「New York City」。
然而,當(dāng)我們啟用批次不變性核函數(shù)后,全部 1000 次結(jié)果都變得完全相同。這正是我們期望采樣器應(yīng)有的表現(xiàn),但若不使用我們的批次不變性核函數(shù),就無(wú)法實(shí)現(xiàn)確定性結(jié)果。
性能
目前,我們還沒(méi)有投入精力優(yōu)化批次不變性核函數(shù)的性能。不過(guò),我們還是進(jìn)行了一些實(shí)驗(yàn)來(lái)驗(yàn)證其性能是否仍在可用范圍內(nèi)。
我們搭建了一個(gè)配備單塊 GPU 的 API 服務(wù)器,運(yùn)行 Qwen-3-8B 模型,并請(qǐng)求生成 1000 個(gè)序列,輸出長(zhǎng)度控制在 90 到 110 個(gè) token 之間。
性能下降的主要原因在于 vLLM 中的 FlexAttention 集成尚未經(jīng)過(guò)深度優(yōu)化。盡管如此,我們看到其性能并未出現(xiàn)災(zāi)難性下降。
真正的在策略強(qiáng)化學(xué)習(xí)
正如研究人員所指出的,訓(xùn)練和推理之間的數(shù)值差異會(huì)隱式地將我們的在策略強(qiáng)化學(xué)習(xí)(on-policy RL)轉(zhuǎn)變?yōu)殡x策略強(qiáng)化學(xué)習(xí)(off-policy RL)。
當(dāng)然,如果我們甚至無(wú)法從兩次相同的推理請(qǐng)求中獲得每一位都相同的結(jié)果,那么在訓(xùn)練和推理之間獲得每一位都相同的結(jié)果也是不可能的。因此,確定性推理使我們能夠修改訓(xùn)練堆棧,從而在采樣和訓(xùn)練之間獲得每一位都相同的結(jié)果,最終實(shí)現(xiàn)真正的在策略強(qiáng)化學(xué)習(xí)。
我們?cè)?Bigmath 上,使用 RLVR 設(shè)置進(jìn)行了實(shí)驗(yàn),其中強(qiáng)化學(xué)習(xí)策略由 Qwen 2.5-VL instruct 8B 模型初始化,最大 rollout 長(zhǎng)度為 4096。
如果我們不使用離策略校正(即重要度加權(quán))進(jìn)行訓(xùn)練,我們的獎(jiǎng)勵(lì)會(huì)在訓(xùn)練中途崩潰;而添加離策略校正項(xiàng)則可以使訓(xùn)練順利進(jìn)行。但是,如果我們?cè)诓蓸悠骱陀?xùn)練器之間實(shí)現(xiàn)了每一位都相同的結(jié)果,我們就完全處于在策略狀態(tài)(即 KL 散度為 0),同樣可以順利地進(jìn)行訓(xùn)練。
我們還可以繪制采樣器和訓(xùn)練器之間對(duì)數(shù)概率的 KL 散度,其中所有 3 次運(yùn)行都表現(xiàn)出顯著不同的行為。在使用重要度加權(quán)運(yùn)行時(shí),KL 散度保持在 0.001 左右,并伴有偶爾的峰值。然而,在不使用重要度加權(quán)的情況下運(yùn)行,最終會(huì)導(dǎo)致 KL 散度在大約與獎(jiǎng)勵(lì)崩潰同一時(shí)間出現(xiàn)峰值。當(dāng)然,在運(yùn)行「真正的在策略強(qiáng)化學(xué)習(xí)」時(shí),我們的 KL 散度始終保持為 0,這表明訓(xùn)練策略和采樣策略之間不存在任何差異。
總結(jié)
現(xiàn)代軟件系統(tǒng)往往由多層抽象構(gòu)成。在機(jī)器學(xué)習(xí)中,當(dāng)我們遇到不確定性和一些微妙的數(shù)值差異時(shí),人們往往會(huì)傾向于視而不見。
畢竟,我們的系統(tǒng)本來(lái)就是「概率性的」,再多一點(diǎn)不確定性又有何妨?單元測(cè)試掛掉時(shí),把 atol/rtol 調(diào)大點(diǎn)有什么問(wèn)題?訓(xùn)練器和采樣器之間的對(duì)數(shù)概率差異,應(yīng)該不是真正的 bug 吧?
我們拒絕這種消極心態(tài)。只要稍微多做一些努力,我們就能理解不確定性的根源,甚至真正解決它們!
我們希望這篇博文能為社區(qū)提供一套可靠的思路,幫助大家在推理系統(tǒng)中應(yīng)對(duì)不確定性,并激勵(lì)更多人深入理解自己的系統(tǒng)。
? THE END
轉(zhuǎn)載請(qǐng)聯(lián)系本公眾號(hào)獲得授權(quán)
投稿或?qū)で髨?bào)道:liyazhou@jiqizhixin.com