(转) Parameter estimation for text analysis 暨LDA学习小结

时间:2024-01-20 18:24:03

Reading Note : Parameter estimation for text analysis 暨LDA学习小结

伟大的Parameter estimation for text analysis!当把这篇看的差不多的时候,也就到了LDA基础知识终结的时刻了,意味着LDA基础模型的基本了解完成了。所以对该模型的学习告一段落,下一阶段就是了解LDA无穷无尽的变种,不过那些不是很有用了,因为LDA已经被人水遍了各大“论坛”……

抛开LDA背后复杂深入的数学背景不说,光就LDA的内容,确实不多,虽然变分法还是不懂,不过现在终于还是理解了“LDA is just a simple model”这句话。

总结一下学习过程:
1.概率的基本概念:CDF、PDF、Bayes’rule、各种简单的分布Bernoulli,binomial,multinomial、包括对prior、likelihood、postprior的理解(PRML1.2)
​3.概率图模型 Probabilistic Graphical Models: PRML Chapter 8 基本概念即可
4.采样算法:Basic Sampling,Sampling Methods(PRML Chapter 11),马尔科夫蒙特卡洛 MCMC,Gibbs Sampling
​6.进阶资料:《Gibbs Sampling for the Uninitiated》、本文
——————————————– 伟大的分割线 !PETA! ​——————————————–

一、前面无关部分

关于ML、MAP、Bayesian inference

二、模型进一步记忆

(转) Parameter estimation for text analysis 暨LDA学习小结

从本图来看,需要记住:

1.θm是每一个document单独一个θ,所以M个doc共有M个θm,整个θ是一个M*K的矩阵(M个doc,每个doc一个K维topic分布向量)。

2.φk总共只有K个,对于每一个topic,有一个φk,这些参数是独立于文档的,也就是对于整个corpus只sample一次。不像θm那样每一个都对应一个文档,每个文档都不同,φk对于所有文档都相同,是一个K*V的矩阵(K个topic,每个topic一个V维从topic产生词的概率分布)。

就这些了。

三、推导

公式(39):P(p|α)=Dir(p|α)意思是从参数为α的狄利克雷分布,采样一个多项分布参数p的概率是多少,概率是标准狄利克雷PDF。这里Dirichlet delta function为:

Δ(α⃗ )=Γ(α1)∗Γ(α2)∗…∗Γ(αk)Γ(∑K1 αk)

这个function要记住,下面一溜烟全是这个。

公式(43)是一元语言模型的likelihood,意思是如果提供了语料库W,知道了W里面每个词的个数,那么使用最大似然估计最大化L就可以估计出参数多项分布p。

公式(44)是考虑了先验的情形,假如已知语料库W和参数α,那么他们产生多项分布参数p的概率是Dir(p|α+n),这个推导我记得在PRML2.1中有解释,抛开复杂的数学证明,只要参考标准狄利克雷分布的归一化项,很容易想出式(46)的归一化项就是Δ(α+n)。这时如果要通过W估计参数p,那么就要使用贝叶斯推断,用这个狄利克雷pdf输出一个p的期望即可。

最关键的推导(63)-(78):从63-73的目标是要求出整个LDA的联合概率表达式,这样(63)就可以被用在Gibbs Sampler的分子上。首先(63)把联合概率拆成相互独立的两部分p(w|z,β)和p(z|α),然后分别对这两部分布求表达式。式(64)、(65)首先不考虑超参数β,而是假设已知参数Φ。这个Φ就是那个K*V维矩阵,表示从每一个topic产生词的概率。然后(66)要把Φ积分掉,这样就可以求出第一部分p(w|z,β)为表达式(68)。从66-68的积分过程一直在套用狄利克雷积分的结果,反正整篇文章套来套去始终就是这么一个狄利克雷积分。n⃗ z是一个V维的向量,对于topic z,代表每一个词在这个topic里面有几个。从69到72的道理其实和64-68一模一样了。n⃗ m是一个K维向量,对于文档m,代表每一个topic在这个文档里有几个词。

最后(78)求出了Gibbs Sampler所需要的条件概率表达式。这个表达式还是要贴出来的,为了和代码里面对应:

(转) Parameter estimation for text analysis 暨LDA学习小结

