最適輸送が好きな6つの理由

こんにちは。Sansan株式会社 技術本部 研究開発部の大田尾 匠と申します。

普段は、営業DXサービス「Sansan」でメールから氏名や会社名などの名刺相当の送信者情報を抽出する「メール署名取り込み」という機能の性能改善に取り組んでいます。メールは過去のスレッドも含めるとテキスト長が膨大になることや、複数人の情報が含まれていることもあり、その中から送信者情報のみを抽出するのは簡単ではなく、アルゴリズムに工夫し甲斐がある面白い問題です。

この記事では、私が学部・修士の時に研究していた「最適輸送 (Optimal Transport)」という手法について紹介します。最適輸送は、物流の最適化をするというわけではなく、確率分布と確率分布を比較する手法で、最適化問題として数式で定義できます。問題そのものが面白く、最適化問題に工夫を加えることで考えられる亜種の存在や、目まぐるしい速度で発展している自然言語処理へ応用できることも踏まえて、とても魅力的だと私は思っています。以下では、私がそんな最適輸送を好きな理由を6個紹介しようと思います。

最適輸送が好きな理由

視覚的にわかりやすい

最適輸送が魅力的な理由の一つとして、視覚的に何をしているのかが分かりやすいという点があります。最適輸送は最適化問題の一種であり、数式は一見複雑になりますが、図でイメージがつかみやすい手法です。この章では、最適輸送とは何か、ということを図を用いて直感的に理解できることを説明します。

まず、本記事では「確率分布」は特に断りがない場合、確率分布の中でも点と重み(確率値)が対応付けられたものの集合を対象とすることとします。これは離散分布と呼ばれるものです。他にも連続分布と呼ばれる確率分布もあり、最適輸送で比較できますが、本記事では扱いません。また、確率分布では重みの総和は1になります。それぞれが1/4の重みを持つ4つの点から構成される、赤と青の2つの確率分布の例を図1に示しました。最適輸送を使うと、赤の確率分布と青の確率分布を比較できます。

図1: 2つの確率分布

図1を例に使って、最適輸送という手法は何をしているかを説明します。最適輸送とは一言でいうと、「片方の確率分布を移動させて、もう片方の確率分布に一致させるための最適な移動方法」といえます。点の重みを砂山と捉えると、確率分布を移動させるというのは、確率分布の各点の砂山を別の点に輸送させることに対応します。砂山を輸送する方法は無限にありますが、最適輸送ではコストが最も少ない輸送方法を計算で求め、その方法に従って輸送します。図1の2つの確率分布における最適な輸送は、図2のようになります。赤い確率分布の各点が、それぞれ最も近い青い確率分布の点に砂山を輸送しており、全体として最も輸送コストが少ない輸送になっています。これに対して、最適ではない輸送の例の1つを図3に示します。図2の輸送と比べて輸送の矢印が長く、輸送コストが高くなっていることが分かります。

(左)図2: 最適な輸送 、(右)図3: 最適でない輸送

また、上の例では赤と青の点群の数を等しくしましたが、異なっていても問題ありません。例えば、青の点群を2つにし、各重みを1/2にした場合には、最適な輸送は図4のようになり、青い点群の各点に対して、赤い点群の2つずつの点から砂山が輸送されています。

図4: 点の数が異なる場合の最適な輸送

確率分布を比較できるとどういうことができるか、についても触れようと思います。最適輸送は「片方の確率分布を移動させて、もう片方の確率分布に一致させるための最適な移動方法」であり、輸送コストが低ければ低いほど確率分布の距離が近い、ということができます。図5の3つの確率分布において、赤の分布は青の分布と緑の分布のどちらの方が距離が近いかを考えてみます。

図5: 三つの確率分布

視覚的には赤と青の確率分布は近い箇所にあり、赤と緑の確率分布は離れた箇所にあります。最適輸送を使うと、この視覚的なイメージを、輸送コストという数値で表現できます。具体的には、赤と緑の確率分布間よりも、赤と青の確率分布間の方が最適輸送コストが小さいということが計算で分かるので、赤の確率分布から見ると、緑よりも青の方が距離が近い、ということができます。

