独热编码和交叉熵损失函数 | 学习笔记

简介: 快速学习独热编码和交叉熵损失函数

开发者学堂课程【Tensorflow2.0入门与实战独热编码和交叉熵损失函数】学习笔记,与课程紧密联系,让用户快速学习知识。

课程地址https://developer.aliyun.com/learning/course/664/detail/11107


独热编码和交叉熵损失函数


基本介绍

一.什么是独热编码

二.使用独热编码训练要注意的事项


一. 什么是独热编码

train_lable 是顺序编码,用0、1、2、3、4、5、6、7、8、9分别代表衬衫、鞋子等等将它编码成独热编码。

独热编码是一种数值化的方法。比如有三个城市分别是北京、上海、深圳,对这三个城市进行编码,可以编码成0、1、2,也可以编码成与它长度相同的维度。北京编码为【1,0,0】,上海【0,1,0】,深圳【0,0,1】标志为这个城市的为1其它地方都为0。 

当 label 进行独热编码时使用 categorical_ crossentropy 交叉熵。将 train_lable 进行独热编码改成 train_label_onehot,onehot 代表独热编码,tf.keras.utils.to_categorical 将顺序编码变为独热编码。

train_label_onehot=tf.keras.utils.to_categorical(train_lable)

train label_onehot

array([[0.,0.,0.,..., 0.,0.,1.],

[1.,0.,0.,..., 0.,0.,0.],

[1.,0.,0.,..., 0.,0.,0.],

...,

[0.,0.,0.,..., 0.,0.,0.],

[1.,0.,0.,..., 0.,0.,0.],

[0.,0.,0.,..., 0.,0.,0.]],dtype=float32)

9将会变化成[0.,0.,0.,..., 0.,0.,1.],长度为十的向量,最后一位被标注为1

train_label_onehot[0]

array([0.,0.,0.,0.,0.,0.,0.,0.,0.,1.],dtype=float32)

最后一个5变化成独热编码[0.,0.,0.,0.,0.,1.,0.,0.,0.,0.]

train_label_onehot[-1]

array([0.,0.,0.,0.,0.,1.,0.,0.,0.,0.],dtype=float32)

可以将 test 数据集做成独热编码,

同样使用tf.keras.utils.to_categorical

test_label_onehot =tf.keras.utils.to_categorical(test_label)

test_label_onehot

array([[0.,0.,0.,..., 0.,0.,1.].

[0.,0.,1.,..., 0.,0.,0.],

[0.,1.,0.,..., 0.,0.,0.],

...,

[0.,0.,0.,..., 0.,1.,0.],

[0.,1.,0.,..., 0.,0.,0.],

[0.,0.,0.,..., 0.,0.,0.]],dtype=float32)

分类数据进行处理时也会讲到独热编码,可以算为一种方法。

test_label 第一个为9变化为[0.,0.,0.,..., 0.,0.,1.],第二个2变化为[0.,0.,1.,..., 0.,0.,0.]

 

二.使用独热编码训练要注意的事项

建立与之前相同的网络,复制网络。训练5个 epochs

model =tf.keras. Sequential()

model.add(tf.keraslayers.Flatten(input_shape=(28,28)))#28*28

