超参数优化 - 贝叶斯优化方法¶

In [1]:
import numpy as np
import pandas as pd
import sklearn
import matplotlib as mlp
import matplotlib.pyplot as plt
import seaborn as sns
import time
import re, pip, conda
In [2]:
for package in [sklearn,mlp,np,pd,sns,pip,conda]:
    print(re.findall("([^']*)",str(package))[2],package.__version__)
sklearn 1.0.1
matplotlib 3.4.3
numpy 1.21.4
pandas 1.3.4
seaborn 0.11.2
pip 21.3.1
conda 4.11.0
In [ ]:
#pip install --upgrade scikit-learn
#conda update scikit-learn

目录

一 贝叶斯优化基础方法
  1 贝叶斯优化的基本流程
  2 贝叶斯优化用于HPO
二 贝叶斯优化的实现
  1 基于Bayes_opt实现GP优化
  2 基于HyperOpt实现TPE优化
  3 基于Optuna实现多种贝叶斯优化

一 贝叶斯优化基础方法¶

在之前的课程中我们讲解了网格搜索、随机网格搜索与Halving网格搜索,无论具体每种网格搜索的思想如何变化,网格优化都是在一个大参数空间中、尽量对所有点进行验证后再返回最优损失函数值的方法,这一类方法在计算量与计算时间上有着不可避免的缺陷,因此才会有随机、Halving等试图缩短训练时间、让整体网格搜索更加适合于大型数据和大型空间的手段。然而,尽管sklearn在提高网格搜索效率方面做出了种种优化,但上述方法仍然无法在效率和精度上做到双赢,若希望更快速的进行参数搜索、并且搜索出一组泛化能力尽可能强的参数,目前的常见做法还是选用一些带有先验过程的调参工具,即一些基于贝叶斯过程调参工具。

贝叶斯优化方法是当前超参数优化领域的SOTA手段(State of the Art),被认为是当前最为先进的优化框架,它可以被应用于AutoML的各大领域,不止限于超参数搜索(HPO)的领域,更是可以被用于神经网络架构搜索NAS以及元学习等先进的领域。现代几乎所有在效率和效果上取得优异成果的超参数优化方法都是基于贝叶斯优化的基本理念而形成的,因此贝叶斯优化是整个AutoML中学习的重点。

然而,虽然贝叶斯优化非常强大,但整体的学习难度却非常高。在学习贝叶斯优化之前,学习者不仅需要充分理解机器学习的主要概念和算法、熟悉典型的超参数优化流程,还需要对部分超出微积分、概率论和线性代数的数学知识有所掌握。特别的是,贝叶斯优化算法本身,与贝叶斯优化用于HPO的过程还有区别。在我们课程有限的时间内,我将重点带大家来看贝叶斯优化用于HPO的核心过程。

1 贝叶斯优化的基本流程¶

首先,我们不理会HPO的问题,先来看待下面的例子。假设现在我们知道一个函数$f(x)$的表达式以及其自变量$x$的定义域,现在,我们希望求解出$x$的取值范围上$f(x)$的最小值,你打算如何求解这个最小值呢?面对这个问题,无论是从单纯的数学理论角度,还是从机器学习的角度,我们都已经见过好几个通俗的思路:

  • 1 我们可以对$f(x)$求导、令其一阶导数为0来求解其最小值

函数$f(x)$可微,且微分方程可以直接被求解

  • 2 我们可以通过梯度下降等优化方法迭代出$f(x)$的最小值

函数$f(x)$可微,且函数本身为凸函数,可以用凸优化的方式求解

  • 3 我们将全域的$x$带入$f(x)$计算出所有可能的结果,再找出最小值

函数$f(x)$相对不复杂、自变量维度相对低、计算量可以承受

当我们知道函数$f(x)$的表达式时,以上方法常常能够有效,但每个方法都有自己的前提条件。假设现在函数$f(x)$是一个平滑均匀的函数,但它异常复杂、且不可微,我们无法使用上述三种方法中的任意一种方法求解,但我们还是想求解其最小值,可以怎么办呢?由于函数异常复杂,带入任意$x$计算的所需的时间很长,所以我们不太可能将全域$x$都带入进行计算,但我们还是可以从中随机抽样部分观测点来观察整个函数可能存在的趋势。于是我们选择在$x$的定义域上随机选择了4个点,并将4个点带入$f(x)$进行计算,得到了如下结果:

01

好了,现在有了这4个观测值,你能告诉我$f(x)$的最小值在哪里吗?你认为最小值点可能在哪里呢?大部分人会倾向于认为,最小值点可能非常接近于已观测出4个$f(x)$值中最小的那个值,但也有许多人不这么认为。当我们有了4个观测值,并且知道我们的函数时相对均匀、平滑的函数,那我们可能对函数的整体分布有如下猜测:

02

当我们对函数整体分布有一个猜测时,这个分布上一定会存在该函数的最小值。同时,不同的人可能对函数的整体分布有不同的猜测,不同猜测下对应的最小值也是不同的。

03

04

现在,假设我们邀请了数万个人对该问题做出猜测,每个人所猜测的曲线如下图所示。不难发现,在观测点的附近,每个人猜测的函数值差距不大,但是在远离远侧点的地方,每个人猜测的函数值就高度不一致了。这也是当然的,因为观测点之间函数的分布如何完全是未知的,并且该分布离观测点越远时,我们越不确定真正的函数值在哪里,因此人们猜测的函数值的范围非常巨大。

05

In [ ]:
[0,1] - 100个小区间

[0,0.01] n1
[0.01,0.02] n2
[0.02,0.03] n3
...
[0.99,1] n100

现在,我们将所有猜测求均值,并将任意均值周围的潜在函数值所在的区域用色块表示,可以得到一条所有人猜测的平均曲线。不难发现,色块所覆盖的范围其实就是大家猜测的函数值的上界和下界,而任意$x$所对应的上下界差异越大,表示人们对函数上该位置的猜测值的越不确定。因此上下界差异可以衡量人们对该观测点的置信度,色块范围越大,置信度越低。

在观测点周围,置信度总是很高的,远离观测点的地方,置信度总是很低,所以如果我们能够在置信度很低的地方补充一个实际的观测点,我们就可以很快将众人的猜测统一起来。以下图为例,当我们在置信度很低的区间内取一个实际观测值时,围绕该区间的“猜测”会立刻变得集中,该区间内的置信度会大幅升高。

当整个函数上的置信度都非常高时,我们可以说我们得出了一条与真实的$f(x)$曲线高度相似的曲线$f^*$,次数我们就可以将$f^*$的最小值当作真实$f(x)$的最小值来看待。自然,如果估计越准确,$f^*$越接近$f(x)$,则$f^*$的最小值也会越接近于$f(x)$的真实最小值。那如何才能够让$f^*$更接近$f(x)$呢?根据我们刚才提升置信度的过程,很明显——观测点越多,我们估计出的曲线会越接近真实的$f(x)$。然而,由于计算量有限,我们每次进行观测时都要非常谨慎地选择观测点。那现在,如何选择观测点才能够最大程度地帮助我们估计出$f(x)$的最小值呢?

有非常多的方法,其中最简单的手段是使用最小值出现的频数进行判断。由于不同的人对函数的整体分布有不同的猜测,不同猜测下对应的最小值也是不同的,根据每个人猜测的函数结果,我们在$X$轴上将定义域区间均匀划分为100个小区间,如果有某个猜测的最小值落在其中一个区间中,我们就对该区间进行计数(这个过程跟对离散型变量绘制直方图的过程完全一致)。

当有数万个人进行猜测之后,我们同时也绘制了基于$X$轴上不同区间的频数图,频数越高,说明猜测最小值在该区间内的人越多,反之则说明该猜测最小值在该区间内的人越少。该频数一定程度上反馈出最小值出现的概率,频数越高的区间,函数真正的最小值出现的概率越高。

当我们将$X$轴上的区间划分得足够细后,绘制出的频数图可以变成概率密度曲线,曲线的最大值所对应的点是$f(x)$的最小值的概率最高,因此很明显,我们应该将曲线最大值所对应的点确认为下一个观测点。根据图像,我们知道最小值最有可能在的区间就在x=0.7左右的位置。当我们不取新的观测点时,现在$f(x)$上可以获得的可靠的最小值就是x=0.6时的点,但我们如果在x=0.7处取新的观测值,我们就很有可能找到比当前x=0.6的点还要小的$f_{min}$。因此,我们可以就此决定,在x=0.7处进行观测。

当我们在x=0.7处取出观测值之后,我们就有了5个已知的观测点。现在,我们再让数万人根据5个已知的观测点对整体函数分布进行猜测,猜测完毕之后再计算当前最小值频数最高的区间,然后再取新的观测点对$f(x)$进行计算。当允许的计算次数被用完之后(比如,500次),整个估计也就停止了。

你发现了吗?在这个过程当中,我们其实在不断地优化我们对目标函数$f(x)$的估计,虽然没有对$f(x)$进行全部定义域上的计算,也没有找到最终确定一定是$f(x)$分布的曲线,但是随着我们观测的点越来越多,我们对函数的估计是越来越准确的,因此也有越来越大的可能性可以估计出$f(x)$真正的最小值。这个优化的过程,就是贝叶斯优化。

2 贝叶斯优化用于超参数优化(HPO)¶


在贝叶斯优化的数学过程当中,我们主要执行以下几个步骤:

  • 1 定义需要估计的$f(x)$以及$x$的定义域

  • 2 取出有限的n个$x$上的值,求解出这些$x$对应的$f(x)$(求解观测值)

  • 3 根据有限的观测值,对函数进行估计(该假设被称为贝叶斯优化中的先验知识),得出该估计$f^*$上的目标值(最大值或最小值)

  • 4 定义某种规则,以确定下一个需要计算的观测点

并持续在2-4步骤中进行循环,直到假设分布上的目标值达到我们的标准,或者所有计算资源被用完为止(例如,最多观测m次,或最多允许运行t分钟)。


