XGBoost failing with Rabit error on AWS EMR Serverless

Environment

AWS EMR Serverless 7.0.0
PySpark 3.5.0
XGBoost 2.0.3

I'm using XGBoost for regression, specifically the SparkXGBRegressor. It works without issues on my local machine, but the same job fails with the following Rabit-related error when run on AWS EMR Serverless. Any ideas? Is it related to this?
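
For reference, here is roughly how the regressor is set up. This is a minimal sketch: the column names ("features", "label", "is_validation") and the train_df variable are placeholders, while the parameter values mirror the booster params in the log below.

from xgboost.spark import SparkXGBRegressor

# Sketch of the training setup; column names and train_df are placeholders,
# the hyperparameters match the ones printed in the log below.
regressor = SparkXGBRegressor(
    features_col="features",
    label_col="label",
    validation_indicator_col="is_validation",  # needed for early stopping
    num_workers=1,
    objective="reg:squarederror",
    device="cpu",
    learning_rate=0.03,
    max_depth=5,
    random_state=0,
    subsample=1.0,
    eval_metric="rmse",
    early_stopping_rounds=50,
    n_estimators=10,  # maps to num_boost_round=10
    missing=0.0,
)
model = regressor.fit(train_df)

The fit() call (a Pipeline stage in my case) is what fails; the full log and traceback follow: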

2024-05-28 21:54:33.922 [INFO] XGBoost-PySpark: Running xgboost-2.0.3 on 1 workers with
booster params: {'objective': 'reg:squarederror', 'device': 'cpu', 'learning_rate': 0.03, 'max_depth': 5, 'random_state': 0, 'subsample': 1.0, 'eval_metric': 'rmse', 'nthread': 8}
train_call_kwargs_params: {'early_stopping_rounds': 50, 'verbose_eval': True, 'num_boost_round': 10}
dmatrix_kwargs: {'nthread': 8, 'missing': 0.0}
Traceback (most recent call last):
File "/tmp/spark-86c93d08-da51-4c1d-9834-689be49aad15/train_model.py", line 436, in <module>
main()
File "/home/hadoop/environment/lib64/python3.9/site-packages/my_project/utils/logger_utils.py", line 38, in wrapper
return func(*args, **kwargs)
File "/home/hadoop/environment/lib64/python3.9/site-packages/click/core.py", line 1157, in __call__
return self.main(*args, **kwargs)
File "/home/hadoop/environment/lib64/python3.9/site-packages/click/core.py", line 1078, in main
rv = self.invoke(ctx)
File "/home/hadoop/environment/lib64/python3.9/site-packages/click/core.py", line 1434, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/home/hadoop/environment/lib64/python3.9/site-packages/click/core.py", line 783, in invoke
return __callback(*args, **kwargs)
File "/tmp/spark-86c93d08-da51-4c1d-9834-689be49aad15/train_model.py", line 326, in main
train_model(
File "/home/hadoop/environment/lib64/python3.9/site-packages/my_project/utils/logger_utils.py", line 55, in timed
result: Callable = method(*args, **kw)
File "/home/hadoop/environment/lib64/python3.9/site-packages/my_project/utils/logger_utils.py", line 38, in wrapper
return func(*args, **kwargs)
File "/tmp/spark-86c93d08-da51-4c1d-9834-689be49aad15/train_model.py", line 410, in train_model
model_training_strategy.train_model(
File "/home/hadoop/environment/lib64/python3.9/site-packages/my_project/pyspark_utils/model_processor/xgboost_pipeline_model_training_strategy.py", line 38, in train_model
return model_training_pipeline.fit(my_project_dataframe)
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/base.py", line 205, in fit
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/pipeline.py", line 134, in _fit
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/base.py", line 205, in fit
File "/home/hadoop/environment/lib64/python3.9/site-packages/xgboost/spark/core.py", line 1136, in _fit
(config, booster) = _run_job()
File "/home/hadoop/environment/lib64/python3.9/site-packages/xgboost/spark/core.py", line 1122, in _run_job
ret = rdd_with_resource.collect()[0]
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/rdd.py", line 1833, in collect
File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1322, in __call__
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/errors/exceptions/captured.py", line 179, in deco
File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/protocol.py", line 326, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(314, 0) finished unsuccessfully.
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/home/hadoop/environment/lib64/python3.9/site-packages/xgboost/spark/core.py", line 1067, in _train_booster
_rabit_args = _get_rabit_args(context, num_workers)
File "/home/hadoop/environment/lib64/python3.9/site-packages/xgboost/spark/utils.py", line 77, in _get_rabit_args
env = _start_tracker(context, n_workers)
File "/home/hadoop/environment/lib64/python3.9/site-packages/xgboost/spark/utils.py", line 66, in _start_tracker
rabit_context = RabitTracker(host_ip=host, n_workers=n_workers)
File "/home/hadoop/environment/lib64/python3.9/site-packages/xgboost/tracker.py", line 208, in __init__
sock.bind((host_ip, port))
OSError: [Errno 99] Cannot assign requested address

at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:118)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at org.apache.spark.ContextAwareIterator.hasNext(ContextAwareIterator.scala:39)
at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.hasNext(SerDeUtil.scala:86)
at scala.collection.Iterator.foreach(Iterator.scala:943)
at scala.collection.Iterator.foreach$(Iterator.scala:943)
at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.foreach(SerDeUtil.scala:80)
at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:322)
at org.apache.spark.api.python.PythonRunner$$anon$2.writeIteratorToStream(PythonRunner.scala:751)
at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:451)
at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1962)
at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:282)

at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:3067)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:3003)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:3002)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:3002)
at org.apache.spark.scheduler.DAGScheduler.handleTaskCompletion(DAGScheduler.scala:2326)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3265)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3205)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3194)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1041)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2406)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2427)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2446)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2471)
at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1046)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:407)
at org.apache.spark.rdd.RDD.collect(RDD.scala:1045)
at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:195)
at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:568)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
at java.base/java.lang.Thread.run(Thread.java:840)

2024-05-28 21:54:39.439 [INFO] py4j.clientserver: Closing down clientserver connection
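
From what I can tell from the traceback, xgboost.spark starts a RabitTracker inside the barrier task and binds a socket to the host IP it derives from the Spark task context, and that bind() is what fails. Errno 99 generally means the address handed to bind() is not assigned to any interface on the machine doing the binding, which is easy to reproduce in isolation on Linux (the IP below is a reserved documentation address, not my actual host):

import socket

# Reproducing errno 99 in isolation: bind to an IP no local interface owns.
# 203.0.113.7 is a made-up documentation address, standing in for whatever
# address the tracker derives from the Spark context.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    sock.bind(("203.0.113.7", 0))
except OSError as exc:
    print(exc)  # [Errno 99] Cannot assign requested address (on Linux)
finally:
    sock.close()

So my guess is that the hostname/IP the tracker picks up inside EMR Serverless doesn't resolve to a bindable local address, but I don't know how to confirm that or work around it.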