keras 迁移学习inception_v3,缺陷检测

简介:

from keras.models import Sequential
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.core import Dropout
from keras.layers.core import Dense
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.preprocessing import image
from keras.preprocessing.image import img_to_array
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from keras.utils import to_categorical
from keras.applications import inception_v3
from keras.layers import GlobalAveragePooling2D
from keras.models import Model
import matplotlib.pyplot as plt
import imutils
import numpy as np
import argparse
import random
import pickle
import cv2
import os
from PIL import Image
import matplotlib
matplotlib.use("Agg")

# 获取该路径下所有图片
path = list(imutils.paths.list_images(r'C:\Users\Desktop\guangdong\train'))

imagePaths = sorted(path)
random.shuffle(imagePaths)

name_dic = {'正常':'norm','不导电':'defect1','擦花':'defect2','横条压凹':'defect3','桔皮':'defect4','漏底':'defect5',
'碰伤':'defect6','起坑':'defect7','凸粉':'defect8','涂层开裂':'defect9','脏点':'defect10','其他':'defect11'}

# 将其他文件夹中,名称都改为其他
other_list_1 = os.listdir(r'C:\Users\Desktop\guangdong\train\guangdong_round1_train2_20180916\guangdong_round1_train2_20180916\瑕疵样本\其他')
other_list = other_list_1[1:]

other_dic = { '伤口':'其他', '划伤':'其他', '变形':'其他', '喷流':'其他', '喷涂碰伤':'其他', '打白点':'其他',
'打磨印':'其他','拖烂':'其他', '杂色':'其他', '气泡':'其他', '油印':'其他', '油渣':'其他',
'漆泡':'其他', '火山口':'其他', '碰凹':'其他', '粘接':'其他', '纹粗':'其他', '角位漏底':'其他',
'返底':'其他', '铝屑':'其他', '驳口':'其他'}

# 打印出name_dic里的英文部分,手动复制,再在每个后面添加‘:’及相应的数字
name_dic.values()
digit_dir = {'norm':0, 'defect1':1, 'defect2':2, 'defect3':3, 'defect4':4, 'defect5':5, 'defect6':6, 'defect7':7, 'defect8':8,
'defect9':9, 'defect10':10, 'defect11':11}

# 将图片resize成inception_v3需要的(299,299)大小,并转化成array
labels = []
data =[]
for imagePath in imagePaths:
img = Image.open(imagePath)
img = img.resize((299,299))
img = img_to_array(img)
data.append(img)
label_gbk = imagePath.split('\\')[-1].split('2')[0]
if label_gbk in other_list:
label_gbk = other_dic[label_gbk]
label_english = name_dic[label_gbk]
label = digit_dir[label_english]
print(label_gbk,':',label_english,':',label)
labels.append(label)

# 像素归一化(有利于加速收敛)
labels = np.array(labels)
data = np.array(data, dtype="float") / 255.0
# 标签one-hot
labels = to_categorical(labels)

x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)
# 数据增强
train_aug = ImageDataGenerator(rotation_range=25, width_shift_range=0.1,height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
horizontal_flip=True, fill_mode="nearest",preprocessing_function=inception_v3.preprocess_input)
# inception_v3基础模型,include_top=False就是要修改原模型的最后一层
base_model = inception_v3.InceptionV3(weights='imagenet',include_top=False)

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(units=1024,activation='relu')(x)
predictions = Dense(units=12,activation='softmax')(x)
model = Model(inputs=base_model.input, output=predictions)

base_model.summary()
model.summary()

# 不训练基础层
for layer in base_model.layers:
layer.trainable = False


model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])

# batch_size最好选2的n次方,参考的是内存格式
history_tl = model.fit_generator(generator=train_aug.flow(x=x_train,y=y_train,batch_size=32),validation_data=(x_test, y_test),
steps_per_epoch=len(x_train)//32,epochs=10,verbose=1)


model.save()
目录
相关文章
|
6月前
|
机器学习/深度学习 数据采集 PyTorch
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
|
机器学习/深度学习 PyTorch 算法框架/工具
基于pytorch搭建VGGNet神经网络用于花类识别
基于pytorch搭建VGGNet神经网络用于花类识别
448 0
基于pytorch搭建VGGNet神经网络用于花类识别
|
3月前
|
API 异构计算
4.3.2 图像分类ResNet实战:眼疾识别——模型构建
这篇文章介绍了如何使用飞桨框架中的ResNet50模型进行眼疾识别的实战,通过5个epoch的训练,在验证集上达到了约96%的准确率,并提供了模型构建、训练、评估和预测的详细代码实现。
|
5月前
|
机器学习/深度学习 分布式计算 并行计算
基于YOLO和Darknet预训练模型的对象检测
【6月更文挑战第6天】基于YOLO和Darknet预训练模型的对象检测。
51 2
|
6月前
|
机器学习/深度学习 算法 Serverless
YoLo_V4模型训练过程
YoLo_V4模型训练过程
89 0
|
6月前
|
机器学习/深度学习 算法 PyTorch
使用PyTorch实现去噪扩散模型
在深入研究去噪扩散概率模型(DDPM)如何工作的细节之前,让我们先看看生成式人工智能的一些发展,也就是DDPM的一些基础研究。
92 0
|
PyTorch 算法框架/工具
ShuffleNet v2网络结构复现(Pytorch版)
ShuffleNet v2网络结构复现(Pytorch版)
ShuffleNet v2网络结构复现(Pytorch版)
|
机器学习/深度学习 编解码 自然语言处理
基于EasyCV复现ViTDet:单层特征超越FPN
ViTDet其实是恺明团队MAE和ViT-based Mask R-CNN两个工作的延续。MAE提出了ViT的无监督训练方法,而ViT-based Mask R-CNN给出了用ViT作为backbone的Mask R-CNN的训练技巧,并证明了MAE预训练对下游检测任务的重要性。而ViTDet进一步改进了一些设计,证明了ViT作为backone的检测模型可以匹敌基于FPN的backbone(如SwinT和MViT)检测模型。
|
机器学习/深度学习 算法 PyTorch
使用Pytorch实现对比学习SimCLR 进行自监督预训练
SimCLR(Simple Framework for Contrastive Learning of Representations)是一种学习图像表示的自监督技术。 与传统的监督学习方法不同,SimCLR 不依赖标记数据来学习有用的表示。 它利用对比学习框架来学习一组有用的特征,这些特征可以从未标记的图像中捕获高级语义信息。
1035 1
|
搜索推荐 TensorFlow 数据处理
【推荐系统】TensorFlow复现论文DeepCrossing特征交叉网络结构
【推荐系统】TensorFlow复现论文DeepCrossing特征交叉网络结构
122 1
【推荐系统】TensorFlow复现论文DeepCrossing特征交叉网络结构