model.add(tf.keras.layers.Dense(128,activation='relu')) model.add(tf.keras.lavers.Dense(10.activation=softmax)

model.compile(optimizer=adam',

loss='categorical_crossentropy’,

metrics=['acc']

)

model. fit(train_image, train_lable,epochs=5)

注意:

当 label 是顺序编码时使用的交叉熵损失函数是 sparse_categorical_crossentropy,label 是独热编码时使用 categorical_crossentropy,loss 值独热编码使用 categorical_crossentropy。数据使用的是 train 数据,但是编码使用的是 categorical_crossentropy 所以会报错。应该使用 train_label_onehot

model. fit(train_image,train_label_onehot,epochs=5

对 test_image 进行 predict,查看 predict 的形状,有10000个长度为10的向量

predict=model.predict(test_image)

predict.shape

(10000, 10)

test 数据集的形状,10000张图片

test_image.shape

(10000,28,28)

第一个分类结果

predict[0]

array([7.7417062e-05,1.2555851e-07,5.2015298e-06,3.90

63170e-06,6.1778355e-06,1.3308496e-02,5.2028918e-05,1,2039219e-02,6.5957895e-05,9.7444147e-01]dtype=float32)

softmax 的输出所有分量样本之和为1,所有样本相加的可能性为100%。哪个最大就是要预测的值。

使用 np.argamx 取出最大值所在的缩影

np.argamx(predict[0])

9

test_label[0]

9

说明预测的结果是正确的。

相关文章
|
9月前
|
机器学习/深度学习 编解码 算法
NeoBERT:4096 tokens上下文窗口,参数更少但性能翻倍
NeoBERT是新一代双向编码器模型,整合了前沿架构改进、大规模数据集和优化预训练策略,缩小了传统编码器与高性能自回归语言模型的差距。它支持4096 tokens的扩展上下文窗口,仅250M参数规模,却在MTEB基准中超越多个更大参数量的模型。通过技术创新如旋转位置嵌入和SwiGLU激活函数,以及两阶段预训练策略,NeoBERT在高效性和性能上取得了显著突破。
240 26
NeoBERT:4096 tokens上下文窗口,参数更少但性能翻倍
|
10月前
|
机器学习/深度学习 计算机视觉
YOLOv11改进策略【注意力机制篇】| 2024 PPA 并行补丁感知注意模块,提高小目标关注度
YOLOv11改进策略【注意力机制篇】| 2024 PPA 并行补丁感知注意模块,提高小目标关注度
289 1
YOLOv11改进策略【注意力机制篇】| 2024 PPA 并行补丁感知注意模块,提高小目标关注度
|
机器学习/深度学习 人工智能 算法框架/工具
深入浅出:使用深度学习进行图像分类
【8月更文挑战第31天】在本文中,我们将一起探索如何利用深度学习技术对图像进行分类。通过简明的语言和直观的代码示例,我们将了解构建和训练一个简单卷积神经网络(CNN)模型的过程。无论你是初学者还是有一定基础的开发者,这篇文章都将为你提供清晰的指导和启发性的见解,帮助你理解并应用深度学习解决实际问题。
|
机器学习/深度学习 存储 人工智能
【博士每天一篇文献-算法】改进的PNN架构Progressive learning A deep learning framework for continual learning
本文提出了一种名为“Progressive learning”的深度学习框架,通过结合课程选择、渐进式模型容量增长和剪枝机制来解决持续学习问题,有效避免了灾难性遗忘并提高了学习效率。
524 4
|
机器学习/深度学习 自然语言处理 搜索推荐
基于图神经网络的电商购买预测
基于图神经网络的电商购买预测
|
数据处理 索引 Python
Pandas中的filter函数:有点鸡肋
Pandas中的filter函数:有点鸡肋
400 1
|
机器学习/深度学习 算法 Python
在Python中,独热编码(One-Hot Encoding)
在Python中,独热编码(One-Hot Encoding)
1522 8
|
机器学习/深度学习 算法 计算机视觉
【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting
LwF(Learning without Forgetting)是一种机器学习方法,通过知识蒸馏损失来在训练新任务时保留旧任务的知识,无需旧任务数据,有效解决了神经网络学习新任务时可能发生的灾难性遗忘问题。
1199 9
|
数据处理 索引 Python
Pandas中的filter函数:有点鸡肋
Pandas中的filter函数:有点鸡肋
461 0
|
编解码 开发工具 git
技术心得记录:小波变换(wavelettransform)的通俗解释(一)
技术心得记录:小波变换(wavelettransform)的通俗解释(一)
1075 0