🧭 Getting started
Install YDF
pip install ydf -U
Import the libraries
import ydf  # Yggdrasil Decision Forests
import pandas as pd  # We use Pandas to load small datasets
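Download and load the dataset
We use the Adult dataset, a binary classification dataset. The goal is to predict the value of the income column, which is either <=50K or >50K, from the other numerical and categorical columns. The dataset contains missing values.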
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
# Download the datasets and load them as Pandas DataFrames
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
# Display the first 5 training examples
train_ds.head(5)
| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
| 3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
| 4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
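The dataset description mentions missing values and a two-class label. A small pandas sketch can check both. Here it runs on the five rows shown above (only a subset of the columns, to stay short); on the real data you would call summarize(train_ds, "income") instead. The helper name summarize is hypothetical, not part of YDF.

```python
import pandas as pd


def summarize(df: pd.DataFrame, label: str):
    """Return (number of missing values per column, class proportions of the label)."""
    missing = df.isna().sum()
    balance = df[label].value_counts(normalize=True)
    return missing, balance


# The first five training rows shown above, restricted to three columns.
head = pd.DataFrame({
    "age": [44, 20, 40, 30, 67],
    "workclass": ["Private", "Private", "Private", "Private", "Self-emp-inc"],
    "income": ["<=50K", "<=50K", "<=50K", "<=50K", ">50K"],
})

missing, balance = summarize(head, "income")
print(missing)
print(balance)
```

On the full training set, the same call shows the class imbalance and the missing values in columns such as workclass and occupation.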
Train a model
Let's train a Gradient Boosted Trees model with the default values for all hyperparameters.
model = ydf.GradientBoostedTreesLearner(label="income").train(train_ds)
Train model on 22792 examples
Model trained in 0:00:03.698584
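Notes
- YDF distinguishes between learning algorithms (i.e., learners, such as GradientBoostedTreesLearner) and models. You will see why in more advanced examples later :).
- The only required argument of a learner is label. The other arguments have good default values.
- We did not specify the input features, so all the columns other than the label are used as input features. The type of each feature (e.g., numerical, categorical, boolean, text, possibly with missing values) is detected automatically and handled accordingly.
- By default, learners train classification models. Other tasks (e.g., regression, ranking, uplift) can be configured with the task argument, e.g. task=ydf.Task.REGRESSION.
- Training logs can be displayed during training with the verbose=2 argument, or after training with model.describe(). This is useful for debugging and for understanding the training process.
- No validation dataset was specified. In this case, learners such as GradientBoostedTreesLearner set aside part of the training dataset for validation. Other learners, such as RandomForestLearner, do not need a validation dataset and use all the data for training.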
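If you prefer to control the validation split yourself instead of letting the learner carve one out of the training data, a plain pandas split is enough. This is a sketch: split_train_valid is a hypothetical helper, the 10% ratio and the seed are arbitrary choices, and it does not reproduce the split the learner makes internally.

```python
import pandas as pd


def split_train_valid(df: pd.DataFrame, valid_ratio: float = 0.1, seed: int = 1234):
    """Randomly hold out `valid_ratio` of the rows as a validation set."""
    valid = df.sample(frac=valid_ratio, random_state=seed)
    train = df.drop(index=valid.index)  # Remaining rows form the training set.
    return train, valid


# Toy frame standing in for train_ds.
df = pd.DataFrame({"x": range(100), "income": ["<=50K"] * 80 + [">50K"] * 20})
train, valid = split_train_valid(df)
print(len(train), len(valid))
```

The two frames could then be passed to a learner, for example as ydf.GradientBoostedTreesLearner(label="income").train(train, valid=valid). We assume here that train accepts a valid argument; check the YDF API reference before relying on it.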
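Inspect the model
model.describe() shows:
- Model: the model's task, input features, and size.
- Dataspec: the type and statistics of every input feature.
- Training: the training and validation losses and metrics.
- Tuning (only when hyperparameter tuning is enabled): the tuning logs.
- Variable importances: the features that matter most to the model.
- Structure: the trees in the model.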
model.describe()
Task : CLASSIFICATION
Label : income
Features (14) : age workclass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country
Weights : None
Trained with tuner : No
Model size : 2174 kB
Number of records: 22792
Number of columns: 15

Number of columns by type:
CATEGORICAL: 9 (60%)
NUMERICAL: 6 (40%)

Columns:

CATEGORICAL: 9 (60%)
0: "income" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"<=50K" 17308 (75.9389%)
2: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:3 (0.0139308%) most-frequent:"Private" 15879 (73.7358%)
4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%)
6: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%)
7: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:4 (0.018577%) most-frequent:"Prof-specialty" 2870 (13.329%)
8: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%)
9: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%)
10: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%)
14: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%)

NUMERICAL: 6 (40%)
1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661
3: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423
5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427
11: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48
12: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01
13: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249

Terminology:
nas: Number of non-available (i.e. missing) values.
ood: Out of dictionary.
manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
tokenized: The attribute value is obtained through tokenization.
has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
vocab-size: Number of unique values.
The following evaluation is computed on the validation or out-of-bag dataset.
Task: CLASSIFICATION
Label: income
Loss (BINOMIAL_LOG_LIKELIHOOD): 0.576162
Accuracy: 0.868526 CI95[W][0 1]
ErrorRate: : 0.131474
Confusion Table:
truth\prediction
<=50K >50K
<=50K 1557 107
>50K 190 405
Total: 2259
Variable importances measure the importance of an input feature for a model.
Variable Importance: INV_MEAN_MIN_DEPTH:
1. "age" 0.226642 ################
2. "occupation" 0.219727 #############
3. "capital_gain" 0.214876 ############
4. "education" 0.213746 ###########
5. "marital_status" 0.212739 ###########
6. "relationship" 0.206040 #########
7. "fnlwgt" 0.203843 ########
8. "hours_per_week" 0.203735 ########
9. "capital_loss" 0.196549 ######
10. "native_country" 0.190548 ####
11. "workclass" 0.187795 ###
12. "education_num" 0.184215 ##
13. "race" 0.180495
14. "sex" 0.177647
Variable Importance: NUM_AS_ROOT:
1. "age" 26.000000 ################
2. "capital_gain" 26.000000 ################
3. "marital_status" 20.000000 ############
4. "relationship" 17.000000 ##########
5. "capital_loss" 14.000000 ########
6. "hours_per_week" 14.000000 ########
7. "education" 12.000000 #######
8. "fnlwgt" 10.000000 #####
9. "race" 9.000000 #####
10. "education_num" 7.000000 ###
11. "sex" 4.000000 #
12. "occupation" 2.000000
13. "workclass" 1.000000
14. "native_country" 1.000000
Variable Importance: NUM_NODES:
1. "occupation" 724.000000 ################
2. "fnlwgt" 513.000000 ###########
3. "age" 483.000000 ##########
4. "education" 464.000000 ##########
5. "hours_per_week" 339.000000 #######
6. "capital_gain" 326.000000 ######
7. "native_country" 306.000000 ######
8. "capital_loss" 297.000000 ######
9. "relationship" 262.000000 #####
10. "workclass" 244.000000 #####
11. "marital_status" 210.000000 ####
12. "education_num" 82.000000 #
13. "sex" 42.000000
14. "race" 21.000000
Variable Importance: SUM_SCORE:
1. "relationship" 3014.690076 ################
2. "capital_gain" 2065.521668 ##########
3. "education" 1144.490954 ######
4. "marital_status" 1111.389695 #####
5. "occupation" 1094.619502 #####
6. "education_num" 796.666823 ####
7. "capital_loss" 584.055066 ###
8. "age" 582.288569 ###
9. "hours_per_week" 366.856509 #
10. "native_country" 263.872689 #
11. "fnlwgt" 216.537764 #
12. "workclass" 196.085503 #
13. "sex" 47.217730
14. "race" 5.428727
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
Only printing the first tree.
Tree #0:
"relationship" is in [BITMAP] {<OOD>, Husband, Wife} [s:0.036623 n:20533 np:9213 miss:1] ; pred:-4.15883e-09
├─(pos)─ "education_num">=12.5 [s:0.0343752 n:9213 np:2773 miss:0] ; pred:0.116933
| ├─(pos)─ "capital_gain">=5095.5 [s:0.0125728 n:2773 np:434 miss:0] ; pred:0.272683
| | ├─(pos)─ "occupation" is in [BITMAP] {<OOD>, Prof-specialty, Exec-managerial, Craft-repair, Adm-clerical, Sales, Other-service, Machine-op-inspct, Transport-moving, Handlers-cleaners, ...[2 left]} [s:0.000434532 n:434 np:429 miss:1] ; pred:0.416173
| | | ├─(pos)─ "age">=79.5 [s:0.000449964 n:429 np:5 miss:0] ; pred:0.417414
| | | | ├─(pos)─ pred:0.309737
| | | | └─(neg)─ pred:0.418684
| | | └─(neg)─ pred:0.309737
| | └─(neg)─ "capital_loss">=1782.5 [s:0.0101181 n:2339 np:249 miss:0] ; pred:0.246058
| | ├─(pos)─ "capital_loss">=1989.5 [s:0.00201289 n:249 np:39 miss:0] ; pred:0.406701
| | | ├─(pos)─ pred:0.349312
| | | └─(neg)─ pred:0.417359
| | └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Sales, Tech-support, Protective-serv} [s:0.0097175 n:2090 np:1688 miss:0] ; pred:0.226919
| | ├─(pos)─ pred:0.253437
| | └─(neg)─ pred:0.11557
| └─(neg)─ "capital_gain">=5095.5 [s:0.0205419 n:6440 np:303 miss:0] ; pred:0.0498685
| ├─(pos)─ "age">=60.5 [s:0.00421502 n:303 np:43 miss:0] ; pred:0.40543
| | ├─(pos)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Machine-op-inspct, Transport-moving, Handlers-cleaners} [s:0.0296244 n:43 np:25 miss:0] ; pred:0.317428
| | | ├─(pos)─ pred:0.397934
| | | └─(neg)─ pred:0.205614
| | └─(neg)─ "fnlwgt">=36212.5 [s:1.36643e-16 n:260 np:250 miss:1] ; pred:0.419984
| | ├─(pos)─ pred:0.419984
| | └─(neg)─ pred:0.419984
| └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Tech-support, Protective-serv} [s:0.0100346 n:6137 np:2334 miss:0] ; pred:0.0323136
| ├─(pos)─ "age">=33.5 [s:0.00939348 n:2334 np:1769 miss:1] ; pred:0.102799
| | ├─(pos)─ pred:0.132992
| | └─(neg)─ pred:0.00826457
| └─(neg)─ "education" is in [BITMAP] {<OOD>, HS-grad, Some-college, Bachelors, Masters, Assoc-voc, Assoc-acdm, Prof-school, Doctorate} [s:0.00478423 n:3803 np:2941 miss:1] ; pred:-0.0109452
| ├─(pos)─ pred:0.00969668
| └─(neg)─ pred:-0.0813718
└─(neg)─ "capital_gain">=7073.5 [s:0.0143125 n:11320 np:199 miss:0] ; pred:-0.0951681
├─(pos)─ "age">=21.5 [s:0.00807667 n:199 np:194 miss:1] ; pred:0.397823
| ├─(pos)─ "capital_gain">=7565.5 [s:0.00761118 n:194 np:184 miss:0] ; pred:0.405777
| | ├─(pos)─ "capital_gain">=30961.5 [s:0.000242202 n:184 np:20 miss:0] ; pred:0.416988
| | | ├─(pos)─ pred:0.392422
| | | └─(neg)─ pred:0.419984
| | └─(neg)─ "education" is in [BITMAP] {Bachelors, Masters, Assoc-voc, Prof-school} [s:0.16 n:10 np:5 miss:0] ; pred:0.19949
| | ├─(pos)─ pred:0.419984
| | └─(neg)─ pred:-0.0210046
| └─(neg)─ pred:0.0892425
└─(neg)─ "education" is in [BITMAP] {<OOD>, Bachelors, Masters, Prof-school, Doctorate} [s:0.00229611 n:11121 np:2199 miss:1] ; pred:-0.10399
├─(pos)─ "age">=31.5 [s:0.00725859 n:2199 np:1263 miss:1] ; pred:-0.0507848
| ├─(pos)─ "education" is in [BITMAP] {<OOD>, HS-grad, Some-college, Assoc-voc, 11th, Assoc-acdm, 10th, 7th-8th, Prof-school, 9th, ...[5 left]} [s:0.0110157 n:1263 np:125 miss:1] ; pred:-0.0103552
| | ├─(pos)─ pred:0.16421
| | └─(neg)─ pred:-0.0295298
| └─(neg)─ "capital_loss">=1977 [s:0.00164232 n:936 np:5 miss:0] ; pred:-0.105339
| ├─(pos)─ pred:0.19949
| └─(neg)─ pred:-0.106976
└─(neg)─ "capital_loss">=2218.5 [s:0.000534265 n:8922 np:41 miss:0] ; pred:-0.117103
├─(pos)─ "fnlwgt">=125450 [s:0.0755454 n:41 np:28 miss:1] ; pred:0.0704198
| ├─(pos)─ pred:-0.0328167
| └─(neg)─ pred:0.292776
└─(neg)─ "hours_per_week">=40.5 [s:0.000447024 n:8881 np:1559 miss:0] ; pred:-0.117969
├─(pos)─ pred:-0.0927111
└─(neg)─ pred:-0.123347
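Each pred value in the tree above is a node output. In the standard formulation of a binary gradient boosted trees model trained with the BINOMIAL_LOG_LIKELIHOOD loss (the loss reported by model.describe()), the leaf outputs reached in every tree are summed, together with an initial constant, into a log-odds score, and the logistic sigmoid turns that score into a probability. The sketch below shows only that arithmetic, assuming YDF follows the standard formulation; the function name and the leaf values are hypothetical and not read from the trained model.

```python
import math


def gbt_probability(leaf_values, initial_prediction=0.0):
    """Combine per-tree leaf outputs of a binary GBT model into a probability.

    The raw score (log-odds) is the initial prediction plus the sum of the
    leaf values reached in each tree; the sigmoid maps it to P(positive class).
    """
    score = initial_prediction + sum(leaf_values)
    return 1.0 / (1.0 + math.exp(-score))


# Hypothetical leaf values reached by one example in three trees.
print(gbt_probability([0.25, -0.1, 0.05]))
```

A score of 0 maps to a probability of 0.5; positive leaf values push the prediction toward the positive class and negative ones away from it.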
Make predictions
model.predict(ds) applies the model to a dataset and returns the predictions as a NumPy array.
model.predict(test_ds)
array([0.01860435, 0.36130956, 0.83858865, ..., 0.03087652, 0.08280362,
0.00970956], dtype=float32)
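For binary classification, each value returned by model.predict is a probability. Assuming it is the probability of the second class, ">50K" (model.label_classes() should list the classes in that order; check the API reference), class names can be recovered with a decision threshold. The to_labels helper and the 0.5 threshold are illustrative choices, not YDF defaults.

```python
import numpy as np


def to_labels(probabilities, classes=("<=50K", ">50K"), threshold=0.5):
    """Map P(second class) to class names using a decision threshold."""
    probabilities = np.asarray(probabilities)
    return np.where(probabilities >= threshold, classes[1], classes[0])


# The first three test predictions printed above.
print(to_labels([0.01860435, 0.36130956, 0.83858865]))
```

Raising the threshold trades recall on the >50K class for precision, which can matter because this label is imbalanced.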
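Methods that consume a dataset, such as train and predict, support several dataset formats: Pandas DataFrames, dictionaries of lists or NumPy arrays, TensorFlow Datasets, and even file paths!
# Predict from a dictionary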
model.predict({
'age': [39],
'workclass': ['State-gov'],
'fnlwgt': [77516],
'education': ['Bachelors'],
'education_num': [13],
'marital_status': ['Never-married'],
'occupation': ['Adm-clerical'],
'relationship': ['Not-in-family'],
'race': ['White'],
'sex': ['Male'],
'capital_gain': [2174],
'capital_loss': [0],
'hours_per_week': [40],
'native_country': ['United-States'],
'income': ['<=50K'],
})
array([0.01860435], dtype=float32)
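Instead of typing the dictionary by hand, the same dictionary-of-lists format can be derived from one row of a DataFrame with pandas' to_dict(orient="list"). The row_as_dict helper is hypothetical, and the two-column frame below is a toy stand-in for test_ds.

```python
import pandas as pd


def row_as_dict(df: pd.DataFrame, i: int) -> dict:
    """Return row i as a dict mapping each column name to a one-element list."""
    return df.iloc[[i]].to_dict(orient="list")


df = pd.DataFrame({"age": [39, 50], "workclass": ["State-gov", "Self-emp-not-inc"]})
example = row_as_dict(df, 0)
print(example)
```

With the real data, model.predict(row_as_dict(test_ds, 0)) would predict the first test example.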
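Evaluate the model
While the validation dataset above gives an indication of the model's quality, we also want to evaluate the model on the test dataset.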
evaluation = model.evaluate(test_ds)
# Query individual evaluation metrics
print(f"test accuracy: {evaluation.accuracy}")
# Show the full evaluation report
print("Full evaluation report:")
evaluation
test accuracy: 0.8738867847271983
Full evaluation report:
| Label \ Pred | <=50K | >50K |
|---|---|---|
| <=50K | 6962 | 782 |
| >50K | 450 | 1575 |
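Accuracy can be recomputed from a confusion table as the diagonal (correct predictions) divided by the total number of examples; the result does not depend on whether rows hold labels or predictions. The sketch below checks this against the test report above and the validation confusion table printed by model.describe(). The accuracy helper is hypothetical.

```python
def accuracy(confusion):
    """Accuracy = sum of the diagonal (correct predictions) / total count."""
    correct = sum(confusion[i][i] for i in range(len(confusion)))
    total = sum(sum(row) for row in confusion)
    return correct / total


# Counts copied from the test report and the validation table.
test_confusion = [[6962, 782], [450, 1575]]
valid_confusion = [[1557, 107], [190, 405]]

print(accuracy(test_confusion))   # Test accuracy.
print(accuracy(valid_confusion))  # Validation accuracy.
```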
model.analyze(test_ds, sampling=0.1)
Variable importances measure the importance of an input feature for a model.
Variable Importance: MEAN_DECREASE_IN_ACCURACY:
1. "capital_gain" 0.052513 ################
2. "occupation" 0.020882 ######
3. "age" 0.015559 ####
4. "relationship" 0.015150 ####
5. "marital_status" 0.014331 ####
6. "capital_loss" 0.014331 ####
7. "education" 0.009110 ##
8. "hours_per_week" 0.006551 #
9. "education_num" 0.005323 #
10. "workclass" 0.003378
11. "race" 0.001024
12. "sex" 0.000921
13. "fnlwgt" 0.000614
14. "native_country" 0.000614
Variable Importance: MEAN_DECREASE_IN_AP_>50K_VS_OTHERS:
1. "capital_gain" 0.248326 ################
2. "age" 0.051386 ###
3. "marital_status" 0.046224 ##
4. "capital_loss" 0.044403 ##
5. "occupation" 0.037985 ##
6. "relationship" 0.037500 ##
7. "education" 0.021677 #
8. "hours_per_week" 0.015487
9. "education_num" 0.008588
10. "workclass" 0.003808
11. "sex" 0.003478
12. "fnlwgt" 0.002788
13. "native_country" 0.001978
14. "race" 0.001111
Variable Importance: MEAN_DECREASE_IN_AUC_>50K_VS_OTHERS:
1. "capital_gain" 0.061589 ################
2. "age" 0.033311 ########
3. "marital_status" 0.029546 #######
4. "relationship" 0.020694 #####
5. "occupation" 0.019686 #####
6. "capital_loss" 0.014316 ###
7. "education" 0.012061 ##
8. "hours_per_week" 0.009984 ##
9. "education_num" 0.004140
10. "sex" 0.001985
11. "workclass" 0.001577
12. "native_country" 0.001397
13. "fnlwgt" 0.000936
14. "race" 0.000637
Variable Importance: MEAN_DECREASE_IN_PRAUC_>50K_VS_OTHERS:
1. "capital_gain" 0.248064 ################
2. "age" 0.051338 ###
3. "marital_status" 0.045982 ##
4. "capital_loss" 0.044387 ##
5. "occupation" 0.037982 ##
6. "relationship" 0.037494 ##
7. "education" 0.021676 #
8. "hours_per_week" 0.015486
9. "education_num" 0.008585
10. "workclass" 0.003812
11. "sex" 0.003477
12. "fnlwgt" 0.002791
13. "native_country" 0.001981
14. "race" 0.001112
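The additional importances that model.analyze computes on the test data are, as far as we know, permutation importances: shuffle one feature, re-evaluate the model, and record how much a metric drops. Below is a generic single-shuffle sketch of that idea on toy data with a stand-in model; YDF's own implementation may average several shuffles and differs in its details. All names here are hypothetical.

```python
import numpy as np


def permutation_importance(predict, X, y, metric, seed=0):
    """Drop in `metric` when each column of X is shuffled once, in turn."""
    rng = np.random.default_rng(seed)
    baseline = metric(y, predict(X))
    drops = []
    for j in range(X.shape[1]):
        X_perm = X.copy()
        X_perm[:, j] = rng.permutation(X_perm[:, j])  # Break the link between feature j and y.
        drops.append(baseline - metric(y, predict(X_perm)))
    return np.array(drops)


def accuracy(y_true, y_pred):
    return float(np.mean(y_true == y_pred))


def toy_model(data):
    # Stand-in model that only looks at column 0.
    return data[:, 0] > 0.5


# Toy data: the label depends only on column 0; column 1 is noise.
rng = np.random.default_rng(42)
X = rng.uniform(size=(500, 2))
y = X[:, 0] > 0.5

drops = permutation_importance(toy_model, X, y, accuracy)
print(drops)
```

Shuffling the unused column leaves the predictions unchanged, so its importance is exactly zero. Sorting the result, e.g. with np.argsort(-drops), ranks the features.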
Benchmark model speed
In applications where model speed matters, model.benchmark(ds) measures the model's inference speed.
model.benchmark(test_ds)
Inference time per example and per cpu core: 0.891 us (microseconds)
Estimated over 345 runs over 3.004 seconds.
* Measured with the C++ serving API. Check model.to_cpp() for details.
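For comparison, a wall-clock timer around model.predict gives an end-to-end Python figure that includes data conversion and interpreter overhead. The sketch below is generic: time_per_example is hypothetical, and a stand-in function replaces model.predict so that it runs without a trained model. Its result is not comparable to the per-core C++ figure above.

```python
import time


def time_per_example(predict, data, num_examples, runs=5):
    """Best-of-`runs` wall-clock seconds per example for predict(data)."""
    best = float("inf")
    for _ in range(runs):
        start = time.perf_counter()
        predict(data)
        best = min(best, time.perf_counter() - start)
    return best / num_examples


# Stand-in workload; with YDF this could be time_per_example(model.predict, test_ds, len(test_ds)).
data = list(range(10_000))
seconds = time_per_example(lambda d: [x * 2 for x in d], data, len(data))
print(f"{seconds * 1e6:.3f} us per example")
```

Taking the best of several runs reduces the noise from other processes and one-time warm-up costs.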
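The benchmark measures the model's speed with the C++ API. The Python API is slower because of the Python interpreter's overhead. If you are not familiar with the C++ API, the model.to_cpp() method generates C++ code that runs the model, which you can use to measure its speed.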
print(model.to_cpp())
// Automatically generated code running an Yggdrasil Decision Forests model in
// C++. This code was generated with "model.to_cpp()".
//
// Date of generation: 2023-12-19 15:29:09.343331
// YDF Version: 0.0.8
//
// How to use this code:
//
// 1. Copy this code in a new .h file.
// 2. If you use Bazel/Blaze, use the following dependencies:
// //third_party/absl/status:statusor
// //third_party/absl/strings
// //external/ydf_cc/yggdrasil_decision_forests/api:serving
// 3. In your existing code, include the .h file and do:
// // Load the model (to do only once).
// namespace ydf = yggdrasil_decision_forests;
// const auto model = ydf::exported_model_123::Load(<path to model>);
// // Run the model
// predictions = model.Predict();
// 4. By default, the "Predict" function takes no inputs and creates fake
// examples. In practice, you want to add your input data as arguments to
// "Predict" and call "examples->Set..." functions accordingly.
// 4. (Bonus)
// Allocate one "examples" and "predictions" per thread and reuse them to
// speed-up the inference.
//
#ifndef YGGDRASIL_DECISION_FORESTS_GENERATED_MODEL_my_model
#define YGGDRASIL_DECISION_FORESTS_GENERATED_MODEL_my_model
#include <memory>
#include <vector>
#include "third_party/absl/status/statusor.h"
#include "third_party/absl/strings/string_view.h"
#include "external/ydf_cc/yggdrasil_decision_forests/api/serving.h"
namespace yggdrasil_decision_forests {
namespace exported_model_my_model {
struct ServingModel {
std::vector<float> Predict() const;
// Compiled model.
std::unique_ptr<serving_api::FastEngine> engine;
// Index of the input features of the model.
//
// Non-owning pointer. The data is owned by the engine.
const serving_api::FeaturesDefinition* features;
// Number of output predictions for each example.
// Equal to 1 for regression, ranking and binary classification with compact
// format. Equal to the number of classes for classification.
int NumPredictionDimension() const {
return engine->NumPredictionDimension();
}
// Indexes of the input features.
serving_api::NumericalFeatureId feature_age;
serving_api::CategoricalFeatureId feature_workclass;
serving_api::NumericalFeatureId feature_fnlwgt;
serving_api::CategoricalFeatureId feature_education;
serving_api::NumericalFeatureId feature_education_num;
serving_api::CategoricalFeatureId feature_marital_status;
serving_api::CategoricalFeatureId feature_occupation;
serving_api::CategoricalFeatureId feature_relationship;
serving_api::CategoricalFeatureId feature_race;
serving_api::CategoricalFeatureId feature_sex;
serving_api::NumericalFeatureId feature_capital_gain;
serving_api::NumericalFeatureId feature_capital_loss;
serving_api::NumericalFeatureId feature_hours_per_week;
serving_api::CategoricalFeatureId feature_native_country;
};
// TODO: Pass input feature values to "Predict".
inline std::vector<float> ServingModel::Predict() const {
// Allocate memory for 2 examples. Alternatively, for speed-sensitive code,
// an "examples" object can be allocated for each thread and reused. It is
// okay to allocate more examples than needed.
const int num_examples = 2;
auto examples = engine->AllocateExamples(num_examples);
// Set all the values to be missing. The values may then be overridden by the
// "Set*" methods. If all the values are set with "Set*" methods,
// "FillMissing" can be skipped.
examples->FillMissing(*features);
// Example #0
examples->SetNumerical(/*example_idx=*/0, feature_age, 1.f, *features);
examples->SetCategorical(/*example_idx=*/0, feature_workclass, "A", *features);
examples->SetNumerical(/*example_idx=*/0, feature_fnlwgt, 1.f, *features);
examples->SetCategorical(/*example_idx=*/0, feature_education, "A", *features);
examples->SetNumerical(/*example_idx=*/0, feature_education_num, 1.f, *features);
examples->SetCategorical(/*example_idx=*/0, feature_marital_status, "A", *features);
examples->SetCategorical(/*example_idx=*/0, feature_occupation, "A", *features);
examples->SetCategorical(/*example_idx=*/0, feature_relationship, "A", *features);
examples->SetCategorical(/*example_idx=*/0, feature_race, "A", *features);
examples->SetCategorical(/*example_idx=*/0, feature_sex, "A", *features);
examples->SetNumerical(/*example_idx=*/0, feature_capital_gain, 1.f, *features);
examples->SetNumerical(/*example_idx=*/0, feature_capital_loss, 1.f, *features);
examples->SetNumerical(/*example_idx=*/0, feature_hours_per_week, 1.f, *features);
examples->SetCategorical(/*example_idx=*/0, feature_native_country, "A", *features);
// Example #1
examples->SetNumerical(/*example_idx=*/1, feature_age, 2.f, *features);
examples->SetCategorical(/*example_idx=*/1, feature_workclass, "B", *features);
examples->SetNumerical(/*example_idx=*/1, feature_fnlwgt, 2.f, *features);
examples->SetCategorical(/*example_idx=*/1, feature_education, "B", *features);
examples->SetNumerical(/*example_idx=*/1, feature_education_num, 2.f, *features);
examples->SetCategorical(/*example_idx=*/1, feature_marital_status, "B", *features);
examples->SetCategorical(/*example_idx=*/1, feature_occupation, "B", *features);
examples->SetCategorical(/*example_idx=*/1, feature_relationship, "B", *features);
examples->SetCategorical(/*example_idx=*/1, feature_race, "B", *features);
examples->SetCategorical(/*example_idx=*/1, feature_sex, "B", *features);
examples->SetNumerical(/*example_idx=*/1, feature_capital_gain, 2.f, *features);
examples->SetNumerical(/*example_idx=*/1, feature_capital_loss, 2.f, *features);
examples->SetNumerical(/*example_idx=*/1, feature_hours_per_week, 2.f, *features);
examples->SetCategorical(/*example_idx=*/1, feature_native_country, "B", *features);
// Run the model on the two examples.
//
// For speed-sensitive code, reuse the same predictions.
std::vector<float> predictions;
engine->Predict(*examples, num_examples, &predictions);
return predictions;
}
inline absl::StatusOr<ServingModel> Load(absl::string_view path) {
ServingModel m;
// Load the model
ASSIGN_OR_RETURN(auto model, serving_api::LoadModel(path));
// Compile the model into an inference engine.
ASSIGN_OR_RETURN(m.engine, model->BuildFastEngine());
// Index the input features of the model.
m.features = &m.engine->features();
// Index the input features.
ASSIGN_OR_RETURN(m.feature_age, m.features->GetNumericalFeatureId("age"));
ASSIGN_OR_RETURN(m.feature_workclass, m.features->GetCategoricalFeatureId("workclass"));
ASSIGN_OR_RETURN(m.feature_fnlwgt, m.features->GetNumericalFeatureId("fnlwgt"));
ASSIGN_OR_RETURN(m.feature_education, m.features->GetCategoricalFeatureId("education"));
ASSIGN_OR_RETURN(m.feature_education_num, m.features->GetNumericalFeatureId("education_num"));
ASSIGN_OR_RETURN(m.feature_marital_status, m.features->GetCategoricalFeatureId("marital_status"));
ASSIGN_OR_RETURN(m.feature_occupation, m.features->GetCategoricalFeatureId("occupation"));
ASSIGN_OR_RETURN(m.feature_relationship, m.features->GetCategoricalFeatureId("relationship"));
ASSIGN_OR_RETURN(m.feature_race, m.features->GetCategoricalFeatureId("race"));
ASSIGN_OR_RETURN(m.feature_sex, m.features->GetCategoricalFeatureId("sex"));
ASSIGN_OR_RETURN(m.feature_capital_gain, m.features->GetNumericalFeatureId("capital_gain"));
ASSIGN_OR_RETURN(m.feature_capital_loss, m.features->GetNumericalFeatureId("capital_loss"));
ASSIGN_OR_RETURN(m.feature_hours_per_week, m.features->GetNumericalFeatureId("hours_per_week"));
ASSIGN_OR_RETURN(m.feature_native_country, m.features->GetCategoricalFeatureId("native_country"));
return m;
}
} // namespace exported_model_my_model
} // namespace yggdrasil_decision_forests
#endif // YGGDRASIL_DECISION_FORESTS_GENERATED_MODEL_my_model
Save the model
Finally, we save the model for later use.
model.save("/tmp/my_model")
We can then load the model back with:
loaded_model = ydf.load_model("/tmp/my_model")
print(f"This is a {loaded_model.name()} model.")
This is a GRADIENT_BOOSTED_TREES model.
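Conclusion
That's all. You now know the core features of YDF 😊.
To learn more about YDF, check out the other tutorials on ydf.readthedocs.io. For example, learn how to:
- Train ranking, regression, or uplift models with the task argument.
- Measure distances and find the nearest neighbors between examples with model.distance.
- Impose monotonic constraints on features with the features argument.
- Run models in a web page with JavaScript using model.to_javascript().
- Convert models to TensorFlow SavedModels and run them in TensorFlow Serving with model.to_tensorflow_saved_model().
- Train models on billions of training examples with distributed training.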