Support for left-truncated data / time-dependent covariates for Cox regression


#1

Hi there,

I’m interested in using xgboost to do survival analysis with left-truncated data and time-dependent covariates. My understanding is that the current version does not support either, and a quick skim of the built-in Cox regression object
in https://github.com/dmlc/xgboost/blob/master/src/objective/regression_obj.cu seems to confirm this.

For those who are more familiar with the architecture of the program, would it be difficult to add support for left-truncated data and time-dependent covariates? Conceptually, I believe all that would be required is an adjustment of the risk set based on the entry time (left-truncation time) of each data point; time-dependent covariates then come along for free, and the gradients and Hessians would be calculated much as they are now. Is there an architectural reason this would be difficult to implement? Perhaps xgboost doesn’t support passing in a second parameter for the left-truncation times (in addition to the single “outcome” variable)?
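To make that concrete, here is a rough sketch (plain R, nothing xgboost-specific) of what I mean by adjusting the risk set; the exact boundary convention at the entry time varies between references:

# Standard Cox risk set at a death time t: everyone whose observed (stop) time is still >= t
risk_set <- function(t, stops) which(stops >= t)
# With left truncation: additionally require that the subject had already entered the study before t
risk_set_lt <- function(t, starts, stops) which(starts < t & stops >= t)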

Do you have any advice in case I try to implement this myself?

Thanks


#2

As you said, XGBoost currently doesn’t support passing in a second parameter. Is there any way to process the preds argument so as to account for the truncation times?


#3

Hi,

I have begun implementing this externally as a custom objective + evaluation function in R, using the ability to add attributes to xgb.DMatrix objects to store pre-calculated risk sets as lookup tables (this avoids recalculating the risk sets from the truncation times at every objective/evaluation call). For an example of using attributes to pass extra information into custom objective functions, see e.g. https://github.com/dmlc/xgboost/blob/master/R-package/demo/custom_objective.R. The approach seems reasonably efficient, though it is much slower than the non-truncated case: with truncation the risk sets can no longer be handled with a single sort, which is what makes the built-in Cox code so fast.
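To illustrate the mechanism with a toy example (made-up data and a trivial weighted squared-error objective, nothing to do with the Cox case), the extra information is simply attached as an attribute and read back inside the objective at each call:

library(xgboost)
# Toy data; the only point is the attribute round-trip into the custom objective
dtoy <- xgb.DMatrix(matrix(rnorm(20), nrow = 10), label = rnorm(10))
attr(dtoy, 'w') <- runif(10)                 # any R object can be attached as an attribute
wobj <- function(preds, dtrain) {
  w <- attr(dtrain, 'w')                     # read back inside the objective at each boosting round
  y <- getinfo(dtrain, 'label')
  list(grad = w * (preds - y), hess = w)     # gradient/Hessian of 0.5 * w * (pred - y)^2
}
bst <- xgb.train(params = list(), data = dtoy, nrounds = 2, obj = wobj)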

There may be a better algorithm or implementation choice than the one I’m working on; I am never sure my R code is as fast as it could be, and other languages may be more efficient or flexible. On a data set of approximately 100,000 points, with about one in five being “death” events, the objective function takes about 10 seconds to evaluate, which is a bit too slow to be convenient. Parallelization may help, but I haven’t worked out how to implement it effectively.
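For instance (just a sketch, nothing I have benchmarked), the loop over death events that accumulates the risk-set sums could be farmed out with the parallel package, where i_to_R_ti is the per-event risk-set lookup list and preds the current predictions; mclapply() relies on forking, so this would only help on Unix-alikes:

library(parallel)
exp_r_sum <- unlist(mclapply(seq_along(i_to_R_ti), function(i) {
  if (is.null(i_to_R_ti[[i]])) return(0)   # censored observations have no risk set of their own
  sum(exp(preds[i_to_R_ti[[i]]]))
}, mc.cores = 4))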

Unfortunately, I’ve had to abandon work on this because of other priorities at work. If you are interested, I could share the code I have so far for the objective function and risk set construction.


#4

It would be interesting if this GSOC project gets picked up: https://github.com/dmlc/xgboost/issues/4242


#5

Yes, that would be great!
For those interested, here is my code. It has not been checked carefully for accuracy, but I would welcome any corrections.

# 1. Functions to pre-calculate risk sets and inverse map
# Return the index of the greatest value in sorted_vals that is < (strict=TRUE) or <= (strict=FALSE) val,
# restricting the search to indices between min and max; returns NA if no such value exists.
binSearch <- function( sorted_vals, val, min, max, strict=FALSE ){
  #print( paste("min", min, "max", max))
  width <- max-min+1
  if ( min > max ){
    print(paste("Error: min > max (", min, max,")"))
    return(NA)
  }
  if (width == 1) {
    if(strict){
      if ( sorted_vals[min] < val){
        return (min)
      } else{
        return(NA)
      }
    }else{
      if (sorted_vals[min] <= val){
        return (min)
      } else {
        return(NA)
      }
    }
  }
  
  
  if (width %% 2){
    #odd
    mid = min + (width-1)/2
  } else {
    # even
    mid = min + width/2
  }  
  #print(paste("mid",mid))
  if(strict){
    cond <- (sorted_vals[mid] < val)
  }else{
    cond <- (sorted_vals[mid] <= val)
  }
  
  if (cond){
    return( binSearch(sorted_vals, val, mid, max, strict) )
  } else {
    return( binSearch(sorted_vals, val, min, mid-1, strict) )
  }
  
}
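
# Note (a possible simplification, not something I have benchmarked): base R's findInterval()
# performs the same lookup as binSearch() in compiled code, returning 0 rather than NA when no
# element qualifies. Quick illustration on made-up values:
fi_example <- c(1, 3, 3, 7)
findInterval(5, fi_example)                    # 3, same as binSearch(fi_example, 5, 1, 4)
findInterval(3, fi_example, left.open = TRUE)  # 1, same as binSearch(fi_example, 3, 1, 4, strict = TRUE)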

getRiskSets <- function (starts, stops, delta){
  starts.sorted <- sort(starts,index.return=TRUE)
  stops.sorted <- sort(stops,index.return=TRUE)
  nstarts <- length(starts)
  
  start_time <- Sys.time()
  overlaps <- vector(mode="list", length=nstarts )
  ov1 <- vector(mode="logical", length=nstarts)
  ov2 <- vector(mode="logical", length=nstarts)
  ov2[1:nstarts] <- TRUE  # flag vector for intervals with stops >= current stop 
  ov1[1:nstarts] <- FALSE # flag vector for intervals with starts <= current stop
  laststartmax=1
  first = TRUE
  laststopi=1
  for(i in 1:nstarts){
    stop_i <- stops.sorted$ix[i]
    if( !delta[stop_i] ){
      next
    }
    t_stop <- (stops.sorted$x)[i]
    start_max <- binSearch( starts.sorted$x, t_stop, laststartmax, nstarts) 
    # might be faster to just count up, but not sure, depends on density of intervals:
    # for(j in laststartmax:nstarts){
    #    if (starts.sorted$x[j] > t_stop){
    #      start_max = j-1
    #      break
    #    }
    #  }
    
    #start_max <- findInterval( t_stop,starts.sorted$x,left.open=TRUE)
    if( is.na(start_max) ){
      next
    }
    #  start_is <- starts.sorted$ix[1:start_max]
    # stop_is <- stops.sorted$ix[i:nstarts]
    
    #overlaps[stop_i] <- list(intersect(start_is, stop_is))
    # ov1[1:nstarts] <- FALSE 
    #  ov1[starts.sorted$ix[1:start_max]] <- TRUE
    if( first){
      ov1[starts.sorted$ix[1:start_max]] <- TRUE
      first=FALSE
    } else if(start_max > laststartmax){ #only this branch should ever occur after first loop
      #print(paste(i, start_max, ">", laststartmax))  
      ov1[starts.sorted$ix[(laststartmax+1):start_max]] <- TRUE  # only update indices that have changed since last loop
    }else if(start_max < laststartmax){ # this should throw an error, since we are assuming above in binsearch that start_max >= laststartmax, should be true by sort of starts and stops
      #print(paste(i,"<"))
      #ov1[starts.sorted$ix[(start_max+1):laststartmax]] <- FALSE
      print(paste("ERROR: start_max < laststartmax at iteration ",i ))
    }
    #if( i < 5) { print(which(ov1))}
    laststartmax <- start_max
    
    #ov2[stops.sorted$ix[i:nstarts]] <- TRUE
    # intervals whose stop time is strictly before t_stop can no longer be at risk; advance
    # laststopi incrementally. Ties with t_stop are kept at risk, so that this agrees with
    # getRiskSetsContaining() below (which uses stops <= stop).
    while (laststopi < i && stops.sorted$x[laststopi] < t_stop) {
      ov2[stops.sorted$ix[laststopi]] <- FALSE
      laststopi <- laststopi + 1
    }
    #ov2[stops.sorted$ix[-i:-nstarts]] <- FALSE
    
    # final step - make list over intervals that overlap stop_i by a quick vector operation
    overlaps[stop_i] <- list ( which(ov1 & ov2))
    
  }
  
  end_time <- Sys.time()
  print(end_time-start_time)
  #m2<-overlaps
  return(overlaps)
}

# much simpler though slower code
getRiskSetsContaining <- function(starts, stops, delta){
  start_time <- Sys.time()
  nstarts <- length(starts)
  risksets <- vector(mode="list", length=nstarts)
  
  for(i in 1:nstarts){
    start <- starts[i]
    stop <- stops[i]
    
    ivec <- ( (start <= stops) & (stops <= stop) & delta)  # delta is needed: only death events define risk sets
    
    risksets[i] <- list(which(ivec))
    
  }
  
  end_time <- Sys.time()
  print(end_time-start_time)
  return(risksets)
}
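
# Quick self-check on synthetic data (illustrative only; the values are made up): the two lookup
# tables should be inverses of each other, i.e. k is in the risk set of death event i exactly
# when i appears in the "risk sets containing k" list.
set.seed(1)
n_chk      <- 200
starts_chk <- runif(n_chk, 0, 5)
stops_chk  <- starts_chk + runif(n_chk, 0.1, 5)
delta_chk  <- rbinom(n_chk, 1, 0.3)
rs_chk  <- getRiskSets(starts_chk, stops_chk, delta_chk)
rsc_chk <- getRiskSetsContaining(starts_chk, stops_chk, delta_chk)
ok_chk <- vapply(seq_len(n_chk), function(i)
  all(vapply(rs_chk[[i]], function(k) i %in% rsc_chk[[k]], logical(1))),
  logical(1))
stopifnot(all(ok_chk))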


#2. Custom xgb objective and eval functions

# Assume these are given as lookup tables - risk set map and its inverse
#i_to_R_ti # list of vectors; if delta_i = 0, null; if delta_i=1, Risk(t_i) 
#k_to_i_k_in_R_ti  # list of vectors, for each index k containing all i s.t. delta_i=1 and k is in R(t_i) 
#delta # death indicator
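
# For reference, a sketch of the math that ltcoxgradhess() below implements (Breslow-style
# handling of ties assumed):
#   negative log partial likelihood:  -l(f) = - sum_{i: delta_i = 1} [ f_i - log(S_i) ],
#       where S_i = sum_{k in R(t_i)} exp(f_k)
#   gradient w.r.t. f_k:  -( delta_k - exp(f_k) * sum_{i: k in R(t_i)} 1/S_i )
#   diagonal Hessian:       exp(f_k) * sum_i 1/S_i  -  exp(f_k)^2 * sum_i 1/S_i^2
# In the code below, r[k] = sum_i 1/S_i and rsq[k] = sum_i 1/S_i^2, with the sums running over
# the death events i such that k is in R(t_i) (i.e. over k_to_i_k_in_R_ti[[k]]).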


ltcoxgradhess <- function(preds, dtrain){
  start_time <- Sys.time()
  i_to_R_ti <- attr(dtrain,'i_to_R_ti')
  k_to_i_k_in_R_ti <- attr(dtrain,'k_to_i_k_in_R_ti')
  delta <- attr(dtrain, 'delta')
  
  #y <- attr(dtrain, 'y')
  n <- length(delta)
  r <- vector(length=n, mode="numeric")
  rsq <- vector(length=n, mode="numeric")
  exp_r_sum <- vector(length=n, mode="numeric")
  for( i in 1:n){
    if(!delta[[i]]){
      next
    }
    # R_ti <- i_to_R_ti[[i]]  
    #exp_r_sum[[i]] <- sum(exp(preds[R_ti]))
    exp_r_sum[[i]] <- sum(exp(preds[i_to_R_ti[[i]]]))
  }
  
  for( k in 1:n){
    ti <- k_to_i_k_in_R_ti[[k]]
    exp_r_sum_ti <- exp_r_sum[ti]
    r[[k]] <- sum(1/exp_r_sum_ti)
    rsq[[k]] <- sum( 1/ (exp_r_sum_ti*exp_r_sum_ti))
  }
  
  exp_preds <- exp(preds)
  exp_preds_r <- exp_preds*r
  grad <- (-1)*(delta - exp_preds_r)
  hess <- exp_preds_r - exp_preds*exp_preds*rsq
  
  end_time <- Sys.time()
  #print(paste("ltcoxgradhess:",end_time-start_time))
  
  return(list(grad=grad,hess=hess))
}


# now create custom evaluation error function - i.e. the log partial likelihood

ltcoxevalerror <- function(preds, dtrain) {
  start_time <- Sys.time()
  i_to_R_ti <- attr(dtrain,'i_to_R_ti')
  #k_to_i_k_in_R_ti <- attr(dtrain,'k_to_i_k_in_R_ti')
  delta <- attr(dtrain, 'delta')
  
  n<- length(delta)
  loglik = vector(length=n, mode="numeric")
  for( i in 1:n){
    if( !delta[[i]]){
      next
    }
    ti <- i_to_R_ti[[i]]
    loglik[[i]]<- preds[[i]]-log(sum(exp(preds[ti])))
  }
  
  err <- sum(loglik)
  
  end_time <- Sys.time()
  #print(paste("ltcoxevalerror:",end_time-start_time))
  return(list(metric = "log partial likelihood", value = err)) 
}




# 3. How to run - assumes ds_d is the input data set with observation id (id), start time (start), stop time (stop), and death indicator (delta)

require(xgboost)
require(dplyr)

ds_train_ids <- ds_d %>% distinct(id) %>% sample_frac(size=0.75)  # sample 75% of unique ids (an id may have several start/stop rows)
ds_train <- ds_d %>% semi_join(ds_train_ids)
ds_test <- ds_d %>% anti_join(ds_train_ids)

# strictly speaking, these labels are not needed or used, since we pass delta as a DMatrix attribute
ds_train_l <-  ( ds_train %>% select(delta))[[1]]
ds_test_l <- (ds_test %>% select(delta ) )[[1]]
ds_all_l <- (ds_d%>% select(delta))[[1]]

# VERY IMPORTANT: remove id, start/stop times and delta (outcome!!) from DMatrix! We don't want to use these as predictors.
ds_all_m <- xgb.DMatrix( as.matrix(ds_d %>% select(-id, -start, -stop, -delta)), label=ds_all_l)
ds_train_m <- xgb.DMatrix( as.matrix(ds_train %>% select(-id, -start, -stop, -delta)), label=ds_train_l)
ds_test_m <- xgb.DMatrix( as.matrix(ds_test%>% select(-id, -start, -stop, -delta)), label=ds_test_l)

ds_list <- list(ds_d=ds_d,train_ids=ds_train_ids,train=ds_train,test=ds_test, train_l=ds_train_l, test_l=ds_test_l, all_l=ds_all_l, train_m=ds_train_m, test_m=ds_test_m, all_m=ds_all_m) 



#create lookup tables for risk set and its inverse for training and test sets - to be used by custom xgboost objective and evaluation functions

train_riskSet <- getRiskSets( starts=ds_list$train$start, stops=ds_list$train$stop, delta= ds_list$train$delta )
train_riskSetsContaining <- getRiskSetsContaining( starts=ds_list$train$start, stops=ds_list$train$stop, delta= ds_list$train$delta )
train_delta <- ds_list$train$delta

test_riskSet <- getRiskSets( starts=ds_list$test$start, stops=ds_list$test$stop, delta= ds_list$test$delta )
test_riskSetsContaining <- getRiskSetsContaining( starts=ds_list$test$start, stops=ds_list$test$stop, delta= ds_list$test$delta )
test_delta <- ds_list$test$delta


#attach these lookup tables and delta flags to the DMatrices to be used in custom xgboost calls

attr(ds_list$train_m,'i_to_R_ti') <- train_riskSet
attr(ds_list$train_m,'k_to_i_k_in_R_ti') <- train_riskSetsContaining
attr(ds_list$train_m, 'delta') <- train_delta 

attr(ds_list$test_m,'i_to_R_ti') <- test_riskSet
attr(ds_list$test_m,'k_to_i_k_in_R_ti') <- test_riskSetsContaining
attr(ds_list$test_m, 'delta') <- test_delta 

# set up parameters and custom objective/eval functions
test_params<- list(booster="gbtree", objective=ltcoxgradhess,eval_metric=ltcoxevalerror, eta=0.2, gamma=0, max_depth=4, min_child_weight=1, subsample=0.5, colsample_bytree=1, nthread=8)

# fit the model with early stopping
test_xgb <- xgb.train(params = test_params, data =ds_list$train_m, nrounds = 2001, watchlist = list(train=ds_list$train_m, eval=ds_list$test_m), print_every_n = 10, 
                      early_stopping_rounds = 50, maximize = TRUE)
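
# Optional follow-up (a sketch, not something I have validated carefully): score the test set and
# compute a concordance index that respects left truncation via the counting-process Surv() from
# the survival package; reverse = TRUE because a larger score means a higher hazard.
require(survival)
test_scores <- predict(test_xgb, ds_list$test_m)
concordance(Surv(start, stop, delta) ~ test_scores, data = ds_test, reverse = TRUE)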