Support for left-truncated data / time-dependent covariates for Cox regression

Hi there,

I’m interested in using xgboost to do survival analysis with left-truncated data and time-dependent covariates. My understanding is that the current version supports neither, and a quick skim of the built-in Cox regression objective in https://github.com/dmlc/xgboost/blob/master/src/objective/regression_obj.cu seems to confirm this.

For those who are more familiar with the architecture of the program, would it be difficult to add support for left-truncated data and time-dependent covariates? Conceptually, I believe all that would be required is an adjustment of the risk set based on the entry time (left-truncation time) of each data point; time-dependent covariates then come along for free, and the gradients and Hessians would be calculated much as they are now. Is there an architectural reason this would be difficult to implement? Maybe xgboost doesn’t support passing in a second parameter for left-truncation times (in addition to the single “outcome” variable)?
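For concreteness, here is the adjustment in standard counting-process notation (a textbook formulation, not anything from the xgboost code): the partial likelihood keeps its usual form and only the risk set changes,

$$\ell(\eta) = \sum_{i:\,\delta_i = 1}\Big[\eta_i - \log\!\!\sum_{j \in R(t_i)} e^{\eta_j}\Big], \qquad R(t_i) = \{\, j : \mathrm{start}_j < t_i \le \mathrm{stop}_j \,\},$$

where $\eta_j$ is the model's score for row $j$. Without truncation, $R(t_i)$ reduces to $\{ j : t_j \ge t_i \}$, and time-dependent covariates are handled by splitting each subject into several (start, stop] rows.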

Any advice, should I try to implement this myself?

Thanks


As you said, XGBoost currently doesn’t support passing a second parameter. Is there any way to process the preds argument so as to take account of the truncation times?

Hi,

I have begun implementing this externally as a custom objective + evaluation function in R, using the ability to add attributes to xgb.DMatrix objects to store pre-calculated risk sets as lookup tables (this avoids recalculating the risk sets from the truncation times at every objective/evaluation call). For an example of using attributes to pass extra information to custom objective functions, see e.g. https://github.com/dmlc/xgboost/blob/master/R-package/demo/custom_objective.R. The approach seems reasonably efficient, though it is much slower than the non-truncated case, because the truncated risk sets cannot be handled by the simple sort-and-accumulate scheme that the built-in Cox code uses.
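In outline, the mechanism looks like this (a minimal, self-contained sketch with a placeholder squared-error gradient, just to show how an attribute attached to the DMatrix reaches the custom objective; the names are illustrative only):

library(xgboost)

X <- matrix(rnorm(20), nrow = 10)                     # toy feature matrix
dtrain <- xgb.DMatrix(X, label = rbinom(10, 1, 0.5))

# attach a pre-computed lookup table to the DMatrix once...
attr(dtrain, "i_to_R_ti") <- lapply(1:10, function(i) i:10)

# ...and read it back inside the custom objective at every boosting round
toy_obj <- function(preds, dtrain) {
  risk_sets <- attr(dtrain, "i_to_R_ti")   # available here without recomputation (unused by this toy gradient)
  grad <- preds - getinfo(dtrain, "label") # placeholder gradient (squared error)
  hess <- rep(1, length(preds))
  list(grad = grad, hess = hess)
}

bst <- xgb.train(params = list(objective = toy_obj, max_depth = 2, eta = 0.1),
                 data = dtrain, nrounds = 2)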

There may be a better algorithm or implementation choice than the one I’m working on, since I am never sure my R code is as fast as it could be, and other languages may be more efficient or flexible. On a data set of approx. 100,000 points, with about 1/5 “death” events, the objective function takes about 10 seconds to evaluate. It’s a bit too slow to be convenient. Parallelization may help, but I haven’t worked out how to implement that effectively.
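One possible direction for the parallelization (an untested sketch of my own, reusing the delta, preds and i_to_R_ti objects that appear in the code below) would be to spread the per-event denominator sums over cores with the base parallel package:

library(parallel)

# indices of the death events; only these have a risk-set denominator to compute
death_idx <- which(delta == 1)

# each denominator sum_{j in R(t_i)} exp(preds[j]) is independent, so fork one task per event
# (mclapply forks on Unix-alikes; on Windows a parLapply cluster would be needed instead)
sums <- mclapply(death_idx,
                 function(i) sum(exp(preds[i_to_R_ti[[i]]])),
                 mc.cores = 4)

exp_r_sum <- numeric(length(delta))
exp_r_sum[death_idx] <- unlist(sums)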

Unfortunately, I’ve had to abandon work on this because of other priorities at work. If you are interested, I could share the code I have so far for the objective function and risk set construction.

It would be interesting if this GSOC project gets picked up: https://github.com/dmlc/xgboost/issues/4242

Yes, that would be great!
For those interested, here is my code. It has not been checked carefully for accuracy, but I would welcome any corrections.

# 1. Functions to pre-calculate risk sets and inverse map
# return the largest index i in [min, max] with sorted_vals[i] <= val (or < val when strict=TRUE); NA if no such index exists
binSearch <- function(sorted_vals, val, min, max, strict = FALSE) {
  if (min > max) {
    print(paste("Error: min > max (", min, max, ")"))
    return(NA)
  }
  width <- max - min + 1
  # base case: a single candidate index left
  if (width == 1) {
    if (strict) {
      if (sorted_vals[min] < val) return(min) else return(NA)
    } else {
      if (sorted_vals[min] <= val) return(min) else return(NA)
    }
  }
  # probe the middle index (upper middle for even widths, so the range always shrinks)
  if (width %% 2) {
    mid <- min + (width - 1) / 2
  } else {
    mid <- min + width / 2
  }
  if (strict) {
    cond <- (sorted_vals[mid] < val)
  } else {
    cond <- (sorted_vals[mid] <= val)
  }
  if (cond) {
    return(binSearch(sorted_vals, val, mid, max, strict))
  } else {
    return(binSearch(sorted_vals, val, min, mid - 1, strict))
  }
}

getRiskSets <- function (starts, stops, delta){
  starts.sorted <- sort(starts,index.return=TRUE)
  stops.sorted <- sort(stops,index.return=TRUE)
  nstarts <- length(starts)
  
  start_time <- Sys.time()
  overlaps <- vector(mode="list", length=nstarts )
  ov1 <- vector(mode="logical", length=nstarts)
  ov2 <- vector(mode="logical", length=nstarts)
  ov2[1:nstarts] <- TRUE  # flag vector for intervals with stops >= current stop 
  ov1[1:nstarts] <- FALSE # flag vector for intervals with starts <= current stop
  laststartmax=1
  first = TRUE
  laststopi=1
  for(i in 1:nstarts){
    stop_i <- stops.sorted$ix[i]
    if( !delta[stop_i] ){
      next
    }
    t_stop <- (stops.sorted$x)[i]
    start_max <- binSearch( starts.sorted$x, t_stop, laststartmax, nstarts) 
    # might be faster to just count up, but not sure, depends on density of intervals:
    # for(j in laststartmax:nstarts){
    #    if (starts.sorted$x[j] > t_stop){
    #      start_max = j-1
    #      break
    #    }
    #  }
    
    #start_max <- findInterval( t_stop,starts.sorted$x,left.open=TRUE)
    if( is.na(start_max) ){
      next
    }
    #  start_is <- starts.sorted$ix[1:start_max]
    # stop_is <- stops.sorted$ix[i:nstarts]
    
    #overlaps[stop_i] <- list(intersect(start_is, stop_is))
    # ov1[1:nstarts] <- FALSE 
    #  ov1[starts.sorted$ix[1:start_max]] <- TRUE
    if (first) {
      ov1[starts.sorted$ix[1:start_max]] <- TRUE
      first <- FALSE
    } else if (start_max > laststartmax) { # only this branch should ever occur after the first event
      # only update the indices that have changed since the previous event
      ov1[starts.sorted$ix[(laststartmax + 1):start_max]] <- TRUE
    } else if (start_max < laststartmax) {
      # should never happen: start_max is non-decreasing because events are processed in increasing order of stop time
      print(paste("ERROR: start_max < laststartmax at iteration", i))
    }
    #if( i < 5) { print(which(ov1))}
    laststartmax <- start_max
    
    #ov2[stops.sorted$ix[i:nstarts]] <- TRUE
    if (i > 1) {
      # clear intervals whose stop time (in sorted order) falls before the current stop
      ov2[stops.sorted$ix[laststopi:(i - 1)]] <- FALSE
      laststopi <- i
    }
    #ov2[stops.sorted$ix[-i:-nstarts]] <- FALSE
    
    # final step: the risk set for this event is the set of intervals flagged in both ov1 and ov2
    overlaps[[stop_i]] <- which(ov1 & ov2)
    
  }
  
  end_time <- Sys.time()
  print(end_time-start_time)
  #m2<-overlaps
  return(overlaps)
}

# much simpler though slower code
getRiskSetsContaining <- function(starts, stops, delta){
  start_time <- Sys.time()
  nstarts <- length(starts)
  risksets <- vector(mode="list", length=nstarts)
  
  for(i in 1:nstarts){
    start <- starts[i]
    stop <- stops[i]
    
    ivec<- ( (start <= stops) & (stops <= stop) & delta)  # should delta be here?
    
    risksets[i] <- list(which(ivec))
    
  }
  
  end_time <- Sys.time()
  print(end_time-start_time)
  return(risksets)
}


#2. Custom xgb objective and eval functions

# Assume these are given as lookup tables - risk set map and its inverse
#i_to_R_ti # list of vectors; if delta_i = 0, null; if delta_i=1, Risk(t_i) 
#k_to_i_k_in_R_ti  # list of vectors, for each index k containing all i s.t. delta_i=1 and k is in R(t_i) 
#delta # death indicator
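# A toy illustration (my own, using the closed-interval convention of the functions above)
# with three intervals (start, stop, delta): (0, 5, 1), (2, 7, 0), (3, 4, 1).
# Deaths occur at t = 5 (obs 1) and t = 4 (obs 3), so:
#   i_to_R_ti:        [[1]] = c(1, 2)    [[2]] = NULL      [[3]] = c(1, 2, 3)
#   k_to_i_k_in_R_ti: [[1]] = c(1, 3)    [[2]] = c(1, 3)   [[3]] = c(3)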


ltcoxgradhess <- function(preds, dtrain){
  start_time <- Sys.time()
  i_to_R_ti <- attr(dtrain,'i_to_R_ti')
  k_to_i_k_in_R_ti <- attr(dtrain,'k_to_i_k_in_R_ti')
  delta <- attr(dtrain, 'delta')
  
  #y <- attr(dtrain, 'y')
  n <- length(delta)
  r <- vector(length=n, mode="numeric")
  rsq <- vector(length=n, mode="numeric")
  exp_r_sum <- vector(length=n, mode="numeric")
  for( i in 1:n){
    if(!delta[[i]]){
      next
    }
    # R_ti <- i_to_R_ti[[i]]  
    #exp_r_sum[[i]] <- sum(exp(preds[R_ti]))
    exp_r_sum[[i]] <- sum(exp(preds[i_to_R_ti[[i]]]))
  }
  
  for( k in 1:n){
    ti <- k_to_i_k_in_R_ti[[k]]
    exp_r_sum_ti <- exp_r_sum[ti]
    r[[k]] <- sum(1/exp_r_sum_ti)
    rsq[[k]] <- sum( 1/ (exp_r_sum_ti*exp_r_sum_ti))
  }
  
  exp_preds <- exp(preds)
  exp_preds_r <- exp_preds*r
  grad <- (-1)*(delta - exp_preds_r)
  hess <- exp_preds_r - exp_preds*exp_preds*rsq
  
  end_time <- Sys.time()
  #print(paste("ltcoxgradhess:",end_time-start_time))
  
  return(list(grad=grad,hess=hess))
}
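# For reference (standard derivation, not from the original post): writing
#   S_i = sum_{j in R(t_i)} exp(preds[j])
# the negative log partial likelihood is
#   L = -sum_{i: delta_i = 1} ( preds[i] - log(S_i) )
# and differentiating with respect to preds[k] gives
#   grad[k] = -( delta[k] - exp(preds[k]) * sum_{i: k in R(t_i)} 1/S_i )
#   hess[k] = exp(preds[k]) * sum_i 1/S_i - exp(preds[k])^2 * sum_i 1/S_i^2
# which is what exp_r_sum, r, rsq, grad and hess compute above.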


# now create a custom evaluation function - the log partial likelihood (higher is better, hence maximize=TRUE below)

ltcoxevalerror <- function(preds, dtrain) {
  start_time <- Sys.time()
  i_to_R_ti <- attr(dtrain,'i_to_R_ti')
  #k_to_i_k_in_R_ti <- attr(dtrain,'k_to_i_k_in_R_ti')
  delta <- attr(dtrain, 'delta')
  
  n<- length(delta)
  loglik = vector(length=n, mode="numeric")
  for( i in 1:n){
    if( !delta[[i]]){
      next
    }
    ti <- i_to_R_ti[[i]]
    loglik[[i]]<- preds[[i]]-log(sum(exp(preds[ti])))
  }
  
  err <- sum(loglik)
  
  end_time <- Sys.time()
  #print(paste("ltcoxevalerror:",end_time-start_time))
  return(list(metric = "log partial likelihood", value = err)) 
}




#3. How to run - assumes ds_d is the input data set with observation IDs (id), start times (start), stop times (stop), a death indicator (delta), and the predictor columns
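# Purely illustrative example of the expected layout (ds_d_example is a made-up name;
# subject 2 is split into two (start, stop] rows to carry a time-dependent covariate):
ds_d_example <- data.frame(id      = c(1, 2, 2, 3),
                           start   = c(0, 0, 3, 2),
                           stop    = c(5, 3, 7, 4),
                           delta   = c(1, 0, 0, 1),
                           age     = c(60, 55, 55, 70),
                           treated = c(0, 0, 1, 1))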

require(xgboost)
require(dplyr)

ds_train_ids <- ds_d %>% select(id) %>% sample_frac(size = 0.75)
ds_train <- ds_d %>% semi_join(ds_train_ids, by = "id")
ds_test <- ds_d %>% anti_join(ds_train_ids, by = "id")

# strictly speaking, these labels are not needed or used, since we pass delta as a DMatrix attribute instead
ds_train_l <-  ( ds_train %>% select(delta))[[1]]
ds_test_l <- (ds_test %>% select(delta ) )[[1]]
ds_all_l <- (ds_d%>% select(delta))[[1]]

# VERY IMPORTANT: remove id, start/stop times and delta (outcome!!) from DMatrix! We don't want to use these as predictors.
ds_all_m <- xgb.DMatrix( as.matrix(ds_d %>% select(-id, -start, -stop, -delta)), label=ds_all_l)
ds_train_m <- xgb.DMatrix( as.matrix(ds_train %>% select(-id, -start, -stop, -delta)), label=ds_train_l)
ds_test_m <- xgb.DMatrix( as.matrix(ds_test%>% select(-id, -start, -stop, -delta)), label=ds_test_l)

ds_list <- list(ds_d=ds_d,train_ids=ds_train_ids,train=ds_train,test=ds_test, train_l=ds_train_l, test_l=ds_test_l, all_l=ds_all_l, train_m=ds_train_m, test_m=ds_test_m, all_m=ds_all_m) 



#create lookup tables for risk set and its inverse for training and test sets - to be used by custom xgboost objective and evaluation functions

train_riskSet <- getRiskSets( starts=ds_list$train$start, stops=ds_list$train$stop, delta= ds_list$train$delta )
train_riskSetsContaining <- getRiskSetsContaining( starts=ds_list$train$start, stops=ds_list$train$stop, delta= ds_list$train$delta )
train_delta <- ds_list$train$delta

test_riskSet <- getRiskSets( starts=ds_list$test$start, stops=ds_list$test$stop, delta= ds_list$test$delta )
test_riskSetsContaining <- getRiskSetsContaining( starts=ds_list$test$start, stops=ds_list$test$stop, delta= ds_list$test$delta )
test_delta <- ds_list$test$delta


# attach these lookup tables and delta flags to the DMatrices, to be used by the custom xgboost objective/eval functions

attr(ds_list$train_m,'i_to_R_ti') <- train_riskSet
attr(ds_list$train_m,'k_to_i_k_in_R_ti') <- train_riskSetsContaining
attr(ds_list$train_m, 'delta') <- train_delta 

attr(ds_list$test_m,'i_to_R_ti') <- test_riskSet
attr(ds_list$test_m,'k_to_i_k_in_R_ti') <- test_riskSetsContaining
attr(ds_list$test_m, 'delta') <- test_delta 

# set up parameters and custom objective/eval functions
test_params<- list(booster="gbtree", objective=ltcoxgradhess,eval_metric=ltcoxevalerror, eta=0.2, gamma=0, max_depth=4, min_child_weight=1, subsample=0.5, colsample_bytree=1, nthread=8)

# fit the model with early stopping
test_xgb <- xgb.train(params = test_params, data =ds_list$train_m, nrounds = 2001, watchlist = list(train=ds_list$train_m, eval=ds_list$test_m), print_every_n = 10, 
                      early_stopping_rounds = 50, maximize = TRUE)

Hi. I would be interested to know if there has been any update since this post (i.e., the addition of left truncation for the Cox model, such as is implemented in gbm3).

Alternatively (and mathematically equivalently), has an implementation of conditional logit been added yet?

It is not difficult to implement, and would be handy. There are only two changes from binary logistic regression:

  1. the link function is exp(x) instead of the logistic exp(x)/(1+exp(x))
  2. renormalisation across strata (i.e. groups at risk) is required to estimate p

All other things are the same: the loss is -log(p), with gradient g = -(y - p) and Hessian h = p(1 - p).
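In symbols (my restatement of the two changes above, with $x_i$ the raw score for observation $i$ and $S(i)$ its stratum / group at risk):

$$p_i = \frac{e^{x_i}}{\sum_{j \in S(i)} e^{x_j}}, \qquad \text{loss} = -\sum_i y_i \log p_i, \qquad g_i = -(y_i - p_i), \qquad h_i = p_i (1 - p_i).$$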

Hope this helps.