训练第一个机器学习模型

简介:

机器模型

导语

在笔者的上一篇文章中[1],使用了 k-NN 算法来识别手写字数据集,它的缺点是浪费存储空间且执行效率低。本文将使用决策树算法来解决同样的问题。相对 k-NN 算法,它更节约存储空间且执行效率更高。更重要的是,实施决策树算法的过程将训练算法并得到知识 —— 这是开发机器学习程序的一般步骤。一旦理解了这个工作流程,才有可能利用好机器学习这把利剑。

在本文中,笔者将训练一个决策树模型并使用该模型来识别手写字数据集。从中读者将可以了解到:如何构建学习模型?模型经过训练后学习到了怎样的知识?学习到的知识怎么表示和存储?又该如何利用这些学到的知识来解决同类的问题?

本文适合以下背景的读者阅读:

  • 了解 MNIST 数据集[2];
  • 使用 Javascript 作为编程语言的开发者;
  • 不需要具备算法能力和高数的背景:全文只有一道数学公式;
  • 加上示例代码,全文总共 460 行,大约需要 20 分钟的阅读时间。

作者学识有限,如有疏漏,敬请指正。

生活中的决策

在开始构建决策树之前,必须了解决策树的工作原理。更详细的内容可以从参考资料的链接[2]中获得。

一个例子是,如何教育一个学龄前的儿童辨认猫和老虎?

猫和老虎

  • 我们会拿来一些示例照片,对照这些照片根据某些特征来训练小孩,告他 A 是猫,B 是老虎;
  • 这些特征可能是,表面的颜色、耳朵的形状、体积的大小等等;
  • 我们总是希望儿童能快速辨认出猫和老虎,毕竟假如他们真的遇到了老虎,则需要和老虎保持一定的距离;
  • 其中一种筛选方法就是决策模型:把认为最重要的特征先进行甄别,然后到次要的,再到次次要的,以此来加速决策过程并得出判定。

作为一个示例,这里假设将识别老虎分为 2 个特征,分别是耳朵的形状和体积大小,那么已知的数据可能是这样的:

Index Shape of the ear Size Animal
1 Triangle Small Cat
2 Triangle Small Cat
3 Triangle Big Tiger
4 Circular Small Tiger
5 Circular Big Tiger

在程序中将使用数组的形式来表示上列数据,我把它称为「抓虎的数据集」:

const dataSet = [
  ['Triangle', 'Small', 'Cat'],
  ['Triangle', 'Small', 'Cat'],
  ['Triangle', 'Big', 'Tiger'],
  ['Circular', 'Small', 'Tiger'],
  ['Circular', 'Big', 'Tiger'],
];

根据已有的数据集(经验),猫和老虎的决策树则是这样:

「抓虎」的决策树

这就是决策树的工作原理了。因为属于分类算法,所以决策树也可以推演到 MNIST 数据集的识别中。把 728 个点作为特征,对应的数字作为分类目标即可应用决策树算法。当然决策树算法不适合解决 MNIST 数据集这类特征为数值型的问题,但是因为它易于理解和实现,人们在通过解释后都有能力去理解决策树所表达的意义,因此作为机器学习中训练模型的算法来进行入门则非常合适。

那么决策树模型在程序中应该如何构建和表示呢?

构建决策树

决策树的构建过程就是在训练数据集中不断划分数据集,直到找到目标分类的过程。在此过程中需要找到最好的数据集划分方式,递归地不断划分数据集,直到所有的分类都属于同一类目或没有多余特征时停止生长。可以结合上一章节的「抓虎」的决策树进行理解。

找出最佳特征来划分数据

不难看出,构建决策树的关键问题是如何找出最佳的特征来划分数据集。先要回答问题是,假设我按照某个特征将数据集一分为二,那么有 N 种划分方式,哪一种才算做「最好的划分方式」?这就得引入香农熵的概念。

香农熵

划分数据集的大原则是:将无序的数据变得更加有序。

在「抓虎」的决策树中,耳朵的形状是最佳的划分特征,因为根据它来划分后的数据集更加有序了(混杂项更少)。度量集合有序程度的其中一种方法就是香农熵。香农熵是信息论中的内容,有兴趣的读者可以从参考资料的链接[4]中获得更详细的内容。在此只需要知道的是,香农熵越低则集合越有序

