博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
模型评估方法
阅读量:4567 次
发布时间:2019-06-08

本文共 2478 字,大约阅读时间需要 8 分钟。

注:本文是的学习笔记

  1. Estimator对象的score方法
  2. 在交叉验证中使用scoring参数
  3. 使用sklearn.metric中的性能度量函数

Estimator对象的score方法

分类算法必须要继承ClassifierMixin类, 回归算法必须要继承RegressionMixin类,里面都有一个score

()方法。

score(self, X, y_true)函数会在内部调用predict函数获得预测响应y_predict,然后与传入的真是响应进行比较,计算得分。

使用estimator的score函数来评估模型的属性,默认情况下,分类器对应于准确率:sklearn.metrics.accuracy_score, 回归器对应于均方差: sklearn.metrics.r2_score。

在交叉验证中使用scoring参数

GridSearchCV(scoring=None)

cross_val_score(scoring=None)
...

指定在进行网格搜索或者计算交叉验证得分的时候,使用什么标准度量'estimator'的预测性能,默认是None,就是使用estimator自己的score方法来计算得分。我们可以指定别的性能度量标准,它必须是一个可调用对象,sklearn.metrics不仅为我们提供了一系列预定义的可调用对象,而且还支持自定义评估标准

Scoring Function
分类
accuracy metrics.accuracy_score
average_precision metrics.average_precision_score
f1 metrics.f1_score
f1_micro metrics.f1_score
f1_macro metrics.f1_score
f1_weighted metrics.f1_score
f1_sample metrics.f1_score
neg_log_loss metrics.log_loss
precision metrics.precision_score
recall metrics.recall_score
roc_auc metrics.roc_auc_score
聚类
adjusted_rand_score metrics.adjusted_rand_score
回归
neg_mean_absolute_erroe metrics.neg_mean_absolute_erroe
neg_mean_squared_error metrics.neg_mean_squared_error
neg_median_absolute_error metrics.neg_median_absolute_error
r2 metrics.r2

约定: 返回值越大代表性能越好

可以使用sklearn.metrics.SCORERS返回以上的评估函数。

在交叉验证中使用自定义scoring参数

  1. 把sklearn.metrics中已有的度量指标封装成符合‘scoring’参数要求的形式。
    Metrics模块中的很多的度量方法并没有被分配‘scoring’参数可用的名字。因为这些度量指标需要附加参数,比如:‘fbeta_score’。在这种情况下,如果我们想要使用‘fbeta_score’的话,必须要产生一个合适的scoring对象,产生可调用对象的最简单的方法就是使用‘make_scorer’,该函数会把'fbeta_score'这个函数转换成能够在模型评估中使用的可调用对象。
from sklearn.metrics import fbeta_score, make_scorerftwo_score = make_scorer(fbeta_score, beta=2)  # 添加参数from sklearn.model_selection import  GridSearchCVfrom sklearn.svm import LinearSVCgrid = GridSearchCV(LinearSVC(), param_grid={'C': [1,10]}, scoring=ftwo_score)
  1. 完全自定义自己的度量指标然后用'make_scorer'函数转换成符合’scoring‘参数要求的形式
from sklearn.metrics import fbeta_score, make_scorerimport numpy as npdef my_custom_loss_func(ground_truth, predictions):    diff = np.abs(ground_truth - predictions).max()    return np.log(1 + diff)loss = make_scorer(my_custom_loss_func, greater_is_better = False)score = make_scorer(my_custom_loss_func, greater_is_better = False)ground_truth = [[1,1]]predictions = [0,1]from sklearn.dummy import DummyClassifierclf = DummyClassifier(strategy='most_frequent', random_state = 0)clf = clf.fit(ground_truth, predictions)print(loss(clf, ground_truth, predictions))print(score(clf, ground_truth, predictions))

1203446-20171013142458996-55650302.png

转载于:https://www.cnblogs.com/cnkai/p/7755122.html

你可能感兴趣的文章
05-Python基础之函数基础
查看>>
水晶苍蝇拍:价值投资的“基础,重点和核心” (2010-06-23 18:30:07)
查看>>
HTML超链接的使用
查看>>
h5微信支付在微信内页使用微信公众号支付
查看>>
分区函数Partition By的与row_number()的用法以及与排序rank()的用法详解(获取分组(分区)中前几条记录)(转)...
查看>>
设计模式学习之责任链模式(Chain of Responsibility,行为型模式)(22)
查看>>
AnimatorCompatHelper clearInterpolator
查看>>
Flutter基础Widget之按钮(RaisedButton、FlatButton、OutlineButton,IconButton)
查看>>
Android自定义控件View(一)
查看>>
Java Web模块——验证码模块
查看>>
设置部门公用流程,上级领导审批,设置注意事项
查看>>
命令服务器linux中tftp服务器设置及测试,图解
查看>>
Java Binary Search
查看>>
RPM包制作总结篇
查看>>
设计模式(六)—原型模式Prototype(创建型)
查看>>
Windows下配置Jenkins 实现自动发布maven项目至tomcat(svn+maven+tomcat)
查看>>
RFID电动车防盗系统的几个问题
查看>>
PostgreSQL 建库建表脚本
查看>>
第四次作业 何雅
查看>>
input 批量修改
查看>>