この確率分布間の距離を考える応用先は多くあります。例えば、自然言語処理において、一つの文章を単語集合からなる確率分布としてとらえると、単語集合間の最適輸送コストは文章の意味の近さの指標になります。また、画像生成分野において、画像集合を確率分布としてとらえると、データセットの画像集合と生成モデルが出力した画像集合の間の最適輸送コストは、生成モデルがどれだけ元データセットに近い画像を出力しているかの指標になります。

確率分布を比較する、と聞くと一見難しげに聞こえますが、このように最適輸送は視覚的にわかりやすく、応用先も多岐にわたる面白い手法です。ここまではイメージの話をしましたが、最適な輸送を求める具体的な方法について次の章で説明します。

シンプルな最適化問題として定式化できる

最適輸送の最適化問題としての定式化を説明します。まず、n個の重み付き点群である確率分布\muと、m個の重み付き点群である確率分布\nuを数式で定義します。d次元の空間において、n個の点群の位置を\mathbf{x}_1, ..., \mathbf{x}_n \in \mathbb{R}^dm個の点群の位置を\mathbf{y}_1, ..., \mathbf{y}_m \in \mathbb{R}^dとします。そして、\mathbf{x}_iの重みをa_i\mathbf{y}_jの重みをb_jとします。このとき、\sum_{i=1}^n a_i = \sum_{j=1}^m b_j = 1を満たしているとします。この時、二つの確率分布は、\mu=(\{\mathbf{x}_i\}_{i=1}^n, \{a_i\}_{i=1}^n)\nu=(\{\mathbf{y}_j\}_{j=1}^m, \{b_j\}_{j=1}^m) のように、位置と重みの組で定義できます。また、点\mathbf{x}_iから点\mathbf{y}_jに重み1を輸送する際にかかるコストをC(\mathbf{x}_i, \mathbf{y}_j)とします。ここまでで数式的な準備は完了です。

最適輸送とは、「片方の確率分布を移動させて、もう片方の確率分布に一致させるための最適な移動方法」でした。点\mathbf{x}_iから点\mathbf{y}_jへの重みの輸送量をP_{ij}としたとき、全体の輸送プランは行列\mathbf{P}で表現され、確率分布\mu\nuの間の最適輸送は以下の最適化問題を用いて求められます。


【最適輸送問題】

\mathbf{x}_iから点\mathbf{y}_jへの輸送コストは、C(\mathbf{x}_i, \mathbf{y}_j) P_{ij}のように(重み1を輸送する際にかかるコスト) \times (輸送量)で定義されます。この時、式(1)は確率分布\mu\nuの間の総輸送コストを意味します。輸送プラン\mathbf{P}を変数としてこの総輸送コストを最小化するというのが最適輸送問題です。

しかし、輸送プラン\mathbf{P}が何でもよいというわけではなく、subject to 以下の3つの制約を課す必要があります。まず、輸送元である点\mathbf{x}_1, ..., \mathbf{x}_nから輸送する重みの総量がそれぞれa_1, ..., a_nに一致するという制約が式(2)です。次に、輸送先である点\mathbf{y}_1, ..., \mathbf{y}_mに輸送される重みの総量がそれぞれb_1, ..., b_mに一致するという制約が式(3)です。最後に、輸送量は0以上という制約が式(4)です。

このように、最適輸送問題は目的関数と制約が全て線形(1次式)である、線形計画問題と呼ばれる最適化問題として定式化できます。線形計画問題は最適化問題の中ではシンプルな部類であり、解法が確立していて、実用的にはソルバを用いて解を得られます。

最適化問題の定義ができたところで、図1を例として実際に最適輸送コストを求めてみます。赤の確率分布を\mu、 青の確率分布を\nuとすると、n=m=4であり、位置と重みを数式で図に明記すると図6のようになります。

図6: 確率分布\mu(赤色)と\nu(青色)

また、各点間のコスト行列も以下のように定義します。ここで、C_{ij}=C(\mathbf{x}_i, \mathbf{y}_j)です。


