Save and Load model in XGBoost4J with Databricks DBFS


I am using Databricks (Spark 2.4.4) and XGBoost4J 0.90.

I am able to save my model to an S3 bucket (using dbutils.fs.cp after saving it to the local file system), but I can't load it.

Code and errors are below:

val trainedModel = ??? // fit the pipeline (VectorAssembler + XGBoostRegressor)

Create a directory to save the pipeline in (again, model + vector assembler):
val trainedModelPath = "/dbfs/tmp/test-sage/m"

Save the model with:

Then copy it from DBFS to S3:
dbutils.fs.cp("/tmp/test-sage/m", "/mnt/S3/XXXX-data-science/sandbox/save-test-xgboost/model")
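The save path above is "/dbfs/tmp/test-sage/m", while dbutils.fs.cp reads "/tmp/test-sage/m". On Databricks these name the same location only because local file APIs see DBFS through the /dbfs FUSE mount, whereas Spark and dbutils.fs treat a scheme-less path as a DBFS path. If a "/dbfs/..." path is handed to a Spark writer instead, it is typically resolved as "dbfs:/dbfs/...", which is a different location. A minimal sketch of a hypothetical `DbfsPaths` helper (not a Databricks API) that keeps the two forms consistent:

```scala
// Hypothetical helper, not part of Databricks or Spark.
// Assumption: "/dbfs/..." is the driver's local FUSE view of DBFS, while Spark
// and dbutils.fs resolve scheme-less absolute paths like "/tmp/x" to "dbfs:/tmp/x".
object DbfsPaths {
  private val FusePrefix = "/dbfs/"
  private val Scheme = "dbfs:"

  /** Path for Spark / dbutils.fs: "/dbfs/tmp/x" or "/tmp/x" -> "dbfs:/tmp/x". */
  def toSparkPath(path: String): String =
    if (path.startsWith(Scheme)) path
    else {
      require(path.startsWith("/"), s"expected an absolute path, got: $path")
      if (path.startsWith(FusePrefix)) Scheme + path.stripPrefix("/dbfs")
      else Scheme + path
    }

  /** Path for local file APIs: "dbfs:/tmp/x" or "/tmp/x" -> "/dbfs/tmp/x". */
  def toFusePath(path: String): String =
    if (path.startsWith(FusePrefix)) path
    else {
      val noScheme = path.stripPrefix(Scheme)
      require(noScheme.startsWith("/"), s"expected an absolute path, got: $path")
      "/dbfs" + noScheme
    }
}
```

For example, `DbfsPaths.toSparkPath("/dbfs/tmp/test-sage/m")` gives "dbfs:/tmp/test-sage/m", the same place dbutils.fs.cp reads when given "/tmp/test-sage/m". Passing the converted path to every Spark save and load call avoids mixing the two views.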

I can see a file called model in S3 (screenshot attached).

However, when I try to load it with:
val xgb = XGBoostRegressor.load("/mnt/S3/XXXX-data-science/sandbox/save-test-xgboost/model")
I get the error:
org.apache.hadoop.mapred.InvalidInputException: Input path does not exist: /mnt/S3/XXXX-data-science/sandbox/save-test-xgboost/model/metadata

That is, the saved model is a single file (< 1 MB), and no metadata file is saved alongside it.
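The error points at ".../model/metadata": Spark ML writers save a model or pipeline as a directory, and `load` reads its `metadata` subdirectory, so a lone single file cannot be loaded this way. A hypothetical pre-flight check (the `ModelDirCheck` name is made up for illustration), assuming the S3 mount is reachable through the /dbfs FUSE path on the driver:

```scala
import java.nio.file.{Files, Path}

// Hypothetical sanity check before calling a Spark ML `load`.
// A directory written by a Spark ML writer contains a "metadata" subdirectory;
// if it is missing, `load` fails with "Input path does not exist: .../metadata".
object ModelDirCheck {
  def looksLikeSparkMlModel(dir: Path): Boolean =
    Files.isDirectory(dir) && Files.isDirectory(dir.resolve("metadata"))
}
```

Usage (path from the post, seen through the assumed FUSE mount): `ModelDirCheck.looksLikeSparkMlModel(java.nio.file.Paths.get("/dbfs/mnt/S3/XXXX-data-science/sandbox/save-test-xgboost/model"))`. It only inspects the layout; it does not validate the model contents.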

Am I doing something wrong?

From a quick search, this seems to be a real pain point with XGBoost4J and Spark, so if this gets solved, I would be happy to write detailed documentation and open a relevant PR for it.


@hcho3 FYI

@hcho3, can you assist, as always? Thanks in advance 🙂

I have no idea. Can you raise the issue with Databricks?

Will try them as well. Thank you.

@hcho3, is it even possible to save an XGBoost4J (Spark) model as a pickle? If so, could you elaborate on the best practice for doing it? Thanks.

@hcho3, I think this is related to XGBoost4J:

When I try to save a pipeline (with an XGBoost4J model), I get the error java.lang.ClassCastException: java.lang.Integer cannot be cast to java.lang.Long

Would love to hear from you about it. Thank you once again.
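The thread never pins down this exception, so treat the following only as a hedged illustration. In Scala, this exact message appears when a value that was boxed as an Int is later read back as a Long. That can happen if an integral parameter declared as a Long (XGBoost4J-Spark's `seed` may be one; that is an assumption here) is passed as a plain `0` in a `Map[String, Any]` of params. The sketch reproduces the mechanism in plain Scala with a hypothetical `IntVsLongDemo` object:

```scala
// Hypothetical demo of the JVM behaviour behind the error message.
object IntVsLongDemo {
  // A params map of the kind often passed to estimators; 0 is boxed as java.lang.Integer.
  val params: Map[String, Any] = Map("seed" -> 0, "eta" -> 0.1)

  // Unboxing a java.lang.Integer as a Long throws ClassCastException.
  def readAsLong(m: Map[String, Any], key: String): Long =
    m(key).asInstanceOf[Long]

  // Safer variant: widen Int to Long explicitly, reject anything non-integral.
  def readAsLongSafely(m: Map[String, Any], key: String): Long =
    m(key) match {
      case i: Int  => i.toLong
      case l: Long => l
      case other   => throw new IllegalArgumentException(s"$key is not integral: $other")
    }
}
```

If this is the cause, writing the literal as `0L` in the params map would avoid the cast failure when the pipeline is saved.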

Maybe try using the latest XGBoost snapshot? Again, I have no clue what's going on.

Hey @hcho3, it is the latest.

XGBoost 0.90 is not the latest. We have 1.0 in Maven Central and 1.1 in our private Maven repo.

Oh, got it. I will try 1.0 (the publicly released one) and report back.

How to access the snapshot version:

I am using Databricks, so I will get it from Maven directly. Thanks.

Solved it; it was related to DBFS.
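The thread does not say exactly what the DBFS fix was. Whatever it was, the copy to S3 has to carry the whole model directory (including `metadata/`), not a single file. I believe dbutils.fs.cp takes a `recurse` flag (false by default) that is needed when the source is a directory. If you would rather go through the /dbfs FUSE mount with plain JVM file APIs, a hypothetical `CopyTree` helper (an assumption, not a Databricks utility) could look like this:

```scala
import java.nio.file.{Files, Path, StandardCopyOption}
import scala.collection.JavaConverters._

// Hypothetical helper: copy a model directory tree, e.g. between /dbfs/tmp/...
// and /dbfs/mnt/..., so subdirectories such as "metadata" come along.
object CopyTree {
  def copyTree(src: Path, dst: Path): Unit = {
    val stream = Files.walk(src) // pre-order: parents are visited before children
    try {
      stream.iterator().asScala.foreach { p =>
        val target = dst.resolve(src.relativize(p).toString)
        if (Files.isDirectory(p)) Files.createDirectories(target)
        else Files.copy(p, target, StandardCopyOption.REPLACE_EXISTING)
      }
    } finally stream.close()
  }
}
```

Usage with the post's paths, assuming both are visible under /dbfs: `CopyTree.copyTree(java.nio.file.Paths.get("/dbfs/tmp/test-sage/m"), java.nio.file.Paths.get("/dbfs/mnt/S3/XXXX-data-science/sandbox/save-test-xgboost/model"))`. `scala.collection.JavaConverters` is used because it works on the Scala versions Spark 2.4 ships with.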