以上流程又被称为序贯模型优化(SMBO),是最为经典的贝叶斯优化方法。在实际的运算过程当中,尤其是超参数优化的过程当中,有以下具体细节需要注意:

  • 当贝叶斯优化不被用于HPO时,一般$f(x)$可以是完全的黑盒函数(black box function,也译作黑箱函数,即只知道$x$与$f(x)$的对应关系,却丝毫不知道函数内部规律、同时也不能写出具体表达式的一类函数),因此贝叶斯优化也被认为是可以作用于黑盒函数估计的一类经典方法。但在HPO过程当中,需要定义的$f(x)$一般是交叉验证的结果/损失函数的结果,而我们往往非常清楚损失函数的表达式,只是我们不了解损失函数内部的具体规律,因此HPO中的$f(x)$不能算是严格意义上的黑盒函数。

  • 在HPO中,自变量$x$就是超参数空间。在上述二维图像表示中,$x$为一维的,但在实际进行优化时,超参数空间往往是高维且极度复杂的空间。

  • 最初的观测值数量n、以及最终可以取到的最大观测数量m都是贝叶斯优化的超参数,最大观测数量m也决定了整个贝叶斯优化的迭代次数。m越大贝叶斯优化所需时间越长

  • 在第3步中,根据有限的观测值、对函数分布进行估计的工具被称为概率代理模型(Probability Surrogate model),毕竟在数学计算中我们并不能真的邀请数万人对我们的观测点进行连线。这些概率代理模型自带某些假设,他们可以根据廖廖数个观测点估计出目标函数的分布$f^*$(包括$f^*$上每个点的取值以及该点对应的置信度)。在实际使用时,概率代理模型往往是一些强大的算法,最常见的比如高斯过程、高斯混合模型等等。传统数学推导中往往使用高斯过程,但现在最普及的优化库中基本都默认使用基于高斯混合模型的TPE过程。

  • 在第4步中用来确定下一个观测点的规则被称为采集函数(Aquisition Function),采集函数衡量观测点对拟合$f^*$所产生的影响,并选取影响最大的点执行下一步观测,因此我们往往关注采集函数值最大的点。最常见的采集函数主要是概率增量PI(Probability of improvement,比如我们计算的频数)、期望增量(Expectation Improvement)、置信度上界(Upper Confidence Bound)、信息熵(Entropy)等等。上方gif图像当中展示了PI、UCB以及EI。其中大部分优化库中默认使用期望增量,因为期望增量是所有采集函数中敏感性比较居中,不仅可以可以衡量期望增减,还可以衡量增减大小。采集函数跟决策树中分枝时衡量不纯度衡量指标很像,就像基尼系数、信息熵。

在超参数优化(HPO)中使用贝叶斯优化时,我们常常会看见下面的图像,这张图像表现了贝叶斯优化的全部基本元素,我们的目标就是在采集函数指导下,让$f^*$尽量接近$f(x)$。 11.png

现在我们已经了解贝叶斯优化的基本流程。与许多算法一样,基础流程足以支撑我们使用已经搭建好的优化库进行超参数优化了,即便我们没有对优化原理的每个细节都了如指掌,我们也可以通过实验反馈出的结果来直接判断是否应该调整我们的代码。接下来,我们会先学习如何应用贝叶斯优化的各类库实现不同的贝叶斯优化算法。

二 贝叶斯优化的实现¶

贝叶斯优化是当今黑盒函数估计领域最为先进和经典的方法,在同一套序贯模型下使用不同的代理模型以及采集函数、还可以发展出更多更先进的贝叶斯优化改进版算法,因此,贝叶斯优化的其算法本身就多如繁星,实现各种不同种类的贝叶斯优化的库也是琳琅满目,几乎任意一个专业用于超参数优化的工具库都会包含贝叶斯优化的内容。我们可以在以下页面找到大量可以实现贝叶斯优化方法的HPO库:https://www.automl.org/automl/hpo-packages/ ,其中大部分库都是由独立团队开发和维护,因此不同的库之间之间的优劣、性格、功能都有很大的差异。在课程中,我们将介绍如下三个可以实现贝叶斯优化的库:bayesian-optimization,hyperopt,optuna。

HPO库 优劣评价 推荐指数
bayes_opt ✅实现基于高斯过程的贝叶斯优化
✅当参数空间由大量连续型参数构成时

⛔包含大量离散型参数时避免使用
⛔算力/时间稀缺时避免使用
⭐⭐
hyperopt ✅实现基于TPE的贝叶斯优化
✅支持各类提效工具
✅进度条清晰,展示美观,较少怪异警告或报错
✅可推广/拓展至深度学习领域

⛔不支持基于高斯过程的贝叶斯优化
⛔代码限制多、较为复杂,灵活性较差
⭐⭐⭐⭐
optuna ✅(可能需结合其他库)实现基于各类算法的贝叶斯优化
✅代码最简洁,同时具备一定的灵活性
✅可推广/拓展至深度学习领域

⛔非关键性功能维护不佳,有怪异警告与报错
⭐⭐⭐⭐

bayes_opt是最为经典的贝叶斯优化库,而下面这两个是最为推荐的贝叶斯优化库。如果要指定用TPE,就用hyperopt;optuna是代码最简洁的超参数优化库。

注意,以上三个库**都不支持基于Python环境的并行或GPU加速**,大多数优化算法库只能够支持基于数据库(如MangoDB,MySQL、Apache Spark)的并行或加速,但以上库都可以被部署在分布式计算平台上。

三个库极其辅助包的安装方法分别如下,使用pip或conda安装时注意关闭梯子。

  • Bayes_opt
In [ ]:
#!pip install bayesian-optimization
#!conda install -c conda-forge bayesian-optimization
  • Hyperopt
In [ ]:
#!pip install hyperopt
  • Optuna
In [ ]:
#!pip install optuna
#!conda install -c conda-forge optuna
  • Skopt(作为Optuna辅助包安装,也可单独使用)
In [ ]:
#!pip install scikit-optimize

接下来我们会分别使用三个库来实现贝叶斯优化。在课程中,我们依然使用集成算法中的房价数据作为验证数据,并且呈现出我们之前在不同优化方法上得出的结果作为对比。同时,我们将使用与集成算法中完全一致的随机数种子、以及随机森林算法作为被优化的评估器。

  • 导入库,确认使用数据
In [12]:
# 基本工具
import numpy as np
import pandas as pd
import time
import os # 用于修改环境设置

# 算法/损失/评估指标等
import sklearn
from sklearn.ensemble import RandomForestRegressor as RFR
from sklearn.model_selection import KFold, cross_validate

# 导入贝叶斯优化器bayes_opt
from bayes_opt import BayesianOptimization

# 导入贝叶斯优化器hyperopt
import hyperopt
from hyperopt import hp, fmin, tpe, Trials, partial
from hyperopt.early_stop import no_progress_loss

# 导入贝叶斯优化器optuna
import optuna
C:\Users\zhiyuan\anaconda3\envs\kaggle\lib\site-packages\tqdm\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

Bayes_opt版本:1.2.0

In [14]:
print(optuna.__version__)
3.1.0
In [15]:
print(hyperopt.__version__)
0.2.7
In [19]:
data = pd.read_csv(r"..\Lesson 09.随机森林模型\datasets\House Price\train_encode.csv",index_col=0)

X = data.iloc[:,:-1]
y = data.iloc[:,-1]
In [20]:
X.head()
Out[20]:
Id 住宅类型 住宅区域 街道接触面积(英尺) 住宅面积 街道路面状况 巷子路面状况 住宅形状(大概) 住宅现状 水电气 ... 半开放式门廊面积 泳池面积 泳池质量 篱笆质量 其他配置 其他配置的价值 销售月份 销售年份 销售类型 销售状态
0 0.0 5.0 3.0 36.0 327.0 1.0 0.0 3.0 3.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 1.0 2.0 8.0 4.0
1 1.0 0.0 3.0 51.0 498.0 1.0 0.0 3.0 3.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 4.0 1.0 8.0 4.0
2 2.0 5.0 3.0 39.0 702.0 1.0 0.0 0.0 3.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 8.0 2.0 8.0 4.0
3 3.0 6.0 3.0 31.0 489.0 1.0 0.0 0.0 3.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 8.0 0.0
4 4.0 5.0 3.0 55.0 925.0 1.0 0.0 0.0 3.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 11.0 2.0 8.0 4.0

5 rows × 80 columns

In [21]:
X.shape
Out[21]:
(1460, 80)
  • 确认该数据集上的历史成果
超参数优化方法(HPO) 默认参数 网格搜索 随机搜索 随机搜索
(大空间)
随机搜索
(连续型)
搜索空间/全域空间 - 1536/1536 800/1536 1536/3000 1536/无限
运行时间(分钟) - 6.36 **2.83(↓)** **3.86(↓)** 3.92
搜索最优(RMSE) 30571.266 29179.698 29251.284 **29012.905(↓)** 29148.381
重建最优(RMSE) - 28572.070 **28639.969(↑)** **28346.673(↓)** 28495.682

1 基于Bayes_opt实现GP优化¶

bayes-optimization是最早开源的贝叶斯优化库之一,也是为数不多至今依然保留着高斯过程优化的优化库,虽然高斯过程优化非常经典,但是现在使用的不是那么广泛了,主要是因为使用贝叶斯优化是为了提高速度,而高斯过程不够快。

由于开源较早、代码简单,bayes-opt常常出现在论文、竞赛kernels或网络学习材料当中,因此理解Bayes_opt的代码是极其重要的课题。不过,bayes-opt对参数空间的处理方式较为原始,也缺乏相应的提效/监控功能,对算力的要求较高,因此它往往不是我们进行优化时的第一首选库。

