Categorical Reparameterization with Gumbel-Softmax [arXiv:1611.01144]

2016年11月12日

概要

はじめに

Deep Learningなどでクラス分類を行う場合、カテゴリカル分布と呼ばれる分布を用いて属するクラスの確率を求めると思います。

たとえばMNISTであれば10個のクラスを用意し、10次元の出力ベクトルをsoftmax関数に通すことでカテゴリカル分布を作ります。

categorycal

上の画像はクラス数が6個の場合の分布の例です。

この分布からサンプリングを行うとクラスを得ることができます。

Deep Learningではクラスを表す変数をスカラーではなくone-hotなベクトルとするのが一般的ですので、たとえばクラス2を表すベクトルz

\[z=(0,1,0,0,0,0)
\\]

のように2番目の要素だけ1で他の要素はすべて0となります。

カテゴリカル分布からのサンプリングは一般的に、データxとパラメータϕのニューラルネットfϕ、さらにsoftmax関数を用いて

\[zsoftmax(fϕ(x))
\\]

のように行います。

(クラス分類の場合はサンプリングではなくargmaxで確率最大のクラスを取ります)

得られたサンプルzはパラメータϕで微分することができないため、この論文ではGumbel-Softmax分布を用いたreparameterization trickにより微分可能なサンプリングを実現しています。

Gumbel-Softmax

以下、クラス数をk、クラス変数zk次元のベクトルとし、それぞれのクラスの確率をπ=(π1,π2,πk)とします。

Gumbel-Max trick

Gumbel-Max trickはargmaxを用いてカテゴリカル分布からサンプリングを行うことができる手法です。

まずノイズgを以下のように生成します。

\[uUniform(0,1)g=log(log(u))
\\]

ノイズはクラスの数だけ生成します。

次に以下のような操作によってサンプリングを行います。

\[z=one_hot(argmaxi[gi+logπi])
\\]

one_hotはクラスの番号からone-hotなベクトルを作る関数です。

この式がもし以下のような形であれば、何度argmaxしても同じクラスが出力されますが、

\[z=one_hot(argmaxi[πi])
\\]

式(5)のようにノイズを乗せて分布の形を変えることで、argmaxしたときに違うクラスが出力されるようになり、擬似的にサンプリングが行えます。

Gumbel-Softmax分布

Gumbel-Max trickでは、カテゴリカル分布にノイズを乗せ、argmaxしてからone_hotすることでクラス変数zをサンプリングできるようになりました。

論文ではこのargmaxしてからone-hotなベクトルに変える処理を省略し、分布から直接one-hotなベクトルをサンプリングするために、Gumbel-Softmax分布を提案しています。

この分布は温度パラメータτを導入し、クラス変数zの各要素の値を以下のように定義します。

\[zi=exp((log(πi)+gi)/τ)kj=1exp((log(πi)+gi)/τ)
\\]

これはsoftmax関数の操作と同じですので、以下のように書くことができます。

\[z=softmax((log(π)+g)/τ)
\\]

確率ベクトルπはニューラルネットから出力させますので、式(2)と同じ記号を使うと

\[z=softmax((log(fϕ(x))+g)/τ)
\\]

のようになります。

式(2)との違いは、ノイズgを決めるとクラス変数zが決定的に求まるということです。

さらに式(9)によって得られるzはパラメータϕで微分することができます。

実際に可視化したほうがわかりやすいので、Gumbel-Softmax分布と温度τの関係を見ていきます。

まずsoftmax(fϕ(x))は以下のような形をしています。

original

100回サンプリングすればクラス2が80回くらい出るような分布です。

次に温度τ=0.1とした時に式(9)から得られるzの各要素の値です。

original

one-hotなベクトルがサンプリングされました。

次に温度τ=1,5,100とした時に式(9)から得られるzの各要素の値です。

τ=1

original

τ=5

original

τ=100

original

全くone-hotではないベクトルが生成されました。

実はGumbel-Softmax分布は温度が低いとone-hotな形になり、温度が高いと一様分布の形になる分布です。

形がone-hotなのでこれをそのままone-hotなクラス変数にしようというのが論文のアイディアです。

実際に温度τ=0.1のときに式(9)を複数回実行した場合に得られるzは以下のようになります。

original

クラス2が頻出しているため、サンプリングが上手くできていることがわかります。

