Hello,
I am using XGBoost4J-Spark to train a distributed XGBoost model on my data. I am developing the model in Databricks.
The Spark version is 3.1.1, with Scala 2.12 and XGBoost4J 1.4.1.
My cluster setup is as follows:
- 5 worker nodes, each with 32 GB of memory and 16 cores
- Driver node with 14 GB of memory and 4 cores
My Spark configuration is as follows:
spark.executor.memory 9g
spark.executor.cores 5
So basically I have 10 executors with 4.6 GB of memory each and 1 driver with 3.3 GB of memory.
I imported the package as follows:
import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}
To find the best parameters for my model, I created a parameter grid for use with a train-validation split, as shown below:
// Parameter tuning
import org.apache.spark.ml.tuning._
import org.apache.spark.ml.PipelineModel
import Array._

// Create parameter grid
val paramGrid = new ParamGridBuilder()
  .addGrid(xgbRegressor.maxDepth, range(6, 10, 2))
  .addGrid(xgbRegressor.eta, Array(0.01))
  .addGrid(xgbRegressor.minChildWeight, Array(8.0, 10.0, 12.0, 14.0))
  .addGrid(xgbRegressor.gamma, Array(0, 0.25, 0.5, 1))
  .build()
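For a sense of scale, the size of this grid can be counted without Spark. The following is a small sketch in plain Scala; the object name `GridSize` is only for illustration, and the values are copied from the grid above.

```scala
// Plain Scala, no Spark needed: count the combinations in the parameter grid.
// GridSize is a hypothetical name used only for this illustration.
object GridSize {
  // Array.range excludes the upper bound, so this yields 6 and 8.
  val maxDepths: Array[Int] = Array.range(6, 10, 2)
  val etas: Array[Double] = Array(0.01)
  val minChildWeights: Array[Double] = Array(8.0, 10.0, 12.0, 14.0)
  val gammas: Array[Double] = Array(0.0, 0.25, 0.5, 1.0)

  // ParamGridBuilder builds the cross product of all value lists.
  val numCombinations: Int =
    maxDepths.length * etas.length * minChildWeights.length * gammas.length

  def main(args: Array[String]): Unit = {
    println(s"maxDepth values: ${maxDepths.mkString(", ")}")   // 6, 8
    println(s"parameter combinations: $numCombinations")       // 32
  }
}
```

So the grid contains 2 x 1 x 4 x 4 = 32 parameter combinations, and `maxDepth` only takes the values 6 and 8.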
I then fit it to my data and saved the resulting model:
val trainValidationSplit = new TrainValidationSplit()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.75)

val tvmodel = trainValidationSplit.fit(train)
tvmodel.write.overwrite().save("spark-train-validation-split-28072021")
The error comes up when I try to load the saved model again:
import org.apache.spark.ml.tuning._
val tvmodel = TrainValidationSplitModel.load("spark-train-validation-split-28072021")
The error message is:
XGBoostError: std::bad_alloc
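One way to narrow this down, as a sketch only, would be to save and reload just the best model instead of the whole `TrainValidationSplitModel`, to see whether the error comes from the XGBoost model itself or from the tuning wrapper. The helper name `saveAndReloadBest` and the path passed to it are hypothetical, and the cast assumes that the estimator passed to `setEstimator` is a `Pipeline`.

```scala
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.tuning.TrainValidationSplitModel

// Hypothetical helper: persist only the winning model, then read it back.
def saveAndReloadBest(tv: TrainValidationSplitModel, path: String): PipelineModel = {
  // Assumption: the tuned estimator was a Pipeline, so bestModel is a PipelineModel.
  val best = tv.bestModel.asInstanceOf[PipelineModel]
  best.write.overwrite().save(path)
  // If this load also fails with std::bad_alloc, the tuning wrapper is not the cause.
  PipelineModel.load(path)
}
```

It could be called right after fitting, for example as `saveAndReloadBest(tvmodel, "best-model-only")`, where the path is again a made-up name.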
I checked the executor and driver logs. The executor logs looked fine, but the same error appears in the driver logs (both the stderr and log4j files). Both log files are available at this link: log_files
Since the error message was mainly found in the driver logs, I tried the following:
- Increased the driver memory to 28 GB
- Increased the driver cores to 8
- Made the driver the same size as a worker
None of these changes helped. The log files indicate that the driver memory is not overloaded, so I am struggling to work out what the actual cause of the error is.
It would be great if someone could point me in the right direction.
Thanks