香农熵的计算公式是:

香农熵公式图

根据公式,在程序中实现计算香农熵的代码:

function calcShannonEnt(dataSet) {
  const labelCounts = {};
  for (let featVec of dataSet) {
    const currentLabel = featVec[featVec.length - 1];
    if (Object.keys(labelCounts).indexOf(currentLabel) === -1) {
      labelCounts[currentLabel] = 1;
    } else {
      labelCounts[currentLabel]++;
    }
  }

  let shannonEnt = 0.0;
  const numEntries = dataSet.length;
  for (let i in labelCounts) {
    const x = labelCounts[i];
    const probability = x / numEntries; // p(x)
    shannonEnt = shannonEnt - probability * log2(probability); // -Σp*log(p) 
  }
  return shannonEnt;
}

进行一些测试将会有助于理解香农熵的含义:

// 注意:初始化时数据集里面只有 2 个目标分类(yes or no)
const dataSet = [
  [1, 1, 'yes'],
  [1, 1, 'yes'],
  [1, 0, 'no'],
  [0, 1, 'no'],
  [0, 0 'no']
];

console.log(calcShannonEnt(dataSet)); // 0.9709505944546686

dataSet[0][dataSet[0].length - 1] = 'maybe'; // 混合更多的分类
console.log(calcShannonEnt(dataSet)); // 1.3709505944546687 (香农熵变大,说明数据集更无序了)

根据特征划分数据集

实现一个函数,根据特征来划分数据集:

function splitDataSet(dataSet, index, value) {
  const retDataSet = [];
  for (let featVec of dataSet) {
    if (featVec[index] === value) {
      let reducedFeatVec = featVec.slice(0, index);
      reducedFeatVec = reducedFeatVec.concat(featVec.slice(index + 1));
      retDataSet.push(reducedFeatVec);
    }
  }

  return retDataSet;
}

拿「抓虎」的数据集进行测试,看看划分后的数据长什么样?

console.log(splitDataSet(dataSet, 0, 'Triangle'));
// Triangle [ [ 'Small', 'Cat' ], [ 'Small', 'Cat' ], [ 'Big', 'Tiger' ] ]

console.log(splitDataSet(dataSet, 0, 'Circular'));
// Circular [ [ 'Small', 'Tiger' ], [ 'Big', 'Tiger' ] ]

从结果上看,成功地按照某个特征值把数据划分了出来。

组合计算熵的算法和划分数据集的函数,就可以找出最佳的数据划分特征项。以下是代码实现:

function uniqueDataSetColumn(dataSet, i) {
  const uniqueValues = [];
  dataSet.forEach((element) => {
    const value = element[i];
    if (uniqueValues.indexOf(value) === -1) {
      uniqueValues.push(value)
    }
  });

  return uniqueValues;
}
function chooseBestFeatureToSplit(dataSet) {
  const numberFeatures = dataSet[0].length;
  let baseEntropy = calcShannonEnt(dataSet);
  let bestInfoGain = 0.0;
  let bestFeature = -1;

  // 对比每个特征划分数据的熵,找出最佳划分特征
  for (let i = 0, length = numberFeatures - 1; length > i; i++) {
    const uniqueValues = uniqueDataSetColumn(dataSet, i);

    // 计算熵
    let newEntropy = 0.0;
    uniqueValues.forEach((value) => {
      const subDataSet = splitDataSet(dataSet, i, value);
      const probability = subDataSet.length / dataSet.length;
      newEntropy += probability * calcShannonEnt(subDataSet);
    });

    const infoGain = baseEntropy - newEntropy;
    if (infoGain > bestInfoGain) {
      bestInfoGain = infoGain;
      bestFeature = i;
    }
  }

  return bestFeature;
}

将该函数在「抓虎」的数据集进行测试,这个数据集的第一划分依据是什么特征?

console.log(chooseBestFeatureToSplit(dataSet));

如无意外,程序将输出 0。耳朵的形状是最佳的划分特征,证明程序达到了我们预想的效果。

递归构建决策树