\begin{equation*}
\mathbf{C} = \begin{pmatrix}
0.1 & 0.4 & 0.5 & 0.6 \\
0.4 & 0.1 & 0.6 & 0.5 \\
0.2 & 0.3 & 0.1 & 0.3 \\
0.3 & 0.2 & 0.3 & 0.1 \\
\end{pmatrix}
\end{equation*}

この時、最適な輸送(図2)と最適でない輸送の一つ(図3)に各輸送の輸送量を明記すると、図7と図8のようになります。

(左)図7: 最適な輸送、(右)図8: 最適でない輸送

最適な輸送の総輸送コストは、


\begin{equation*}
\begin{split}
& \quad \: C(\mathbf{x}_1, \mathbf{y}_1)P_{11} + C(\mathbf{x}_2, \mathbf{y}_2)P_{22} + C(\mathbf{x}_3, \mathbf{y}_3)P_{33} + C(\mathbf{x}_4, \mathbf{y}_4)P_{44} \\
&= 0.1*0.25 + 0.1*0.25 + 0.1 * 0.25 * 0.1 * 0.25 \\
&= 0.1
\end{split}
\end{equation*}

となり、最適でない輸送の総輸送コストは、


\begin{equation*}
\begin{split}
& \quad \: C(\mathbf{x}_1, \mathbf{y}_2)P_{12} + C(\mathbf{x}_2, \mathbf{y}_1)P_{21} + C(\mathbf{x}_3, \mathbf{y}_4)P_{34} + C(\mathbf{x}_4, \mathbf{y}_3)P_{43} \\
&= 0.4*0.25 + 0.4*0.25 + 0.3*0.25 * 0.3 * 0.25 \\
&= 0.35
\end{split}
\end{equation*}

となるため、最適な輸送の方が総輸送コストが小さいことが計算から分かります。

正則化を考えるとGPU上で効率的に計算できる

最適輸送問題は線形計画問題として定式化でき、ソルバを用いて解けます。ただ、この最適化問題は、解を求めるのに時間がかかるという問題があります。線形計画問題を解くには、点群の数の約三乗に比例した計算時間がかかります。これは、点群の数が10倍になると計算時間が約10^3倍になるということであり、点群の数が多くなると計算がかなり遅くなります。
最適輸送分野では、 【最適輸送問題】の目的関数である式(1)に正則化と呼ばれる項を加え、より高速に最適化問題を解く最適輸送の亜種が提案されています。 [1]

では、具体的な定式化を確認していきます。まず、行列\mathbf{P} \in \mathbb{R}^{n \times m}に対するエントロピー関数を以下のように定義します。

\begin{equation*}
h(\mathbf{P}) = - \sum_{i=1}^n \sum_{j=1}^m P_{ij} (\log P_{ij} - 1)
\end{equation*}

ここで、\log 0 = 0とします。エントロピー関数は、行列\mathbf{P}の各値が一様であるほど大きくなるという特徴があります。\mathbf{P}の要素数はnmなので、\mathbf{P}の要素の総和を1とすると、h(\mathbf{P})を最大にするのは\mathbf{P}のすべての要素が1/nmの時となります。

【最適輸送問題】の目的関数にエントロピー関数を正則化として加えた最適化問題は以下になります。制約式(2)(3)(4)に変化はありません。


【エントロピー正則化付き最適輸送問題】

\begin{equation*}
\begin{split}
\min_{\mathbf{P} \in \mathbb{R}^{n \times m}} \quad & \sum_{i=1}^n \sum_{j=1}^m \: C(\mathbf{x}_i, \mathbf{y}_j) P_{ij} - \epsilon \: h(\mathbf{P}) \\
\text{subject to} \quad & \sum_{j=1}^m P_{ij} = a_i
\quad (\forall i)\\
& \sum_{i=1}^n P_{ij} = b_j \quad (\forall j) \\
& P_{ij} \ge 0 \quad \:\:\: (\forall i, \forall j)
\end{split}
\end{equation*}