次に、温度τ=0.1,1,5,100のときに式(9)をそれぞれ100回実行し、得られたzを全部足して平均を取ったものは以下のようになります。

τ=0.1

original

τ=1

original

τ=5

original

τ=100

original

平均(つまり期待値)は温度が低いほどもとの分布に近く、温度が高いほど一様分布に近づきます。

以上のことをまとめると以下のことが言えます。

  • Gumbel-Softmax分布からのサンプルは、温度が低いとone-hotなベクトルに近づく
  • Gumbel-Softmax分布の期待値は温度が低いとカテゴリカル分布に近づく
  • 温度が低い時のGumbel-Softmax分布からのサンプルはone_hot(argmax(…))の近似になっている

この手法はone-hotなベクトルをサンプリングするというよりは、one-hotな形の分布をサンプリングする手法と言ったほうが良いのかもしれません。

実験に用いたコードを載せておきます。

# -*- coding: utf-8 -*-
import math
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from chainer import Variable
from chainer import functions as F
from chainer import links as L

sns.set(style="white", context="talk")

x = Variable(np.random.normal(0, 1, size=(1, 10)).astype(np.float32))
layer = L.Linear(10, 6, wscale=3)
cat = F.softmax(layer(x))
plt.clf()
plt.ylim(ymax=1, ymin=0)
plot = sns.barplot(np.arange(1, 7), cat.data[0], color="#d73c2c")
plot.figure.savefig("original")
eps = 1e-12
log_pi = layer(x)

temperature = 100
plt.clf()
plt.ylim(ymax=1, ymin=0)
u = np.random.uniform(0, 1, 6).astype(np.float32)
g = -np.log(-np.log(u + eps) + eps)
cat = F.softmax((log_pi + g) / temperature)
plot = sns.barplot(np.arange(1, 7), cat.data[0], color="#ca2c68")
plot.figure.savefig("sample_%d" % (temperature * 10))
expectation = np.zeros((6,), dtype=np.float32)
for i in xrange(100):
	u = np.random.uniform(0, 1, 6).astype(np.float32)
	g = -np.log(-np.log(u + eps) + eps)
	cat = F.softmax((log_pi + g) / temperature)
	expectation += cat.data[0]
expectation /= 100
plt.clf()
plt.ylim(ymax=1, ymin=0)
plot = sns.barplot(np.arange(1, 7), expectation, color="#ca2c68")
plot.figure.savefig("expectation_%d" % (temperature * 10))


temperature = 5
plt.clf()
plt.ylim(ymax=1, ymin=0)
u = np.random.uniform(0, 1, 6).astype(np.float32)
g = -np.log(-np.log(u + eps) + eps)
cat = F.softmax((log_pi + g) / temperature)
plot = sns.barplot(np.arange(1, 7), cat.data[0], color="#6e248d")
plot.figure.savefig("sample_%d" % (temperature * 10))
expectation = np.zeros((6,), dtype=np.float32)
for i in xrange(100):
	u = np.random.uniform(0, 1, 6).astype(np.float32)
	g = -np.log(-np.log(u + eps) + eps)
	cat = F.softmax((log_pi + g) / temperature)
	expectation += cat.data[0]
expectation /= 100
plt.clf()
plt.ylim(ymax=1, ymin=0)
plot = sns.barplot(np.arange(1, 7), expectation, color="#6e248d")
plot.figure.savefig("expectation_%d" % (temperature * 10))


temperature = 1
plt.clf()
plt.ylim(ymax=1, ymin=0)
u = np.random.uniform(0, 1, 6).astype(np.float32)
g = -np.log(-np.log(u + eps) + eps)
cat = F.softmax((log_pi + g) / temperature)
plot = sns.barplot(np.arange(1, 7), cat.data[0], color="#0067b0")
plot.figure.savefig("sample_%d" % (temperature * 10))
expectation = np.zeros((6,), dtype=np.float32)
for i in xrange(100):
	u = np.random.uniform(0, 1, 6).astype(np.float32)
	g = -np.log(-np.log(u + eps) + eps)
	cat = F.softmax((log_pi + g) / temperature)
	expectation += cat.data[0]
expectation /= 100
plt.clf()
plt.ylim(ymax=1, ymin=0)
plot = sns.barplot(np.arange(1, 7), expectation, color="#0067b0")
plot.figure.savefig("expectation_%d" % (temperature * 10))


temperature = 0.1
plt.clf()
plt.ylim(ymax=1, ymin=0)
u = np.random.uniform(0, 1, 6).astype(np.float32)
g = -np.log(-np.log(u + eps) + eps)
cat = F.softmax((log_pi + g) / temperature)
plot = sns.barplot(np.arange(1, 7), cat.data[0], color="#009c41")
plot.figure.savefig("sample_%d" % (temperature * 10))
expectation = np.zeros((6,), dtype=np.float32)
for i in xrange(100):
	u = np.random.uniform(0, 1, 6).astype(np.float32)
	g = -np.log(-np.log(u + eps) + eps)
	cat = F.softmax((log_pi + g) / temperature)
	expectation += cat.data[0]
expectation /= 100
plt.clf()
plt.ylim(ymax=1, ymin=0)
plot = sns.barplot(np.arange(1, 7), expectation, color="#009c41")
plot.figure.savefig("expectation_%d" % (temperature * 10))

for i in xrange(10):
	plt.clf()
	plt.ylim(ymax=1, ymin=0)
	u = np.random.uniform(0, 1, 6).astype(np.float32)
	g = -np.log(-np.log(u + eps) + eps)
	cat = F.softmax((log_pi + g) / temperature)
	plot = sns.barplot(np.arange(1, 7), cat.data[0], color="#9e6c4b")
	plot.figure.savefig("z_%d" % i)

savefigはファイル名にピリオドが使えないので10倍しています。

半教師あり学習での利用

変分オートエンコーダと呼ばれるVAEADGMで半教師あり学習を行う場合、目的関数はだいたい以下のような感じになります。

\[a(l)qϕ(ax)z(l)qϕ(z(l)a(l),x,y)L(x,y)1NMCNMCl=1logpθ(a(l)x,y,z(l))pθ(xy,z(l))p(y)p(z(l))qϕ(a(l)x)qϕ(z(l)a(l),x,y)f()=logpθ(a,x,y,z)qϕ(a,zx,y)U(x)1NMCNMCl=1{y{qϕ(ya(l),x)f()}+H(qϕ(ya(l),x))}
\\]

L(x,y)はラベルありの場合の目的関数なのでここでは無関係です。)

ラベルなしデータの目的関数U(x)を求める際、変分オートエンコーダでは隠れ変数yz(ADGMではさらにa)を消去するのですが、zaNMC個のサンプルを用いて消去するのに対し、yは全てのクラスについて周辺化yを行うことで消去していました。

このyはMNISTのようにクラスが10個しかなければ、メモリを大量に使う富豪的なテクニックを使うと一発で求めることができるのですが、クラス数が数十~数百もある場合、全クラスを列挙して周辺化をすると計算量が増大する問題点があります。

しかし、今回のGumbel-Softmaxを用いるとクラスyも微分可能な形でサンプリングできるため、周辺化をサンプリングで近似することで高速化できると考えられます。

つまり、

\[U(x)1NMCNMCl=1{y{qϕ(ya(l),x)logpθ(a(l),x,y,z(l))qϕ(a(l),z(l)x,y)}+H(qϕ(ya(l),x))}
\\]

を、

\[y(l)qϕ(ya(l),x)U(x)1NMCNMCl=1{logpθ(a(l),x,y(l),z(l))qϕ(a(l),z(l)x,y(l))+H(qϕ(ya(l),x))}
\\]

のように計算することができます。

そこでMNISTの半教師あり学習で実験を行いました。

MNISTはテストデータ10,000枚、訓練データ60,000枚からなりますが、訓練データをさらに10,000枚のバリデーションデータ、49,900枚のラベルなしデータ、100枚のラベルありデータに分割し学習を行いました。

以下のグラフは各世代のバリデーションデータの分類精度です。

original

Marginalizeの方は全クラスの周辺化、Gumbelはyを1回だけサンプリングして目的関数を近似しています。

両方ともNMC=1としているので、zaはどちらも1回だけサンプリングしています。

Gumbel-Softmaxでも精度は引けを取らないということが分かりました。

おわりに

Gumbel-Softmaxからのサンプリングは必ずしもone-hotなベクトルとはならないので、これでいいのか?と思っていますが、学習はうまくいくのでまあいいかなと思います。