具体选择下一个新topic的方法是:通过计算每一个topic的新的产生概率p(zi=k|z┐i,w)也就是代码中的p[k]产生一个新topic。比如有三个topic,算出来产生新的p的概率值为{0.3,0.2,0.4},注意这个条件概率加起来并不一定是一。然后我为了按照这个概率产生一个新topic,我用random函数从uniform distribution产生一个0至0.9的随机数r。如果0<=r<0.3,则新topic赋值为1,如果0.3<=r<0.5,则新topic赋值为2,如果0.5<=r<0.9,那么新topic赋值为3。

四、代码

  1. /*
  2. * (C) Copyright 2005, Gregor Heinrich (gregor :: arbylon : net)
  3. * LdaGibbsSampler is free software; you can redistribute it and/or modify it
  4. * under the terms of the GNU General Public License as published by the Free
  5. * Software Foundation; either version 2 of the License, or (at your option) any
  6. * later version.
  7. * LdaGibbsSampler is distributed in the hope that it will be useful, but
  8. * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  9. * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
  10. * details.
  11. * You should have received a copy of the GNU General Public License along with
  12. * this program; if not, write to the Free Software Foundation, Inc., 59 Temple
  13. * Place, Suite 330, Boston, MA 02111-1307 USA
  14. */
  15. import java.text.DecimalFormat;
  16. import java.text.NumberFormat;
  17. public class LdaGibbsSampler {
  18. /**
  19. * document data (term lists)
  20. */
  21. int[][] documents;
  22. /**
  23. * vocabulary size
  24. */
  25. int V;
  26. /**
  27. * number of topics
  28. */
  29. int K;
  30. /**
  31. * Dirichlet parameter (document--topic associations)
  32. */
  33. double alpha;
  34. /**
  35. * Dirichlet parameter (topic--term associations)
  36. */
  37. double beta;
  38. /**
  39. * topic assignments for each word.
  40. * N * M 维,第一维是文档,第二维是word
  41. */
  42. int z[][];
  43. /**
  44. * nw[i][j] number of instances of word i (term?) assigned to topic j.
  45. */
  46. int[][] nw;
  47. /**
  48. * nd[i][j] number of words in document i assigned to topic j.
  49. */
  50. int[][] nd;
  51. /**
  52. * nwsum[j] total number of words assigned to topic j.
  53. */
  54. int[] nwsum;
  55. /**
  56. * nasum[i] total number of words in document i.
  57. */
  58. int[] ndsum;
  59. /**
  60. * cumulative statistics of theta
  61. */
  62. double[][] thetasum;
  63. /**
  64. * cumulative statistics of phi
  65. */
  66. double[][] phisum;
  67. /**
  68. * size of statistics
  69. */
  70. int numstats;
  71. /**
  72. * sampling lag (?)
  73. */
  74. private static int THIN_INTERVAL = 20;
  75. /**
  76. * burn-in period
  77. */
  78. private static int BURN_IN = 100;
  79. /**
  80. * max iterations
  81. */
  82. private static int ITERATIONS = 1000;
  83. /**
  84. * sample lag (if -1 only one sample taken)
  85. */
  86. private static int SAMPLE_LAG;
  87. private static int dispcol = 0;
  88. /**
  89. * Initialise the Gibbs sampler with data.
  90. *
  91. * @param V
  92. *            vocabulary size
  93. * @param data
  94. */
  95. public LdaGibbsSampler(int[][] documents, int V) {
  96. this.documents = documents;
  97. this.V = V;
  98. }
  99. /**
  100. * Initialisation: Must start with an assignment of observations to topics ?
  101. * Many alternatives are possible, I chose to perform random assignments
  102. * with equal probabilities
  103. *
  104. * @param K
  105. *            number of topics
  106. * @return z assignment of topics to words
  107. */
  108. public void initialState(int K) {
  109. int i;
  110. int M = documents.length;
  111. // initialise count variables.
  112. nw = new int[V][K];
  113. nd = new int[M][K];
  114. nwsum = new int[K];
  115. ndsum = new int[M];
  116. // The z_i are are initialised to values in [1,K] to determine the
  117. // initial state of the Markov chain.
  118. // 为了方便,他没用从狄利克雷参数采样,而是随机初始化了!
  119. z = new int[M][];
  120. for (int m = 0; m < M; m++) {
  121. int N = documents[m].length;
  122. z[m] = new int[N];
  123. for (int n = 0; n < N; n++) {
  124. //随机初始化!
  125. int topic = (int) (Math.random() * K);
  126. z[m][n] = topic;
  127. // number of instances of word i assigned to topic j
  128. // documents[m][n] 是第m个doc中的第n个词
  129. nw[documents[m][n]][topic]++;
  130. // number of words in document i assigned to topic j.
  131. nd[m][topic]++;
  132. // total number of words assigned to topic j.
  133. nwsum[topic]++;
  134. }
  135. // total number of words in document i
  136. ndsum[m] = N;
  137. }
  138. }
  139. /**
  140. * Main method: Select initial state ? Repeat a large number of times: 1.
  141. * Select an element 2. Update conditional on other elements. If
  142. * appropriate, output summary for each run.
  143. *
  144. * @param K
  145. *            number of topics
  146. * @param alpha
  147. *            symmetric prior parameter on document--topic associations
  148. * @param beta
  149. *            symmetric prior parameter on topic--term associations
  150. */
  151. private void gibbs(int K, double alpha, double beta) {
  152. this.K = K;
  153. this.alpha = alpha;
  154. this.beta = beta;
  155. // init sampler statistics
  156. if (SAMPLE_LAG > 0) {
  157. thetasum = new double[documents.length][K];
  158. phisum = new double[K][V];
  159. numstats = 0;
  160. }
  161. // initial state of the Markov chain:
  162. //启动马尔科夫链需要一个起始状态
  163. initialState(K);
  164. //每一轮sample
  165. for (int i = 0; i < ITERATIONS; i++) {
  166. // for all z_i
  167. for (int m = 0; m < z.length; m++) {
  168. for (int n = 0; n < z[m].length; n++) {
  169. // (z_i = z[m][n])
  170. // sample from p(z_i|z_-i, w)
  171. //核心步骤,通过论文中表达式(78)为文档m中的第n个词采样新的topic
  172. int topic = sampleFullConditional(m, n);
  173. z[m][n] = topic;
  174. }
  175. }
  176. // get statistics after burn-in
  177. //如果当前迭代轮数已经超过 burn-in的限制,并且正好达到 sample lag间隔
  178. //则当前的这个状态是要计入总的输出参数的,否则的话忽略当前状态,继续sample
  179. if ((i > BURN_IN) && (SAMPLE_LAG > 0) && (i % SAMPLE_LAG == 0)) {
  180. updateParams();
  181. }
  182. }
  183. }
  184. /**
  185. * Sample a topic z_i from the full conditional distribution: p(z_i = j |
  186. * z_-i, w) = (n_-i,j(w_i) + beta)/(n_-i,j(.) + W * beta) * (n_-i,j(d_i) +
  187. * alpha)/(n_-i,.(d_i) + K * alpha)
  188. *
  189. * @param m
  190. *            document
  191. * @param n
  192. *            word
  193. */
  194. private int sampleFullConditional(int m, int n) {
  195. // remove z_i from the count variables
  196. //这里首先要把原先的topic z(m,n)从当前状态中移除
  197. int topic = z[m][n];
  198. nw[documents[m][n]][topic]--;
  199. nd[m][topic]--;
  200. nwsum[topic]--;
  201. ndsum[m]--;
  202. // do multinomial sampling via cumulative method:
  203. double[] p = new double[K];
  204. for (int k = 0; k < K; k++) {
  205. //nw 是第i个word被赋予第j个topic的个数
  206. //在下式中,documents[m][n]是word id,k为第k个topic
  207. //nd 为第m个文档中被赋予topic k的词的个数
  208. p[k] = (nw[documents[m][n]][k] + beta) / (nwsum[k] + V * beta)
  209. * (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
  210. }
  211. // cumulate multinomial parameters
  212. for (int k = 1; k < p.length; k++) {
  213. p[k] += p[k - 1];
  214. }
  215. // scaled sample because of unnormalised p[]
  216. double u = Math.random() * p[K - 1];
  217. for (topic = 0; topic < p.length; topic++) {
  218. if (u < p[topic])
  219. break;
  220. }
  221. // add newly estimated z_i to count variables
  222. nw[documents[m][n]][topic]++;
  223. nd[m][topic]++;
  224. nwsum[topic]++;
  225. ndsum[m]++;
  226. return topic;
  227. }
  228. /**
  229. * Add to the statistics the values of theta and phi for the current state.
  230. */
  231. private void updateParams() {
  232. for (int m = 0; m < documents.length; m++) {
  233. for (int k = 0; k < K; k++) {
  234. thetasum[m][k] += (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
  235. }
  236. }
  237. for (int k = 0; k < K; k++) {
  238. for (int w = 0; w < V; w++) {
  239. phisum[k][w] += (nw[w][k] + beta) / (nwsum[k] + V * beta);
  240. }
  241. }
  242. numstats++;
  243. }
  244. /**
  245. * Retrieve estimated document--topic associations. If sample lag > 0 then
  246. * the mean value of all sampled statistics for theta[][] is taken.
  247. *
  248. * @return theta multinomial mixture of document topics (M x K)
  249. */
  250. public double[][] getTheta() {
  251. double[][] theta = new double[documents.length][K];
  252. if (SAMPLE_LAG > 0) {
  253. for (int m = 0; m < documents.length; m++) {
  254. for (int k = 0; k < K; k++) {
  255. theta[m][k] = thetasum[m][k] / numstats;
  256. }
  257. }
  258. } else {
  259. for (int m = 0; m < documents.length; m++) {
  260. for (int k = 0; k < K; k++) {
  261. theta[m][k] = (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
  262. }
  263. }
  264. }
  265. return theta;
  266. }
  267. /**
  268. * Retrieve estimated topic--word associations. If sample lag > 0 then the
  269. * mean value of all sampled statistics for phi[][] is taken.
  270. *
  271. * @return phi multinomial mixture of topic words (K x V)
  272. */
  273. public double[][] getPhi() {
  274. double[][] phi = new double[K][V];
  275. if (SAMPLE_LAG > 0) {
  276. for (int k = 0; k < K; k++) {
  277. for (int w = 0; w < V; w++) {
  278. phi[k][w] = phisum[k][w] / numstats;
  279. }
  280. }
  281. } else {
  282. for (int k = 0; k < K; k++) {
  283. for (int w = 0; w < V; w++) {
  284. phi[k][w] = (nw[w][k] + beta) / (nwsum[k] + V * beta);
  285. }
  286. }
  287. }
  288. return phi;
  289. }
  290. /**
  291. * Configure the gibbs sampler
  292. *
  293. * @param iterations
  294. *            number of total iterations
  295. * @param burnIn
  296. *            number of burn-in iterations
  297. * @param thinInterval
  298. *            update statistics interval
  299. * @param sampleLag
  300. *            sample interval (-1 for just one sample at the end)
  301. */
  302. public void configure(int iterations, int burnIn, int thinInterval,
  303. int sampleLag) {
  304. ITERATIONS = iterations;
  305. BURN_IN = burnIn;
  306. THIN_INTERVAL = thinInterval;
  307. SAMPLE_LAG = sampleLag;
  308. }
  309. /**
  310. * Driver with example data.
  311. *
  312. * @param args
  313. */
  314. public static void main(String[] args) {
  315. // words in documents
  316. int[][] documents = { {1, 4, 3, 2, 3, 1, 4, 3, 2, 3, 1, 4, 3, 2, 3, 6},
  317. {2, 2, 4, 2, 4, 2, 2, 2, 2, 4, 2, 2},
  318. {1, 6, 5, 6, 0, 1, 6, 5, 6, 0, 1, 6, 5, 6, 0, 0},
  319. {5, 6, 6, 2, 3, 3, 6, 5, 6, 2, 2, 6, 5, 6, 6, 6, 0},
  320. {2, 2, 4, 4, 4, 4, 1, 5, 5, 5, 5, 5, 5, 1, 1, 1, 1, 0},
  321. {5, 4, 2, 3, 4, 5, 6, 6, 5, 4, 3, 2}};
  322. // vocabulary
  323. int V = 7;
  324. int M = documents.length;
  325. // # topics
  326. int K = 2;
  327. // good values alpha = 2, beta = .5
  328. double alpha = 2;
  329. double beta = .5;
  330. LdaGibbsSampler lda = new LdaGibbsSampler(documents, V);
  331. //设定sample参数,采样运行10000轮,burn-in 2000轮,第三个参数没用,是为了显示
  332. //第四个参数是sample lag,这个很重要,因为马尔科夫链前后状态conditional dependent,所以要跳过几个采样
  333. lda.configure(10000, 2000, 100, 10);
  334. //跑一个!走起!
  335. lda.gibbs(K, alpha, beta);
  336. //输出模型参数,论文中式 (81)与(82)
  337. double[][] theta = lda.getTheta();
  338. double[][] phi = lda.getPhi();
  339. }
  340. }