\epsilon\ge0は目的関数における正則化の影響の大きさを決めるハイパーパラメータです。\epsilon=0のとき、この問題は元の問題に一致します。\epsilonを大きくするほど、エントロピー関数の影響が大きくなり、目的関数を小さくするために第一項よりも正則化項を重要視する方向に解\mathbf{P}^*が変化します。エントロピー関数を大きくするのは一様な行列なので、\epsilonを大きくすると解\mathbf{P}^*は一様な行列に近づいていきます。一つ注意点としては、正則化を加えたことで、元の最適輸送問題とは解\mathbf{P}^*と最適輸送コストがともに一致しないので、あくまで最適輸送の亜種としてとらえる必要はあります。

正則化付きの最適輸送問題は線形ではない項を加えたので、線形計画ソルバでは解けず、シンクホーンアルゴリズム [2] という別の解法で解くことが提案されており [1]、この記事では概要を紹介します。

まず、 【エントロピー正則化付き最適輸送問題】を主問題と呼ぶと、目的関数の最適値が同じで、変数を変えた双対問題と呼ばれる別の問題を導出できます。主問題と双対問題の最適値は一致しないこともありますが、 【エントロピー正則化付き最適輸送問題】に関しては一致することを示せます。そのため、主問題の最適値を求めるためには双対問題を解けば十分です。導出は省略しますが、主問題の変数\mathbf{P}が登場せず、新しい二つの変数\mathbf{f} \in \mathbb{R}^n, \mathbf{g} \in \mathbb{R}^mについて最適化を行う双対問題は以下になります。


【エントロピー正則化付き最適輸送問題の双対問題】
\begin{equation}
\max_{\mathbf{f} \in \mathbb{R}^n, \mathbf{g} \in \mathbb{R}^m} \sum_{i=1}^n f_i a_i + \sum_{j=1}^m g_j b_j - \epsilon \sum_{i=1}^n \sum_{j=1}^m \exp(\frac{f_i + g_j - C_{ij}}{\epsilon})
\end{equation}

主問題と比べると、変数に対する制約がなくなっており、より解きやすい問題になっています。双対問題の解を求める前に、まず新しい変数\mathbf{u} \in \mathbb{R}^n, \mathbf{v} \in \mathbb{R}^m, \mathbf{K} \in \mathbb{R}^{n \times m}

\begin{align}
& \mathbf{u}=\exp(\mathbf{f}/\epsilon) \\
&\mathbf{v}=\exp(\mathbf{g}/\epsilon) \\
& \mathbf{K}=\exp(- \mathbf{C}/\epsilon) \\
\end{align}

と定義します。ここで、\exp()は、ベクトルや行列の各要素に対して指数関数の計算を行う演算であり、u_{i}=e^{f_i / \epsilon}, \:\: u_{j}=e^{g_j / \epsilon}, \:\: K_{ij}=e^{-C_{ij}  / \epsilon}です。二つの変数\mathbf{u}, \mathbf{v}の最適値は、以下のように交互に最適化を繰り返すことで求められます。この交互最適化アルゴリズムをシンクホーンアルゴリズムと呼びます。

\begin{align}
& 1. \quad\quad \mathbf{u}^{t} = \mathbf{a} ./ (\mathbf{K} \mathbf{v}^{t-1}) \\
& 2. \quad\quad\mathbf{v}^{t} = \mathbf{b} ./ (\mathbf{K}^{\top} \mathbf{u}^{t}) \\
\end{align}

ここで、tは繰り返しの回数であり、1から始めて、最大何回繰り返すかを決める必要があります。もしくは、変数の値の変化が小さくなったら繰り返しを止めるということもできます。\mathbf{v}^{0}は値がすべて1のベクトルで初期化することが多いです。また、演算子./はベクトルの要素ごとの商を意味します。シンクホーンアルゴリズムで求めた最適解\mathbf{u}^*, \mathbf{v}^*を使うと、【エントロピー正則化付き最適輸送問題の双対問題】の解\mathbf{f}^*, \mathbf{g}^*

\begin{align}
& \mathbf{f}^* = \epsilon \log \mathbf{u}^* \\
& \mathbf{g}^* = \epsilon \log \mathbf{v}^* \\
\end{align}

で得られます。また、主問題の解\mathbf{P}^*

\begin{align}
\mathbf{P}^* = \text{diag} (\mathbf{u}^*)\: \mathbf{K}\: \text{diag}(\mathbf{v}^*)
\end{align}

