求解微分方程,用seq2seq就夠了,性能遠超 Mathematica、Matlab

求解微分方程,用seq2seq就夠了,性能遠超 Mathematica、Matlab

作者 | XK

距離用深度學習技術求解符號數學推理問題,或許只差一個恰當的表示和恰當的數據集。

近日,Facebook AI研究院的Guillaume Lample 和Francois Charton兩人在arxiv上發表了一篇論文,標題為《Deep Learning for Symbolic Mathematics》。

求解微分方程,用seq2seq就够了,性能远超 Mathematica、Matlab

論文地址:https://arxiv.org/abs/1912.01412

這篇論文提出了一種新的基於seq2seq的方法來求解符號數學問題,例如函數積分、一階常微分方程、二階常微分方程等複雜問題。其結果表明,這種模型的性能要遠超現在常用的能進行符號運算的工具,例如Mathematica、Matlab、Maple等。

有例為證:

求解微分方程,用seq2seq就够了,性能远超 Mathematica、Matlab

上圖左側幾個微分方程,Mathematica和Matlab都求解失敗,而作者所提的模型卻能夠獲得右側的正確結果(這不是個案,而是普遍現象,具體可見後文)。

更有意思的是,這還並不僅僅是它的唯一好處。由於seq2seq模型的特點,作者所提方法能夠對同一個公式得出不止一個的運算結果,例如如下的微分方程

求解微分方程,用seq2seq就够了,性能远超 Mathematica、Matlab

該模型能夠反饋這麼多的結果:

求解微分方程,用seq2seq就够了,性能远超 Mathematica、Matlab

可以驗證一下,這些結果都是正確的,至多差一個常數 c。

我們來看下這樣美好的結果,作者是如何做到的。(其實很簡單!)

一、總體思路

首先需要強調,在過往中,機器學習(包括神經網絡)是一種統計學習方法,這些方法被證明在統計模式識別方面非常有效,例如在CV、NLP、語音識別等問題上均已經達到了超過人類的性能。

但機器學習(這裡特別強調是神經網絡)卻不適合去解決符號推理問題,目前僅有少數這樣的工作,但主要集中在解決基本的算術任務(例如加法和乘法)上,且實驗上證明在這些問題上,神經網絡的方法往往表現不佳,需要引入一些已有的指向任務的組件才勉強可行。

相比於以往的各種方法,作者思想獨特,他們認為數學符號計算的過程本質上就是一個模式識別的過程。由此他們將數學(尤其是符號計算)視為一個 NLP 模型問題,符號推理等同於seq2seq的「機器翻譯」過程。(真是“機器翻譯”解決一切啊)

具體來講,作者在文章中主要針對函數積分和常微分方程(ODE)進行研究。

學過高等數學的我們都有過求積分和解微分方程的痛苦經歷,對計算機軟件來講,求解這些問題事實上也同樣困難。以函數積分為例,人類在求解過程中主要是依賴一些規則(例如基本函數的積分公式、換元積分、部分積分等);而傳統的計算機代數系統則主要是通過從大量具體的案例中進行搜索,例如對用於函數積分的Risch算法的完整描述就超過了100頁。

但,回過頭,我們思考,從本質上來講,求積分的過程不正是一個模式識別的過程嗎?當給你一個公式yy′(y^2 + 1)^{−1/2},你會從腦海中牢牢記住的數十、數百個積分模型中尋找出「模式」最為匹配的結果\\sqrt{y^2 + 1}。

基於這種思路,作者首先提出了將數學表達式轉換為seq2seq表示形式的方法,並用多種策略生成了用於監督學習的數據集(積分、一階和二階微分方程),然後將seq2seq模型用於這些數據集,便得出了比最新計算機代數程序Matlab、Mathematica等更好的性能。

二、表示:從數學公式到seq

作者將數學問題視作自然語言處理的問題,因此首要一步便是將數學公式轉化為NLP模型能夠處理的形式,即序列(seq)。

這分兩步:

首先,將數學公式轉化為樹結構。

求解微分方程,用seq2seq就够了,性能远超 Mathematica、Matlab

運算符和函數(例如cos、pow等)為內部節點,數字、常數和變量為葉。可以看出這裡每一個數學公式都對應唯一一個樹結構。

需要強調兩點:

  • 這裡把2+3 和 3 +2視作不同的數學公式;

  • 這裡x/0、log(0)等在數學中認為是無效的函數表達式在這裡並不會排除在外;

