# 1. Functions to pre-calculate risk sets and inverse map
# Binary search restricted to the index range [min, max] of a sorted vector.
#
# Returns the largest index k in [min, max] with sorted_vals[k] <= val
# (sorted_vals[k] < val when strict = TRUE), or NA when no index in the range
# qualifies. sorted_vals is assumed to be sorted in non-decreasing order.
# When min > max an error message is printed and NA is returned.
binSearch <- function(sorted_vals, val, min, max, strict = FALSE) {
  if (min > max) {
    print(paste("Error: min > max (", min, max, ")"))
    return(NA)
  }
  qualifies <- function(k) {
    if (strict) sorted_vals[k] < val else sorted_vals[k] <= val
  }
  lo <- min
  hi <- max
  # Invariant: the answer, if any, lies in [lo, hi]. Shrink until one index
  # remains, then test that index.
  while (lo < hi) {
    # Upper-middle probe guarantees lo strictly advances when it moves.
    mid <- lo + floor((hi - lo + 1) / 2)
    if (qualifies(mid)) {
      lo <- mid
    } else {
      hi <- mid - 1
    }
  }
  if (qualifies(lo)) lo else NA
}
# Build the risk set R(t_i) for every event row of interval (start, stop) data.
#
# For a row i with a non-zero delta[i], t_i = stops[i] and R(t_i) contains
# every row j with starts[j] <= t_i and stops[j] >= t_i (both ends inclusive,
# the same convention as getRiskSetsContaining). Rows whose stop time ties
# with t_i are therefore part of R(t_i).
#
# Args:
#   starts, stops: numeric vectors of interval start/stop times (same length).
#   delta: event indicator per row; rows with a zero/FALSE delta get no set.
# Returns:
#   A list of length(starts). Element i is the integer vector of row indices
#   in R(t_i) for event rows, and NULL for censored rows (or for an event row
#   when no interval starts at or before its stop time).
# Side effect: prints the elapsed computation time.
getRiskSets <- function(starts, stops, delta) {
  starts.sorted <- sort(starts, index.return = TRUE)
  stops.sorted <- sort(stops, index.return = TRUE)
  nstarts <- length(starts)
  start_time <- Sys.time()
  overlaps <- vector(mode = "list", length = nstarts)
  # ov1[j]: row j has started by the current event time (starts[j] <= t).
  ov1 <- rep(FALSE, nstarts)
  # ov2[j]: row j has not yet stopped before the current event time
  # (stops[j] >= t).
  ov2 <- rep(TRUE, nstarts)
  # Event times are visited in non-decreasing order, so both pointers only
  # move forward: n_started counts sorted starts already flagged in ov1,
  # n_closed counts sorted stops already cleared from ov2.
  n_started <- 0
  n_closed <- 0
  for (i in seq_len(nstarts)) {
    stop_i <- stops.sorted$ix[i]
    if (!delta[stop_i]) {
      next
    }
    t_stop <- stops.sorted$x[i]
    while (n_started < nstarts && starts.sorted$x[n_started + 1] <= t_stop) {
      n_started <- n_started + 1
      ov1[starts.sorted$ix[n_started]] <- TRUE
    }
    if (n_started == 0) {
      next
    }
    # Clear only rows whose stop is strictly before t_stop. Clearing by sorted
    # position would wrongly drop rows tied at t_stop from the risk set.
    while (n_closed < nstarts && stops.sorted$x[n_closed + 1] < t_stop) {
      n_closed <- n_closed + 1
      ov2[stops.sorted$ix[n_closed]] <- FALSE
    }
    overlaps[[stop_i]] <- which(ov1 & ov2)
  }
  end_time <- Sys.time()
  print(end_time - start_time)
  overlaps
}
# Inverse risk-set map: for every row k, the event rows i whose risk set
# contains k.
#
# Row i is listed for row k when delta[i] is non-zero and
# starts[k] <= stops[i] <= stops[k], i.e. k is at risk (both ends inclusive)
# at the event time t_i = stops[i]. delta belongs in the mask because only
# event rows have risk sets. This is a simpler but slower alternative to the
# sweep used by getRiskSets: each row does a full-length vector comparison,
# giving O(n^2) work overall.
#
# Returns a list of length(starts); element k is an integer vector of row
# indices (possibly empty).
# Side effect: prints the elapsed computation time.
getRiskSetsContaining <- function(starts, stops, delta) {
  start_time <- Sys.time()
  risksets <- lapply(seq_along(starts), function(k) {
    which(starts[k] <= stops & stops <= stops[k] & delta)
  })
  end_time <- Sys.time()
  print(end_time - start_time)
  risksets
}
#2. Custom xgb objective and eval functions
# Custom xgboost objective: gradient and diagonal Hessian of the negative Cox
# log partial likelihood for (start, stop) interval data.
#
# Reads these attributes from dtrain (any object carrying them works):
#   i_to_R_ti        - list; element i is R(t_i), the row indices at risk at
#                      event time t_i (NULL when delta_i = 0).
#   k_to_i_k_in_R_ti - list; element k holds every event row i with k in R(t_i).
#   delta            - event (death) indicator per row.
# With eta = preds, S_i = sum over j in R(t_i) of exp(eta_j), and sums over the
# event rows i in k_to_i_k_in_R_ti[[k]]:
#   r_k   = sum 1 / S_i,   rsq_k = sum 1 / S_i^2
#   grad_k = exp(eta_k) * r_k - delta_k
#   hess_k = exp(eta_k) * r_k - exp(eta_k)^2 * rsq_k
# Returns list(grad = <numeric>, hess = <numeric>), one entry per row.
ltcoxgradhess <- function(preds, dtrain) {
  i_to_R_ti <- attr(dtrain, "i_to_R_ti")
  k_to_i_k_in_R_ti <- attr(dtrain, "k_to_i_k_in_R_ti")
  delta <- attr(dtrain, "delta")
  n <- length(delta)
  # Computed once and reused for every risk-set sum.
  exp_preds <- exp(preds)
  # S_i for event rows; left at 0 for censored rows, which are never looked up
  # because k_to_i_k_in_R_ti only lists event rows.
  exp_r_sum <- numeric(n)
  for (i in which(as.logical(delta))) {
    exp_r_sum[i] <- sum(exp_preds[i_to_R_ti[[i]]])
  }
  r <- numeric(n)
  rsq <- numeric(n)
  for (k in seq_len(n)) {
    s_k <- exp_r_sum[k_to_i_k_in_R_ti[[k]]]
    r[k] <- sum(1 / s_k)
    rsq[k] <- sum(1 / (s_k * s_k))
  }
  exp_preds_r <- exp_preds * r
  grad <- exp_preds_r - delta
  hess <- exp_preds_r - exp_preds * exp_preds * rsq
  list(grad = grad, hess = hess)
}
# Custom xgboost evaluation metric: the Cox log partial likelihood (higher is
# better).
#
# Sums preds[i] - log(sum(exp(preds[R(t_i)]))) over event rows i, reading the
# 'i_to_R_ti' (risk set per event row) and 'delta' (event indicator)
# attributes of dtrain. Censored rows contribute nothing.
# Returns list(metric = "log partial likelihood", value = <sum>).
ltcoxevalerror <- function(preds, dtrain) {
  i_to_R_ti <- attr(dtrain, "i_to_R_ti")
  delta <- attr(dtrain, "delta")
  events <- which(as.logical(delta))
  loglik <- vapply(events, function(i) {
    preds[[i]] - log(sum(exp(preds[i_to_R_ti[[i]]])))
  }, numeric(1))
  list(metric = "log partial likelihood", value = sum(loglik))
}
# 3. How to run. Assumes ds_d is the input data set with observation ids (id),
# interval start time (start), stop time (stop) and death indicator (delta);
# all remaining columns are used as predictors.
library(dplyr)
library(xgboost)
# Sample 75% of the distinct ids, not 75% of the rows: with several rows per
# id, sampling rows biases the split toward ids with many rows. The joins then
# keep all rows of an id together in either train or test.
# NOTE: call set.seed() beforehand if a reproducible split is needed.
ds_train_ids <- ds_d %>% distinct(id) %>% sample_frac(size = 0.75)
ds_train <- ds_d %>% semi_join(ds_train_ids, by = "id")
ds_test <- ds_d %>% anti_join(ds_train_ids, by = "id")
# Strictly speaking these labels are not used: the objective and metric read
# delta from the DMatrix attribute instead.
ds_train_l <- (ds_train %>% select(delta))[[1]]
ds_test_l <- (ds_test %>% select(delta))[[1]]
ds_all_l <- (ds_d %>% select(delta))[[1]]
# VERY IMPORTANT: remove id, start/stop times and delta (the outcome!) from the
# DMatrix so they are not used as predictors.
ds_all_m <- xgb.DMatrix(
  as.matrix(ds_d %>% select(-id, -start, -stop, -delta)),
  label = ds_all_l
)
ds_train_m <- xgb.DMatrix(
  as.matrix(ds_train %>% select(-id, -start, -stop, -delta)),
  label = ds_train_l
)
ds_test_m <- xgb.DMatrix(
  as.matrix(ds_test %>% select(-id, -start, -stop, -delta)),
  label = ds_test_l
)
ds_list <- list(
  ds_d = ds_d, train_ids = ds_train_ids, train = ds_train, test = ds_test,
  train_l = ds_train_l, test_l = ds_test_l, all_l = ds_all_l,
  train_m = ds_train_m, test_m = ds_test_m, all_m = ds_all_m
)
# Build the risk-set lookup tables and their inverses for the training and
# test sets; the custom objective and evaluation functions read them.
# Row order matches the DMatrix rows because both come from the same frame.
train_riskSet <- getRiskSets(
  starts = ds_list$train$start, stops = ds_list$train$stop,
  delta = ds_list$train$delta
)
train_riskSetsContaining <- getRiskSetsContaining(
  starts = ds_list$train$start, stops = ds_list$train$stop,
  delta = ds_list$train$delta
)
train_delta <- ds_list$train$delta
test_riskSet <- getRiskSets(
  starts = ds_list$test$start, stops = ds_list$test$stop,
  delta = ds_list$test$delta
)
test_riskSetsContaining <- getRiskSetsContaining(
  starts = ds_list$test$start, stops = ds_list$test$stop,
  delta = ds_list$test$delta
)
test_delta <- ds_list$test$delta
# Attach the lookup tables and delta flags to the DMatrices for the custom
# xgboost objective/metric.
attr(ds_list$train_m, "i_to_R_ti") <- train_riskSet
attr(ds_list$train_m, "k_to_i_k_in_R_ti") <- train_riskSetsContaining
attr(ds_list$train_m, "delta") <- train_delta
attr(ds_list$test_m, "i_to_R_ti") <- test_riskSet
attr(ds_list$test_m, "k_to_i_k_in_R_ti") <- test_riskSetsContaining
attr(ds_list$test_m, "delta") <- test_delta
# Parameters plus the custom objective and evaluation functions.
test_params <- list(
  booster = "gbtree", objective = ltcoxgradhess,
  eval_metric = ltcoxevalerror, eta = 0.2, gamma = 0, max_depth = 4,
  min_child_weight = 1, subsample = 0.5, colsample_bytree = 1, nthread = 8
)
# Fit with early stopping; the metric is a log-likelihood, hence maximize.
test_xgb <- xgb.train(
  params = test_params, data = ds_list$train_m, nrounds = 2001,
  watchlist = list(train = ds_list$train_m, eval = ds_list$test_m),
  print_every_n = 10, early_stopping_rounds = 50, maximize = TRUE
)