tensorflow2.0回归模型---如何用好keras对sklearn的api

简介: tensorflow2.0回归模型---如何用好keras对sklearn的api

之前写过如何用tf.keras搭建模型,那个时候埋下了一个伏笔,就是超参数搜索的问题。如何得到最好的模型,我们用sklearn的时候就是GridSearchCV或者RandomizedSearchCV,所以我今天想讲讲怎么通过tf.keras的api来实现超参数搜索。


1.看看官方文档的介绍


官方文档


image.png


发现调用这个api只需要写一个build_fn,也就是写一个搭建网络的函数,知道这个之后就来实战看看。


2.实现超参数搜索


  1. 首先导入数据集,我使用的是California Housing dataset
  2. 切分好训练集,验证集和测试集,并且对数据标准化
  3. 写好我们需要的网络模型,调用keras.wrappers.scikit_learn.KerasRegressor实现model可以使用sklearn的方法
  4. 训练得到最好的参数以及模型


代码:


import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib
import matplotlib.pyplot as plt
import sklearn
import os,sys
from sklearn.datasets import fetch_california_housing
housing=fetch_california_housing()
house=pd.DataFrame(housing.data)
house.columns=housing.feature_names
house['price']=housing.target
house.info()
house.head(10)
from sklearn.model_selection import train_test_split
x_train_all,x_test,y_train_all,y_test=train_test_split(housing.data,housing.target,test_size=0.25,random_state=2)
x_train,x_valid,y_train,y_valid=train_test_split(x_train_all,y_train_all,test_size=0.25,random_state=2)
#标准化
from sklearn.preprocessing import StandardScaler
scaler=StandardScaler()
x_train_scaled=scaler.fit_transform(x_train)
x_valid_scaled=scaler.fit_transform(x_valid)
x_test_scaled=scaler.fit_transform(x_test)
#构建自己的模型
def build_model(hidden_layers=1,layer_size=30,learning_rate=3e-3):
    model=keras.models.Sequential()
    model.add(keras.layers.Dense(layer_size,activation="relu",input_shape=x_train.shape[1:]))
    for _ in range(hidden_layers-1):
        model.add(keras.layers.Dense(layer_size,activation="relu"))
    model.add(keras.layers.Dense(1))
    optimizer=keras.optimizers.SGD(learning_rate)
    model.compile(loss="mse",optimizer=optimizer)
    return model
import datetime
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join('logs', current_time)
callbacks=[
    keras.callbacks.EarlyStopping(patience=3,min_delta=1e-3),
    keras.callbacks.TensorBoard(log_dir=logdir)
]
model=keras.wrappers.scikit_learn.KerasRegressor(build_model)
callback=[keras.callbacks.EarlyStopping(patience=3,min_delta=1e-3)]
#使用默认参数的模型
history=model.fit(x_train_scaled,y_train,validation_data=(x_valid_scaled,y_valid),epochs=100,callbacks=callback)
def plot_learning_curves(history):
    pd.DataFrame(history.history).plot(figsize=(8,5))
    plt.grid()
    plt.gca().set_ylim(0,1)
    plt.show()
plot_learning_curves(history)
#实现超参数搜索
from scipy.stats import reciprocal
param_distribution={
    "hidden_layers":[3,4],
    "layer_size":np.arange(24,27),
    "learning_rate":reciprocal(0.001,0.005)
}
#随机搜索
from sklearn.model_selection import RandomizedSearchCV
random_searchcv=RandomizedSearchCV(model,param_distribution,n_iter=10,verbose=0)
random_searchcv.fit(x_train_scaled,y_train,validation_data=(x_valid_scaled,y_valid),epochs=100)
print("得到的最好参数为:",random_searchcv.best_params_)
print("最好的得分为:",random_searchcv.best_score_)
model=random_searchcv.best_estimator_.model
history1=model.fit(x_train_scaled,y_train,validation_data=(x_valid_scaled,y_valid),epochs=100,callbacks=callbacks)
plot_learning_curves(history1)
print(model.evaluate(x_test_scaled,y_test,verbose=0))


image.png

image.png

默认参数的模型效果


image.png

使用随机搜索的结果

image.png


打开tensorboard看看模型


image.png

3.总结:


虽然使用超参数搜索很方便,但是也有一些需要注意的地方。比如需要知道哪些参数比较重要,就尽量精确;相反的就可以稍微放宽要求。毕竟不论是使用网格搜索还是随机搜索如果参数太多会导致计算时间很久。BTW,在使用超参数搜索的时候,我想把n_jobs设置成大于1的,运行就会报错,还没有找到很好的办法解决,这也就导致了对数据的计算时间会更久,终究调参还是一个需要经验的事情鸭!

目录
相关文章
|
16天前
|
自然语言处理 安全 API
API First:模型驱动的阿里云API保障体系
本文介绍了阿里云在API设计和管理方面的最佳实践。首先,通过API First和模型驱动的方式确保API的安全、稳定和效率。其次,分享了阿里云内部如何使用CloudSpec IDL语言及配套工具保障API质量,并实现自动化生成多语言SDK等工具。接着,描述了API从设计到上线的完整生命周期,包括规范校验、企业级能力接入、测试和发布等环节。最后,展望了未来,强调了持续提升API质量和开源CloudSpec IDL的重要性,以促进社区共建更好的API生态。
|
3月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
118 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
8天前
|
机器学习/深度学习 人工智能 安全
GLM-Zero:智谱AI推出与 OpenAI-o1-Preview 旗鼓相当的深度推理模型,开放在线免费使用和API调用
GLM-Zero 是智谱AI推出的深度推理模型,专注于提升数理逻辑、代码编写和复杂问题解决能力,支持多模态输入与完整推理过程输出。
112 24
GLM-Zero:智谱AI推出与 OpenAI-o1-Preview 旗鼓相当的深度推理模型,开放在线免费使用和API调用
|
27天前
|
存储 人工智能 API
AgentScope:阿里开源多智能体低代码开发平台,支持一键导出源码、多种模型API和本地模型部署
AgentScope是阿里巴巴集团开源的多智能体开发平台,旨在帮助开发者轻松构建和部署多智能体应用。该平台提供分布式支持,内置多种模型API和本地模型部署选项,支持多模态数据处理。
203 4
AgentScope:阿里开源多智能体低代码开发平台,支持一键导出源码、多种模型API和本地模型部署
|
2月前
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
191 5
|
2月前
|
人工智能 Java API
ChatClient:探索与AI模型通信的Fluent API
【11月更文挑战第22天】随着人工智能(AI)技术的飞速发展,越来越多的应用场景开始融入AI技术以提升用户体验和系统效率。在Java开发中,与AI模型通信成为了一个重要而常见的需求。为了满足这一需求,Spring AI引入了ChatClient,一个提供流畅API(Fluent API)的客户端,用于与各种AI模型进行通信。本文将深入探讨ChatClient的底层原理、业务场景、概念、功能点,并通过Java代码示例展示如何使用Fluent API与AI模型进行通信。
62 8
|
2月前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
134 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
2月前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
135 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
2月前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
110 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
4月前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
142 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别

热门文章

最新文章