k-means聚类算法原理及其实现

简介:

k-means(k-均值)算法是一种基于距离的聚类算法,它用质心(Centroid)到属于该质心的点距离这个度量来实现聚类,通常可以用于N维空间中对象。下面,我们以二维空间为例,概要地总结一下k-means聚类算法的一些要点:

  • 除了随机选择的初始质心,后续迭代质心是根据给定的待聚类的集合S中点计算均值得到的,所以质心一般不是S中的点,但是标识的是一簇点的中心。
  • 基本k-means算法,开始需要随机选择指定的k个质心,因为初始k个质心是随机选择的,所以每次执行k-means聚类的结果可能都不相同。如果初始随机选择的质心位置不好,可能造成k-means聚类的结果非常不理想。
  • 计算质心:假设k-means聚类过程中,得到某一个簇的集合Ci={p(x1,y1), p(x2,y2), …,p(xn,yn)},则簇Ci的质心,质心x坐标为(x1+x2+ …+xn)/n,质心y坐标为(y1+y2+ …+yn)/n。
  • k-means算法的终止条件:质心在每一轮迭代中会发生变化,然后需要重新将非质心点指派给最近的质心而形成新的簇,如果只有很少的一部分点在迭代过程中,还在改变簇(如,更新一次质心,有些点从一个簇移动到另一个簇),那么满足这样一个收敛条件,可以提前结束迭代过程。
  • k-means算法的框架是:首先随机选择k个初始质心点,然后执行聚类处理迭代,不断更新质心,直到满足算法收敛条件。由于该算法收敛于局部最优,所以多次执行聚类算法,通过比较,选择聚类效果最好的结果作为最终的结果。
  • k-means算法聚类完成后,没有离群点,所有的点都会被指派到对应的簇中。

由于k-means算法比较简单,对于算法的实现过程,我们概要地描述如下:

  1. 随机选择k个初始质心;
  2. 如果没有满足聚类算法终止条件,则继续执行步骤3,否则转步骤5;
  3. 计算每个非质心点p到k个质心的欧几里德距离,将p指派给距离最近的质心;
  4. 根据上一步的k个质心及其对应的非质心点集,重新计算新的质心点,然后转步骤2;
  5. 输出聚类结果,算法可以执行多次,使用散点图比较不同的聚类结果。

下面,我们详细说明上述步骤:

随机选择初始质心

由于随机选择初始质心,每次执行聚类选择的初始质心都不相同,这也导致k-means算法聚类后,没有确定的结果,或者说,可能两次聚类的结果完全不同。该过程的实现,比较简单,只要随机选择给定待聚类点集合中的点即可,初始质心是实际存在的点,代码如下所示:

01 @Override
02 public TreeSet<Centroid> select(int k, List<Point2D> points) {
03 TreeSet<Centroid> centroids = Sets.newTreeSet();
04 Set<Point2D> selectedPoints = Sets.newHashSet();
05 while(selectedPoints.size() < k) { // 先随机选择k个点
06 int index = random.nextInt(points.size());
07 Point2D p = points.get(index);
08 selectedPoints.add(p);
09 }
10
11 Iterator<Point2D> iter = selectedPoints.iterator();
12 int id = 0;
13 while(iter.hasNext()) { // 构造Centroid质心对象,分配一个id作为簇的唯一标识
14 centroids.add(new Centroid(id++, iter.next()));
15 }
16 return centroids;
17 }

有一些方法,可以在这一步中,解决初始质心选择的随机性,可以将选择初始质心作为选择策略的设计,根据需要选择不同的策略,比如,可以这样设计策略接口:

1 public interface SelectInitialCentroidsPolicy {
2
3 TreeSet<Centroid> select(int k, List<Point2D> points);
4 }

我们这里只给了简单地随机选择策略,也是基本k-means算法最基础的策略。其他方法,可以查阅相关资料。

计算欧几里德距离,指派点到质心所在簇