という簡単な行列計算で得られます。ここで、\text{diag} (\mathbf{u}^*)は対角成分が\mathbf{u}^*であり、その他の値が0である対角行列を意味します。

このように、正則化付き最適輸送問題の解は、双対問題を考え、シンクホーンアルゴリズムを使うと行列演算の繰り返しで求められます。行列\mathbf{K}n \times m次元であることから、シンクホーンアルゴリズムの1回の反復ではnmに比例した計算時間が必要で、アルゴリズム全体としてnm、つまり点群の数の二乗に比例した時間で最適輸送を計算できます。これは元の最適輸送の三乗時間と比べて高速であることがわかります。また、線形計画のソルバと違って行列演算で解を求められるので、GPU上で効率的に実装できます。理論的な計算時間が三乗時間から二乗時間に減少したことと、GPU上で高速に計算できることから、実際にはシンクホーンアルゴリズムはかなり高速に動作します。

自然言語処理に応用できる

この章では、最適輸送の自然言語処理への応用について説明します。自然言語処理への応用は、私が最も面白いと思う最適輸送の応用先であり、多くの研究が過去になされています。
自然言語処理分野において、最適輸送を使うと二つの文章の類似度を測れます。文章の類似度は重要であり、手元の文章と似た文章を検索する文章検索や、文章のカテゴリを判断する文章分類などに利用されます。

最適輸送をどう応用するかを説明する前に、まず単語埋め込みという手法について説明します。単語埋め込みとは、一つの単語を多次元のベクトルで表現する手法であり、埋め込みの距離が近いほど単語の意味が似ているということができます。また、単語間の意味的な関係性を埋め込みの演算によって表現することもできます。例えば、単語\text{A}の埋め込みを\textbf{emb}(\text{A})と書くことにすると、

  • \textbf{emb}(\text{king}) - \textbf{emb}(\text{man}) + \textbf{emb}(\text{woman}) \approx \textbf{emb}(\text{queen)}

といったような関係性を捉えられます。

最適輸送では、文章を単語の集合として捉え、単語集合と単語集合の間の輸送コストを計算することで文章の類似度を計算します。その際に必要になる単語間の距離を、単語埋め込みを用いて計算します。最適輸送の入力は確率分布なので、単語が持つ重みを定義する必要がありますが、文中に出現した単語の数に比例する重みを設定する方法 [3]や、単語埋め込みのノルム(ベクトルの大きさ)に比例する重みを設定する方法 [4]が提案されています。最適輸送のイメージを [3]のFigure 1.から引用した図9を用いて説明します。

図9: 単語集合を最適輸送で比較するイメージ ([3]のFigure 1.より引用)

図9では、以下の二つの文章の類似度を計算しようとしています。

  1. Obama speaks to the media in Illinois.
  2. The President greets the press in Chicago.

最適輸送を適応するために、まずword2vec [5]と呼ばれる単語埋め込みを使って二つの文章の単語を単語埋め込み空間で表現します。(もちろん他の単語埋め込みも使用可能です。)この時、to・the・inといった文章の意味に大きく影響しない単語は除きます。そして、単語の重みを適切に設定して最適輸送を計算すると、図の矢印で表現されている輸送の対応関係を得られます。輸送ペアは('Obama', 'President')、('speaks', 'greets')、 ('media', 'press')、('Illinois', 'Chicago')となり、二つの文章の中で意味が似ている単語間で重みが輸送されていることが分かります。このように単語集合間で計算した最適輸送コストは、文章の非類似度(どれだけ似ていないか)を表現できます。つまり、最適輸送コストが小さいほど二つの文章が似ており、大きいほど似ていないということです。文章の類似度を計算する手法は他にも、文章一つを文章ベクトルとして埋め込み、ベクトル対のコサイン類似度を計算する手法などがありますが、文類似度タスクの複数のデータセットにおいて、最適輸送ベースの手法は文章ベクトルを使った手法よりも2020年時点で高性能であることが実験で示されています。 [4]