将上面的函数结合起来,再不断地进行递归就可以构建出决策树模型。什么时候应该停止递归?有 2 种情况:

  1. 当所有的分类都属于同一类目时,停止划分数据 —— 该分类即是目标分类;
  2. 划分的数据集中没有其他特征时,停止划分数据 —— 根据出现次数最多的类别作为目标分类。

构建树的入参是什么?

  1. 训练数据集 —— 从训练数据中提取决策知识;
  2. 特征的标签 —— 用于绘制决策树每个节点。

以下是代码实现:

// 辅助函数,根据出现次数最多的类别作为目标分类
function majority(classList) {
  const classCount = {};
  for (let vote of classList) {
    if (Object.keys(classCount).indexOf(vote) === -1) {
      classCount[vote] = 1;
    } else {
      classCount[vote]++;
    }
  }

  let predictedClass = '';
  let topCount = 0;
  for (const voteLabel in classCount) {
    if (classCount[voteLabel] > topCount) {
      predictedClass = voteLabel;
      topCount = classCount[voteLabel];
    }
  }
  return predictedClass;
}
function createTree(dataSet, featureLabels) {
  const classList = dataSet.map((elements) => elements[elements.length - 1]);
  
  // 当所有的分类都属于同一类目时,停止划分数据
  let count = 0;
  classList.forEach((classItem) => {
    if (classItem === classList[0]) {
      count++;
    }
  });
  if (count == classList.length) {
    return classList[0]
  }

  // 数据集中没有其他特征时,停止划分数据,根据出现次数最多的类别作为返回值
  if (dataSet[0].length === 1) {
    return majority(classList);
  }

  // 1. 找到最佳划分数据集的特征
  const bestFeat = chooseBestFeatureToSplit(dataSet);
  const bestFeatLabel = featureLabels[bestFeat];
  const myTree = {[bestFeatLabel]: {}};

  // 2. 获得特征的枚举值
  const uniqueValues = uniqueDataSetColumn(dataSet, bestFeat);

  // 3. 根据特征值划分数据(创建子节点)
  uniqueValues.forEach((value) => {
    const newDataSet = splitDataSet(dataSet, bestFeat, value);
    const subLabels = featureLabels.filter((label, key) => key !== bestFeat);

    // 4. 递归划分
    myTree[bestFeatLabel][value] = createTree(newDataSet, subLabels)
  });

  return myTree;
}

自此就完成了学习模型的构建。

训练算法得到知识

将已有的数据集使用决策树模型进行训练,将会得到怎样的知识?

以「抓虎」为例,运行以下代码:

const tree = createTree(dataSet, ['Shape', 'Size']);
// {"Shape":{"Triangle":{"Size":{"Small":"Cat","Big":"Tiger"}},"Circular":"Tiger"}}

可见,能得到的知识是针对数据集学习到的特征权重顺序排列,是层层筛选决策的依据。

为了更加直观和易于理解,可以将数据可视化(关于如何进行数据可视化不是本文的内容),它大概长这样:

决策树图

在程序中加入知识的存储和提取函数,方便利用已有的知识进行推理。所以再声明 2 个辅助函数:

function storeTree(inputTree, filename) {
  fs.writeFileSync(filename, JSON.stringify(inputTree));
}

function grabTree(filename) {
  return JSON.parse(fs.readFileSync(filename, 'utf8'))
}

使用已有的知识进行推理

只需要写一个解析树的函数就可以将学习到决策知识推理到同类的数据集中。以下是代码实现:

function classify(inputTree, featureLabels, testVec) {
  const firstStr = Object.keys(inputTree)[0];
  const secondElement = inputTree[firstStr];
  const featIndex = featureLabels.indexOf(firstStr);
  const key = testVec[featIndex];
  const valueOfFeat = secondElement[key];
  if (typeof valueOfFeat === 'object') {
    return classify(valueOfFeat, featureLabels, testVec);
  } else {
    return valueOfFeat;
  }
}

以「抓虎」为例,下次见到一个耳朵形状是三角形,体积较小的动物,根据我们之前学习到的知识,它应该是猫还是老虎?

console.log(classify(tree, ['Shape', 'Size'], ['Triangle', 'Small']));
// Cat

如无意外,将会输出 "Cat"。

应用到 MNIST 数据集

最后,组合上面的函数,将其应用到 MNIST 数据集的识别中。

