一、训练模型
# 1. Load and normalize the CIFAR-10 dataset
import ssl

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# NOTE: this turns off HTTPS certificate verification for the whole process so
# the dataset download works on machines with a broken CA bundle. It is a
# security trade-off; fixing the local certificate store is the better option.
ssl._create_default_https_context = ssl._create_unverified_context

# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


# Plot a few sample images from the training set
def imshow(img):
    """Show a normalized channel-first image tensor with matplotlib.

    Reverses the Normalize(0.5, 0.5) step above, then moves the channel axis
    last because ``plt.imshow`` expects height x width x channels.
    """
    restored = img / 2 + 0.5
    plt.imshow(np.transpose(restored.numpy(), (1, 2, 0)))
    plt.show()


# Grab one shuffled batch of training samples
images, labels = next(iter(trainloader))

# show images
imshow(torchvision.utils.make_grid(images))
# print labels (one name per sample in the batch)
print(' '.join(f'{classes[label]:5s}' for label in labels))


# 2. Define the network
class Net(nn.Module):
    """Small CNN: two conv+pool stages followed by three fully connected layers.

    The layer sizes (fc1 takes 16*5*5 features) imply 3-channel 32x32 inputs;
    fc3 emits 10 logits, one per entry of ``classes``. Attribute names are the
    state_dict keys, so they must match any script that reloads the checkpoint.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # conv -> ReLU -> 2x2 max-pool, twice
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)  # raw logits; CrossEntropyLoss applies softmax itself


net = Net()

# 3. Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 4. Train the network
LOG_EVERY = 2000  # report the mean loss once per this many mini-batches

for epoch in range(2):  # two passes over the training set
    running_loss = 0.0
    for step, (inputs, labels) in enumerate(trainloader):
        # clear gradients left over from the previous step
        optimizer.zero_grad()
        # forward + backward + optimize
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (step + 1) % LOG_EVERY == 0:
            print(f'[{epoch + 1}, {step + 1:5d}] loss: {running_loss / LOG_EVERY:.3f}')
            running_loss = 0.0

print('Finished Training')

# Save the trained weights
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
二、使用本地图片测试模型
import io

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image


class Net(nn.Module):
    """Same CNN as the training script.

    The layer names must match the saved state_dict keys exactly, and the
    layer sizes (fc1 takes 16*5*5 features) imply 3-channel 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Load the trained network parameters
PATH = './cifar_net.pth'
net = Net()
# map_location lets a checkpoint that was saved from a GPU load on a CPU-only machine.
net.load_state_dict(torch.load(PATH, map_location='cpu'))
net.eval()  # inference mode (no dropout/batch-norm here, but it is the correct state)

# Inference preprocessing. It must produce a 32x32 image normalized exactly
# like the training data. Resize(32) shrinks the *whole* image so its shorter
# side is 32 px, and CenterCrop(32) trims the longer side. The ImageNet-style
# Resize(255) would instead leave only a tiny patch from the image centre.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def transform_image(image_bytes):
    """Decode encoded image bytes into a normalized batch of one (1x3x32x32).

    The image is converted to RGB first, so palette, RGBA and grayscale files
    also yield the 3 channels that Normalize and conv1 expect.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return transform(image).unsqueeze(0)


# Read the image file and make sure the handle is closed afterwards.
with open('cat.jpg', 'rb') as image_file:
    img_bytes = image_file.read()
tensor = transform_image(image_bytes=img_bytes)

# No autograd graph is needed for a pure forward pass.
with torch.no_grad():
    outputs = net(tensor)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
_, predicted = torch.max(outputs, 1)  # index of the highest logit per sample
print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(1)))
- 运行效果
三、Flask 在线模型服务
import io
import json
import ssl

import torch
from flask import Flask, jsonify, request
from PIL import Image
from service_streamer import ThreadedStreamer
from torchvision import models, transforms

# NOTE(review): this disables HTTPS certificate verification process-wide so
# the pretrained-weights download succeeds on hosts with a broken CA bundle.
# It is a security trade-off; prefer fixing the local certificate store.
ssl._create_default_https_context = ssl._create_unverified_context

app = Flask(__name__)

# Keys are stringified class indices; each value unpacks into
# (class_id, class_name) in the handlers below.
with open('./imagenet_class_index.json', encoding='utf-8') as class_index_file:
    imagenet_class_index = json.load(class_index_file)

device = 'cpu'
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour
# of `weights=...`; kept here for compatibility with older versions.
model = models.densenet121(pretrained=True)
model.to(device)
model.eval()

# Built once at import time instead of on every request.
_preprocess = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def transform_image(image_bytes):
    """Decode encoded image bytes into a normalized 1x3x224x224 tensor.

    The image is converted to RGB so RGBA/palette/grayscale uploads still give
    the 3 channels the normalization and model expect.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return _preprocess(image).unsqueeze(0)


def get_prediction(image_bytes):
    """Classify one encoded image and return its class-index entry."""
    tensor = transform_image(image_bytes=image_bytes).to(device)
    with torch.no_grad():  # inference only: skip building the autograd graph
        outputs = model(tensor)  # call the module (not .forward) so hooks run
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    """Classify the uploaded form field 'file' one request at a time."""
    # methods=['POST'] means Flask only dispatches POST requests here.
    img_bytes = request.files['file'].read()
    class_id, class_name = get_prediction(image_bytes=img_bytes)
    return jsonify({'class_id': class_id, 'class_name': class_name})


def batch_prediction(image_bytes_batch):
    """Classify a list of encoded images in one forward pass.

    Returns one class-index entry per input, in input order.
    """
    image_tensors = [transform_image(image_bytes=image_bytes)
                     for image_bytes in image_bytes_batch]
    tensor = torch.cat(image_tensors).to(device)
    with torch.no_grad():
        outputs = model(tensor)
    _, y_hat = outputs.max(1)
    return [imagenet_class_index[str(i)] for i in y_hat.tolist()]


# Groups concurrent /stream_predict requests into batches of up to 64 images
# handled by batch_prediction.
streamer = ThreadedStreamer(batch_prediction, batch_size=64)


@app.route('/stream_predict', methods=['POST'])
def stream_predict():
    """Classify the uploaded form field 'file' through the batching streamer."""
    img_bytes = request.files['file'].read()
    class_id, class_name = streamer.predict([img_bytes])[0]
    return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()
四、调用在线模型
import requests

# Upload a local image to the batched endpoint and print the JSON prediction.
# The `with` block closes the file once the request has been sent.
with open('dog.jpg', 'rb') as image_file:
    resp = requests.post(
        "http://localhost:5000/stream_predict",
        files={"file": image_file},
        timeout=30,  # fail instead of hanging forever if the server is stuck
    )
# Surface HTTP errors clearly instead of failing later on a non-JSON error page.
resp.raise_for_status()
print(resp.json())
- 调用效果
参考链接
Vision Recognition Service with Flask and service streamer