官方c api tutorial和文档,非常恶心的一点是,tutorial和文档问题很多。
也参考了不少开源项目,主要有xgboost-c-cplusplus,xgboostpp.
首先导入头文件#include "xgboost/c_api.h"
,接下来xgboost的绝大多数接口都包含在了这个头文件中。
然后我们需要一个宏,来用它获取xgboost函数使用的情况.在每次调用xgboost函数时都应该调用这个宏。
代码语言:c复制#define safe_xgboost(call) {
int err = (call);
if (err != 0) {
fprintf(stderr, "%s:%d: error in %s: %sn", __FILE__, __LINE__, #call, XGBGetLastError());
exit(1);
}
}
我们使用的模型文件为xgboost_model.bin
,训练数据的输入是 11 个元素。
首先我们声明一个boost模型的句柄BoosterHandle booster;
接着用XGBoosterCreate
函数创建一个模型 。
BoosterHandle booster;
safe_xgboost(XGBoosterCreate(NULL, 0, &booster));
设置一个字符串作为模型路径const char *model_path = "../xgboost_model.bin";
(../
是因为编译出来的可执行文件在build目录下) , 通过句柄使用XGBoosterLoadModel
函数加载模型。
const char *model_path = "../xgboost_model.bin";
XGBoosterLoadModel(booster, model_path)
设置一组数据作为推理测试,这里我选的数据标签是1.接着将输入数据转为xgboost的DMatrix格式。
代码语言:c复制float a[11]= {14.0,2.0,1.0,12.0,19010.0,120.0,14.0,0.0,0.0,0.0,0.0};
DMatrixHandle h_test;
safe_xgboost(XGDMatrixCreateFromMat(a, 1, 11, -1, &h_test));
下面就可以进行模型推理了,out_len
代表输出的长度(实际上是一个整型变量),f
的模型推理的结果。
bst_ulong out_len;
const float *f;
safe_xgboost(XGBoosterPredict(booster, h_test, 0, 0, 1, &out_len, &f));
我们可以打印输出查看结果
代码语言:c复制printf("Value of the variable: %fn", f[0]);
最后记得释放内存
代码语言:c复制XGDMatrixFree(h_test);
XGBoosterFree(booster);
完整的代码
代码语言:c复制#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include "xgboost/c_api.h"
#define safe_xgboost(call) {
int err = (call);
if (err != 0) {
fprintf(stderr, "%s:%d: error in %s: %sn", __FILE__, __LINE__, #call, XGBGetLastError());
exit(1);
}
}
int main(int argc, char const *argv[]) {
const char *model_path = "../xgboost_model.bin";
// create booster handle first
BoosterHandle booster;
safe_xgboost(XGBoosterCreate(NULL, 0, &booster));
// load model
safe_xgboost(XGBoosterLoadModel(booster, model_path));
//generate random data of a a[11],every nuber from 0 to 2
// float a[11]= {1.0,12.0,1.0,1.0,16134.0,20600.0,0.0,1.0,0.0,0.0,0.0}; // label: 0.0
float a[11]= {14.0,2.0,1.0,12.0,19010.0,120.0,14.0,0.0,0.0,0.0,0.0}; // label: 1.0
for (int i = 0; i < 11; i ) {
printf("%f, ", a[i]);
if (i == 10) {
printf("n");
}
}
// convert to DMatrix
DMatrixHandle h_test;
safe_xgboost(XGDMatrixCreateFromMat(a, 1, 11, -1, &h_test));
// predict
bst_ulong out_len;
const float *f;
safe_xgboost(XGBoosterPredict(booster, h_test, 0, 0, 1, &out_len, &f));
printf("Value of the variable: %fn", f[0]);
XGDMatrixFree(h_test);
XGBoosterFree(booster);
return 0;
}
使用cmake编译
代码语言:CMakeLists.txt复制cmake_minimum_required(VERSION 3.18)
project(project_name LANGUAGES C CXX VERSION 0.1)
set(xgboost_DIR "/usr/include/xgboost")
include_directories(${xgboost_DIR})
link_directories(${xgboost_DIR})
add_executable(project_name test.c)
target_link_libraries(project_name xgboost)
代码语言:bash复制mkdir build
cd ./build
cmake ..
make .
./project_name