sklearn的GridSearchCV——网格搜索超参数调优

摘要:
基本上,参数不冲突。当参数不冲突时,可以直接使用字典传递参数和要映射到GridSearchCV的候选值。这里的参数冲突是指以下情况:① 参数值是有限的:当参数a=‘a’时,参数b只能取‘b’;当参数a=“a”时,参数b可以取“b”或“b”② 参数互斥:只能从sklearninportdatasetsfromsklearn.svmiportSVCromsklearn.mod中选择一个参数a或b

基本使用

参数不冲突

参数不冲突时,直接用一个字典传递参数和要对应的候选值给GridSearchCV即可

我这里的参数冲突指的是类似下面这种情况:
① 参数取值受限:参数a='a'时,参数b只能取'b',参数a='A'时,参数b能取'b'或'B'
② 参数互斥:参数 a 或 b 二者只能选一个

from sklearn import datasets
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
iris = datasets.load_iris()
model = SVC(random_state=seed)

# 需调参数及候选值
parameters = {
    'C': [0.1, 1, 10], 
    'kernel': ['rbf', 'linear']
}

# 评价依据
# https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter
scores = {
    'acc': 'accuracy',         # 准确率
    'f1_mi': 'f1_micro',       # 一种多分类f1值
}

# 网格搜索实例
gs = GridSearchCV(
    model,
    parameters,
    cv=5,                      # 交叉验证数
    scoring=scores,            # 评价指标
    refit='f1_mi',             # 在此指标下,用最优表现的参数重新训练模型
#     return_train_score=True,   # gs.cv_results_额外保存训练集的评价结果
    verbose=1,                 # 日志信息,默认0不输出
    n_jobs=2                   # 并行加速
)

# 一共要跑的任务数=参数1候选值*...*参数i候选值*交叉验证数
# 这里就是3*2*5=30
gs.fit(iris.data, iris.target)

借助 make_scorer 可以自定义评价指标,如果指标越小越好,那么需要设置greater_is_better=False,sklearn会将这样的指标取负,越小越好取负之后就等同于越大越好。

from sklearn.metrics import make_scorer
def custom_loss_func(y_true, y_pred):
    return len(y_true[y_true!=y_pred])/len(y_true)
# greater_is_better=False,指标越小越好
# needs_proba=False,指标通过标签计算,不是通过概率
loss_socre = make_scorer(custom_loss_func, greater_is_better=False, needs_proba=False)
scores = {
    'acc': 'accuracy',         # 准确率
    'f1_mi': 'f1_micro',       # 一种多分类f1值
    'loss': loss_socre         # 自定义评价指标
}

再通过 gs.best_params_ 获取最优模型的参数,gs.best_estimator_取得最优模型(想这样操作的话GridSearchCV的refit参数不能为False),

print("最优参数")
print(gs.best_params_)
print("最佳模型的评分")
print(gs.best_score_)
print("最优模型")
best_model = gs.best_estimator_  # GridSearchCV的refit参数不能为False

gs.cv_results_ 存放了网格搜索的结果,如果想查看可以借助pandas,我们这里只列出了和评价指标有关的结果

"""
用表格查看训练信息
"""
cv_results = pd.DataFrame(gs.cv_results_)
# 查看其他指标的结果和参数,比如这里按平均准确率排序
cv_results = cv_results.sort_values(by="mean_test_acc", ascending=False)
shown_columns = ["mean_test_"+col for col in scores.keys()] + ["params"]
cv_results[shown_columns].head(3)

sklearn的GridSearchCV——网格搜索超参数调优第1张

参数冲突

参数冲突时,互斥参数搜索空间用不同字典来描述,然后将这些字典放到列表中,再传递给GridSearchCV

parameters = [
    {
        'C': [0.1, 1, 10], 
        'kernel': ['rbf', 'linear']
    },
    {
        'C': [0.1, 1, 10],
        'kernel': ['poly'],
        'degree': [1, 3, 5]
    }
]

复合调参

管道可以用来连接多个操作,比如特征选择+模型训练,数据处理+模型训练等等。如果这些操作也有参数可调,可以用 GridSearchCV 对它们一起调参

from sklearn import datasets
from sklearn.pipeline import Pipeline
from sklearn.feature_selection import SelectKBest, chi2, f_classif
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
iris = datasets.load_iris()

pipe = Pipeline([
    ('selector', SelectKBest()),       # 特征选择
    ('model', SVC(random_state=seed))  # 模型
])

# “双下划线”指定要调整的部件及其参数
parameters = [
    {
        'selector__score_func': [chi2, f_classif],
        'selector__k': [2, 3, 4],
        'model__C': [0.1, 1, 10], 
        'model__kernel': ['rbf', 'linear']
    },
    {
        'selector__score_func': [chi2, f_classif],
        'selector__k': [2, 3, 4],
        'model__C': [0.1, 1, 10],
        'model__kernel': ['poly'],
        'model__degree': [1, 3, 5]
    }
]


gs = GridSearchCV(
    pipe,
    parameters,
    cv=5,
    scoring='accuracy',
    verbose=1,
    n_jobs=2,
)

gs.fit(iris.data, iris.target)

这时候获得的 best_estimator_ 是管道,我们可以用索引获取需要的组件(特征选择器或模型)

print("最优组合")
# best_pipe = gs.best_estimator_
best_selector = gs.best_estimator_[0]
best_model = gs.best_estimator_[1]

免责声明:文章转载自《sklearn的GridSearchCV——网格搜索超参数调优》仅用于学习参考。如对内容有疑问,请及时联系本站处理。

上篇加载NT驱动openssh安装/更新教程(CentOS)下篇

宿迁高防,2C2G15M,22元/月;香港BGP,2C5G5M,25元/月 雨云优惠码:MjYwNzM=

相关文章

机器学习实战:基于Scikit-Learn和TensorFlow 读书笔记 第6章 决策树

数据挖掘作业,要实现决策树,现记录学习过程 win10系统,Python 3.7.0 构建一个决策树,在鸢尾花数据集上训练一个DecisionTreeClassifier: from sklearn.datasets importload_iris from sklearn.tree importDecisionTreeClassifier iris =l...

iris 框架在服务端解决跨域问题

1. 编写中间件,将允许跨域的header添加到响应头 //Cors funcCors(ctxiris.Context){ ctx.Header("Access-Control-Allow-Origin","*") //ctx.Header("Access-Control-Allow-Headers","DNT,X-Mx-ReqToken,Keep-Al...

Go Iris学习笔记01

Iris MVC支持文档: 支持所有 HTTP 方法, 例如,如果想要写一个 GET 那么在控制器中也要写一个 Get() 函数,你可以在一个控制器内定义多个函数。 每个控制器通过 BeforeActivation 自定义事件回调,用来自定义控制器的结构的方法与自定义路径处理程序,如下:(还未实验) func (m *MyController) Befor...

CSS动态滤镜

CSS动态滤镜   动态滤镜可以为页面添加动人的淡入淡出、图象转化效果,它可以分为两种blend(混合)和reveal(显示),前者可以使对象渐渐消失或出现,后者提供了24种图象转化的效果。对于动态滤镜的调用除去象在静态滤镜中要定义的滤镜类型,参数等等,还用到脚本语言控制它的状态。  首先,在开始一个动态效果之前,先需要进行装备(Apply),然后播放(P...

鸢尾花种类预测--数据集

1 案例:鸢尾花种类预测 Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。关于数据集的具体介绍: 2 scikit-learn中数据集介绍 2.1 scikit-learn数据集API介绍 sklearn.datasets 加载获取流行数据集 datasets.load_...