受本篇启发:
Treelite:树模型部署加速工具(支持XGBoost、LightGBM和Sklearn)
项目链接:https://treelite.readthedocs.io/
项目论文:https://mlsys.org/Conferences/doc/2018/196.pdf
支持模型:XGB、LGB、SKlearn树模型
还有一个特性:在树模型运行的每台计算机上安装机器学习包(例如 XGBoost、LightGBM、scikit-learning 等)非常麻烦。
这种情况不再如此:Treelite 将导出模型作为独立预测库,以便无需安装任何机器学习包即可进行预测。
1 安装
代码语言:javascript复制python3 -m pip install --user treelite treelite_runtime
2 Treelite介绍与原理
Treelite能够树模型编译优化为单独库,可以很方便的用于模型部署。经过优化后可以将XGBoost模型的预测速度提高2-6倍。
如上图,黑色曲线为XGBoost在不同batch size下的吞吐量,红色曲线为XGBoost经过TreeLite编译后的吞吐量。
Treelite支持众多的树模型,特别是随机森林和GBDT。同时Treelite可以很好的支持XGBoost, LightGBM和 scikit-learn,也可以将自定义模型根据要求完成编译。
2.1 逻辑分支
对于树模型而言,节点的分类本质使用if语句完成,而CPU在执行if语句时会等待条件逻辑的计算。
代码语言:javascript复制if ( [conditional expression] ) {
foo();
} else {
bar();
}
如果在构建树模型时候,提前计算好每个分支下面样本的个数,则可以提前预知哪一个叶子节点被执行的可能性更大,进而可以提前执行子节点逻辑。
借助于编译命令,可以完成逻辑计算加速。
代码语言:javascript复制/* expected to be false */
if( __builtin_expect([condition],0)){
...
} else {
...
}
2.2 逻辑比较
原始的分支比较可能会有浮点数比较逻辑,可以量化为数值比较逻辑。
代码语言:javascript复制if (data[3].fvalue < 1.5) {
/* floating-point comparison */
...
}
代码语言:javascript复制if (data[3].qvalue < 3) {
/* integer comparison */
...
}
3 快速入门
将树组合模型导入树精简:
代码语言:javascript复制import treelite
model = treelite.Model.load('my_model.model', model_format='xgboost')
部署源存档:
代码语言:javascript复制# Produce a zipped source directory, containing all model information
# Run `make` on the target machine
model.export_srcpkg(platform='unix', toolchain='gcc',
pkgpath='./mymodel.zip', libname='mymodel.so',
verbose=True)
部署共享库:
代码语言:javascript复制# Like export_srcpkg, but generates a shared library immediately
# Use this only when the host and target machines are compatible
model.export_lib(toolchain='gcc', libpath='./mymodel.so', verbose=True)
对目标机器进行预测:
代码语言:javascript复制import treelite_runtime
predictor = treelite_runtime.Predictor('./mymodel.so', verbose=True)
batch = treelite_runtime.Batch.from_npy2d(X)
out_pred = predictor.predict(batch)
4 快速load几类数据模型:XGB、LGB、SKlearn
4.1 XGB
- 从xgboost.Booster加载XGBoost模型
# bst = an object of type xgboost.Booster
model = Model.from_xgboost(bst)
- 从binary 二进制格式加载XGBoost模型
# model had been saved to a file named my_model.model
# notice the second argument model_format='xgboost'
model = Model.load('my_model.model', model_format='xgboost')
4.2 LGB
Microsoft/LightGBM的LightGBM 可以使用load(),可以指定参数:model_format='lightgbm'
代码语言:javascript复制# model had been saved to a file named my_model.txt
# notice the second argument model_format='lightgbm'
model = Model.load('my_model.txt', model_format='lightgbm')
4.3 scikit-learn模型
可以加载以下几种:
- sklearn.ensemble.RandomForestRegressor
- sklearn.ensemble.RandomForestClassifier
- sklearn.ensemble.GradientBoostingRegressor
- sklearn.ensemble.GradientBoostingClassifier
# clf is the model object generated by scikit-learn
import treelite.sklearn
model = treelite.sklearn.import_model(clf)
5 java版本:Treelite4J
Treelite4J 是Java使用的依赖,在本地文件系统中找到编译的模型(dll / so / dylib)。我们通过创建Predictor对象来加载已编译的模型:
代码语言:javascript复制import ml.dmlc.treelite4j.Predictor;
Predictor predictor = new Predictor("path/to/compiled_model.so", -1, true, true);
加载编译的模型后,我们可以对其进行查询:
代码语言:javascript复制// Get the input dimension, i.e. the number of feature values in the input vector
int num_feature = predictor.GetNumFeature();
// Get the number of classes.
// This number is 1 for tasks other than multi-class classification.
// For multi-class classification task, the number is equal to the number of classes.
int num_class = predictor.GetNumClass();
为了使用单个输入进行预测,我们创建了一个Entry对象数组,设置了它们的值,并调用了预测函数。
代码语言:javascript复制// Create an array of feature values for the input
int num_feature = predictor.GetNumFeature();
Entry[] inst = new Entry[num_feature];
// Initialize all feature values as missing
for (int i = 0; i < num_feature; i) {
inst[i] = new Entry();
inst[i].setMissing();
}
// Set feature values that are not missing
// In this example, we set feature 1, 3, and 7
inst[1].setFValue(-0.5);
inst[3].setFValue(3.2);
inst[7].setFValue(-1.7);
// Now run prediction
// (Put false in the second argument to get probability outputs)
float[] result = predictor.predict(inst, false);
// The result is either class probabilities (for multi-class classification)
// or a single number (for all other tasks, such as regression)