计算每个非质心点到全部k个质心点的距离,将该非质心点指派给距离最小的质心点所在的簇。如果输出的数据量比较大,可以将数据集合进行分割,基于多线程去并行处理,最后再合并结果。我们的实现思路是:每个线程都共享k个质心的集合,然后将非质心点均匀分发到多个线程的队列中,然后每个线程从队列取出非质心点,计算非质心点到k个质心的距离,并计算出距离最短的质心,将该非质心点指派给该质心所在的簇。实现代码如下所示:

01 Task task = q.poll();
02 Point2D p1 = task.point;
03
04 // assign points to a nearest centroid
05 Distance minDistance = null;
06 for(Centroid centroid : task.centroids) { // 计算一个非质心点,到k个质心的距离,并计算距离最短的
07 double distance = MetricUtils.euclideanDistance(p1, centroid);
08 if(minDistance != null) {
09 if(distance < minDistance.distance) {
10 minDistance = new Distance(p1, centroid, distance);
11 }
12 } else {
13 minDistance = new Distance(p1, centroid, distance);
14 }
15 }
16 LOG.debug("Assign Point2D[" + p1 + "] to Centroid[" + minDistance.centroid + "]");
17
18 Multiset<Point2D> pointsBelongingToCentroid = localClusteredPoints.get(minDistance.centroid);
19 if(pointsBelongingToCentroid == null) {
20 pointsBelongingToCentroid = HashMultiset.create();
21 localClusteredPoints.put(minDistance.centroid, pointsBelongingToCentroid); // localClusteredPoints是局部的,key为质心,value为属于该质心的非质心点的集合
22 }
23 pointsBelongingToCentroid.add(p1);

这样,经过一轮的迭代计算,每个线程都处理完,得到一个局部的指派的簇的集合,然后对每个局部集合进行合并,得到一个全局的、质心到属于该质心的点的簇的集合,作为下一次迭代的输入,也比较容易处理。

迭代终止条件计算

这一步应该算是k-means算法聚类过程中比较核心的步骤。我们考虑了如下3个终止条件:

  1. 比较相邻的2轮迭代结果,在2轮过程中移动的非质心点的个数,设置移动非质心点占比全部点数的最小比例值,如果达到则算法终止
  2. 为了防止k-means聚类过程长时间不收敛,设置最大迭代次数,如果达到最大迭代次数还没有达到上述条件,则也终止计算
  3. 如果相邻2次迭代过程,质心没有发生变化,则算法终止,这是最强的终止约束条件。能够满足这种条件,几乎是不可能的,除非两次迭代过程中没有非质心点重新指派给到另一个不同的质心。

我们计算k-means聚类的核心代码框架,如下所示:

01 @Override
02 public void clustering() {
03 // start centroid calculators
04 for (int i = 0; i < parallism; i++) { // 启动parallism个线程计算距离并指派簇
05 CentroidCalculator calculator = new CentroidCalculator(calculatorQueueSize);
06 calculators.add(calculator);
07 executorService.execute(calculator);
08 LOG.info("Centroid calculator started: " + calculator);
09 }
10
11 // sort by centroid id ASC
12 TreeSet<Centroid> centroids = selectInitialCentroidsPolicy.select(k, allPoints);// 随机选择初始质心
13 LOG.info("Initial selected centroids: " + centroids);
14
15 // 下面进入迭代过程
16 int iterations = 0;
17 boolean stopped = false;
18 CentroidSetWithClusteringPoints lastClusteringResult = null; // 上一轮聚类结果
19 CentroidSetWithClusteringPoints currentClusteringResult = null; // 当前轮聚类结果
20 int totalPointCount = allPoints.size();
21 float currentClusterMovingPointRate = 1.0f;
22 try {
23 // enter clustering iteration procedure
24 while(currentClusterMovingPointRate > maxMovingPointRate
25 && !stopped
26 && iterations < maxIterations) { // 3个终止条件约束
27 LOG.info("Start iterate: #" + (++iterations));
28
29 currentClusteringResult = computeCentroids(centroids); // 每一轮重新计算质心点
30 LOG.info("Re-computed centroids: " + centroids);
31
32 // compute centroid convergence status
33 int numMovingPoints = 0;
34 if(lastClusteringResult == null) {
35 numMovingPoints = totalPointCount;
36 } else {
37 // compare 2 iterations' result for centroid computation
38 numMovingPoints = analyzeMovingPoints(lastClusteringResult.clusteringPoints, currentClusteringResult.clusteringPoints); // 分析两轮聚类结果:在簇之间移动的非质心点的集合
39
40 // check iteration stop condition
41 boolean isIdentical = (currentClusteringResult.centroids.size() ==
42 Multisets.intersection(HashMultiset.create(lastClusteringResult.centroids), HashMultiset.create(currentClusteringResult.centroids)).size()); // 检测终止最强约束条件:两轮迭代是否没有非质心点发生重新指派,即质心完全没变
43 if(iterations > 1 && isIdentical) {
44 stopped = true;
45 }
46 }
47 lastClusteringResult = currentClusteringResult;
48 centroids = currentClusteringResult.centroids;
49 currentClusterMovingPointRate = (float) numMovingPoints / totalPointCount; // 计算非质心点移动比例
50
51 LOG.info("Clustering meta: k=" + k +
52 ", numMovingPoints=" + numMovingPoints +
53 ", totalPointCount=" + totalPointCount +
54 ", stopped=" + stopped +
55 ", currentClusterMovingPointRate=" + currentClusterMovingPointRate );
56
57 // reset some structures
58 reset();
59 for(CentroidCalculator calculator : calculators) {
60 calculator.reset();
61 }
62
63 LOG.info("Finish iterate: #" + iterations);
64 }
65 } finally {
66 // notify all calculators to exit normally
67 clusteringCompletedFinally = true;
68
69 LOG.info("Shutdown executor service: " + executorService);
70 executorService.shutdown();
71
72 // process final clustering result
73 LOG.info("Final clustering result: ");
74 Iterator<Entry<Centroid, Multiset<Point2D>>> iter = currentClusteringResult.clusteringPoints.entrySet().iterator();
75 while(iter.hasNext()) { // 达到终止条件后,处理最终的结果
76 Entry<Centroid, Multiset<Point2D>> entry = iter.next();
77 int id = entry.getKey().getId();
78 Set<ClusterPoint<Point2D>> set = Sets.newHashSet();
79 for(Point2D p : entry.getValue()) {
80 set.add(new ClusterPoint2D(p, id));
81 }
82 clusteredPoints.put(id, set);
83 id++;
84 }
85 centroidSet = currentClusteringResult.clusteringPoints.keySet();
86 }
87 }

下面,我们讨论一下,如何根据两次聚类迭代结果,计算在簇之间移动的点的个数。如果把两轮聚类迭代结果中的k个簇分别从整体上来比较,得出在前后两轮迭代结果中在簇之间移动的非质心点的个数,可能比较麻烦,也容易陷入混乱的计算逻辑中。
我们可以这么思考:假设a、b两轮迭代结束,a轮中生成k个簇的集合Ca={C(a1),C(a2), …,C(ak)},b轮中生成k个簇的集合Cb={C(b1),C(b2), …,C(bk)},我们假设生成的簇是有编号的,而且,a轮生成的簇C(ai),在b轮重新计算质心后生成的新簇为C(bi),这样一一对应起来,分别计算在簇C(ai)与簇C(bi)之间移动的点的个数,首先计算簇C(ai)与簇C(bi)的交集S:

1 S = C(ai) ∩ C(bi)

然后,分别计算簇C(ai)、簇C(bi)与S的差集Dai、Dbi:

1 Dai = Ca - S = Ca - (C(ai) ∩ C(bi) )
2 Dbi = Cb - S = Cb - (C(ai) ∩ C(bi) )

这样,差集Dai和Dbi中的点都是在两轮聚类中移动的非质心点,由于一个簇中的点可能移动到另一个簇中,如某非质心点p,从C(ai)移动到C(bj),其中i不等于j,那么在计算差集Dai与Dbi时,发现C(ai)中少了点p,点p被放入差集Dai;在计算簇C(aj)与簇C(bj)时,发现C(bj)中多了一个点p,则点p又被放入差集Dbj。可见,点p被放入到两个差集Dai和Dbj中,所以我们需要对最终得到的k个差集先做并计算:

1 D = Σ(Dai ∪ dbi), i=1,2, ...k

然后再对集合D做一个去重操作,得到的点的集合就是两轮迭代过程中,在簇之间移动的点的集合。
我们基于上述计算思路实现的代码,对应上面代码中的analyzeMovingPoints方法,代码实现如下所示:

01 private int analyzeMovingPoints(TreeMap<Centroid, Multiset<Point2D>> lastClusteringPoints,
02 TreeMap<Centroid, Multiset<Point2D>> currentClusteringPoints) {
03 // Map<current, Map<last, intersected point count>>
04 Set<Point2D> movingPoints = Sets.newHashSet(); // 用来收集移动的点,使用Set集合类去重
05 Iterator<Entry<Centroid, Multiset<Point2D>>> lastIter = lastClusteringPoints.entrySet().iterator();
06 Iterator<Entry<Centroid, Multiset<Point2D>>> currentIter = currentClusteringPoints.entrySet().iterator();
07 while(lastIter.hasNext() && currentIter.hasNext()) {
08 Entry<Centroid, Multiset<Point2D>> last = lastIter.next();
09 Entry<Centroid, Multiset<Point2D>> current = currentIter.next();
10 Multiset<Point2D> intersection = Multisets.intersection(last.getValue(), current.getValue()); // 计算交集S = C(ai) ∩ C(bi)
11 movingPoints.addAll(Multisets.difference(last.getValue(), intersection)); // 计算差集Dai = Ca - S = Ca - (C(ai) ∩ C(bi) )
12 movingPoints.addAll(Multisets.difference(current.getValue(), intersection));// 计算差集Dbi = Cb - S = Cb - (C(ai) ∩ C(bi) )
13 }
14 return movingPoints.size();
15 }

通过上面的计算逻辑,就能够计算出两轮聚类过程中,在簇之间移动的点的集合和个数。

聚类效果

每次执行k-means聚类,得到的结果都不相同,我们可以执行两次,取k=10,看一下聚类结果的散点图,如下图所示:

图中,标号为9999的点为质心点,上面两图对比可以看出,聚类结果中簇的形状是不同的,其中红色值满足迭代停止条件的质心的坐标位置。
下面,我们选择不同的k值:5、10、20、50,分别执行k-means聚类,然后对比聚类结果,如下图所示:

总结

通过上面的实现,我们知道基本k-means聚类算法的实现过程比较简单,很容易实现。另外,该聚类算法适用于处理具有中心的球形簇,而且运行相当有效。但是,该聚类算法的结果受随机选择的质心的影响,每次计算都得到不同的结果,而且当待聚数据的具有不同的尺寸,或者密度非常不均匀,聚类结果非常的差。为了解决k-means聚类随机算法选择初始质心的问题,会有很多处理方法,可以查阅相关资料,其中bisecting k-means算法(二分k-均值)就是基于基本k-means得到的一种变体,能够比较好地处理,不受随机选择初始质心的影响,后续我们会实现并详细讨论。