值得注意的是,在数据准备环节需要一些工作以适应上文构建的算法:

  • 将特征由数值型转化为标称型,这里我用了 0 / 1;
  • 将分类值由 one-hot 向量转化为具体的数字。

准备数据

const mnist = require('mnist');
const fs = require('fs');
const path = require('path');
const trainingCount = 8000;
const testCount = 2000;
const {training, test} = mnist.set(trainingCount, testCount);

fs.writeFileSync(path.join(__dirname, 'mnist_trainingData.json'), JSON.stringify(training));
fs.writeFileSync(path.join(__dirname, 'mnist_testData.json'), JSON.stringify(test));

学习阶段

const mnist = require('mnist');
const path = require('path');
const fs = require('fs');

// 1. 加载数据
const trainingData = JSON.parse(fs.readFileSync(path.join(__dirname, 'mnist_trainingData.json'), 'utf8'));

// 2. 准备数据
let data = [];
trainingData.forEach(({input, output}) => {
  // 将分类值由 one-hot 向量转化为具体的数字
  const number = String(output.indexOf(output.reduce((max, activation) => Math.max(max, activation), 0)));
  
  // 数值型特征转换为标称型
  data.push(toZeroOne(input).concat([number]));
});

// 特征的标签
const labels = mnist[0].get().map((number, key) => `number_${key}`);

// 3. 分析数据:在命令行中检查数据,确保它的格式符合要求
console.log('data', JSON.stringify(data[0]));
console.log('labels', JSON.stringify(labels));

// 4. 训练算法
const startTime = Date.now();
const tree = createTree(data, labels);
console.log('tree', JSON.stringify(tree));
console.log(`Spend: ${(Date.now() - startTime) / 1000}s`);

// 存储学到的知识
storeTree(tree, path.join(__dirname, 'mnist_tree.txt'));

在笔者的电脑上大概运行了 10 分钟:

学习解决的耗时

看起来运行时间很长,那怎么能说比 k-NN 算法更有效率?!

其实这是训练阶段的耗时,而训练阶段往往是离线处理,有大量的手段可以优化这部分的性能。

应用阶段

const mnist = require('mnist');
const path = require('path');
const fs = require('fs');

// 1. 加载测试数据
const testData = JSON.parse(fs.readFileSync(path.join(__dirname, 'mnist_testData.json'), 'utf8'));
const testCount = testData.length;

// 获取先前学习的知识
const tree = grabTree(path.join(__dirname, './mnist_tree.txt'));
const labels = mnist[0].get().map((number, key) => `number_${key}`);

// 2. 测试算法
let errorCount = 0;
const startTime = Date.now();
testData.forEach(({input, output}, key) => {
  const number = output.indexOf(output.reduce((max, activation) => Math.max(max, activation), 0));
  const predicted = classify(tree, labels, toZeroOne(input));
  const result = predicted == number;
  console.log(`${key}. number is ${number}, predicted is ${predicted}, result is ${result}`);

  if (!result) {
    errorCount++;
  }
});
console.log(`The total number of errors is: ${errorCount}`);
console.log(`The total error rate is: ${errorCount / testCount}`);
console.log(`Spend: ${(Date.now() - startTime) / 1000}s`);

// 3. 使用算法
const number = 8;
console.log('Result is', classify(tree, labels, toZeroOne(mnist[number].get())));

如无意外,终端命令行中将输出以下结果:

应用的输出结果

在同样的数据集中,笔者上一篇文章构建的 k-NN 算法,运行时长是 325 秒,错误率是 0.05。这组数据该如何解读?笔者认为:

  1. 决策树的在预测阶段计算量非常小,所以执行效率非常高;
  2. 本文做特征处理时丢失了很多信息,数值型特征转换到 0/1 的方式太过于粗暴。

使用决策树算法来识别 MNIST 数据集效果很不理想,不过从中可以看到构建一个机器学习应用的完整过程。

参考资料

  1. 机器学习,Hello World from Javascript!
  2. MNIST 数据集
  3. 决策树
  4. 香农熵
  5. 本文示例代码

文章封面图由 Igor Ovsyannykov 发表在 Unsplash