ここまでは最適輸送のコストに主に着目してきましたが、同時に単語間の対応付け(アラインメント)を得ることもできます。図9を例にすると、二つの文章を比較するうえで、どの単語対に着目したかということが輸送量を見ると一目瞭然にわかります。これは、文章ベクトルでコサイン類似度を計算するときには得られない対応関係です。最適輸送ベースの手法で得られる単語アラインメントを使うと、単一言語間の単語アラインメントタスク(Monolingual word alignment)においてSOTAの手法に匹敵する性能を得られることが報告されており [6]、自然言語処理分野において、文章の類似度計算だけではなく単語アラインメントの点でも最適輸送は有効な手法だといえます。

簡単なPythonコードで実行できる

最適輸送はPythonのライブラリが整備されており、簡単なコードで動かせる部分も私が好きなところです。以下では、前の章で説明した文章の類似度計算をPythonで動かすコードを紹介しようと思います。ここからは、以下の3つの文章を考えます。

  1. Obama speaks to the media in Illinois.
  2. The President greets the press in Chicago.
  3. Optimal transport is interesting.

文章1と文章2は似ていますが、文章1と文章3は全く意味の違う文章です。
この時、文章1と文章2の非類似度と、文章1と文章3の非類似度をそれぞれ最適輸送で計算して、意味の違いを反映できているかを確認します。以下のコードは、Google Colaboratory上で実行しました。

まず、必要なライブラリを用意します。今回使うのは、数値計算を効率的に行うためのライブラリであるnumpyと、単語埋め込みが提供されているgensimと、最適輸送に関する実装がまとまっているPython Optimal Transportです。Python Optimal Transportを以下ではPOTと略します。POTは、私自身も最適輸送に関する研究でよく使っていました。まず、これら3つのライブラリをインストールします。

!pip install numpy
!pip install gensim
!pip install POT

そして、以下のコードでインポートします。

import numpy as np
import gensim.downloader
import ot

次に、3つの文章を単語のリストに分割します。to・the・in・isといった文章の意味に大きく影響しない単語は除いて、以下のようなリストを用意します。

sent1 = ["Obama", "speaks", "media", "Illinois"]
sent2 = ["President", "greets", "press", "Chicago"]
sent3 = ["Optimal", "transport", "interesting"]

次に、単語埋め込みを用意します。gensimの学習済み単語埋め込みは多数提供されていますが、今回はGoogle Newsと呼ばれるコーパスで学習した300次元の埋め込みであるword2vec [5]を使うことにします。まず、学習済み単語埋め込みをロードします。大きなモデルなのでここは少し時間がかかります。

word2vec_model = gensim.downloader.load('word2vec-google-news-300')

そしてこの単語埋め込みを使って、先ほどの単語リストの単語をそれぞれ単語埋め込みに変換し、numpyの行列形式で保持します。

sent1_emb = np.array([word2vec_model[word] for word in sent1])
sent2_emb = np.array([word2vec_model[word] for word in sent2])
sent3_emb = np.array([word2vec_model[word] for word in sent3])

例えばsent1_embの次元は以下のようになっており、sent1_emb[0]には"Obama"の単語埋め込みが格納されています。

print(sent1_emb.shape)
>>>
(4, 300)

ここまでで、単語埋め込みの準備は完了です。ここから、最適輸送に必要な入力を準備します。最適輸送を計算するうえで必要なのは単語の重みと、単語間で計算されるコスト行列なので、これらを用意します。単語の重みは、一様なものとして定義します。

weight1 = np.ones(len(sent1)) / len(sent1) # [1/4, 1/4, 1/4, 1/4]
weight2 = np.ones(len(sent2)) / len(sent2) # [1/4, 1/4, 1/4, 1/4]
weight3 = np.ones(len(sent3)) / len(sent3) # [1/3, 1/3, 1/3]

次に、文章1と文章2の単語間のコスト行列と、文章1と文章3の単語間のコスト行列を定義します。単語間の距離は、2つの単語埋め込みを\mathbf{x} \in \mathbb{R}^d, \mathbf{y}\in \mathbb{R}^dとすると、\|\mathbf{x} - \mathbf{y}\|_2 = \sqrt{\sum_{i=1}^d (x_i - y_i)^2}で定義されるユークリッド距離を使います。

