確率的プログラミングPyro入門
確率的プログラミングPyro入門
はじめまして。eureka-BIチームの小林です。
普段は卓球とスプラトゥーンをやっています。
この記事は eureka Engineering Advent Calendar 2017 – Qiita の17日目の記事です。
16日目は サマーインターン参加者かつSREでインターン中のdatchこと原田くんの
「Pairsのテキストデータを学習させたword2vecを使って、コミュニティを分類してみた」です。
はじめに
BIチームでは、様々な数字を分析することで、プロダクトの意思決定に貢献しています。
その中で、データからモデルを作成し、予測を立てるといった業務をすることがあります。
今までは、簡単な線形回帰でのモデル作成に留まりがちで、知識としてもMCMCで止まっていたので、
今後、確率的プログラミングを取り入れたモデリングをしていきたいと思い、最近発表されたばかりのPyroを触ってみました。
実際に業務に用いられるレベルにはマスターできていないので、今回は導入から簡単な使い方までを、
公式のIntroductionをなぞっていくことで説明します。
確率的プログラミング言語Pyro
PyroはUBERが公開している確率的プログラミング言語です。
プログラミング言語、とは言うものの実際はPythonのライブラリとして公開されています。
確率的プログラミングについては、ここに書くと長くなってしまうので割愛しますが、以下の記事が詳しいです。
確率的プログラミング | POSTD
PyroはPyTorchをバックエンドに使い、高速なテンソル計算と自動微分を実現しています。
Pyroに似たものにEdwardがあり、こちらはTensorflowをバックエンドに利用しています。
Pyroの導入
PyroではPyTorchが必要になるので、インストールしておいてください。
PyTorchが入っていれば、Pyro自体はpip install pyro-pplでインストールできます。
Pyroの基礎
変数は全てtorchのtensorをVariableで包んだ形で保持します。
言葉で表現すると意味不明なのですが、コードで書くとすなわち
1234 import torchfrom torch.autograd import Variable hoge = Variable(torch.Tensor(hage))
こうなります。
つまり、mu = 0 と sigma = 1 を表現すると
12 mu = Variable(torch.zeros(1)) # zerosは0のみのtensorを生成するsigma = Variable(torch.ones(1))
こうなります。
平均 mu 分散 sigma の正規分布にしたがう x は以下で表現できます。
123 import pyro.distributions as dist x = dist.normal(mu, sigma)
または、
12 import pyrox = pyro.sample("my_sample", dist.normal, mu, sigma)
とすることで、my_sample という名前を用いたサンプリングとして定義することもできます。
また、この x の時の対数確率密度の値は
1 log_p_x = dist.normal.log_pdf(x, mu, sigma)
で取得できます。
Pyroでのモデリング
上記の基礎構文を用いて、天気と気温の関係を表現するモデルを作成します。
12345678910 def weather(): cloudy = pyro.sample('cloudy', dist.bernoulli, Variable(torch.Tensor([0.3]))) cloudy = 'cloudy' if cloudy.data[0] == 1.0 else 'sunny' mean_temp = {'cloudy': [55.0], 'sunny': [75.0]}[cloudy] sigma_temp = {'cloudy': [10.0], 'sunny': [15.0]}[cloudy] temp = pyro.sample('temp', dist.normal, Variable(torch.Tensor(mean_temp)), Variable(torch.Tensor(sigma_temp))) return cloudy, temp.data[0]
順を追って説明すると、
・ 2~4行目では、cloudyはベルヌーイ分布により30%の確率で曇り、70%の確率で晴れとなることを表しています。
・ 5,6行目では気温の従う分布の平均と分散を、天気を条件にして定めます。(変数はテンソルで扱うために、配列として渡しています。)
・ 最後に、上で定めた値を元に正規分布より気温の値を決め、天気と気温を返します。
このように、統計的な分布に基づくランダムな値を生成するモデルを作成することができます。
Pyroは、Pythonで作成されているので、統計的な関数はPythonで用いるような複雑な書き方もできます。
例えば、再帰的な書き方をしたモデルの例として、
12345678 def geometric(p, t=None): if t is None: t = 0 x = pyro.sample("x_{}".format(t), dist.bernoulli, p) if torch.equal(x.data, torch.zeros(1)): return x else: return x + geometric(p, t+1)
このように書けます。
サンプリングする際は、必ずユニークな名前を付ける必要があるので、再帰回数ごとに、x_1, x_2 …としています。
また、以下のように他の確率的関数を入力としたり、出力とすることもできます。
1234567891011121314 def normal_product(mu, sigma): z1 = pyro.sample("z1", dist.normal, mu, sigma) z2 = pyro.sample("z2", dist.normal, mu, sigma) y = z1 * z2 return y def make_normal_normal(): mu_latent = pyro.sample("mu_latent", dist.normal, Variable(torch.zeros(1)), Variable(torch.ones(1))) fn = lambda sigma: normal_product(mu_latent, sigma) return fn print(make_normal_normal()(Variable(torch.ones(1))))
Pyroによる推定
重点サンプリングによって、周辺分布を求めることができます。
例えば、毎回測定誤差が出るような秤のモデルを例に置くと、
123 def scale(guess): weight = pyro.sample("weight", dist.normal, guess, Variable(torch.ones(1))) return pyro.sample("measurement", dist.normal, weight, Variable(torch.Tensor([0.75]))
と定義した秤のモデルに対して、
1 posterior = pyro.infer.Importance(scale, num_samples=100)
とすると、重点サンプリングがなされます。
しかし、posterior単体では有用なオブジェクトではなく、pyro.infer.Marginalによる周辺化に用いられます。
123 guess = Variable(torch.Tensor([8.5]))marginal = pyro.infer.Marginal(posterior)print(marginal(guess))
marginalは、scaleを重点サンプリングしたposteriorからヒストグラムを生成し、
それを元に、guessの値が与えられた場合の分布から値をサンプリングします。
同じ引数を持つmarginalを複数回呼び出すと、同じヒストグラムからサンプリングされるので、
1234 plt.hist([marginal(guess).data[0] for _ in range(100)], range=(5.0, 12.0))plt.title("P(measurement | guess)")plt.xlabel("weight")plt.ylabel("#")
とすると、同一のヒストグラムを元にサンプリングされるので、元の形が再現されていきます。

パラメータ調整
確率的プログラミングによるモデリングの有用性は、
観測値によってモデルを調整することで、データ生成における潜在的な要因を推定する能力にあります。
例えば、秤のモデルにおいて,計測値が8.5になる場合はこのように表現します。
12 conditioned_scale = pyro.condition( scale, data={"measurement": Variable(torch.Tensor([8.5]))})
パラメータ調整において、下記のように引数を与えられるようにもできます。
12 def deferred_conditioned_scale(measurement, *args, **kwargs): return pyro.condition(scale, data={"measurement": measurement})(*args, **kwargs)
また、conditionメソッドではなく、obsパラメータを用いる書き方や、pyro.observeを用いた下記からも存在します。
12345678910111213 ## equivalent to pyro.condition(scale, data={"measurement": Variable(torch.ones(1))})def scale_obs(guess): z = pyro.sample("weight", dist.normal, guess, Variable(torch.ones(1))) # here we attach an observation measurement == 1 return pyro.sample("measurement", dist.normal, weight, Variable(torch.ones(1)), obs=Variable(torch.Tensor([0.1]))) ## equivalent to scale_obs:def scale_obs(guess): z = pyro.sample("weight", dist.normal, guess, Variable(torch.ones(1))) # here we attach an observation measurement == 1 return pyro.observe("measurement", dist.normal, Variable(torch.ones(1)), weight, Variable(torch.Tensor([0.1])))
ただし、モデル中でハードコーディングすることはあまり推奨されないので、
pyro.conditionによって、モデルを変更することなく条件を与える方が良いです。
また、複数の条件を与える書き方は複数あり、
123456789101112131415161718 def scale2(guess): weight = pyro.sample("weight", dist.normal, guess, Variable(torch.ones(1))) tolerance = torch.abs( pyro.sample("tolerance", dist.normal, Variable(torch.zeros(1)), Variable(torch.ones(1)))) return pyro.sample("measurement", dist.normal, weight, tolerance) conditioned_scale2_1 = pyro.condition( pyro.condition(scale2, data={"weight": Variable(torch.ones(1))}), data={"measurement": Variable(torch.ones(1))}) conditioned_scale2_2 = pyro.condition( pyro.condition(scale2, data={"measurement": Variable(torch.ones(1))}), data={"weight": Variable(torch.ones(1))}) conditioned_scale2_3 = pyro.condition( scale2, data={"weight": Variable(torch.ones(1)), "measurement": Variable(torch.ones(1))})
3つのconditionメソッドは同質です。
秤のモデルにおいて、pyro.conditionを使ってguessとmeasurementを与えた時のweightの値について推測したいときは、先ほどの重点サンプリングの例と同様にして、
12345678910111213 guess = Variable(torch.Tensor([8.5]))measurement = Variable(torch.Tensor([9.5])) conditioned_scale = pyro.condition(scale, data={"measurement": measurement}) marginal = pyro.infer.Marginal( pyro.infer.Importance(conditioned_scale, num_samples=100), sites=["weight"]) print(marginal(guess))plt.hist([marginal(guess)["weight"].data[0] for _ in range(100)], range=(5.0, 12.0))plt.title("P(weight | measurement, guess)")plt.xlabel("weight")plt.ylabel("#")
ただし、これらのやり方は、事前分布に関する情報や制約がないため計算効率が悪いです。
そこで、Pyroでは、Guideを利用して効率化することができます。
例えば以下のように書くことで、推定を効率化できます。
12345678 def scale_prior_guide(guess): return pyro.sample("weight", dist.normal, guess, Variable(torch.ones(1))) posterior = pyro.infer.Importance(conditioned_scale, guide=scale_prior_guide, num_samples=10) marginal = pyro.infer.Marginal(posterior, sites=["weight"])
または、weightの事後分布は、guessとmeasurementによって表されるので、
123456789 def scale_posterior_guide(measurement, guess): a = (guess + torch.sum(measurement)) / (measurement.size(0) + 1.0) b = Variable(torch.ones(1)) / (measurement.size(0) + 1.0) return pyro.sample("weight", dist.normal, a, b) posterior = pyro.infer.Importance(deferred_conditioned_scale, guide=scale_posterior_guide, num_samples=20)marginal = pyro.infer.Marginal(posterior, sites=["weight"])
と書くことができます。
今回の秤のモデルは、自ら中の仕組みを組み上げているため、
正確な事後分布を書くことができますが、
一般的には正確な事後分布を推定するのは難しいです。
そのため、変分推論と呼ばれるアプローチによって、近似的な事後確率を求めます。
PyroによるSVI(簡易説明)
pyro.paramはpyro.sampleのように、第一引数で名前をつけられます。
初回呼び出し時には、名前とその引数が結びつけられ、その後の呼び出しでは、
他の引数にかかわらず、名前によって値が返されます。
1234 def scale_parametrized_guide(guess): a = pyro.param("a", Variable(torch.randn(1) + guess.data.clone(), requires_grad=True)) b = pyro.param("b", Variable(torch.randn(1), requires_grad=True)) return pyro.sample("weight", dist.normal, a, torch.abs(b))
PyroのSVIについて、公式でSVIのためのチュートリアルが用意されているので、
今回は詳しい説明は省きますが、秤のモデルに適応した簡単なものは以下のように書けます。
12345678910111213 pyro.clear_param_store()svi = pyro.infer.SVI(model=conditioned_scale, guide=scale_parametrized_guide, loss="ELBO") losses = []for t in range(1000): losses.append(svi.step(guess)) plt.plot(losses)plt.title("ELBO")plt.xlabel("step")plt.ylabel("loss")

今回は、optimによる最適化手法の選択と、
lossでの損失関数の指定については説明しません。
また、以下のように最適化されたガイドを重点サンプリングの重要度分布として使用すると、
以前よりも少ないサンプルで周辺分布を推定できます。
1234567 posterior = pyro.infer.Importance(conditioned_scale, scale_parametrized_guide, num_samples=10)marginal = pyro.infer.Marginal(posterior, sites=["weight"]) plt.hist([marginal(guess)["weight"].data[0] for _ in range(100)], range=(5.0, 12.0))plt.title("P(weight | measurement, guess)")plt.xlabel("weight")plt.ylabel("#")
ガイドから直接、事後分布の近似としてサンプリングすることもできます。
1234 plt.hist([scale_parametrized_guide(guess).data[0] for _ in range(100)], range=(5.0, 12.0))plt.title("P(weight | measurement, guess)")plt.xlabel("weight")plt.ylabel("#")
まとめ
以上、簡単にではありますがPyroの導入から基本的な使い方を、Introductionに沿って説明させていただきました。
本当は、SVIの解説を詳しくやっていくつもりだったのですが、文量が10倍になりそうでしたので、今回は省かせていただきました。
また、Edwardとの比較や、pystanなどとの速度比較もしたかったのですが、次回のお楽しみとさせていただきます。
ちなみに、pyroを日本語に訳すと、「火」「熱」「高温」という意味があるようです。
プロジェクトの炎上を連想させる「火」をチョイスするセンスは見習いたいものですね。
明日は、BIチームでもっともホスピタリティのある鈴木さん aka ミニオンさんによる
「非エンジニアがSQLを学習する際の10の心得」です。お楽しみに!
エウレカでは、一緒に働いていただける方を絶賛募集中です。募集中の職種はこちらからご確認ください!皆様のエントリーをお待ちしております!
確率的プログラミングPyro入門
はじめまして。eureka-BIチームの小林です。
普段は卓球とスプラトゥーンをやっています。
この記事は eureka Engineering Advent Calendar 2017 – Qiita の17日目の記事です。
16日目は サマーインターン参加者かつSREでインターン中のdatchこと原田くんの
「Pairsのテキストデータを学習させたword2vecを使って、コミュニティを分類してみた」です。
はじめに
BIチームでは、様々な数字を分析することで、プロダクトの意思決定に貢献しています。
その中で、データからモデルを作成し、予測を立てるといった業務をすることがあります。
今までは、簡単な線形回帰でのモデル作成に留まりがちで、知識としてもMCMCで止まっていたので、
今後、確率的プログラミングを取り入れたモデリングをしていきたいと思い、最近発表されたばかりのPyroを触ってみました。
実際に業務に用いられるレベルにはマスターできていないので、今回は導入から簡単な使い方までを、
公式のIntroductionをなぞっていくことで説明します。
確率的プログラミング言語Pyro
PyroはUBERが公開している確率的プログラミング言語です。
プログラミング言語、とは言うものの実際はPythonのライブラリとして公開されています。
確率的プログラミングについては、ここに書くと長くなってしまうので割愛しますが、以下の記事が詳しいです。
確率的プログラミング | POSTD
PyroはPyTorchをバックエンドに使い、高速なテンソル計算と自動微分を実現しています。
Pyroに似たものにEdwardがあり、こちらはTensorflowをバックエンドに利用しています。
Pyroの導入
PyroではPyTorchが必要になるので、インストールしておいてください。
PyTorchが入っていれば、Pyro自体はpip install pyro-pplでインストールできます。
Pyroの基礎
変数は全てtorchのtensorをVariableで包んだ形で保持します。
言葉で表現すると意味不明なのですが、コードで書くとすなわち
1 2 3 4 | import torchfrom torch.autograd import Variablehoge = Variable(torch.Tensor(hage)) |
こうなります。
つまり、mu = 0 と sigma = 1 を表現すると
1 2 | mu = Variable(torch.zeros(1)) # zerosは0のみのtensorを生成するsigma = Variable(torch.ones(1)) |
こうなります。
平均 mu 分散 sigma の正規分布にしたがう x は以下で表現できます。
1 2 3 | import pyro.distributions as distx = dist.normal(mu, sigma) |
または、
1 2 | import pyrox = pyro.sample("my_sample", dist.normal, mu, sigma) |
とすることで、my_sample という名前を用いたサンプリングとして定義することもできます。
また、この x の時の対数確率密度の値は
1 | log_p_x = dist.normal.log_pdf(x, mu, sigma) |
で取得できます。
Pyroでのモデリング
上記の基礎構文を用いて、天気と気温の関係を表現するモデルを作成します。
1 2 3 4 5 6 7 8 9 10 | def weather(): cloudy = pyro.sample('cloudy', dist.bernoulli, Variable(torch.Tensor([0.3]))) cloudy = 'cloudy' if cloudy.data[0] == 1.0 else 'sunny' mean_temp = {'cloudy': [55.0], 'sunny': [75.0]}[cloudy] sigma_temp = {'cloudy': [10.0], 'sunny': [15.0]}[cloudy] temp = pyro.sample('temp', dist.normal, Variable(torch.Tensor(mean_temp)), Variable(torch.Tensor(sigma_temp))) return cloudy, temp.data[0] |
順を追って説明すると、
・ 2~4行目では、cloudyはベルヌーイ分布により30%の確率で曇り、70%の確率で晴れとなることを表しています。
・ 5,6行目では気温の従う分布の平均と分散を、天気を条件にして定めます。(変数はテンソルで扱うために、配列として渡しています。)
・ 最後に、上で定めた値を元に正規分布より気温の値を決め、天気と気温を返します。
このように、統計的な分布に基づくランダムな値を生成するモデルを作成することができます。
Pyroは、Pythonで作成されているので、統計的な関数はPythonで用いるような複雑な書き方もできます。
例えば、再帰的な書き方をしたモデルの例として、
1 2 3 4 5 6 7 8 | def geometric(p, t=None): if t is None: t = 0 x = pyro.sample("x_{}".format(t), dist.bernoulli, p) if torch.equal(x.data, torch.zeros(1)): return x else: return x + geometric(p, t+1) |
このように書けます。
サンプリングする際は、必ずユニークな名前を付ける必要があるので、再帰回数ごとに、x_1, x_2 …としています。
また、以下のように他の確率的関数を入力としたり、出力とすることもできます。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | def normal_product(mu, sigma): z1 = pyro.sample("z1", dist.normal, mu, sigma) z2 = pyro.sample("z2", dist.normal, mu, sigma) y = z1 * z2 return ydef make_normal_normal(): mu_latent = pyro.sample("mu_latent", dist.normal, Variable(torch.zeros(1)), Variable(torch.ones(1))) fn = lambda sigma: normal_product(mu_latent, sigma) return fnprint(make_normal_normal()(Variable(torch.ones(1)))) |
Pyroによる推定
重点サンプリングによって、周辺分布を求めることができます。
例えば、毎回測定誤差が出るような秤のモデルを例に置くと、
1 2 3 | def scale(guess): weight = pyro.sample("weight", dist.normal, guess, Variable(torch.ones(1))) return pyro.sample("measurement", dist.normal, weight, Variable(torch.Tensor([0.75])) |
と定義した秤のモデルに対して、
1 | posterior = pyro.infer.Importance(scale, num_samples=100) |
とすると、重点サンプリングがなされます。
しかし、posterior単体では有用なオブジェクトではなく、pyro.infer.Marginalによる周辺化に用いられます。
1 2 3 | guess = Variable(torch.Tensor([8.5]))marginal = pyro.infer.Marginal(posterior)print(marginal(guess)) |
marginalは、scaleを重点サンプリングしたposteriorからヒストグラムを生成し、
それを元に、guessの値が与えられた場合の分布から値をサンプリングします。
同じ引数を持つmarginalを複数回呼び出すと、同じヒストグラムからサンプリングされるので、
1 2 3 4 | plt.hist([marginal(guess).data[0] for _ in range(100)], range=(5.0, 12.0))plt.title("P(measurement | guess)")plt.xlabel("weight")plt.ylabel("#") |
とすると、同一のヒストグラムを元にサンプリングされるので、元の形が再現されていきます。
パラメータ調整
確率的プログラミングによるモデリングの有用性は、
観測値によってモデルを調整することで、データ生成における潜在的な要因を推定する能力にあります。
例えば、秤のモデルにおいて,計測値が8.5になる場合はこのように表現します。
1 2 | conditioned_scale = pyro.condition( scale, data={"measurement": Variable(torch.Tensor([8.5]))}) |
パラメータ調整において、下記のように引数を与えられるようにもできます。
1 2 | def deferred_conditioned_scale(measurement, *args, **kwargs): return pyro.condition(scale, data={"measurement": measurement})(*args, **kwargs) |
また、conditionメソッドではなく、obsパラメータを用いる書き方や、pyro.observeを用いた下記からも存在します。
1 2 3 4 5 6 7 8 9 10 11 12 13 | ## equivalent to pyro.condition(scale, data={"measurement": Variable(torch.ones(1))})def scale_obs(guess): z = pyro.sample("weight", dist.normal, guess, Variable(torch.ones(1))) # here we attach an observation measurement == 1 return pyro.sample("measurement", dist.normal, weight, Variable(torch.ones(1)), obs=Variable(torch.Tensor([0.1])))## equivalent to scale_obs:def scale_obs(guess): z = pyro.sample("weight", dist.normal, guess, Variable(torch.ones(1))) # here we attach an observation measurement == 1 return pyro.observe("measurement", dist.normal, Variable(torch.ones(1)), weight, Variable(torch.Tensor([0.1]))) |
ただし、モデル中でハードコーディングすることはあまり推奨されないので、
pyro.conditionによって、モデルを変更することなく条件を与える方が良いです。
また、複数の条件を与える書き方は複数あり、
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | def scale2(guess): weight = pyro.sample("weight", dist.normal, guess, Variable(torch.ones(1))) tolerance = torch.abs( pyro.sample("tolerance", dist.normal, Variable(torch.zeros(1)), Variable(torch.ones(1)))) return pyro.sample("measurement", dist.normal, weight, tolerance)conditioned_scale2_1 = pyro.condition( pyro.condition(scale2, data={"weight": Variable(torch.ones(1))}), data={"measurement": Variable(torch.ones(1))})conditioned_scale2_2 = pyro.condition( pyro.condition(scale2, data={"measurement": Variable(torch.ones(1))}), data={"weight": Variable(torch.ones(1))})conditioned_scale2_3 = pyro.condition( scale2, data={"weight": Variable(torch.ones(1)), "measurement": Variable(torch.ones(1))}) |
3つのconditionメソッドは同質です。
秤のモデルにおいて、pyro.conditionを使ってguessとmeasurementを与えた時のweightの値について推測したいときは、先ほどの重点サンプリングの例と同様にして、
1 2 3 4 5 6 7 8 9 10 11 12 13 | guess = Variable(torch.Tensor([8.5]))measurement = Variable(torch.Tensor([9.5]))conditioned_scale = pyro.condition(scale, data={"measurement": measurement})marginal = pyro.infer.Marginal( pyro.infer.Importance(conditioned_scale, num_samples=100), sites=["weight"])print(marginal(guess))plt.hist([marginal(guess)["weight"].data[0] for _ in range(100)], range=(5.0, 12.0))plt.title("P(weight | measurement, guess)")plt.xlabel("weight")plt.ylabel("#") |
ただし、これらのやり方は、事前分布に関する情報や制約がないため計算効率が悪いです。
そこで、Pyroでは、Guideを利用して効率化することができます。
例えば以下のように書くことで、推定を効率化できます。
1 2 3 4 5 6 7 8 | def scale_prior_guide(guess): return pyro.sample("weight", dist.normal, guess, Variable(torch.ones(1)))posterior = pyro.infer.Importance(conditioned_scale, guide=scale_prior_guide, num_samples=10)marginal = pyro.infer.Marginal(posterior, sites=["weight"]) |
または、weightの事後分布は、guessとmeasurementによって表されるので、
1 2 3 4 5 6 7 8 9 | def scale_posterior_guide(measurement, guess): a = (guess + torch.sum(measurement)) / (measurement.size(0) + 1.0) b = Variable(torch.ones(1)) / (measurement.size(0) + 1.0) return pyro.sample("weight", dist.normal, a, b)posterior = pyro.infer.Importance(deferred_conditioned_scale, guide=scale_posterior_guide, num_samples=20)marginal = pyro.infer.Marginal(posterior, sites=["weight"]) |
と書くことができます。
今回の秤のモデルは、自ら中の仕組みを組み上げているため、
正確な事後分布を書くことができますが、
一般的には正確な事後分布を推定するのは難しいです。
そのため、変分推論と呼ばれるアプローチによって、近似的な事後確率を求めます。
PyroによるSVI(簡易説明)
pyro.paramはpyro.sampleのように、第一引数で名前をつけられます。
初回呼び出し時には、名前とその引数が結びつけられ、その後の呼び出しでは、
他の引数にかかわらず、名前によって値が返されます。
1 2 3 4 | def scale_parametrized_guide(guess): a = pyro.param("a", Variable(torch.randn(1) + guess.data.clone(), requires_grad=True)) b = pyro.param("b", Variable(torch.randn(1), requires_grad=True)) return pyro.sample("weight", dist.normal, a, torch.abs(b)) |
PyroのSVIについて、公式でSVIのためのチュートリアルが用意されているので、
今回は詳しい説明は省きますが、秤のモデルに適応した簡単なものは以下のように書けます。
1 2 3 4 5 6 7 8 9 10 11 12 13 | pyro.clear_param_store()svi = pyro.infer.SVI(model=conditioned_scale, guide=scale_parametrized_guide, loss="ELBO")losses = []for t in range(1000): losses.append(svi.step(guess))plt.plot(losses)plt.title("ELBO")plt.xlabel("step")plt.ylabel("loss") |
今回は、optimによる最適化手法の選択と、
lossでの損失関数の指定については説明しません。
また、以下のように最適化されたガイドを重点サンプリングの重要度分布として使用すると、
以前よりも少ないサンプルで周辺分布を推定できます。
1 2 3 4 5 6 7 | posterior = pyro.infer.Importance(conditioned_scale, scale_parametrized_guide, num_samples=10)marginal = pyro.infer.Marginal(posterior, sites=["weight"])plt.hist([marginal(guess)["weight"].data[0] for _ in range(100)], range=(5.0, 12.0))plt.title("P(weight | measurement, guess)")plt.xlabel("weight")plt.ylabel("#") |
ガイドから直接、事後分布の近似としてサンプリングすることもできます。
1 2 3 4 | plt.hist([scale_parametrized_guide(guess).data[0] for _ in range(100)], range=(5.0, 12.0))plt.title("P(weight | measurement, guess)")plt.xlabel("weight")plt.ylabel("#") |
まとめ
以上、簡単にではありますがPyroの導入から基本的な使い方を、Introductionに沿って説明させていただきました。
本当は、SVIの解説を詳しくやっていくつもりだったのですが、文量が10倍になりそうでしたので、今回は省かせていただきました。
また、Edwardとの比較や、pystanなどとの速度比較もしたかったのですが、次回のお楽しみとさせていただきます。
ちなみに、pyroを日本語に訳すと、「火」「熱」「高温」という意味があるようです。
プロジェクトの炎上を連想させる「火」をチョイスするセンスは見習いたいものですね。
明日は、BIチームでもっともホスピタリティのある鈴木さん aka ミニオンさんによる
「非エンジニアがSQLを学習する際の10の心得」です。お楽しみに!
エウレカでは、一緒に働いていただける方を絶賛募集中です。募集中の職種はこちらからご確認ください!皆様のエントリーをお待ちしております!