相关文章
|
16天前
|
人工智能 JSON 算法
Qwen2.5-Coder 系列模型在 PAI-QuickStart 的训练、评测、压缩及部署实践
阿里云的人工智能平台 PAI,作为一站式、 AI Native 的大模型与 AIGC 工程平台,为开发者和企业客户提供了 Qwen2.5-Coder 系列模型的全链路最佳实践。本文以Qwen2.5-Coder-32B为例,详细介绍在 PAI-QuickStart 完成 Qwen2.5-Coder 的训练、评测和快速部署。
Qwen2.5-Coder 系列模型在 PAI-QuickStart 的训练、评测、压缩及部署实践
|
20天前
|
机器学习/深度学习 PyTorch API
优化注意力层提升 Transformer 模型效率:通过改进注意力机制降低机器学习成本
Transformer架构自2017年被Vaswani等人提出以来,凭借其核心的注意力机制,已成为AI领域的重大突破。该机制允许模型根据任务需求灵活聚焦于输入的不同部分,极大地增强了对复杂语言和结构的理解能力。起初主要应用于自然语言处理,Transformer迅速扩展至语音识别、计算机视觉等多领域,展现出强大的跨学科应用潜力。然而,随着模型规模的增长,注意力层的高计算复杂度成为发展瓶颈。为此,本文探讨了在PyTorch生态系统中优化注意力层的各种技术,
48 6
优化注意力层提升 Transformer 模型效率:通过改进注意力机制降低机器学习成本
|
9天前
|
机器学习/深度学习 人工智能 算法
人工智能浪潮下的编程实践:构建你的第一个机器学习模型
在人工智能的巨浪中,每个人都有机会成为弄潮儿。本文将带你一探究竟,从零基础开始,用最易懂的语言和步骤,教你如何构建属于自己的第一个机器学习模型。不需要复杂的数学公式,也不必担心编程难题,只需跟随我们的步伐,一起探索这个充满魔力的AI世界。
26 12
|
16天前
|
机器学习/深度学习 Python
机器学习中评估模型性能的重要工具——混淆矩阵和ROC曲线。混淆矩阵通过真正例、假正例等指标展示模型预测情况
本文介绍了机器学习中评估模型性能的重要工具——混淆矩阵和ROC曲线。混淆矩阵通过真正例、假正例等指标展示模型预测情况,而ROC曲线则通过假正率和真正率评估二分类模型性能。文章还提供了Python中的具体实现示例,展示了如何计算和使用这两种工具来评估模型。
35 8
|
16天前
|
机器学习/深度学习 Python
机器学习中模型选择和优化的关键技术——交叉验证与网格搜索
本文深入探讨了机器学习中模型选择和优化的关键技术——交叉验证与网格搜索。介绍了K折交叉验证、留一交叉验证等方法,以及网格搜索的原理和步骤,展示了如何结合两者在Python中实现模型参数的优化,并强调了使用时需注意的计算成本、过拟合风险等问题。
35 6
|
19天前
|
机器学习/深度学习 数据采集 算法
从零到一:构建高效机器学习模型的旅程####
在探索技术深度与广度的征途中,我深刻体会到技术创新既在于理论的飞跃,更在于实践的积累。本文将通过一个具体案例,分享我在构建高效机器学习模型过程中的实战经验,包括数据预处理、特征工程、模型选择与优化等关键环节,旨在为读者提供一个从零开始构建并优化机器学习模型的实用指南。 ####
|
23天前
|
人工智能 边缘计算 JSON
DistilQwen2 蒸馏小模型在 PAI-QuickStart 的训练、评测、压缩及部署实践
本文详细介绍在 PAI 平台使用 DistilQwen2 蒸馏小模型的全链路最佳实践。
|
20天前
|
机器学习/深度学习 人工智能 算法
探索机器学习中的线性回归模型
本文深入探讨了机器学习中广泛使用的线性回归模型,从其基本概念和数学原理出发,逐步引导读者理解模型的构建、训练及评估过程。通过实例分析与代码演示,本文旨在为初学者提供一个清晰的学习路径,帮助他们在实践中更好地应用线性回归模型解决实际问题。
|
29天前
|
机器学习/深度学习 数据采集 监控
如何使用机器学习模型来自动化评估数据质量?
如何使用机器学习模型来自动化评估数据质量?
|
25天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
69 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型