How to get the prediction of the last tree directly?
Is there a direct way to get the prediction of only the last tree?
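No. At the time of writing, XGBoost does not offer this capability directly. However, you can get a JSON dump of the model and compute the last tree's prediction manually. The result is that tree's leaf value alone, not a full-model prediction.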
import json

# Placeholder: bind ``bst`` to an already-trained xgboost.Booster before running.
bst = ...  # [trained model]
# get_dump returns one dump string per tree; [-1] selects the last tree.
# NOTE(review): for multi-class models each boosting round presumably adds one
# tree per class, so the last entry is only the last class's tree -- confirm.
model_dump = bst.get_dump(dump_format='json')
last_tree = json.loads(model_dump[-1])  # parse the JSON dump of the last tree
# Now traverse the last tree to compute its prediction manually (see predict).
def _feature_value(inst, feature_id):
    """Return the value of ``feature_id`` in ``inst``, or None if it is absent.

    ``feature_id`` is a split node's ``split`` field from the JSON dump.
    NOTE: depending on the XGBoost version and whether feature names are set,
    this may be an integer index or a name string such as ``'f2'``. When such
    an ``'f<index>'`` name is not itself a key of ``inst``, the integer index
    is tried so that integer-keyed instances keep working.
    """
    if feature_id in inst:
        return inst[feature_id]
    if (isinstance(feature_id, str) and feature_id.startswith('f')
            and feature_id[1:].isdecimal()):
        return inst.get(int(feature_id[1:]))
    return None


def _child_node(node, child_id):
    """Return the child of split ``node`` whose ``nodeid`` equals ``child_id``."""
    children = node['children']
    for child in children:
        if child.get('nodeid') == child_id:
            return child
    # No nodeid match: fall back to positional order, with the "yes" child first.
    return children[0 if child_id == node['yes'] else 1]


def predict(tree, inst):
    """Return the leaf value that a single dumped tree assigns to ``inst``.

    Args:
        tree: One tree from ``Booster.get_dump(dump_format='json')`` parsed
            with ``json.loads``. Split nodes carry ``split``,
            ``split_condition``, ``yes``, ``no``, ``missing`` and
            ``children``; leaf nodes carry ``leaf``.
        inst: Sparse instance mapping feature key -> value. A feature that is
            absent, None, or NaN is treated as missing.

    Returns:
        The ``leaf`` value of the node the instance reaches. This is only this
        tree's leaf value, not a full-model prediction.
    """
    node = tree
    # Walk iteratively so that deep trees cannot hit Python's recursion limit.
    while 'children' in node:
        value = _feature_value(inst, node['split'])
        # NaN is the only value unequal to itself. Like an absent feature, it
        # follows the node's default ("missing") branch instead of failing
        # the `<` test and always going to the "no" branch.
        if value is None or value != value:
            next_id = node['missing']
        elif value < node['split_condition']:
            # The split test is true: take the "yes" (left) branch.
            next_id = node['yes']
        else:
            # The split test is false: take the "no" (right) branch.
            next_id = node['no']
        node = _child_node(node, next_id)
    return node['leaf']
# Score one sparse instance (feature index -> value) with the last tree;
# any feature index not listed is treated as missing.
print(
    predict(
        last_tree,
        {0: 0.02, 2: -0.5, 5: 1.2},
    )
)