目录
相关文章
|
5天前
|
机器学习/深度学习 算法 PyTorch
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
软演员-评论家算法(Soft Actor-Critic, SAC)是深度强化学习领域的重要进展,基于最大熵框架优化策略,在探索与利用之间实现动态平衡。SAC通过双Q网络设计和自适应温度参数,提升了训练稳定性和样本效率。本文详细解析了SAC的数学原理、网络架构及PyTorch实现,涵盖演员网络的动作采样与对数概率计算、评论家网络的Q值估计及其损失函数,并介绍了完整的SAC智能体实现流程。SAC在连续动作空间中表现出色,具有高样本效率和稳定的训练过程,适合实际应用场景。
32 7
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
|
14天前
|
算法 Java 数据库
理解CAS算法原理
CAS(Compare and Swap,比较并交换)是一种无锁算法,用于实现多线程环境下的原子操作。它通过比较内存中的值与预期值是否相同来决定是否进行更新。JDK 5引入了基于CAS的乐观锁机制,替代了传统的synchronized独占锁,提升了并发性能。然而,CAS存在ABA问题、循环时间长开销大和只能保证单个共享变量原子性等缺点。为解决这些问题,可以使用版本号机制、合并多个变量或引入pause指令优化CPU执行效率。CAS广泛应用于JDK的原子类中,如AtomicInteger.incrementAndGet(),利用底层Unsafe库实现高效的无锁自增操作。
理解CAS算法原理
|
1月前
|
存储 人工智能 缓存
【AI系统】布局转换原理与算法
数据布局转换技术通过优化内存中数据的排布,提升程序执行效率,特别是对于缓存性能的影响显著。本文介绍了数据在内存中的排布方式,包括内存对齐、大小端存储等概念,并详细探讨了张量数据在内存中的排布,如行优先与列优先排布,以及在深度学习中常见的NCHW与NHWC两种数据布局方式。这些布局方式的选择直接影响到程序的性能,尤其是在GPU和CPU上的表现。此外,还讨论了连续与非连续张量的概念及其对性能的影响。
56 3
|
2月前
|
机器学习/深度学习 人工智能 算法
探索人工智能中的强化学习:原理、算法与应用
探索人工智能中的强化学习:原理、算法与应用
|
2月前
|
机器学习/深度学习 人工智能 算法
探索人工智能中的强化学习:原理、算法及应用
探索人工智能中的强化学习:原理、算法及应用
|
2天前
|
算法 数据安全/隐私保护
室内障碍物射线追踪算法matlab模拟仿真
### 简介 本项目展示了室内障碍物射线追踪算法在无线通信中的应用。通过Matlab 2022a实现,包含完整程序运行效果(无水印),支持增加发射点和室内墙壁设置。核心代码配有详细中文注释及操作视频。该算法基于几何光学原理,模拟信号在复杂室内环境中的传播路径与强度,涵盖场景建模、射线发射、传播及接收点场强计算等步骤,为无线网络规划提供重要依据。
|
15天前
|
机器学习/深度学习 算法
基于改进遗传优化的BP神经网络金融序列预测算法matlab仿真
本项目基于改进遗传优化的BP神经网络进行金融序列预测,使用MATLAB2022A实现。通过对比BP神经网络、遗传优化BP神经网络及改进遗传优化BP神经网络,展示了三者的误差和预测曲线差异。核心程序结合遗传算法(GA)与BP神经网络,利用GA优化BP网络的初始权重和阈值,提高预测精度。GA通过选择、交叉、变异操作迭代优化,防止局部收敛,增强模型对金融市场复杂性和不确定性的适应能力。
149 80
|
3天前
|
机器学习/深度学习 数据采集 算法
基于GA遗传优化的CNN-GRU-SAM网络时间序列回归预测算法matlab仿真
本项目基于MATLAB2022a实现时间序列预测,采用CNN-GRU-SAM网络结构。卷积层提取局部特征,GRU层处理长期依赖,自注意力机制捕捉全局特征。完整代码含中文注释和操作视频,运行效果无水印展示。算法通过数据归一化、种群初始化、适应度计算、个体更新等步骤优化网络参数,最终输出预测结果。适用于金融市场、气象预报等领域。
基于GA遗传优化的CNN-GRU-SAM网络时间序列回归预测算法matlab仿真
|
3天前
|
算法
基于龙格库塔算法的锅炉单相受热管建模与matlab数值仿真
本设计基于龙格库塔算法对锅炉单相受热管进行建模与MATLAB数值仿真,简化为喷水减温器和末级过热器组合,考虑均匀传热及静态烟气处理。使用MATLAB2022A版本运行,展示自编与内置四阶龙格库塔法的精度对比及误差分析。模型涉及热传递和流体动力学原理,适用于优化锅炉效率。
|
8天前
|
机器学习/深度学习 算法
基于遗传优化的双BP神经网络金融序列预测算法matlab仿真
本项目基于遗传优化的双BP神经网络实现金融序列预测,使用MATLAB2022A进行仿真。算法通过两个初始学习率不同的BP神经网络(e1, e2)协同工作,结合遗传算法优化,提高预测精度。实验展示了三个算法的误差对比结果,验证了该方法的有效性。

热门文章

最新文章