Understanding XGBoost AFT predictions [Question]


I was reading the docs trying to understand the output when we set the objective to survival:aft for survival modeling with XGBoost. There is also a post saying that T(x) is more or less what predict.xgb.Booster returns. If that is true, then the output values are not survival times, since getting a survival time requires an extra term, the σZ part. In the paper, however, the authors write: “Although the XGBoost predict method only computes a point estimate (mean) of the survival time for each individual in the population, …”.

If not, then the predicted values must be some kind of linear predictor (lp), or exp(lp) as is the case with Cox (which is mentioned in the docs). Which one is it?

BR, John.

For AFT models, the tree ensemble model produces predictions in the same scale as the original survival time.

“the output values are not survival times since they need some extra calculations in there to compute that, the σZ part”

The output is an estimate of the survival time. The σZ term represents the error (noise) relative to the ground truth.


Thanks! So that means the returned survival-time estimate is T_estimate = exp(lp), since from the AFT formula we have:

log(T) = lp + error => T = exp(lp) * exp(error), where lp = bX
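A minimal Python sketch of that formula, using only the standard library. It assumes a normally distributed Z (in XGBoost the choice of distribution is controlled by the aft_loss_distribution parameter and σ by aft_loss_distribution_scale); the coefficients b and features x are hypothetical. Setting the error to zero gives the point estimate exp(lp), and random draws of T scatter around it.

```python
import math
import random

def simulate_aft_time(lp: float, sigma: float, rng: random.Random) -> float:
    """Draw one survival time from log(T) = lp + sigma * Z, with Z ~ Normal(0, 1)."""
    z = rng.gauss(0.0, 1.0)
    return math.exp(lp + sigma * z)

def point_estimate(lp: float) -> float:
    """Survival time with the error term set to zero: T_hat = exp(lp)."""
    return math.exp(lp)

# Hypothetical coefficients b and features x for one individual; lp = bX (about 0.3).
b = [0.3, -0.2]
x = [2.0, 1.5]
lp = sum(bi * xi for bi, xi in zip(b, x))

rng = random.Random(0)
draws = sorted(simulate_aft_time(lp, sigma=0.5, rng=rng) for _ in range(10001))
print(f"lp = {lp:.3f}, exp(lp) = {point_estimate(lp):.4f}")
print(f"sample median of simulated T = {draws[len(draws) // 2]:.4f}")
```

Because Z here is symmetric around zero, exp(lp) lands at the median of the simulated survival times; the noise term only spreads individual outcomes around that value.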

Actually, my bad. XGBoost’s AFT model predicts the log survival time, just like the Cox model. See Section 2 of https://arxiv.org/pdf/2006.04920.pdf, which contains the detailed mathematical formulation of AFT.

In the paper, the authors also say: [quoted passage shown as an image in the original post]

which is why it's confusing to me. With the Cox objective, we know that exp(lp) is returned (on the hazard ratio scale, as the docs say).
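Both objectives can end up reporting exp(lp), but the number means something different. Under Cox, exp(lp) is a hazard ratio, so a larger lp means higher risk and shorter survival; under AFT, exp(lp) scales the survival time itself, so a larger lp means longer survival. A minimal stdlib sketch with two hypothetical individuals whose linear predictors differ by log(2):

```python
import math

def cox_hazard_ratio(lp: float) -> float:
    """Cox reading: exp(lp) multiplies the baseline hazard, h(t) = h0(t) * exp(lp)."""
    return math.exp(lp)

def aft_survival_time(lp: float) -> float:
    """AFT reading: exp(lp) is the survival-time estimate with the error term set to zero."""
    return math.exp(lp)

# Two hypothetical individuals, A and B.
lp_a, lp_b = 0.0, math.log(2.0)

# Same arithmetic, opposite interpretation.
hr_ratio = cox_hazard_ratio(lp_b) / cox_hazard_ratio(lp_a)   # B's hazard is about 2x A's: B tends to fail sooner
t_ratio = aft_survival_time(lp_b) / aft_survival_time(lp_a)  # B's survival time is about 2x A's: B tends to last longer
print(f"Cox hazard ratio B/A = {hr_ratio:.3f}")
print(f"AFT survival-time ratio B/A = {t_ratio:.3f}")
```

So reading an AFT prediction "like Cox" flips the direction of the effect, which is why it matters which scale predict() returns.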

Could someone pinpoint the line in the XGBoost code that produces the prediction output? (I couldn't find it with a quick search.) That could help figure out exactly what is returned.

XGBoost first internally estimates the log survival time and computes the AFT loss (since the AFT loss uses the log survival time, according to Section 2 of the paper). Individual trees produce scores on the scale of log survival time. XGBoost then transforms the log survival time into the survival time using the exponential function. See the relevant source code at https://github.com/dmlc/xgboost/blob/d07b7fe8c8f893a6d6d4c09c5ee5dd4d22eb2fa5/src/objective/aft_obj.cu#L104-L113

In short: The individual trees estimate the log survival time, whereas the predict() method estimates the original survival time.
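The two scales can be sketched in plain Python. This is an illustration, not XGBoost's code: the leaf scores below are hypothetical, and base_margin stands in for the constant starting value that the ensemble adds to the tree scores. The raw sum (the "margin") lives on the log survival time scale; predict() applies exp to it. In the Python package, Booster.predict(..., output_margin=True) returns the untransformed margin, so taking the log of a default prediction should match it up to floating-point error.

```python
import math
from typing import List

def margin(tree_scores: List[float], base_margin: float = 0.0) -> float:
    """Raw ensemble output: base value plus per-tree leaf scores (log survival time scale)."""
    return base_margin + sum(tree_scores)

def aft_pred_transform(raw: float) -> float:
    """Map the log-scale margin to the survival-time scale with the exponential function."""
    return math.exp(raw)

# Hypothetical leaf scores that one row falls into in a 3-tree ensemble.
scores = [1.2, 0.5, -0.1]
raw = margin(scores)             # what each tree contributes to: log(T)
pred = aft_pred_transform(raw)   # what predict() reports: T
print(f"margin (log T) = {raw:.3f}, prediction (T) = {pred:.3f}")
```

Keeping the trees on the log scale means any sum of leaf scores maps to a positive survival time after the exp, which is one reason the transform is applied only at the end.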
