TensorFlow 2 quickstart for beginners

简介: This short introduction uses Keras to:1. Build a neural network that classifies images.2. Train this neural network.3. And, finally, evaluate the accuracy of the model.

This short introduction uses Keras to:

  1. Build a neural network that classifies images.
  2. Train this neural network.
  3. And, finally, evaluate the accuracy of the model.
import tensorflow as tf

Load and prepare the MNIST dataset. Convert the samples from integers to floating-point numbers:

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

Build the tf.keras.Sequential model by stacking layers. Choose an optimizer and loss function for training:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

For each example the model returns a vector of "logits" or "log-odds" scores, one for each class.

predictions = model(x_train[:1]).numpy()
predictions

The tf.nn.softmax function converts these logits to "probabilities" for each class:

tf.nn.softmax(predictions).numpy()

Note: It is possible to bake this tf.nn.softmax in as the activation function for the last layer of the network. While this can make the model output more directly interpretable, this approach is discouraged as it's impossible to provide an exact and numerically stable loss calculation for all models when using a softmax output.

The losses.SparseCategoricalCrossentropy loss takes a vector of logits and a True index and returns a scalar loss for each example.

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

This loss is equal to the negative log probability of the true class: It is zero if the model is sure of the correct class.

This untrained model gives probabilities close to random (1/10 for each class), so the initial loss should be close to -tf.math.log(1/10) ~= 2.3

loss_fn(y_train[:1], predictions).numpy()
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

The Model.fit method adjusts the model parameters to minimize the loss:

model.fit(x_train, y_train, epochs=5)

The Model.evaluate method checks the models performance, usually on a "Validation-set" or "Test-set".

model.evaluate(x_test,  y_test, verbose=2)

The image classifier is now trained to ~98% accuracy on this dataset. To learn more, read the TensorFlow tutorials.

If you want your model to return a probability, you can wrap the trained model, and attach the softmax to it:

probability_model = tf.keras.Sequential([
  model,
  tf.keras.layers.Softmax()
])
probability_model(x_test[:5])

代码链接: https://codechina.csdn.net/csdn_codechina/enterprise_technology/-/blob/master/CV_Classification/TensorFlow%202%20quickstart%20for%20beginners.ipynb

目录
相关文章
|
TensorFlow 算法框架/工具 Python
TensorFlow安装部署
1.环境依赖 Centos7 组件 版本 Python 2.7.5 TensorFlow 0.14.0 pyhton依赖库 Package Version -------------------- --------- absl-py 0.
1961 0
|
PyTorch 算法框架/工具
快速安装Pytorch
快速安装Pytorch
103 0
快速安装Pytorch
|
Web App开发 存储 数据可视化
【Pytorch 安装TensorboardX及使用
【Pytorch 安装TensorboardX及使用
1267 0
|
TensorFlow 算法框架/工具 异构计算
YOLO实践应用之搭建开发环境(Windows系统、Python 3.8、TensorFlow2.3版本)
基于YOLO进行物体检测、对象识别,先和大家分享如何搭建开发环境,会分为CPU版本、GPU版本的两种开发环境,本文会分别详细地介绍搭建环境的过程。主要使用TensorFlow2.3、opencv-python4.4.0、Pillow、matplotlib 等依赖库。
353 0
|
PyTorch 算法框架/工具 计算机视觉
Py之torchvision:torchvision库的简介、安装、使用方法之详细攻略
Py之torchvision:torchvision库的简介、安装、使用方法之详细攻略
Py之torchvision:torchvision库的简介、安装、使用方法之详细攻略
|
4月前
|
机器学习/深度学习 开发者 数据格式
Gradio如何使用
**Gradio** 是一个开源 Python 库,用于快速创建和部署机器学习模型的用户界面。它支持多种输入输出形式,如文本、图像、音频等,无需复杂 Web 开发知识即可实现模型的直观展示和交互。Gradio 特点包括简单易用、实时更新、多样的输入输出形式以及轻松部署。通过几个简单的步骤,即可创建和分享功能强大的机器学习应用。
126 0
|
并行计算 PyTorch 算法框架/工具
Win10 Python3.7 安装pytorch1.5+ mmdetection2.1,搭建mmdetection环境
Win10 Python3.7 安装pytorch1.5+ mmdetection2.1,搭建mmdetection环境
403 0
Win10 Python3.7 安装pytorch1.5+ mmdetection2.1,搭建mmdetection环境
|
机器学习/深度学习 并行计算 测试技术
ModelScope官方镜像,CPU环境镜像(python3.8)pull不存在
在pullModelScope官方镜像时,一直pull失败,发现官方镜像应该没有推送,Python3.7的是有的
|
Web App开发 开发工具 git

热门文章

最新文章