cost_12 = np.zeros((len(sent1), len(sent2))) # 文章1と文章2の単語間のコスト行列
for i in range(len(sent1)):
    for j in range(len(sent2)):
        cost_12[i][j] = np.linalg.norm(sent1_emb[i] - sent2_emb[j])

cost_13 = np.zeros((len(sent1), len(sent3))) # 文章1と文章3の単語間のコスト行列
for i in range(len(sent1)):
    for j in range(len(sent3)):
        cost_13[i][j] = np.linalg.norm(sent1_emb[i] - sent3_emb[j])

重みとコスト行列が用意できたので、実際に最適輸送を計算していきます。エントロピー正則化なしの最適輸送問題で計算される最適輸送コストは、以下のように求められます。

ot_cost_12 = ot.emd2(weight1, weight2, cost_12)
ot_cost_13 = ot.emd2(weight1, weight3, cost_13)

print(ot_cost_12)
print(ot_cost_13)
>>>
2.865866243839264
3.985972623030344

文章1と文章2の最適輸送コストの方が、文章1と文章3の最適輸送コストよりも小さいことがわかりました。最適輸送コストは文章の非類似度(どれだけ似ていないか)だったので、文章1と文章2の方がより意味が似ている文章だということができ、人間の理解と一致しています。また、POTライブラリを使うと最適輸送プランも得られます。

ot_plan_12 = ot.emd(weight1, weight2, cost_12)

# sent1 = ["Obama", "speaks", "media", "Illinois"]
# sent2 = ["President", "greets", "press", "Chicago"]
# P_{ij}: sent1のi番目の単語とsent2のj番目の単語の間での輸送量
# 例えば、P_{00}は"Obama"と"President"の間での輸送量で、P_{01}は"Obama"と"greets"の間での輸送量

print(ot_plan_12)
>>>
[[0.25 0.   0.   0.  ]
 [0.   0.25 0.   0.  ]
 [0.   0.   0.25 0.  ]
 [0.   0.   0.   0.25]]

輸送量が0より大きいペアが('Obama', 'President')、('speaks', 'greets')、('media', 'press')、('Illinois', 'Chicago')であり、それぞれ意味が似た単語間で輸送があることが分かります。このように、POTライブラリを使うと文章の非類似度と単語アラインメントを得られます。

また、POTライブラリを使うとエントロピー正則化付きの最適輸送問題も計算できます。内部ではシンクホーンアルゴリズムが計算されています。こちらでも文章の非類似度を比較してみます。正則化係数の大きさを手動で決める必要があるので、ここでは0.5に設定します。

entropic_cost_12 = ot.sinkhorn2(weight1, weight2, cost_12, 0.5)
entropic_cost_13 = ot.sinkhorn2(weight1, weight3, cost_13, 0.5)

print(entropic_cost_12)
print(entropic_cost_13)
>>>
3.138847458686559
4.0402936166229875

正則化なしの最適輸送問題で求めた輸送コストとは違った結果になっていますが、文章1と文章2の最適輸送コストの方が小さいのは変わっていません。正則化付きの最適輸送問題で求めた最適輸送プランも見てみましょう。

entropic_plan_12 = ot.sinkhorn(weight1, weight2, cost_12, 0.5)

# sent1 = ["Obama", "speaks", "media", "Illinois"]
# sent2 = ["President", "greets", "press", "Chicago"]

print(entropic_plan_12)
>>>
[[0.14760676 0.03096842 0.03329064 0.03813418]
 [0.02828278 0.18505624 0.02776536 0.00889561]
 [0.04130818 0.01963072 0.17702336 0.01203775]
 [0.03280228 0.01434461 0.01192064 0.19093247]]

エントロピー正則化を入れた分、最適輸送プランは一様な行列に近づいていくため、輸送の対応が1対1ではなくなります。ただ、輸送量が大きいペアを見ると、やはり('Obama', 'President')、('speaks', 'greets')、('media', 'press')、('Illinois', 'Chicago')が対応していることが分かります。次に、正則化係数を極端な大きな値1000に設定した結果を見てみます。

entropic_plan_12_reg1000 = ot.sinkhorn(weight1, weight2, cost_12, 1000.0)