由於樹和表達式之間存在一一對應的關係,因此表達式之間的相等性,將反映在它們相關的樹上。作為等價關係,由於 2 + 3 = 5 = 12-7 = 1×5,所以這對應於這些表達式的四棵樹是等價的。

形式數學的許多問題都可以重組為對錶達式或樹的運算。例如,表達式簡化等於找到樹的較短等效表示。

在這篇文章中,作者考慮兩個問題:符號積分和微分方程。兩者都可以歸結為將一個表達式轉換為另一個表達式。例如在函數積分中,將 cos(x) 的樹映射到其解 sin(x)+c 的樹。這本質上就是機器翻譯的一個特殊實例,而已。

其次,將樹轉化為序列。

這很顯然,機器翻譯模型運行在序列(seq)。針對這一步,學過計算機的同學應該都不陌生,作者選用了前綴表示法,從左到右,將每個節點寫在其子節點前面。例如 2 + 3×(5+2),表示為序列為 [+ 2 * 3 + 5 2]。這裡,在序列內部,運算符、函數或變量由特定的標記表示。就像在表達式和樹之間的情況一樣,樹和前綴序列之間也存在一對一的映射。

三、數據集生成

當有了合適的表示之後,另一個重要的事情便是如何生成恰當的數據集。作者採用生成隨機表達式的算法(具體這裡不再贅述),如果用p1表示一元運算子(例如cos、sin、exp、log等)的集合,p2表示二元運算子(例如+、-、×、÷等)的集合,L表示變量、常數、整數的集合,n 為一棵樹的內部節點個數(因此也是表達式中運算子的個數)。可以計算,表達式的個數與n之間有如下關係:

求解微分方程,用seq2seq就够了,性能远超 Mathematica、Matlab

要訓練網絡模型,就需要有(問題,解決方案)對的數據集。理想情況下,我們應該生成問題空間的代表性樣本,即隨機生成要積分的函數和要求解的微分方程。但我們知道,並不是所有的函數都能夠積分(例如f=exp(x^2)和f=log(log(x)))。為了生成大型的訓練集,作者提出了一些技巧。

在這裡我們以積分為例(ODE-1 和ODE-2 數據集的生成方法這裡不再贅述,可參見論文)。作者提出了三種方法:

Forward generation(FWD)。給定n 個運算子的表達式,通過計算機代數系統求解出該表達式的積分;如果不能求解,則將該表達式丟棄。顯然這種方式獲得的數據集只是問題空間的一個子集,也即只包含符號框架可以求解的函數積分;且求積分的過程往往是非常耗時的。

Backward generation(BWD)。求微分是容易的。因此我們可以先隨機生成積分表達式f,然後再對其進行微分得到 f',將(f,f')添加到數據集當中。這種方法不會依賴於符號積分系統。這種方法生成的數據集也有一定的問題:1)數據集中簡單積分函數的數量很少,例如 f=x^3 sin(x),其對應的積分式微F=-x^3 cos(x) + 3x^2sin(x) + 6x cos(x) - 6 sin(x),這是一個有15個運算子的表達式,隨機生成的概率相對來說會小一些;2)表達式的微分往往會比表達式本身更長,因此在BWD方式所生成的數據集中,積分(問題的解)傾向短於積分函數(問題)。

Backward generation withintegration by parts(IBP)。為了克服BWD所存在的問題,作者提出IBP的方法,即利用分部積分

求解微分方程,用seq2seq就够了,性能远超 Mathematica、Matlab

隨機生成兩個函數F和G,如果已知fG和它的積分式已經在數據集當中,那麼就可以求解出Fg的積分式,然後把Fg和它的積分式放入數據集。反之也可以求解 fG 的積分式。如果fG和Fg都不在數據集中,那麼可以按照BWD的方式求解FG 對應的微分fg。不斷迭代,從而獲得數據集。

可以對比一下不同的方式,生成數據集的特點:

求解微分方程,用seq2seq就够了,性能远超 Mathematica、Matlab

這裡假設了 n= 15,L ={x} ∪ {-5, ... , 5} \\ {0}, p2={+,-, ×, ÷}, p1= {exp, lgo, sqrt, sin, cos, tan, sin-1, cos-1, tan-1, sinh, cosh,tanh, sinh-1, cosh-1, tanh-1}。

可以看出 FWD和 IBP 傾向於生成輸出比輸入更長的樣本,而 BWD 方法則生成較短的輸出。與 BWD 情況一樣,ODE 生成器傾向於生成比其方程式短得多的解。

補充一點,生成過程中清洗數據也非常重要。這包括幾個方面:

1)方程簡化。例如將 x+1+1+1+1 簡化為x +4

2)係數簡化。例如x + x tan(3) + cx +1 簡化為 cx +1

3)清除無效表達式。例如 log(0)。

四、模型

這篇文章中所使用的模型比較簡單,就是一個seq2seq的模型,當給定一個問題的表達式(seq),來預測其對應的解的表達式(seq)。

訓練

具體來說,作者使用了一個transformer模型,有 8 個注意力頭,6層,512維。(在這個案例中,大的模型並不能提高性能)

在訓練中,作者使用了Adam優化器,學習率為10E-4。對於超過512個token的表達式,直接丟棄;每批使用256個表達式對進行訓練。

在推斷過程中,作者使用了帶有early stopping的beam搜索方法來生成表達式,並通過序列長度來歸一化beam中假設的對數似然分數。

注意一點,在生成過程中沒有任何約束,因此會生成一些無效的前綴表達式,例如[+ 2 * 3]。這很好解決,直接丟棄就行了,並不會影響最終結果。

評估

在機器翻譯中,一般採用對人工翻譯進行對比的BLEU分數作為指標來評價翻譯質量,但許多研究表明,更好的BLEU分數並不一定與更好的表現有關。不過對求解積分(或微分方程)來說,評估則相對比較簡單,只要將生成的表達式與其參考解進行簡單比較,就可以驗證結果的正確性了。例如微分方程xy′ − y + x =0的參考解為 x log(c/ x) ,模型生成的解為 x log(c) − x log(x),顯然這是兩個等價的方程。

由於對錶達式是否正確可以很容易地進行驗證,因此作者提出如果生成的beam中的表達式中,只要有一個正確,則表示模型成功解決了輸入方程(而不是隻選用得分最高的)。例如當 beam =10時,也即生成 10 個可能的解,只要有一個正確即表明模型成功輸出結果正確。

五、結果

1、實驗結果

求解微分方程,用seq2seq就够了,性能远超 Mathematica、Matlab

從上表可以看出,

1)在積分中即使讓beam=1,模型的準確性也是很高的。

2)beam=1時,ODE結果並不太理想。不過當beam尺寸增大時,結果會有非常顯著的提升。原因很簡單,beam大了,可供挑選的選項也就多了,正確率自然會提高。

2、與三大著名數學軟件對比

求解微分方程,用seq2seq就够了,性能远超 Mathematica、Matlab

這個表格顯示了包含 500 個方程的測試集上,本文模型與Mathematica、Matlab、Maple三大著名數學軟件的比較。對於Mathematica,假設了當時間超過30s而未獲得解則認為失敗(更多時延的對比可見論文原文附錄)。對於給定的方程式,本文的模型通常會在不到 1 秒的時間裡找到解決方案。

從正確率上可以看出,本文方法要遠遠優於三大著名數學軟件的結果。

3、等價解

這種方法最有意思的地方出現了。通常你用符號求解軟件,只能得到一個結果。但這種seq2seq 的方法卻能夠同時給你呈現一系列結果,它們完全等價,只是用了不同的表示方式。具體案例,我們前面已經提到過,這裡不再贅述。

4、通用性研究

在前面提到的實驗結果中,測試集與訓練集都來自同一種生成方法。但我們知道每一種生成方法都只是問題空間的一個子集。那麼當跨子集測試時會出現什麼現象呢?

求解微分方程,用seq2seq就够了,性能远超 Mathematica、Matlab

結果很吃驚。

1)當用FWD數據集訓練,用BWD數據集進行測試,分數會極低;不過好在用IBP數據集測試,分數還行;

2)同樣的情況,當用BWD數據集訓練,用FWD數據集進行測試,結果也很差;意外的是,用IBP數據集測試,結果也不理想;

3)當把三個數據集結合在一起共同作為訓練集時,測試結果都還不錯。

這說明

1)FWD數據集和BWD數據集之間的交集真的是非常小;

2)數據集直接影響模型的普適性,因此如何生成更具代表性的數據集,是這種方法未來一個重要的研究內容。

六、總結

我們用幾句話來總結這項工作的意義:

1、本文提出了一種新穎的、利用seq2seq模型求解符號數學推理的方法,這種方法是普遍的,而非特定模型;

2、如何生成更具代表性的數據集,有待進一步研究;

3、完全可以將類似的神經組件,內嵌到標準的數學框架(例如現在的3M:Mathematica、Matlab、Maple)的求解器當中,這會大大提升它們的性能。

鎖定AI研習社,AI社區獨家直播


分享到:


相關文章: