博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Brown Clustering算法和代码学习
阅读量:6539 次
发布时间:2019-06-24

本文共 9088 字,大约阅读时间需要 30 分钟。

hot3.png

一、算法

  布朗聚类是一种自底向上的层次聚类,基于n-gram模型和马尔科夫链模型。布朗聚类是一种硬聚类,每一个词都在且只在唯一的一个类中。

  

w是词,c是词所属的类。

  布朗聚类的输入是一个语料库,这个语料库是一个词序列,输出是一个二叉树,树的叶子节点是一个个词,树的中间节点是我们想要的类(中间结点作为根节点的子树上的所有叶子为类中的词)。

  初始的时候,将每一个词独立的分成一类,然后,将两个类合并,使得合并之后评价函数最大,然后不断重复上述过程,达到想要的类的数量的时候停止合并。

  上面提到的评价函数,是对于n个连续的词(w)序列能否组成一句话的概率的对数的归一化结果。于是,得到评价函数:

n是文本长度,w是词

  上面的评价公式是PercyLiang的“Semi-supervised learning for natural languageprocessing”文章中关于布朗聚类的解释,Browm原文中是基于class-based bigram language model建立的,于是得到下面公式:

 

T是文本长度,t是文本中的词

  上述公式是由于对于bigram,于是归一化处理只需要对T-1个bigram。我觉得PercyLiang的公式容易理解评价函数的定义,但是,Brown的推导过程更加清晰简明,所以,接下来的公式推导遵循Brown原文中的推导过程。

  上面的推导式数学推导,接下来是一个重要的近似处理,近似等于w2在训练集中出现的频率,也就是Pr(w2),于是公式变为:

  H(w)是熵,只跟1-gram的分布有关,也就是与类的分布无关,而I(c1,c2)是相邻类的平均互信息。所以,I决定了L。所以,只有最大化I,L才能最大。

二、优化

  Brown提出了一种估算方式进行优化。首先,将词按照词频进行排序,将前C(词的总聚类数目)个词分到不同的C类中,然后,将接下来词频最高的词,添加到一个新的类,将(C+1)类聚类成C类,即合并两个类,使得平均互信息损失最小。虽然,这种方式使得计算不是特别精确,类的加入顺序,决定了合并的顺序,会影响结果,但是极大的降低了计算复杂度。

  显然上面提及的算法仍然是一种naive的算法,算法复杂度十分高。(上述结果包括下面的复杂度结果来自Percy Liang的论文)。对于这么高的复杂度,对于成百上千词的聚类将变得不现实,于是,优化算法变得不可或缺。Percy Liang和Brown分别从两个角度去优化。

  Brown从代数的角度优化,通过一个表格记录下每次合并的中间结果,然后,用来计算下一次结果。

  Percy Liang从几何的角度考虑优化,更加清晰直观。但是,Percy Liang是从跟Brown的损失函数L相反的角度去考虑(即两者正负号不同),但是,都是为了保留中间结果,减少计算量,个人觉得PercyLiang的算法比较容易理解,而且,他少忽略了一些没必要计算的中间结果,更加优化,后面介绍的代码,也是PercyLiang写的,所以,将会重点介绍一下他的思考方式。

  Percy Liang将聚类结果表示成一个无向图,图的节点有C个,代表C个类,同时,任何两个节点都有一条边,边代表相邻两个节点之间(两个类之间)的平均互信息。边的权重如下表达式:

 

  而评价的总的平均互信息I就是所有边的权重之和。下面是实际代码中的计算损失评价的函数即合并后的I减去合并前的I的损失。

  上述的(c并c')代表合并c和两个节点后的一个节点,C是当前集合,而C'是合并后的集合:

 

三、代码实现

  代码实现的主要过程概览:

   1、读取文本并预处理

     1) 将文本中的每个词读入并编码(其中过滤一些频次极其低的)

     2)统计词表大小、出现次数

     3)将文本左右两个方向的n-gram存储

   2、初始化布朗聚类(N log N)

     1)将词进行排序

     2)将频次最高的initC个词分配到每个类

     3)初始化p1(概率),q2(边的权重)

   3、进行布朗聚类

     1)初始化L2(合并减少的互信息)

     2) 将当前未聚类的词中,出现频次最高的,作为一个类,添加进去,并同时,计算p1,q2,L2

     3)找到最小的L2

     4)合并,并更新q2,L2

  代码还实现了计算KL散度比较相关性,此部分略去。

   这里p1如下

          

    q2如下

        

 

四、重要代码段解析

初始化L2:

 

[cpp]  

  1. <span style="font-size:18px;">// O(C^3) time.  
  2. void compute_L2() {  
  3.   track("compute_L2()", "", true);  
  4.   
  5.   track_block("Computing L2", "", false)  
  6.   FOR_SLOT(s) {  
  7.     track_block("L2", "L2[" << Slot(s) << ", *]", false)  
  8.     FOR_SLOT(t) {  
  9.       if(!ORDER_VALID(s, t)) continue;  
  10.       double l = L2[s][t] = compute_L2(s, t);  
  11.       logs("L2[" << Slot(s) << "," << Slot(t) << "] = " << l << ", resulting minfo = " << curr_minfo-l);  
  12.     }  
  13.   }  
  14. }</span>  

 

 

上面调用,单步计算L2:

 

[cpp]  

  1. <span style="font-size:18px;">// O(C) time.  
  2. double compute_L2(int s, int t) { // compute L2[s, t]  
  3.   assert(ORDER_VALID(s, t));  
  4.   // st is the hypothetical new cluster that combines s and t  
  5.   
  6.   // Lose old associations with s and t  
  7.   double l = 0.0;  
  8.   for (int w = 0; w < len(slot2cluster); w++) {  
  9.     if ( slot2cluster[w] == -1) continue;  
  10.     l += q2[s][w] + q2[w][s];  
  11.     l += q2[t][w] + q2[w][t];  
  12.   }  
  13.   l -= q2[s][s] + q2[t][t];  
  14.   l -= bi_q2(s, t);  
  15.   
  16.   // Form new associations with st  
  17.   FOR_SLOT(u) {  
  18.     if(u == s || u == t) continue;  
  19.     l -= bi_hyp_q2(_(s, t), u);  
  20.   }  
  21.   l -= hyp_q2(_(s, t)); // q2[st, st]  
  22.   return l;  
  23. }  
  24. </span>  

 

 

聚类过程中,更新p1,q2,L2,调用时(两次):

 

[cpp]  

  1. <span style="font-size:18px;">// Stage 1: Maintain initC clusters.  For each of the phrases initC..N-1, make  
  2.   // it into a new cluster.  Then merge the optimal pair among the initC+1  
  3.   // clusters.  
  4.   // O(N*C^2) time.  
  5.   track_block("Stage 1", "", false) {  
  6.     mem_tracker.report_mem_usage();  
  7.     for(int i = initC; i < len(freq_order_phrases); i++) { // Merge phrase new_a  
  8.       int new_a = freq_order_phrases[i];  
  9.       track("Merging phrase", i << '/' << N << ": " << Cluster(new_a), true);  
  10.       logs("Mutual info: " << curr_minfo);  
  11.       incorporate_new_phrase(new_a);//添加后,C->C+1  
  12.       repcheck();  
  13.       merge_clusters(find_opt_clusters_to_merge());//合并后,C+1->C  
  14.   
  15.   
  16.       repcheck();  
  17.     }  
  18.   }  
  19. </span>  

 

添加后,更新p1,q2,L2

[cpp]  

  1. <span style="font-size:18px;">// Add new phrase as a cluster.  
  2. // Compute its L2 between a and all existing clusters.  
  3. // O(C^2) time, O(T) time over all calls.  
  4. void incorporate_new_phrase(int a) {  
  5.   track("incorporate_new_phrase()", Cluster(a), false);  
  6.   
  7.   int s = put_cluster_in_free_slot(a);  
  8.   init_slot(s);  
  9.   cluster2rep[a] = a;  
  10.   rep2cluster[a] = a;  
  11.   
  12.   // Compute p1  
  13.   p1[s] = (double)phrase_freqs[a] / T;  
  14.     
  15.   // Overall all calls: O(T)  
  16.   // Compute p2, q2 between a and everything in clusters  
  17.   IntIntMap freqs;  
  18.   freqs.clear(); // right bigrams  
  19.   forvec(_, int, b, right_phrases[a]) {  
  20.     b = phrase2rep.GetRoot(b);  
  21.     if(!contains(rep2cluster, b)) continue;  
  22.     b = rep2cluster[b];  
  23.     if(!contains(cluster2slot, b)) continue;  
  24.     freqs[b]++;  
  25.   }  
  26.   forcmap(int, b, int, count, IntIntMap, freqs) {  
  27.     curr_minfo += set_p2_q2_from_count(cluster2slot[a], cluster2slot[b], count);  
  28.     logs(Cluster(a) << ' ' << Cluster(b) << ' ' << count << ' ' << set_p2_q2_from_count(cluster2slot[a], cluster2slot[b], count));  
  29.   }  
  30.   
  31.   freqs.clear(); // left bigrams  
  32.   forvec(_, int, b, left_phrases[a]) {  
  33.     b = phrase2rep.GetRoot(b);  
  34.     if(!contains(rep2cluster, b)) continue;  
  35.     b = rep2cluster[b];  
  36.     if(!contains(cluster2slot, b)) continue;  
  37.     freqs[b]++;  
  38.   }  
  39.   forcmap(int, b, int, count, IntIntMap, freqs) {  
  40.     curr_minfo += set_p2_q2_from_count(cluster2slot[b], cluster2slot[a], count);  
  41.     logs(Cluster(b) << ' ' << Cluster(a) << ' ' << count << ' ' << set_p2_q2_from_count(cluster2slot[b], cluster2slot[a], count));  
  42.   }  
  43.   
  44.   curr_minfo -= q2[s][s]; // q2[s, s] was double-counted  
  45.   
  46.   // Update L2: O(C^2)  
  47.   track_block("Update L2", "", false) {  
  48.   
  49.     the_job.s = s;  
  50.     the_job.is_type_a = true;  
  51.     // start the jobs  
  52.     for (int ii=0; ii<num_threads; ii++) {  
  53.       thread_start[ii].unlock(); // the thread waits for this lock to begin  
  54.     }  
  55.     // wait for them to be done  
  56.     for (int ii=0; ii<num_threads; ii++) {  
  57.       thread_idle[ii].lock();  // the thread releases the lock to finish  
  58.     }  
  59.   }  
  60.   
  61.   //dump();  
  62. }  
  63. </span>  

合并后,更新

[cpp]  

  1. <span style="font-size:18px;">// O(C^2) time.  
  2. // Merge clusters a (in slot s) and b (in slot t) into c (in slot u).  
  3. void merge_clusters(int s, int t) {  
  4.   assert(ORDER_VALID(s, t));  
  5.   int a = slot2cluster[s];  
  6.   int b = slot2cluster[t];  
  7.   int c = curr_cluster_id++;  
  8.   int u = put_cluster_in_free_slot(c);  
  9.   
  10.   free_up_slots(s, t);  
  11.   
  12.   // Record merge in the cluster tree  
  13.   cluster_tree[c] = _(a, b);  
  14.   curr_minfo -= L2[s][t];  
  15.   
  16.   // Update relationship between clusters and rep phrases  
  17.   int A = cluster2rep[a];  
  18.   int B = cluster2rep[b];  
  19.   phrase2rep.Join(A, B);  
  20.   int C = phrase2rep.GetRoot(A); // New rep phrase of cluster c (merged a and b)  
  21.   
  22.   track("Merging clusters", Cluster(a) << " and " << Cluster(b) << " into " << c << ", lost " << L2[s][t], false);  
  23.   
  24.   cluster2rep.erase(a);  
  25.   cluster2rep.erase(b);  
  26.   rep2cluster.erase(A);  
  27.   rep2cluster.erase(B);  
  28.   cluster2rep[c] = C;  
  29.   rep2cluster[C] = c;  
  30.   
  31.   // Compute p1: O(1)  
  32.   p1[u] = p1[s] + p1[t];  
  33.   
  34.   // Compute p2: O(C)  
  35.   p2[u][u] = hyp_p2(_(s, t));  
  36.   FOR_SLOT(v) {  
  37.     if(v == u) continue;  
  38.     p2[u][v] = hyp_p2(_(s, t), v);  
  39.     p2[v][u] = hyp_p2(v, _(s, t));  
  40.   }  
  41.   
  42.   // Compute q2: O(C)  
  43.   q2[u][u] = hyp_q2(_(s, t));  
  44.   FOR_SLOT(v) {  
  45.     if(v == u) continue;  
  46.     q2[u][v] = hyp_q2(_(s, t), v);  
  47.     q2[v][u] = hyp_q2(v, _(s, t));  
  48.   }  
  49.   
  50.   // Compute L2: O(C^2)  
  51.   track_block("Compute L2", "", false) {  
  52.     the_job.s = s;  
  53.     the_job.t = t;  
  54.     the_job.u = u;  
  55.     the_job.is_type_a = false;  
  56.   
  57.     // start the jobs  
  58.     for (int ii=0; ii<num_threads; ii++) {  
  59.       thread_start[ii].unlock(); // the thread waits for this lock to begin  
  60.     }  
  61.     // wait for them to be done  
  62.     for (int ii=0; ii<num_threads; ii++) {  
  63.       thread_idle[ii].lock();  // the thread releases the lock to finish  
  64.     }  
  65.   }  
  66. }  
  67. void merge_clusters(const IntPair &st) { merge_clusters(st.first, st.second); }  
  68. </span>  

