XGBoost4J-Spark giving XGBoostError: std::bad_alloc on Databricks


I am using XGBoost4J-Spark to build a distributed XGBoost model on my data, and I am developing the model in Databricks.

I am on Spark 3.1.1, Scala 2.12 and XGBoost4J 1.4.1.

My cluster setup is as follows:

  • 5 worker nodes, each with 32 GB of memory and 16 cores

  • A driver node with 14 GB of memory and 4 cores

My cluster configuration is:

    spark.executor.memory 9g
    spark.executor.cores 5

So I end up with 10 executors with 4.6 GB of memory each and one driver with 3.3 GB of memory.
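For reference, Spark's unified memory manager reserves 300 MB of the heap and gives `spark.memory.fraction` (0.6 by default) of the remainder to execution and storage. Assuming the 4.6 GB and 3.3 GB figures come from the Storage Memory column of the Spark UI's Executors tab, they describe this pool rather than the full heap, and the usable heap the JVM reports may already be somewhat below `-Xmx`. A plain-Scala sketch of the calculation (the object and method names are hypothetical):

```scala
object UnifiedMemory {
  // Spark's UnifiedMemoryManager sets aside a fixed 300 MB of the heap.
  val ReservedMb: Double = 300.0

  // Size of the shared execution + storage pool for a given usable heap (MB),
  // assuming the default spark.memory.fraction of 0.6.
  def unifiedPoolMb(usableHeapMb: Double, memoryFraction: Double = 0.6): Double =
    (usableHeapMb - ReservedMb) * memoryFraction

  def main(args: Array[String]): Unit = {
    val heapMb = 9 * 1024.0 // spark.executor.memory 9g
    // Upper bound: if the JVM's usable heap is below -Xmx, the UI shows less.
    println(f"unified pool for a 9g heap: ${unifiedPoolMb(heapMb) / 1024}%.1f GB")
  }
}
```

For a 9g executor this gives about 5.2 GB as an upper bound, so a reported 4.6 GB per executor is in the expected range and does not mean memory is missing.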

I imported the package as follows:

    import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}

To find the best parameters for my model, I created a parameter grid and used a train-validation split:

    // Parameter tuning
    import org.apache.spark.ml.tuning._

    // xgbRegressor is the XGBoostRegressor estimator defined earlier (not shown)
    // Create the parameter grid; Array.range(6, 10, 2) gives Array(6, 8)
    val paramGrid = new ParamGridBuilder()
        .addGrid(xgbRegressor.maxDepth, Array.range(6, 10, 2))
        .addGrid(xgbRegressor.eta, Array(0.01))
        .addGrid(xgbRegressor.minChildWeight, Array(8.0, 10.0, 12.0, 14.0))
        .addGrid(xgbRegressor.gamma, Array(0.0, 0.25, 0.5, 1.0))
        .build()
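`ParamGridBuilder` expands the grid into the Cartesian product of the listed values, and `TrainValidationSplit` fits the estimator once per combination. A plain-Scala sketch (no Spark; the values are copied from the grid above, the object name is hypothetical) that counts the candidate models:

```scala
object GridSize {
  // Values copied from the parameter grid; Array.range(6, 10, 2) is Array(6, 8).
  val maxDepth: Seq[Int] = Array.range(6, 10, 2).toSeq
  val eta: Seq[Double] = Seq(0.01)
  val minChildWeight: Seq[Double] = Seq(8.0, 10.0, 12.0, 14.0)
  val gamma: Seq[Double] = Seq(0.0, 0.25, 0.5, 1.0)

  // Cartesian product: one tuple per candidate model.
  val combinations: Seq[(Int, Double, Double, Double)] =
    for {
      d <- maxDepth
      e <- eta
      w <- minChildWeight
      g <- gamma
    } yield (d, e, w, g)

  def main(args: Array[String]): Unit =
    println(s"${combinations.size} candidate models")
}
```

That is 2 × 1 × 4 × 4 = 32 distributed XGBoost fits. By default only the best of them is stored in the saved `TrainValidationSplitModel`, unless sub-models are collected.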

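I then fit it to my data and saved it:

    val trainValidationSplit = new TrainValidationSplit()
        .setEstimator(xgbRegressor)
        .setEstimatorParamMaps(paramGrid)
        .setEvaluator(new org.apache.spark.ml.evaluation.RegressionEvaluator()) // assumed: RMSE
    val tvmodel = trainValidationSplit.fit(train)
    tvmodel.save("spark-train-validation-split-28072021")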


The error occurs when I try to load the model again:

    import org.apache.spark.ml.tuning._
    val tvmodel = TrainValidationSplitModel.load("spark-train-validation-split-28072021")

The error message is:

    XGBoostError: std::bad_alloc

I checked the executor and driver logs. The executor logs look fine, but I found the same error in the driver logs (both the stderr and the log4j files). Both log files are available at this link: log_files

Since the error appears mainly in the driver logs, I tried the following:

  • Increased the driver memory to 28 GB

  • Increased the driver cores to 8

  • Made the driver the same size as the workers

None of these worked. The logs show that the driver is not running out of memory, so I am struggling to work out what is actually causing the error.
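One caveat when reading those logs: `std::bad_alloc` is a C++ exception, so it is raised inside XGBoost's native library (called through JNI), whose allocations live outside the JVM heap. Heap figures reported by Spark or the JVM therefore do not by themselves show whether a native allocation failed. A minimal sketch for reading the driver's JVM heap figures from a notebook cell, to compare against the logs (the object and class names are hypothetical):

```scala
object HeapStats {
  final case class Heap(maxMb: Long, totalMb: Long, freeMb: Long) {
    def usedMb: Long = totalMb - freeMb
  }

  // JVM heap only: memory allocated by native (JNI) code is not counted here.
  def current(): Heap = {
    val rt = Runtime.getRuntime
    val mb = 1024L * 1024L
    Heap(rt.maxMemory / mb, rt.totalMemory / mb, rt.freeMemory / mb)
  }

  def main(args: Array[String]): Unit = {
    val h = current()
    println(s"heap used ${h.usedMb} MB of ${h.totalMb} MB committed (max ${h.maxMb} MB)")
  }
}
```

A low heap usage here is consistent with the logs but leaves native memory unmeasured.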

It would be great if one of you could point me in the right direction.