通常来说,当且仅当我们必须要实现基于高斯过程的贝叶斯优化,且算法的参数空间中带有大量连续型参数时,我们才会优先考虑Bayes_opt库。我们可以在github上找到bayes-optmization的官方文档(https://github.com/fmfn/BayesianOptimization) ,想要进一步了解其基本功能与原理的小伙伴可以进行阅读。

In [39]:
from bayes_opt import BayesianOptimization
  • 1 定义目标函数

目标函数的值即$f(x)$的值。贝叶斯优化会计算$f(x)$在不同$x$上的观测值,因此$f(x)$的计算方式需要被明确。在HPO过程中,我们希望能够筛选出令模型泛化能力最大的参数组合,因此$f(x)$应该是损失函数的交叉验证值或者某种评估指标的交叉验证值。需要注意的是,bayes_opt库存在三个影响目标函数定义的规则:

1 目标函数的输入必须是具体的超参数,而不能是整个超参数空间,更不能是数据、算法等超参数以外的元素,例如字典,因此在定义目标函数时,我们需要让超参数作为目标函数的输入。

2 超参数的输入值只能是浮点数,不支持整数与字符串。因此当算法的实际参数需要输入字符串时,该参数不能使用bayes_opt进行调整,当算法的实际参数需要输入整数时,则需要在目标函数中规定参数的类型。

3 bayes_opt只支持寻找$f(x)$的最大值,不支持寻找最小值。因此当我们定义的目标函数是某种损失时,目标函数的输出需要取负(即,如果使用RMSE,则应该让目标函数输出负RMSE,这样最大化负RMSE后,才是最小化真正的RMSE。)当我们定义的目标函数是准确率,或者auc等指标,则可以让目标函数的输出保持原样。

In [62]:
def bayesopt_objective(n_estimators, max_depth, max_features, min_impurity_decrease):
    
    # 定义评估器:随机森林
    # 需要调整的超参数等于目标函数的输入,不需要调整的超参数则直接等于固定值
    # 默认参数输入一定是浮点数,因此需要套上int函数处理成整数
    reg = RFR(n_estimators = int(n_estimators)
              ,max_depth = int(max_depth)
              ,max_features = int(max_features)
              ,min_impurity_decrease = min_impurity_decrease
              ,random_state = 1412
              ,verbose = False # 可自行决定是否开启森林建树的verbose
              ,n_jobs = -1)
    
    # 交叉验证
    # 定义损失的输出,5折交叉验证下的结果,输出负根均方误差(-RMSE)
    # 注意,交叉验证需要使用数据,但我们不能让数据X,y成为目标函数的输入
    cv = KFold(n_splits=5,shuffle=True,random_state=1412)
    # 输出交叉验证的结果
    validation_loss = cross_validate(reg,X,y
                                     ,scoring="neg_root_mean_squared_error"
                                     ,cv=cv
                                     ,verbose=False       # 不打印具体流程
                                     ,n_jobs=-1
                                     ,error_score='raise' # 如果交叉验证中的算法执行报错,则告诉我们错误的理由
                                    )
    
    # 交叉验证的结果
    # 交叉验证输出的评估指标是负根均方误差,因此本来就是负的损失
    # 目标函数可直接输出该损失的均值
    return np.mean(validation_loss["test_score"])
  • 2 定义参数空间

在任意超参数优化器中,优化器会将参数空格中的超参数组合作为备选组合,一组一组输入到算法中进行训练。在贝叶斯优化中,超参数组合会被输入我们定义好的目标函数$f(x)$中。

在bayes_opt中,我们使用字典方式来定义参数空间,其中参数的名称为键,参数的取值范围为值。且任意参数的取值范围为双向闭区间,以下方的空间为例,在n_estimators的取值中,80与100都可以被取到。

以下参数空间与我们在随机森林中获得最高分的随机搜索的范围高度相似。

In [41]:
param_grid_simple = {'n_estimators': (80,100)
                     , 'max_depth':(10,25)
                     , "max_features": (10,20)
                     , "min_impurity_decrease":(0,1)
                    }

需要注意的是,bayes_opt只支持填写参数空间的上界与下界,不支持填写步长等参数,且bayes_opt会将所有参数都当作连续型超参进行处理,因此bayes_opt会直接取出闭区间中任意浮点数作为备选参数。例如,取92.28作为n_estimators的值。

这也是为什么在目标函数中,我们需要对整数型超参的取值都套上int函数。假设优化器取出92.28作为n_estimators的值,实际传入随机森林算法的会是int(92.28) = 92,如此我们可以保证算法运行过程中不会因参数类型不符而报错。也因为bayes_opt的这个性质,输入bayes_opt的参数空间天生会比其他贝叶斯优化库更大/更密,因此需要的迭代次数也更多,速度会相对慢。

  • 3 定义优化目标函数的具体流程

在有了目标函数与参数空间之后,我们就可以按bayes_opt的规则进行优化了。在任意贝叶斯优化算法的实践过程中,一定都有涉及到随机性的过程,因为最开始就是随机抽取点作为观测点的,随机抽样部分观测点进行采集函数的计算等等。因而,在大部分优化库当中,这种随机性是无法控制的,并没有像sklean那样提供一个随机数种子方便我们复现,贝叶斯优化算法中即便允许我们填写随机数种子,优化算法也不能固定下来。我们可以尝试填写随机数种子,但需要记住优化算法每次运行时一定都会不一样。

虽然优化算法无法被复现,但是优化算法得出的最佳超参数的结果却是可以被复现的。只要优化完毕之后,可以从优化算法的实例化对象中取出最佳参数组合以及最佳分数,该最佳参数组合被输入到交叉验证中后,是一定可以复现其最佳分数的。如果没能复现最佳分数,则是交叉验证过程的随机数种子设置存在问题,或者优化算法的迭代流程存在问题。

In [46]:
def param_bayes_opt(init_points,n_iter):
    
    # 定义优化器,先实例化优化器
    opt = BayesianOptimization(bayesopt_objective # 需要优化的目标函数
                               ,param_grid_simple # 备选参数空间
                               ,random_state=1412 # 随机数种子,虽然无法控制住
                              )
    
    # 使用优化器,Ps:记住bayes_opt只支持最大化
    opt.maximize(init_points = init_points # 抽取多少个初始观测值
                 , n_iter=n_iter           # 一共观测/迭代多少次
                )
    
    # 优化完成,取出最佳参数与最佳分数
    params_best = opt.max["params"] # 使用键值对取出最佳参数
    score_best = opt.max["target"]  # 获取最佳分数
    
    # 打印最佳参数与最佳分数
    print("\n","\n","best params: ", params_best,
          "\n","\n","best cvscore: ", score_best)
    
    # 返回最佳参数与最佳分数
    return params_best, score_best
  • 4 定义验证函数(非必须)

因为优化后的结果是可以复现的,即我们可以对优化算法给出的最优参数进行再验证,其中验证函数与目标函数高度相似,输入参数或超参数空间、输出最终的损失函数结果。在使用sklearn中自带的优化算法时,由于优化算法自己会执行分割数据、交叉验证的步骤,因此优化算法得出的最优分数往往与我们自身验证的分数不同(因为交叉验证时的数据分割不同)。然而在贝叶斯优化过程中,目标函数中的交叉验证即数据分割都是我们自己规定的,因此原则上来说,只要在目标函数中设置了随机数种子,贝叶斯优化给出的最佳分数一定与我们验证后的分数相同,所以当你对优化过程的代码比较熟悉时,可以不用进行二次验证。

In [48]:
def bayes_opt_validation(params_best):
    
    # 输入参数
    reg = RFR(n_estimators = int(params_best["n_estimators"]) 
              ,max_depth = int(params_best["max_depth"])
              ,max_features = int(params_best["max_features"])
              ,min_impurity_decrease = params_best["min_impurity_decrease"]
              ,random_state=1412
              ,verbose=False
              ,n_jobs=-1)

    cv = KFold(n_splits=5,shuffle=True,random_state=1412)
    
    # 交叉验证
    validation_loss = cross_validate(reg,X,y
                                     ,scoring="neg_root_mean_squared_error"
                                     ,cv=cv
                                     ,verbose=False
                                     ,n_jobs=-1
                                    )
    # 返回测试集分数
    return np.mean(validation_loss["test_score"])
  • 5 执行实际优化流程
In [49]:
start = time.time()
params_best, score_best = param_bayes_opt(20,280) #初始看20个观测值,后面迭代280次
print('It takes %s minutes' % ((time.time() - start)/60))
validation_score = bayes_opt_validation(params_best)
print("\n","\n","validation_score: ",validation_score)

# 输出结果第一列是迭代次数、第2列是目标函数值、第3列 ... 是最大深度 ... 最大特征 等等参数。整体迭代300次。最后打印最佳参数、分数、耗时。
|   iter    |  target   | max_depth | max_fe... | min_im... | n_esti... |
-------------------------------------------------------------------------
|  1        | -2.948e+0 |  23.2     |  17.52    |  0.06379  |  88.79    |
|  2        | -2.909e+0 |  14.8     |  17.61    |  0.9214   |  97.58    |
|  3        | -2.9e+04  |  15.86    |  15.56    |  0.2661   |  87.98    |
|  4        | -2.887e+0 |  14.05    |  16.84    |  0.06744  |  89.72    |
|  5        | -2.887e+0 |  18.71    |  19.17    |  0.9315   |  83.7     |
|  6        | -2.895e+0 |  17.7     |  19.58    |  0.7127   |  89.18    |
|  7        | -2.968e+0 |  14.21    |  12.62    |  0.3381   |  91.51    |
|  8        | -2.91e+04 |  23.23    |  10.89    |  0.6078   |  95.06    |
|  9        | -2.891e+0 |  14.89    |  14.0     |  0.9487   |  80.16    |
|  10       | -2.958e+0 |  11.52    |  12.58    |  0.03276  |  92.56    |
|  11       | -2.91e+04 |  13.14    |  13.31    |  0.2563   |  98.24    |
|  12       | -2.95e+04 |  17.94    |  11.48    |  0.3778   |  82.09    |
|  13       | -2.913e+0 |  16.02    |  17.03    |  0.7735   |  88.31    |
|  14       | -2.925e+0 |  13.92    |  15.04    |  0.529    |  93.66    |
|  15       | -2.938e+0 |  12.51    |  13.69    |  0.4482   |  99.9     |
|  16       | -2.933e+0 |  17.73    |  10.05    |  0.4143   |  82.79    |
|  17       | -2.952e+0 |  16.6     |  10.84    |  0.9134   |  88.37    |
|  18       | -2.958e+0 |  21.92    |  15.0     |  0.8219   |  85.86    |
|  19       | -2.934e+0 |  14.07    |  11.38    |  0.05068  |  91.53    |
|  20       | -2.962e+0 |  10.35    |  17.38    |  0.7624   |  99.19    |
|  21       | -2.937e+0 |  13.97    |  16.09    |  0.3349   |  88.16    |
|  22       | -2.887e+0 |  14.95    |  16.89    |  0.07827  |  89.8     |
|  23       | -2.908e+0 |  14.21    |  17.99    |  0.4688   |  90.52    |
|  24       | -2.887e+0 |  18.29    |  19.35    |  0.8535   |  85.12    |
|  25       | -2.968e+0 |  19.49    |  20.0     |  0.0      |  84.53    |
|  26       | -2.906e+0 |  17.9     |  18.8     |  1.0      |  84.29    |
|  27       | -2.886e+0 |  14.42    |  16.55    |  0.0      |  90.57    |
|  28       | -2.885e+0 |  14.52    |  16.53    |  1.0      |  90.12    |
|  29       | -2.902e+0 |  15.17    |  15.66    |  0.319    |  89.94    |
|  30       | -2.928e+0 |  13.27    |  16.41    |  0.8249   |  90.6     |
|  31       | -2.881e+0 |  15.2     |  16.7     |  0.6155   |  90.9     |
|  32       | -2.905e+0 |  16.04    |  17.33    |  1.0      |  90.32    |
|  33       | -2.889e+0 |  18.73    |  18.79    |  1.0      |  82.43    |
|  34       | -2.891e+0 |  17.79    |  19.32    |  1.0      |  86.55    |
|  35       | -2.876e+0 |  19.43    |  17.86    |  1.0      |  83.36    |
|  36       | -2.905e+0 |  20.3     |  17.85    |  1.0      |  82.25    |
|  37       | -2.899e+0 |  18.43    |  16.97    |  0.6339   |  82.39    |
|  38       | -2.929e+0 |  18.72    |  17.58    |  0.8506   |  86.01    |
|  39       | -2.887e+0 |  16.64    |  19.95    |  0.8027   |  85.54    |
|  40       | -2.889e+0 |  16.01    |  19.89    |  0.9527   |  87.31    |
|  41       | -2.89e+04 |  15.33    |  15.78    |  1.0      |  80.25    |
|  42       | -2.939e+0 |  13.54    |  15.32    |  1.0      |  80.0     |
|  43       | -2.947e+0 |  16.93    |  14.72    |  0.81     |  80.6     |
|  44       | -2.895e+0 |  17.53    |  19.89    |  0.3215   |  87.48    |
|  45       | -2.889e+0 |  18.71    |  18.17    |  0.03282  |  82.97    |
|  46       | -2.903e+0 |  15.93    |  19.28    |  0.007221 |  86.27    |
|  47       | -2.9e+04  |  16.01    |  20.0     |  0.0      |  88.86    |
|  48       | -2.875e+0 |  15.67    |  17.5     |  1.0      |  80.3     |
|  49       | -2.878e+0 |  15.38    |  17.32    |  1.0      |  81.6     |
|  50       | -2.853e+0 |  15.99    |  18.21    |  0.0      |  81.1     |
|  51       | -2.904e+0 |  15.55    |  19.09    |  0.601    |  81.11    |
|  52       | -2.924e+0 |  16.73    |  17.63    |  0.4579   |  81.92    |
|  53       | -2.876e+0 |  15.32    |  17.65    |  0.0      |  80.81    |
|  54       | -2.925e+0 |  16.63    |  18.02    |  0.5035   |  80.17    |
|  55       | -2.852e+0 |  15.45    |  18.28    |  0.09286  |  81.23    |
|  56       | -2.853e+0 |  15.55    |  18.11    |  0.0      |  81.88    |
|  57       | -2.907e+0 |  14.55    |  18.07    |  0.0      |  82.08    |
|  58       | -2.921e+0 |  16.2     |  18.99    |  0.0997   |  81.97    |
|  59       | -2.852e+0 |  15.6     |  18.06    |  0.5865   |  81.86    |
|  60       | -2.876e+0 |  15.49    |  17.44    |  0.07222  |  81.46    |
|  61       | -2.877e+0 |  15.42    |  17.81    |  0.9484   |  82.82    |
|  62       | -2.853e+0 |  15.55    |  18.09    |  0.5876   |  80.75    |
|  63       | -2.923e+0 |  19.64    |  16.39    |  0.7572   |  83.73    |
|  64       | -2.909e+0 |  14.75    |  17.82    |  0.9799   |  81.05    |
|  65       | -2.88e+04 |  15.32    |  16.56    |  0.5475   |  83.21    |
|  66       | -2.914e+0 |  15.0     |  17.24    |  1.0      |  84.24    |
|  67       | -2.896e+0 |  15.05    |  15.7     |  1.0      |  82.43    |
|  68       | -2.895e+0 |  15.77    |  15.46    |  0.0      |  83.74    |
|  69       | -2.853e+0 |  15.26    |  18.45    |  0.02002  |  80.46    |
|  70       | -2.941e+0 |  16.16    |  16.76    |  0.2009   |  92.03    |
|  71       | -2.909e+0 |  14.34    |  19.08    |  0.0958   |  80.04    |
|  72       | -2.875e+0 |  15.59    |  17.78    |  0.0424   |  82.72    |
|  73       | -2.967e+0 |  14.36    |  12.3     |  0.6712   |  80.33    |
|  74       | -2.907e+0 |  14.23    |  19.86    |  0.9782   |  86.91    |
|  75       | -2.972e+0 |  19.33    |  20.0     |  1.0      |  88.35    |
|  76       | -2.894e+0 |  17.28    |  19.91    |  0.6293   |  90.76    |
|  77       | -2.935e+0 |  13.92    |  15.84    |  0.1577   |  83.59    |
|  78       | -2.883e+0 |  18.14    |  19.61    |  0.6678   |  92.57    |
|  79       | -2.93e+04 |  19.34    |  18.94    |  0.5768   |  92.74    |
|  80       | -2.887e+0 |  16.88    |  19.67    |  0.3616   |  93.24    |
|  81       | -2.889e+0 |  16.73    |  19.17    |  0.8361   |  91.97    |
|  82       | -2.9e+04  |  17.85    |  20.0     |  1.0      |  93.95    |
|  83       | -2.881e+0 |  15.84    |  20.0     |  0.0      |  92.1     |
|  84       | -2.9e+04  |  15.28    |  19.27    |  0.85     |  93.07    |
|  85       | -2.902e+0 |  15.43    |  19.74    |  0.3355   |  90.84    |
|  86       | -2.901e+0 |  15.86    |  19.64    |  0.252    |  95.07    |
|  87       | -2.942e+0 |  16.98    |  14.7     |  0.2749   |  86.28    |
|  88       | -2.887e+0 |  13.06    |  19.95    |  0.6557   |  94.74    |
|  89       | -2.887e+0 |  13.48    |  19.61    |  0.1196   |  96.02    |
|  90       | -2.881e+0 |  12.42    |  18.66    |  0.7506   |  95.3     |
|  91       | -2.918e+0 |  11.9     |  19.72    |  0.0      |  95.61    |
|  92       | -2.887e+0 |  13.73    |  18.69    |  1.0      |  95.17    |
|  93       | -2.877e+0 |  12.65    |  18.6     |  1.0      |  93.96    |
|  94       | -2.925e+0 |  12.07    |  17.6     |  0.9865   |  94.43    |
|  95       | -2.887e+0 |  13.28    |  18.87    |  0.0      |  94.21    |
|  96       | -2.877e+0 |  12.86    |  19.29    |  0.8052   |  92.65    |
|  97       | -2.875e+0 |  12.29    |  19.9     |  0.7932   |  93.42    |
|  98       | -2.924e+0 |  11.49    |  18.92    |  0.9259   |  92.54    |
|  99       | -2.906e+0 |  14.06    |  19.77    |  0.3828   |  92.67    |
|  100      | -2.89e+04 |  13.04    |  18.51    |  0.6317   |  96.31    |
|  101      | -2.853e+0 |  15.36    |  18.1     |  0.4371   |  80.01    |
|  102      | -2.898e+0 |  14.89    |  19.93    |  0.9062   |  96.9     |
|  103      | -2.889e+0 |  13.66    |  19.58    |  0.4008   |  98.42    |
|  104      | -2.89e+04 |  17.32    |  19.69    |  0.04412  |  91.96    |
|  105      | -2.853e+0 |  15.69    |  18.36    |  0.01086  |  80.16    |
|  106      | -2.9e+04  |  14.83    |  19.7     |  0.4319   |  99.75    |
|  107      | -2.884e+0 |  18.21    |  19.92    |  0.1281   |  99.46    |
|  108      | -2.95e+04 |  19.8     |  20.0     |  0.0      |  100.0    |
|  109      | -2.908e+0 |  17.01    |  20.0     |  0.08407  |  98.65    |
|  110      | -2.905e+0 |  17.93    |  18.06    |  0.6703   |  99.44    |
|  111      | -2.941e+0 |  25.0     |  10.0     |  0.0      |  80.0     |
|  112      | -2.928e+0 |  19.14    |  10.0     |  0.0      |  100.0    |
|  113      | -2.883e+0 |  13.7     |  18.17    |  0.9195   |  92.96    |
|  114      | -2.887e+0 |  13.06    |  19.1     |  0.93     |  93.48    |
|  115      | -2.878e+0 |  12.84    |  19.99    |  0.4258   |  91.65    |
|  116      | -2.889e+0 |  12.35    |  19.96    |  0.3574   |  89.73    |
|  117      | -2.916e+0 |  10.71    |  19.94    |  0.0938   |  88.5     |
|  118      | -2.869e+0 |  12.28    |  19.98    |  0.3843   |  99.26    |
|  119      | -2.869e+0 |  12.74    |  19.85    |  0.007808 |  99.91    |
|  120      | -2.891e+0 |  13.21    |  19.26    |  0.9999   |  99.94    |
|  121      | -2.914e+0 |  11.59    |  19.63    |  0.07302  |  99.93    |
|  122      | -2.87e+04 |  12.73    |  19.19    |  0.02993  |  98.63    |
|  123      | -2.872e+0 |  12.47    |  19.88    |  0.5012   |  97.91    |
|  124      | -2.915e+0 |  11.48    |  19.24    |  0.9974   |  97.93    |
|  125      | -2.892e+0 |  13.28    |  18.95    |  0.01142  |  99.52    |
|  126      | -2.877e+0 |  12.26    |  19.89    |  0.0868   |  92.72    |
|  127      | -2.859e+0 |  24.86    |  14.87    |  0.4054   |  99.26    |
|  128      | -2.873e+0 |  25.0     |  14.04    |  1.0      |  99.6     |
|  129      | -2.929e+0 |  25.0     |  15.45    |  0.9673   |  100.0    |
|  130      | -2.859e+0 |  24.78    |  14.17    |  0.06212  |  98.69    |
|  131      | -2.859e+0 |  24.24    |  14.61    |  0.8034   |  98.56    |
|  132      | -2.856e+0 |  23.96    |  14.38    |  0.04263  |  99.34    |
|  133      | -2.94e+04 |  24.2     |  15.22    |  0.0      |  98.55    |
|  134      | -2.923e+0 |  24.12    |  13.81    |  0.6933   |  99.04    |
|  135      | -2.858e+0 |  24.97    |  14.6     |  0.8047   |  99.24    |
|  136      | -2.858e+0 |  24.78    |  14.14    |  0.1613   |  99.61    |
|  137      | -2.859e+0 |  24.91    |  14.74    |  0.8767   |  98.07    |
|  138      | -2.856e+0 |  23.63    |  14.27    |  0.1131   |  99.81    |
|  139      | -2.917e+0 |  24.85    |  13.76    |  0.8032   |  97.85    |
|  140      | -2.859e+0 |  24.31    |  14.65    |  0.3618   |  99.91    |
|  141      | -2.845e+0 |  22.75    |  14.43    |  0.3707   |  99.28    |
|  142      | -2.845e+0 |  22.77    |  14.8     |  0.02135  |  99.85    |
|  143      | -2.845e+0 |  22.42    |  14.24    |  0.198    |  99.91    |
|  144      | -2.925e+0 |  22.83    |  15.08    |  0.5595   |  99.81    |
|  145      | -2.844e+0 |  22.85    |  14.19    |  0.0      |  99.53    |
|  146      | -2.844e+0 |  22.23    |  14.53    |  0.0      |  99.35    |
|  147      | -2.908e+0 |  22.2     |  13.9     |  0.3868   |  99.11    |
|  148      | -2.869e+0 |  12.85    |  19.93    |  0.7621   |  99.26    |
|  149      | -2.96e+04 |  21.35    |  15.39    |  0.2309   |  99.81    |
|  150      | -2.856e+0 |  23.2     |  14.49    |  0.0109   |  99.33    |
|  151      | -2.861e+0 |  24.92    |  14.82    |  0.9327   |  98.62    |
|  152      | -2.919e+0 |  11.88    |  19.23    |  0.07981  |  94.02    |
|  153      | -2.872e+0 |  12.77    |  19.05    |  0.04917  |  97.76    |
|  154      | -2.914e+0 |  11.82    |  20.0     |  0.0      |  91.19    |
|  155      | -2.856e+0 |  23.07    |  14.19    |  0.003883 |  99.96    |
|  156      | -2.872e+0 |  12.96    |  19.91    |  0.4315   |  97.25    |
|  157      | -2.937e+0 |  24.88    |  15.15    |  0.9654   |  96.93    |
|  158      | -2.913e+0 |  14.28    |  18.03    |  0.7987   |  94.11    |
|  159      | -2.927e+0 |  12.77    |  17.46    |  0.07643  |  98.25    |
|  160      | -2.87e+04 |  12.84    |  19.95    |  0.1335   |  98.81    |
|  161      | -2.893e+0 |  13.3     |  19.4     |  0.983    |  90.89    |
|  162      | -2.853e+0 |  15.37    |  18.67    |  0.8017   |  80.24    |
|  163      | -2.883e+0 |  13.16    |  18.12    |  0.33     |  92.24    |
|  164      | -2.89e+04 |  13.35    |  19.56    |  0.9955   |  97.53    |
|  165      | -2.983e+0 |  10.0     |  10.0     |  1.0      |  85.6     |
|  166      | -2.904e+0 |  15.63    |  19.01    |  0.02633  |  80.04    |
|  167      | -2.909e+0 |  14.94    |  18.14    |  0.918    |  80.02    |
|  168      | -2.853e+0 |  15.75    |  18.13    |  0.02219  |  80.51    |
|  169      | -2.844e+0 |  22.34    |  14.72    |  0.2479   |  98.64    |
|  170      | -2.843e+0 |  22.58    |  14.73    |  0.09646  |  97.92    |
|  171      | -2.964e+0 |  21.98    |  15.36    |  0.3448   |  97.86    |
|  172      | -2.856e+0 |  23.09    |  14.55    |  0.5981   |  98.69    |
|  173      | -2.844e+0 |  22.68    |  14.59    |  0.4372   |  98.17    |
|  174      | -2.844e+0 |  22.74    |  14.42    |  0.02986  |  98.45    |
|  175      | -2.909e+0 |  22.68    |  13.94    |  0.06073  |  97.66    |
|  176      | -2.923e+0 |  22.95    |  15.3     |  0.1521   |  98.32    |
|  177      | -2.844e+0 |  22.12    |  14.38    |  0.04509  |  98.05    |
|  178      | -2.844e+0 |  22.75    |  14.55    |  0.03358  |  98.83    |
|  179      | -2.927e+0 |  25.0     |  10.0     |  1.0      |  89.0     |
|  180      | -2.915e+0 |  25.0     |  20.0     |  0.0      |  80.0     |
|  181      | -2.853e+0 |  15.79    |  18.56    |  0.7099   |  80.36    |
|  182      | -2.891e+0 |  24.98    |  19.91    |  0.1936   |  93.81    |
|  183      | -2.868e+0 |  23.37    |  19.27    |  0.1689   |  94.63    |
|  184      | -2.922e+0 |  23.63    |  18.52    |  0.0      |  93.94    |
|  185      | -2.885e+0 |  23.23    |  20.0     |  0.3337   |  95.25    |
|  186      | -2.89e+04 |  22.78    |  18.79    |  0.3049   |  95.3     |
|  187      | -2.845e+0 |  22.69    |  14.63    |  0.04476  |  99.41    |
|  188      | -2.887e+0 |  24.1     |  19.4     |  0.9451   |  94.93    |
|  189      | -2.895e+0 |  22.57    |  19.96    |  0.3987   |  94.32    |
|  190      | -2.888e+0 |  24.67    |  19.63    |  0.05494  |  96.67    |
|  191      | -2.865e+0 |  23.11    |  19.72    |  0.0      |  97.06    |
|  192      | -2.865e+0 |  23.07    |  19.74    |  0.4877   |  97.98    |
|  193      | -2.914e+0 |  23.47    |  18.98    |  0.02639  |  97.68    |
|  194      | -2.905e+0 |  22.51    |  20.0     |  0.6894   |  97.27    |
|  195      | -2.908e+0 |  22.91    |  20.0     |  0.0      |  98.7     |
|  196      | -2.881e+0 |  23.75    |  20.0     |  0.6416   |  97.48    |
|  197      | -2.844e+0 |  22.39    |  14.74    |  0.02051  |  98.29    |
|  198      | -2.91e+04 |  24.31    |  18.68    |  0.04868  |  95.52    |
|  199      | -2.899e+0 |  25.0     |  20.0     |  1.0      |  98.6     |
|  200      | -2.924e+0 |  24.19    |  15.02    |  0.9349   |  80.11    |
|  201      | -2.86e+04 |  25.0     |  14.54    |  0.3456   |  98.65    |
|  202      | -2.866e+0 |  23.72    |  19.99    |  0.2093   |  96.33    |
|  203      | -2.937e+0 |  18.19    |  10.0     |  1.0      |  94.69    |
|  204      | -2.907e+0 |  10.0     |  20.0     |  1.0      |  83.16    |
|  205      | -2.949e+0 |  20.88    |  13.2     |  0.0      |  91.29    |
|  206      | -2.899e+0 |  24.93    |  19.71    |  0.7239   |  84.92    |
|  207      | -2.864e+0 |  23.47    |  19.53    |  0.9883   |  96.6     |
|  208      | -2.89e+04 |  22.62    |  18.81    |  0.9123   |  96.46    |
|  209      | -2.927e+0 |  22.54    |  10.05    |  0.6898   |  84.75    |
|  210      | -2.895e+0 |  22.53    |  19.11    |  0.9994   |  98.67    |
|  211      | -2.886e+0 |  24.47    |  19.79    |  0.8973   |  96.5     |
|  212      | -2.891e+0 |  22.95    |  19.3     |  0.1932   |  96.38    |
|  213      | -2.98e+04 |  10.0     |  10.0     |  0.0      |  80.0     |
|  214      | -2.968e+0 |  10.3     |  10.06    |  0.7974   |  97.34    |
|  215      | -2.928e+0 |  17.22    |  13.86    |  1.0      |  97.21    |
|  216      | -2.966e+0 |  10.04    |  15.28    |  0.6897   |  85.13    |
|  217      | -2.861e+0 |  24.83    |  14.35    |  0.9897   |  92.38    |
|  218      | -2.873e+0 |  25.0     |  14.19    |  1.0      |  91.43    |
|  219      | -2.861e+0 |  24.11    |  14.84    |  0.681    |  91.97    |
|  220      | -2.859e+0 |  24.93    |  14.77    |  0.1068   |  92.13    |
|  221      | -2.935e+0 |  24.99    |  15.37    |  0.8659   |  92.53    |
|  222      | -2.912e+0 |  24.07    |  13.9     |  0.439    |  91.94    |
|  223      | -2.927e+0 |  23.61    |  15.31    |  0.7272   |  91.32    |
|  224      | -2.859e+0 |  24.53    |  14.67    |  0.07302  |  93.05    |
|  225      | -2.926e+0 |  23.92    |  15.24    |  0.04927  |  92.62    |
|  226      | -2.86e+04 |  24.82    |  14.2     |  0.5013   |  93.37    |
|  227      | -2.859e+0 |  24.89    |  14.48    |  0.03647  |  92.78    |
|  228      | -2.86e+04 |  24.88    |  14.63    |  0.0183   |  94.33    |
|  229      | -2.859e+0 |  24.25    |  14.1     |  0.0      |  93.87    |
|  230      | -2.861e+0 |  24.18    |  14.19    |  0.8938   |  94.33    |
|  231      | -2.913e+0 |  24.68    |  13.42    |  0.2816   |  94.46    |
|  232      | -2.929e+0 |  23.83    |  15.11    |  0.9128   |  94.35    |
|  233      | -2.861e+0 |  24.25    |  14.29    |  0.9944   |  92.97    |
|  234      | -2.915e+0 |  23.37    |  13.37    |  0.762    |  94.11    |
|  235      | -2.856e+0 |  24.73    |  14.37    |  0.261    |  90.4     |
|  236      | -2.935e+0 |  24.95    |  15.1     |  0.4849   |  89.9     |
|  237      | -2.912e+0 |  24.91    |  13.3     |  0.2908   |  89.98    |
|  238      | -2.861e+0 |  24.62    |  14.94    |  0.6089   |  91.26    |
|  239      | -2.861e+0 |  24.54    |  14.36    |  0.1977   |  91.36    |
|  240      | -2.861e+0 |  24.36    |  14.05    |  0.9764   |  93.54    |
|  241      | -2.856e+0 |  24.52    |  14.33    |  0.8382   |  90.72    |
|  242      | -2.927e+0 |  25.0     |  10.0     |  1.0      |  100.0    |
|  243      | -2.861e+0 |  25.0     |  14.41    |  0.8521   |  94.48    |
|  244      | -2.851e+0 |  23.64    |  14.41    |  0.2165   |  90.07    |
|  245      | -2.855e+0 |  24.06    |  14.16    |  0.0      |  90.54    |
|  246      | -2.915e+0 |  23.67    |  13.92    |  0.8395   |  90.12    |
|  247      | -2.836e+0 |  22.8     |  14.65    |  0.3069   |  89.92    |
|  248      | -2.85e+04 |  23.1     |  14.78    |  0.0      |  89.46    |
|  249      | -2.835e+0 |  22.77    |  14.14    |  0.0      |  89.82    |
|  250      | -2.92e+04 |  22.32    |  15.03    |  0.07138  |  89.74    |
|  251      | -2.85e+04 |  23.13    |  14.58    |  0.5024   |  89.17    |
|  252      | -2.836e+0 |  22.73    |  14.18    |  0.5437   |  89.63    |
|  253      | -2.85e+04 |  23.24    |  14.07    |  0.0      |  89.4     |
|  254      | -2.914e+0 |  21.99    |  13.75    |  0.341    |  89.29    |
|  255      | -2.84e+04 |  22.8     |  14.36    |  0.4114   |  90.29    |
|  256      | -2.908e+0 |  22.86    |  13.51    |  0.06174  |  90.37    |
|  257      | -2.852e+0 |  23.11    |  14.57    |  0.001069 |  88.48    |
|  258      | -2.853e+0 |  23.87    |  14.3     |  0.0      |  88.17    |
|  259      | -2.918e+0 |  23.23    |  13.78    |  0.0      |  87.93    |
|  260      | -2.853e+0 |  23.8     |  14.97    |  0.0      |  88.56    |
|  261      | -2.852e+0 |  23.84    |  14.8     |  0.7077   |  88.22    |
|  262      | -2.933e+0 |  24.01    |  15.45    |  0.2949   |  87.59    |
|  263      | -2.849e+0 |  23.52    |  14.67    |  0.1701   |  89.08    |
|  264      | -2.858e+0 |  24.83    |  14.17    |  0.3443   |  88.34    |
|  265      | -2.857e+0 |  24.37    |  14.5     |  0.898    |  88.49    |
|  266      | -2.909e+0 |  24.65    |  13.69    |  0.3106   |  87.35    |
|  267      | -2.841e+0 |  22.86    |  14.78    |  0.9396   |  88.5     |
|  268      | -2.852e+0 |  23.24    |  14.61    |  0.8546   |  88.18    |
|  269      | -2.841e+0 |  22.77    |  14.62    |  0.9389   |  90.12    |
|  270      | -2.93e+04 |  23.3     |  15.33    |  0.912    |  88.89    |
|  271      | -2.84e+04 |  22.53    |  14.24    |  0.797    |  88.73    |
|  272      | -2.905e+0 |  21.9     |  14.54    |  0.5857   |  87.91    |
|  273      | -2.852e+0 |  23.74    |  14.29    |  0.86     |  88.47    |
|  274      | -2.837e+0 |  22.28    |  14.62    |  0.9523   |  89.16    |
|  275      | -2.908e+0 |  24.15    |  13.91    |  0.05677  |  88.88    |
|  276      | -2.858e+0 |  24.76    |  14.84    |  0.3275   |  88.16    |
|  277      | -2.836e+0 |  22.96    |  14.25    |  0.4754   |  89.87    |
|  278      | -2.939e+0 |  15.1     |  10.0     |  0.0      |  98.37    |
|  279      | -2.836e+0 |  22.84    |  14.38    |  0.8515   |  89.17    |
|  280      | -2.971e+0 |  10.0     |  17.82    |  0.0      |  80.0     |
|  281      | -2.982e+0 |  10.0     |  10.0     |  1.0      |  89.99    |
|  282      | -2.97e+04 |  13.98    |  10.21    |  0.0337   |  84.75    |
|  283      | -2.849e+0 |  23.02    |  14.42    |  0.0763   |  89.95    |
|  284      | -2.968e+0 |  21.22    |  20.0     |  1.0      |  80.0     |
|  285      | -2.905e+0 |  21.91    |  14.42    |  0.9744   |  90.57    |
|  286      | -2.951e+0 |  21.38    |  12.03    |  1.0      |  80.0     |
|  287      | -2.872e+0 |  25.0     |  16.98    |  0.0      |  83.25    |
|  288      | -2.948e+0 |  24.54    |  17.98    |  0.3718   |  82.82    |
|  289      | -2.909e+0 |  25.0     |  15.89    |  0.0      |  83.67    |
|  290      | -2.971e+0 |  20.52    |  10.0     |  0.0      |  88.21    |
|  291      | -2.938e+0 |  10.0     |  14.77    |  1.0      |  89.34    |
|  292      | -2.934e+0 |  10.13    |  14.54    |  0.3669   |  95.81    |
|  293      | -2.976e+0 |  19.17    |  15.27    |  0.0      |  94.48    |
|  294      | -2.965e+0 |  25.0     |  12.26    |  0.0      |  83.29    |
|  295      | -2.888e+0 |  24.84    |  19.87    |  0.8937   |  90.86    |
|  296      | -2.961e+0 |  13.87    |  10.0     |  1.0      |  94.86    |
|  297      | -2.895e+0 |  22.89    |  19.95    |  0.6566   |  91.44    |
|  298      | -2.975e+0 |  10.13    |  13.87    |  0.9502   |  81.74    |
|  299      | -2.946e+0 |  16.58    |  15.55    |  0.01365  |  99.89    |
|  300      | -2.888e+0 |  24.79    |  19.88    |  0.8136   |  88.76    |
=========================================================================

 
 best params:  {'max_depth': 22.76987255830374, 'max_features': 14.139404019924546, 'min_impurity_decrease': 0.0, 'n_estimators': 89.82134842869006} 
 
 best cvscore:  -28346.672687223065
It takes 2.107581957181295 minutes

 
 validation_score:  -28346.672687223065
超参数优化方法(HPO) 默认参数 网格搜索 随机搜索 随机搜索
(大空间)
随机搜索
(连续型)
贝叶斯优化
(基于GP)
搜索空间/全域空间 - 1536/1536 800/1536 1536/3000 1536/无限 300/无限
运行时间(分钟) - 6.36 **2.83(↓)** **3.86(↓)** 3.92 2.11(↓)
搜索最优(RMSE) 30571.266 29179.698 29251.284 **29012.905(↓)** 29148.381 **28346.673(↓)**
重建最优(RMSE) - 28572.070 **28639.969(↑)** **28346.673(↓)** 28495.682 **28346.673(↓)**
  • 原理上有优越性

可以看到,基于高斯过程的贝叶斯优化在2.11分钟内锁定了最佳分数28346.673,这是之前使用随机搜索时获得的最佳分数,很可能也是我们当前超参数空间上可以获得的最佳分数。贝叶斯优化作为从原理上高于网格优化的HPO方法,能够以更短的时间获得与随机网格搜索相同的结果,可见其原理上的优越性。能够在最短的时间获得与大空间随机网格搜索一样好的结果。

  • 优化过程无法复现,但优化结果可以复现

但同时要注意,由于贝叶斯优化每次都是随机的,因此我们并不能在多次运行代码时复现出28346.673这个结果,事实上如果我们重复运行,也只有很小的概率可以再次找到这个最低值(这一点对于随机搜索来说也是类似的,如果不规定随机数种子,我们也无法复现最低值)。因此我们在执行贝叶斯优化时,往往会多运行几次观察模型找出的结果。同时,验证分数与目标函数最后输出的分数一模一样,可见最终输出的超参数组合的效力是可以复现的。

  • 效率不足

不难发现,bayes_opt的速度虽然快,效率却不高。实际上在迭代到170次时,贝叶斯优化就已经找到了最小损失,但由于没有提前停止机制,模型还持续地迭代了130次才停下,如果bayes_opt支持提前停止机制,贝叶斯优化所需的实际迭代时间可能会更少。

同时,由于Bayes_opt只能够在参数空间提取浮点数,bayes_opt在随机森林上的搜索效率是较低的,即便在10次不同的迭代中分别取到了[88.89, 88.23, 88.16, 88.59……]等值,在取整之后也只能够获得一个备选值88,但bayes_opt无法辨别这种区别,对于决策树来说,浮点数的参数没有意义,因此可能取出了众多无效的观测点。如果使用其他贝叶斯优化器,贝叶斯优化的效率将会更高。

  • 支持灵活修改

虽然在我们的代码中没有体现,但bayes_opt是支持灵活修改采集函数与高斯过程中的种种参数的,具体可以参考这里:https://github.com/fmfn/BayesianOptimization/blob/master/examples/advanced-tour.ipynb

2 基于HyperOpt实现TPE优化¶

Hyperopt优化器是目前最为通用的贝叶斯优化器之一,Hyperopt中集成了包括随机搜索、模拟退火和TPE(Tree-structured Parzen Estimator Approach)等多种优化算法。相比于Bayes_opt,Hyperopt的是更先进、更现代、维护更好的优化器,也是我们最常用来实现TPE方法的优化器。在实际使用中,相比基于高斯过程的贝叶斯优化,基于高斯混合模型的TPE在大多数情况下以更高效率获得更优结果,该方法目前也被广泛应用于AutoML领域中。TPE算法原理可以参阅原论文Multiobjective tree-structured parzen estimator for computationally expensive optimization problems,在这里我们将重点介绍关于Hyperopt中使用TPE进行超参数搜索的过程。

In [3]:
import hyperopt
from hyperopt import hp, fmin, tpe, Trials, partial
from hyperopt.early_stop import no_progress_loss # 用于控制提前停止
In [4]:
print(hyperopt.__version__)
0.2.7
  • 1 定义目标函数

在定义目标函数$f(x)$时,我们需要严格遵守需要使用的当下优化库的基本规则。与Bayes_opt一样,Hyperopt也有一些特定的规则会限制我们的定义方式,主要包括:

1 目标函数的输入必须是符合hyperopt规定的字典,不能是类似于sklearn的参数空间字典、不能是参数本身,更不能是数据、算法等超参数以外的元素。因此在自定义目标函数时,我们需要让超参数空间字典作为目标函数的输入。

2 Hyperopt只支持寻找$f(x)$的最小值,不支持寻找最大值,因此当我们定义的目标函数是某种正面的评估指标时(如准确率,auc),我们需要对该评估指标取负。如果我们定义的目标函数是负损失,也需要对负损失取绝对值。当且仅当我们定义的目标函数是普通损失时,我们才不需要改变输出。

In [5]:
# 大概参数长这种样子
#params = {'参数名称':参数范围}
In [6]:
def hyperopt_objective(params):
    
    # 定义评估器
    # 需要搜索的参数需要从输入的字典中索引出来
    # 不需要搜索的参数,可以是设置好的某个值
    # 在需要整数的参数前调整参数类型
    reg = RFR(n_estimators = int(params["n_estimators"])
              ,max_depth = int(params["max_depth"])
              ,max_features = int(params["max_features"])
              ,min_impurity_decrease = params["min_impurity_decrease"]
              ,random_state=1412
              ,verbose=False
              ,n_jobs=-1)
    
    # 交叉验证结果,输出负根均方误差(-RMSE)
    cv = KFold(n_splits=5,shuffle=True,random_state=1412)
    validation_loss = cross_validate(reg,X,y
                                     ,scoring="neg_root_mean_squared_error"
                                     ,cv=cv
                                     ,verbose=False
                                     ,n_jobs=-1
                                     ,error_score='raise'
                                    )
    
    # 最终输出结果;由于hyperopt只支持取最小值,所以必须对(-RMSE)求绝对值
    # 以求解最小RMSE所对应的参数组合
    return np.mean(abs(validation_loss["test_score"]))
  • 2 定义参数空间

在任意超参数优化器中,优化器会将参数空格中的超参数组合作为备选组合,一组一组输入到算法中进行训练。在贝叶斯优化中,超参数组合会被输入我们定义好的目标函数$f(x)$中。

在hyperopt中,我们使用特殊的字典形式来定义参数空间,其中键值对上的键可以任意设置,只要与目标函数中索引参数的键一致即可,键值对的值则是hyperopt独有的hp函数,包括了:

hp.quniform("参数名称", 下界, 上界, 步长) - 适用于均匀分布的浮点数,最优参数由值表示

hp.uniform("参数名称",下界, 上界) - 适用于随机分布的浮点数,最优参数由值表示

hp.randint("参数名称",上界) - 适用于[0,上界)的整数,区间为前闭后开,最优参数由值表示

hp.choice("参数名称",["字符串1","字符串2",...]) - 适用于字符串类型,最优参数由索引表示

hp.choice("参数名称",[*range(下界,上界,步长)]) - 适用于整数型,最优参数由索引表示

hp.choice("参数名称",[整数1,整数2,整数3,...]) - 适用于整数型,最优参数由索引表示

hp.choice("参数名称",["字符串1",整数1,...]) - 适用于字符与整数混合,最优参数由索引表示

在hyperopt的说明当中,并未明确参数取值范围空间的开闭,根据实验,如无特殊说明,hp中的参数空间定义方法应当都为前闭后开区间。我们依然使用在随机森林上获得最高分的随机搜索的参数空间:

In [7]:
param_grid_simple = {'n_estimators': hp.quniform("n_estimators",80,100,1)
                     , 'max_depth': hp.quniform("max_depth",10,25,1)
                     , "max_features": hp.quniform("max_features",10,20,1)
                     , "min_impurity_decrease":hp.quniform("min_impurity_decrease",0,5,1)
                    }

由于hp.choice最终会返回最优参数的索引,容易与数值型参数的具体值混淆,所以最好使用hp.quniform()函数,而hp.randint又只能够支持从0开始进行计数,因此我们常常会使用quniform获得均匀分布的浮点数来替代整数。对于需要取整数的参数值,如果采用quniform方式构筑参数空间,则需要在目标函数中使用int函数限定输入类型。例如,在范围[0,5]中取值时,可以取出[0.0, 1.0, 2.0, 3.0,...]这种均匀浮点数,在输入目标函数时,则必须确保参数值前存在int函数。当然,如果使用hp.choice则不会存在该问题。

由于不涉及到连续型变量,因此我们可以计算出当前参数空间的大小:

In [8]:
len([*range(80,100,1)])*len([*range(10,25,1)])*len([*range(10,20,1)])*len([range(0,5,1)])
Out[8]:
3000
  • 3 定义优化目标函数的具体流程

有了目标函数和参数空间,接下来我们就可以进行优化了。在Hyperopt中,我们用于优化的基础功能叫做fmin,fmin用于求解最小值。

在fmin中,我们可以自定义使用的代理模型(参数algo),一般来说我们有tpe.suggest以及rand.suggest两种选项,前者指代TPE方法,后者指代随机网格搜索方法。tpe.suggest代表使用TPE的默认参数,当然我们还可以通过partial功能来修改算法涉及到的具体参数,包括模型具体使用了多少个初始观测值(参数n_start_jobs),以及在计算采集函数值时究竟考虑多少个样本(参数n_EI_candidates)。

除此之外,Hyperopt当中还有两个值得注意的功能,一个记录整个迭代过程的trials,另一个是提前停止参数early_stop_fn。提前停止参数early_stop_fn中我们一般输入从hyperopt库导入的方法no_progress_loss(),这个方法中可以输入具体的数字n,表示当损失连续n次没有下降时,让算法提前停止。由于贝叶斯方法的随机性较高,当样本量不足时需要多次迭代才能够找到最优解,因此一般no_progress_loss()中的数值不会设置得太高。在我们的课程中,由于数据量较少,我设置了一个较高的值来避免迭代停止太早。

其中,trials直译为“实验”或“测试”,表示我们不断尝试的每一种参数组合,这个参数中我们一般输入从hyperopt库中导入的方法Trials(),当优化完成之后,我们可以从保存好的trials中查看损失、参数等各种中间信息。

In [9]:
def param_hyperopt(max_evals=100):
    
    # 保存迭代过程
    trials = Trials()
    
    # 设置提前停止
    early_stop_fn = no_progress_loss(100)              # 当连续100次迭代,损失函数值都没有下降就停止吧,一般我们不会设置这么大,当数据量小的时候可以设置的小一些,多给一些机会。
    
    # 定义代理模型,注释掉代表使用代理模型参数默认值,一般来说我们也不调整代理模型
    #algo = partial(tpe.suggest, n_startup_jobs=20, n_EI_candidates=50)
    params_best = fmin(hyperopt_objective              # 目标函数
                       , space = param_grid_simple     # 参数空间
                       , algo = tpe.suggest            # 代理模型你要哪个呢?
                       #, algo = algo
                       , max_evals = max_evals         # 允许的迭代次数
                       , verbose=True                  # 打印优化流程
                       , trials = trials               # 保存迭代过程
                       , early_stop_fn = early_stop_fn # 控制提前停止
                      )
    
    # 打印最优参数,fmin会自动打印最佳分数
    print("\n","\n","best params: ", params_best,
          "\n")
    return params_best, trials
  • 4 定义验证函数(非必要)
In [10]:
def hyperopt_validation(params):    
    reg = RFR(n_estimators = int(params["n_estimators"])
              ,max_depth = int(params["max_depth"])
              ,max_features = int(params["max_features"])
              ,min_impurity_decrease = params["min_impurity_decrease"]
              ,random_state=1412
              ,verbose=False
              ,n_jobs=-1
             )
    cv = KFold(n_splits=5,shuffle=True,random_state=1412)
    validation_loss = cross_validate(reg,X,y
                                     ,scoring="neg_root_mean_squared_error"
                                     ,cv=cv
                                     ,verbose=False
                                     ,n_jobs=-1
                                    )
    return np.mean(abs(validation_loss["test_score"]))
  • 5 执行实际优化流程
In [22]:
params_best, trials = param_hyperopt(30) # 迭代30次,即1%的空间大小
100%|████████████████████████████████████████████████| 30/30 [01:26<00:00,  2.90s/trial, best loss: 28547.282757540164]

 
 best params:  {'max_depth': 15.0, 'max_features': 18.0, 'min_impurity_decrease': 5.0, 'n_estimators': 80.0} 

In [64]:
params_best, trials = param_hyperopt(100) # 3%的空间大小
100%|███████████████████████████████████████████████| 100/100 [00:21<00:00,  4.71trial/s, best loss: 28450.06487530331]

 
 best params:  {'max_depth': 22.0, 'max_features': 14.0, 'min_impurity_decrease': 0.0, 'n_estimators': 94.0} 

In [50]:
params_best, trials = param_hyperopt(300) # 10%的空间大小
 92%|██████████████████████████████████████████▍   | 277/300 [01:01<00:05,  4.52trial/s, best loss: 28346.672687223065]

 
 best params:  {'max_depth': 22.0, 'max_features': 14.0, 'min_impurity_decrease': 0.0, 'n_estimators': 89.0} 

In [33]:
hyperopt_validation(params_best)
Out[33]:
28346.672687223065
In [18]:
# 打印所有搜索相关的记录
trials.trials[0]
Out[18]:
{'state': 2,
 'tid': 0,
 'spec': None,
 'result': {'loss': 28766.452192638408, 'status': 'ok'},
 'misc': {'tid': 0,
  'cmd': ('domain_attachment', 'FMinIter_Domain'),
  'workdir': None,
  'idxs': {'max_depth': [0],
   'max_features': [0],
   'min_impurity_decrease': [0],
   'n_estimators': [0]},
  'vals': {'max_depth': [13.0],
   'max_features': [18.0],
   'min_impurity_decrease': [4.0],
   'n_estimators': [80.0]}},
 'exp_key': None,
 'owner': None,
 'version': 0,
 'book_time': datetime.datetime(2021, 12, 24, 13, 33, 19, 633000),
 'refresh_time': datetime.datetime(2021, 12, 24, 13, 33, 19, 840000)}
In [20]:
# 打印全部搜索的目标函数值
trials.losses()[:10]
Out[20]:
[28766.452192638408,
 29762.22885008687,
 29233.57333898302,
 29257.33343872428,
 29180.63733732971,
 29249.676793746046,
 29309.41793204717,
 28915.33638544984,
 29122.269575607537,
 29150.39720576636]
HPO方法 默认参数 网格搜索 随机搜索 随机搜索
(大空间)
随机搜索
(连续型)
贝叶斯优化
(基于GP)
贝叶斯优化
(基于TPE)
搜索空间/全域空间 - 1536/1536 800/1536 1536/3000 1536/无限 300/无限 277/3000
运行时间(分钟) - 6.36 **2.83(↓)** **3.86(↓)** 3.92 **2.11(↓)** **1.00(↓)**
搜索最优(RMSE) 30571.266 29179.698 29251.284 **29012.905(↓)** 29148.381 **28346.673(↓)** **28346.673(-)**
重建最优(RMSE) - 28572.070 **28639.969(↑)** **28346.673(↓)** 28495.682 **28346.673(-)** **28346.673(-)**

由于具有提前停止功能,因此基于TPE的hyperopt优化可能在我们设置的迭代次数被达到之前就停止,也因此hyperopt迭代到实际最优值所需的迭代次数可能更少。同时,TPE方法相比于高斯过程计算会更加迅速,因此在运行277次迭代的情况下,hyperopt只需要1分钟时间,而运行300次迭代的bayes_opt却需要2.11分钟,可见,即便运行同样的迭代次数,hyperopt也是更有优势的,这或许是因为hyperopt的参数空间更加稀疏、在整数型参数搜索上更高效。

不过HyperOpt的缺点也很明显,那就是代码精密度要求较高、灵活性较差,略微的改动就可能让代码疯狂报错难以跑通。同时,HyperOpt所支持的优化算法也不够多,如果我们专注地使用TPE方法,则掌握HyperOpt即可,如果我们希望拥有丰富的HPO手段,则可以更深入地接触Optuna库。

3 基于Optuna实现多种贝叶斯优化¶

Optuna是目前为止最为成熟、拓展性最强的超参数优化框架,与古旧的bayes_opt相比,Optuna明显是专门为机器学习和深度学习所设计。为了满足机器学习开发者的需求,Optuna拥有强大且固定的API,因此Optuna代码简单,编写高度模块化,是我们介绍的库中代码最为简练的库,有点kears的意思了。Optuna的优势在于,它可以无缝衔接到PyTorch、Tensorflow等深度学习框架上,也可以与sklearn的优化库scikit-optimize结合使用,因此Optuna可以被用于各种各样的优化场景。在我们的课程中,我们将重点介绍Optuna实现贝叶斯优化的过程,其他优化方面内容可以参考以下页面:https://github.com/optuna/optuna 。

In [25]:
import optuna
In [26]:
print(optuna.__version__)
2.10.0
  • 1 定义目标函数与参数空间

Optuna的目标函数相当特别。在其他优化库中,我们需要单独输入参数或参数空间,优化器会在具体优化过程中将参数空间一一放入我们的目标函数进行优化,但在Optuna中,我们并不需要将参数或参数空间输入目标函数,而是需要直接在目标函数中定义参数空间。特别的是,Optuna优化器会生成一个指代备选参数的变量trial,该变量无法被用户获取或打开,但该变量在优化器中生存,并被输入目标函数。在目标函数中,我们可以通过变量trail所携带的方法来构造参数空间,具体如下所示:

In [27]:
def optuna_objective(trial):  # 只能输入trial这一个参数,trial表示一次尝试,不由我们定义,由Optuna自己决定,本质是个对象
    
    # 定义参数空间
    n_estimators = trial.suggest_int("n_estimators",80,100,1)                           # 整数型,(参数名称,下界,上界,步长)
    max_depth = trial.suggest_int("max_depth",10,25,1)
    max_features = trial.suggest_int("max_features",10,20,1)
    #max_features = trial.suggest_categorical("max_features",["log2","sqrt","auto"])    # 字符型
    min_impurity_decrease = trial.suggest_int("min_impurity_decrease",0,5,1)
    #min_impurity_decrease = trial.suggest_float("min_impurity_decrease",0,5,log=False) # 浮点型
    
    # 定义评估器
    # 需要优化的参数由上述参数空间决定
    # 不需要优化的参数则直接填写具体值
    reg = RFR(n_estimators = n_estimators
              ,max_depth = max_depth
              ,max_features = max_features
              ,min_impurity_decrease = min_impurity_decrease
              ,random_state=1412
              ,verbose=False
              ,n_jobs=-1
             )
    
    # 交叉验证过程,输出负均方根误差(-RMSE)
    # optuna同时支持最大化和最小化,因此如果输出-RMSE,则选择最大化
    # 如果选择输出RMSE,则选择最小化
    cv = KFold(n_splits=5,shuffle=True,random_state=1412)
    validation_loss = cross_validate(reg,X,y
                                     ,scoring="neg_root_mean_squared_error"
                                     ,cv=cv                # 交叉验证模式
                                     ,verbose=False        # 是否打印进程
                                     ,n_jobs=-1            # 线程数
                                     ,error_score='raise'
                                    )
    # 最终输出RMSE
    return np.mean(abs(validation_loss["test_score"]))
  • 2 定义优化目标函数的具体流程

在HyperOpt当中我们可以调整参数algo来自定义用于执行贝叶斯优化的具体算法,在Optuna中我们也可以。大部分备选的算法都集中在Optuna的模块sampler中,包括我们熟悉的TPE优化、随机网格搜索以及其他各类更加高级的贝叶斯过程,对于Optuna.sampler中调出的类,我们也可以直接输入参数来设置初始观测值的数量、以及每次计算采集函数时所考虑的观测值量。在Optuna库中并没有集成实现高斯过程的方法,但我们可以从scikit-optimize里面导入高斯过程来作为optuna中的algo设置,而具体的高斯过程相关的参数则可以通过如下方法进行设置:

In [28]:
def optimizer_optuna(n_trials, algo):
    
    # 定义使用TPE或者高斯过程(GP)
    if algo == "TPE":
        algo = optuna.samplers.TPESampler(n_startup_trials = 10, n_ei_candidates = 24) # n_startup_trials:初始观测点个数;n_ei_candidates:期望增量;这里都是默认值
    elif algo == "GP":
        # Optuna没有内置高斯过程类,使用Optuna库的integration模块,有很多与其他模块相结合的工具,例如这里的SkoptSampler
        from optuna.integration import SkoptSampler
        import skopt
        algo = SkoptSampler(skopt_kwargs={'base_estimator':'GP', # 选择高斯过程
                                          'n_initial_points':10, # 初始观测点10个
                                          'acq_func':'EI'}       # 选择的采集函数为期望增量(EI)
                           )
    
    # 实际优化过程,首先实例化优化器
    study = optuna.create_study(sampler = algo         # 要使用的具体算法,sampler表示抽样器
                                , direction="minimize" # 优化的方向,可以填写minimize或maximize确立找最小值还是最大值
                               )
    # 开始优化,n_trials为允许的最大迭代次数
    # 由于参数空间已经在目标函数中定义好,因此这里不需要输入参数空间
    study.optimize(optuna_objective         # 目标函数
                   , n_trials=n_trials      # 最大迭代次数(包括最初的观测值的)
                   , show_progress_bar=True # 是否展示进度条
                  )
    
    # 可直接从优化好的对象study中调用优化的结果
    # 打印最佳参数与最佳损失值
    print("\n","\n","best params: ", study.best_trial.params,
          "\n","\n","best score: ", study.best_trial.values,
          "\n")
    
    return study.best_trial.params, study.best_trial.values
  • 3 执行实际优化流程

Optuna库虽然是当今最为成熟的HPO方法之一,但当参数空间较小时,Optuna库在迭代中容易出现抽样BUG,即Optuna会持续抽到曾经被抽到过的参数组合,并且持续报警告说"算法已在这个参数组合上检验过目标函数了"。在实际迭代过程中,一旦出现这个Bug,那当下的迭代就无用了,因为已经检验过的观测值不会对优化有任何的帮助,因此对损失的优化将会停止。如果出现该BUG,则可以增大参数空间的范围或密度。或者使用如下的代码令警告关闭:

In [ ]:
# 屏蔽警告
import warnings
warnings.filterwarnings('ignore', message='The objective has been evaluated at this point before.')
In [29]:
best_params, best_score = optimizer_optuna(10,"GP") #默认打印迭代过程
[I 2021-12-24 22:14:26,709] A new study created in memory with name: no-name-05950945-f6f7-41c3-bd8a-ffb15a284ea9
D:\ProgramData\Anaconda3\lib\site-packages\optuna\progress_bar.py:47: ExperimentalWarning: Progress bar is experimental (supported from v1.2.0). The interface can change in the future.
  self._init_valid()
  0%|          | 0/10 [00:00<?, ?it/s]
[I 2021-12-24 22:14:28,229] Trial 0 finished with value: 28848.70339210933 and parameters: {'n_estimators': 99, 'max_depth': 14, 'max_features': 16, 'min_impurity_decrease': 4}. Best is trial 0 with value: 28848.70339210933.
[I 2021-12-24 22:14:29,309] Trial 1 finished with value: 28632.395126147465 and parameters: {'n_estimators': 90, 'max_depth': 23, 'max_features': 16, 'min_impurity_decrease': 2}. Best is trial 1 with value: 28632.395126147465.
[I 2021-12-24 22:14:30,346] Trial 2 finished with value: 29301.159287113685 and parameters: {'n_estimators': 89, 'max_depth': 17, 'max_features': 12, 'min_impurity_decrease': 0}. Best is trial 1 with value: 28632.395126147465.
[I 2021-12-24 22:14:31,215] Trial 3 finished with value: 29756.446415640086 and parameters: {'n_estimators': 80, 'max_depth': 11, 'max_features': 14, 'min_impurity_decrease': 3}. Best is trial 1 with value: 28632.395126147465.
[I 2021-12-24 22:14:31,439] Trial 4 finished with value: 29784.547574554617 and parameters: {'n_estimators': 88, 'max_depth': 11, 'max_features': 15, 'min_impurity_decrease': 2}. Best is trial 1 with value: 28632.395126147465.
[I 2021-12-24 22:14:31,651] Trial 5 finished with value: 28854.291800282757 and parameters: {'n_estimators': 82, 'max_depth': 12, 'max_features': 18, 'min_impurity_decrease': 3}. Best is trial 1 with value: 28632.395126147465.
[I 2021-12-24 22:14:31,853] Trial 6 finished with value: 29268.28890743908 and parameters: {'n_estimators': 80, 'max_depth': 10, 'max_features': 19, 'min_impurity_decrease': 5}. Best is trial 1 with value: 28632.395126147465.
[I 2021-12-24 22:14:32,111] Trial 7 finished with value: 29302.5258321895 and parameters: {'n_estimators': 99, 'max_depth': 16, 'max_features': 14, 'min_impurity_decrease': 3}. Best is trial 1 with value: 28632.395126147465.
[I 2021-12-24 22:14:32,353] Trial 8 finished with value: 29449.903990989755 and parameters: {'n_estimators': 80, 'max_depth': 21, 'max_features': 17, 'min_impurity_decrease': 1}. Best is trial 1 with value: 28632.395126147465.
[I 2021-12-24 22:14:32,737] Trial 9 finished with value: 29168.76064401323 and parameters: {'n_estimators': 97, 'max_depth': 22, 'max_features': 17, 'min_impurity_decrease': 1}. Best is trial 1 with value: 28632.395126147465.

 
 best params:  {'n_estimators': 90, 'max_depth': 23, 'max_features': 16, 'min_impurity_decrease': 2} 
 
 best score:  [28632.395126147465] 

In [80]:
optuna.logging.set_verbosity(optuna.logging.ERROR)     # 关闭自动打印的info,只显示进度条
#optuna.logging.set_verbosity(optuna.logging.INFO)
best_params, best_score = optimizer_optuna(300,"TPE")  # 300次迭代,TPE
D:\ProgramData\Anaconda3\lib\site-packages\optuna\progress_bar.py:47: ExperimentalWarning: Progress bar is experimental (supported from v1.2.0). The interface can change in the future.
  self._init_valid()
  0%|          | 0/300 [00:00<?, ?it/s]
 
 best params:  {'n_estimators': 96, 'max_depth': 22, 'max_features': 14, 'min_impurity_decrease': 3} 
 
 best score:  [28457.22400533479] 

In [85]:
optuna.logging.set_verbosity(optuna.logging.ERROR)
best_params, best_score = optimizer_optuna(300,"GP")
D:\ProgramData\Anaconda3\lib\site-packages\optuna\progress_bar.py:47: ExperimentalWarning: Progress bar is experimental (supported from v1.2.0). The interface can change in the future.
  self._init_valid()
  0%|          | 0/300 [00:00<?, ?it/s]
 
 best params:  {'n_estimators': 87, 'max_depth': 23, 'max_features': 16, 'min_impurity_decrease': 5} 
 
 best score:  [28541.05837443567] 

很显然,基于高斯过程的贝叶斯优化是比基于TPE的贝叶斯优化运行更加缓慢的。在Optuna进行调试时,我并没有多次运行并取出Optuna表现最好的值,因此我们可以不将Optuna的结果最终放入表格进行比较,不过在TPE模式下,其运行速度与HyperOpt的运行速度高度接近。在未来的课程中,除非特殊说明,我们将默认使用TPE方法进行优化。