更新L2过程,其中使用了多线程:

使用:

[cpp]  

  1. <span style="font-size:18px;">// Variables used to control the thread pool  
  2. mutex * thread_idle;  
  3. mutex * thread_start;  
  4. thread * threads;  
  5. struct Compute_L2_Job {  
  6.   int s;  
  7.   int t;  
  8.   int u;  
  9.   bool is_type_a;  
  10. };  
  11. Compute_L2_Job the_job;  
  12. bool all_done = false;  
  13. </span>  

初始化,将所有线程锁住:

[html]  

  1. <span style="font-size:18px;">// start the threads  
  2.   thread_start = new mutex[num_threads];  
  3.   thread_idle = new mutex[num_threads];  
  4.   threads = new thread[num_threads];  
  5.   for (int ii=0; ii<num_threads; ii++) {  
  6.     thread_start[ii].lock();  
  7.     thread_idle[ii].lock();  
  8.     threads[ii] = thread(update_L2, ii);  
  9.   }  
  10. </span>  

调用线程,共计2处,第一处是在添加后:

 

[cpp]  

  1. <span style="font-size:18px;">// Update L2: O(C^2)  
  2.   track_block("Update L2", "", false) {  
  3.   
  4.     the_job.s = s;  
  5.     the_job.is_type_a = true;  
  6.     // start the jobs  
  7.     for (int ii=0; ii<num_threads; ii++) {  
  8.       thread_start[ii].unlock(); // the thread waits for this lock to begin  
  9.     }  
  10.     // wait for them to be done  
  11.     for (int ii=0; ii<num_threads; ii++) {  
  12.       thread_idle[ii].lock();  // the thread releases the lock to finish  
  13.     }  
  14.   }  
  15. </span>  

 

第二处是在合并后

[cpp]  

  1. <span style="font-size:18px;">// Compute L2: O(C^2)  
  2.   track_block("Compute L2", "", false) {  
  3.     the_job.s = s;  
  4.     the_job.t = t;  
  5.     the_job.u = u;  
  6.     the_job.is_type_a = false;  
  7.   
  8.     // start the jobs  
  9.     for (int ii=0; ii<num_threads; ii++) {  
  10.       thread_start[ii].unlock(); // the thread waits for this lock to begin  
  11.     }  
  12.     // wait for them to be done  
  13.     for (int ii=0; ii<num_threads; ii++) {  
  14.       thread_idle[ii].lock();  // the thread releases the lock to finish  
  15.     }  
  16.   }  
  17. </span>  

结束调用:

[cpp]  

  1. <span style="font-size:18px;">// finish the threads  
  2.   all_done = true;  
  3.   for (int ii=0; ii<num_threads; ii++) {  
  4.     thread_start[ii].unlock(); // thread will grab this to start  
  5.     threads[ii].join();  
  6.   }  
  7.   delete [] thread_start;  
  8.   delete [] thread_idle;  
  9.   delete [] threads;  
  10. </span>  

    通过两个锁实现调用,每次调用时通过更新the_job来改变计算参数,调用时打开thread_start锁,结束后,关闭thread_idle锁。

参考文献:

 

Liang: Semi-supervised learning for natural language processing

Brown, et al.: Class-Based n-gram Models of Natural Language

代码来源:

https://github.com/percyliang/brown-cluster

转载于:https://my.oschina.net/airship/blog/895472

你可能感兴趣的文章
Range
查看>>
爬虫之lxml - etree - xpath的使用
查看>>
PyalgoTrade 打印收盘价(二)
查看>>
关于C语言指针【第二季】
查看>>
MYSQLi数据访问批量删除
查看>>
浪潮K-UNIX操作系统了解
查看>>
less: CSS 预处理语言
查看>>
知识管理系统VS文档管理系统的区别【转】
查看>>
最近点对
查看>>
《团队作业第三、第四周》五小福团队作业--Scrum 冲刺阶段--Day2
查看>>
PHP为什么会被认为是草根语言?
查看>>
解决NetBeans编辑器中文乱码问题
查看>>
ztree-demo 2
查看>>
javascript常用方法
查看>>
decode行转列,case when,
查看>>
C#数据类型与数据库字段类型对应
查看>>
JAVA 查找类的所有引用关系(python实现)
查看>>
Linux系统常用命令之top
查看>>
Android自定义控件之AlertDialog
查看>>
.net core 定时任务
查看>>