print(entropic_plan_12_reg1000)
>>>
[[0.06252602 0.06248701 0.06248779 0.06249919]
 [0.0624861  0.06255462 0.06249384 0.06246544]
 [0.06249567 0.06248222 0.06254948 0.06247262]
 [0.0624922  0.06247615 0.06246889 0.06256275]]

正則化係数が大きいと、最適輸送問題自体の目的関数ではなく、エントロピー関数が優先されて最適化が行われるため、最適輸送プランがほぼ一様な行列になっています。実際にエントロピー正則化付き最適輸送問題を使う場合には、1以下の値を使うことが多いです。

以上では、最適輸送の自然言語処理への応用のコードについて説明しました。最適輸送を計算する部分の実装はライブラリが整備されているため簡単なので、自然言語処理に限らずどの分野でも最適輸送は実際に手を動かしやすい手法だと思います。もし何かのタスクに取り組んでいて、最適輸送が使えそうな部分が見えた際は、ぜひ実際に手を動かしてみてください!

教材が豊富

最適輸送は、勉強できる教材が充実している点も私が好きなところの一つです。まず、国立情報学研究所の佐藤竜馬先生が、「最適輸送の理論とアルゴリズム」 [7]という最適輸送の本を書かれています。この本は最適輸送の理論から応用まで詳細に書かれています。証明も省略されることなく詳しい記述があるので、全体的にとても分かりやすく最適輸送を勉強するにはもってこいだと思います。例えば、シンクホーンアルゴリズムは双対問題から最適解の収束性まで丁寧に解説されているので、この記事を読んで厳密性が気になった方はぜひ本を手に取ってみてください。

また、学会のチュートリアル講演などで使用されたスライドも公開されています。前述した本の著者である佐藤竜馬先生がIBIS2021で講演された「最適輸送入門」のスライドや、東北大学の横井祥先生が言語処理学会第28回年次大会のチュートリアルで講演された「最適輸送と自然言語処理」のスライドが公開されており、どれもわかりやすい図や説明があるので、最適輸送を勉強するときには大いに役立つと思います。

まとめ

この記事では、私が最適輸送を好きな理由を6つ紹介しました。説明のために数式が多くなってしまった部分はありますが、「最適輸送面白いな」と思っていただければ嬉しいです。この記事では紹介しきれませんでしたが、木構造を使ってさらに高速に最適輸送を計算したり、強力な生成モデルである拡散モデルにも応用したりなど、最適輸送にはまだまだ面白い部分があります。
この記事はほんの導入しか触れていないですが、最適輸送を勉強しようと思っている方の一助になれば幸いです。

最適輸送は集合同士を比較できるので応用先が広い手法ですが、弊社Sansanではまだ最適輸送を使ったプロダクトは存在しておらず、これから社内に最適輸送を浸透させて盛り上げていきたいと思っています。もし最適輸送とSansanの研究開発の両方に興味がある方がいれば、ぜひ弊社で一緒に最適輸送を盛り上げましょう!

参考文献

  • [1] Marco Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. Advances in neural information processing systems, 2013.
  • [2] Richard Sinkhorn. Diagonal equivalence to matrices with prescribed row and column sums. The American Mathematical Monthly, Vol. 74, No. 4, pp. 402–405, 1967
  • [3] Matt Kusner, Yu Sun, Nicholas Kolkin, and Kilian Weinberger. From word embeddings to document distances. In International conference on machine learning, 2015.
  • [4] Sho Yokoi, Ryo Takahashi, Reina Akama, Jun Suzuki, and Kentaro Inui. Word rotator’s distance. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP), 2020.
  • [5] Tomas Mikolov, Ilya Sutskever, Kai Chen, Greg S Corrado, and Jeff Dean. Distributed representations of words and phrases and their compositionality.Advances in neural information processing systems, 2013.
  • [6] Yuki Arase, Han Bao, and Sho Yokoi. Unbalanced optimal transport for unbalanced word alignment. In Proceedings of the 61st Annual Meeting ofthe Association for Computational Linguistics, 2023
  • [7] 佐藤竜馬. 最適輸送の理論とアルゴリズム. 講談社, 2023.