1e2σ2(x−μ)2可得KL(p1,p2) 等于:
∫p1(x)logp2(x)p1(x)dx=∫p1(x)(logp1(x)dx−logp2(x))dx=∫p1(x)×(log2πσ121e2σ12(x−μ1)2−log2πσ221e2σ22(x−μ2)2)dx=∫p1(x)×(−21log2π−logσ1−2σ12(x−μ1)2+21log2π+logσ2+2σ22(x−μ2)2)dx=∫p1(x)(logσ1σ2+[2σ22(x−μ2)2−2σ12(x−μ1)2])dx=∫(logσ1σ2)p1(x)dx+∫(2σ22(x−μ2)2)p1(x)dx−∫(2σ12(x−μ1)2)p1(x)dx=logσ1σ2+2σ221∫((x−μ2)2)p1(x)dx−2σ121∫((x−μ1)2)p1(x)dx
最右一项为方差计算,值约分为:−21,可知上式得:
=logσ1σ2+2σ221∫((x−μ2)2)p1(x)dx−21=logσ1σ2+2σ221∫((x−μ1+μ1−μ2)2)p1(x)dx−21=logσ1σ2+2σ221[∫(x−μ1)2p1(x)dx+∫(μ1−μ2)2p1(x)dx+2∫(x−μ1)(μ1−μ2)]p1(x)dx−21=logσ1σ2+2σ221[∫(x−μ1)2p1(x)dx+(μ1−μ2)2]−21=logσ1σ2+2σ22σ12+(μ1−μ2)2−21
假设N2为正态分布,μ2=0,σ22=1,可知分布N1 及其对应的KL散度为:
KL(μ1,σ1)=−logσ1+2σ12+μ12−21
从上式可以看出当μ1=0,σ12=1时,KL散度值最小。
多维高斯分布的KL散度
多维高斯分布的公式如下:
p(x1,x2,…xn)=2π∗det(Σ)1e(−21(x−μ)TΣ−1(x−μ))
由于通常假定各维变量独立同分布,因此协方差矩阵为对角矩阵,下面直接给出多维高斯分布的KL散度计算结果:
KL(p1∥p2)=21[logdet(Σ1)det(Σ2)−d+tr(Σ2−1Σ1)+(μ2−μ1)TΣ2−1(μ2−μ1)]
Python 实例
假设两个离散分布为p 和q,p 的分布为{1,1,2,2,3},q 的分布为{1,1,1,1,1,2,3,3,3,3}。
两个分布中元素数量不同,但是都包含1
,2
,3
三个元素。
当x=1时,p(x=1)=52=0.4,q(x=1)=105=0.5;
当x=2时,p(x=2)=52=0.4,q(x=2)=101=0.1;
当x=3时,p(x=3)=51=0.2,q(x=3)=104=0.4;
代入KL散度计算公式:
D(P∥∥Q)=0.4log20.50.4+0.4log20.10.4+0.2log20.40.2=0.47
PyTorch 实现
对于上例使用PyTorch进行实现:
1 2 3 4 5 6 7 8
| In [1]: import torch
In [2]: p = torch.tensor([0.4,0.4,0.2], dtype=torch.float32)
In [3]: q = torch.tensor([0.5,0.1,0.4], dtype=torch.float32)
In [4]: (p*torch.log2(p/q)).sum() Out[4]: tensor(0.4712)
|
内置函数
torch.nn.functional.kl_div(q.log(),p,reduction='sum')
函数中的p 和q 位置相反(也就是想要计算DL(p∥∥q),要写成kl_div(q.log(),p)
的形式),而且q要先取 log
。
1 2 3 4 5 6
| In [10]: p = torch.tensor([0.4,0.4,0.2], dtype=torch.float32)
In [11]: q = torch.tensor([0.5,0.1,0.4], dtype=torch.float32)
In [12]: torch.nn.functional.kl_div(q.log(),p,reduction='sum') Out[12]: tensor(0.3266)
|
计算结果不同但是同样是正确的,是函数log
对数以e
为底。
分类KL散度
1 2 3 4 5 6 7 8 9
| import torch.nn.functional as F
def kl_categorical(p_logit, q_logit): p = F.softmax(p_logit, dim=-1) _kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1) - F.log_softmax(q_logit, dim=-1)), 1) return torch.mean(_kl)
|
联系作者
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 梦家博客! 打赏
wechat
alipay