RabitTracker __init__ error when training a SparkXGBClassifier on EMR

Hi,

I’m trying to test xgboost models on EMR Serverless (release 6.12.0) using the following Python script, which works locally in my Linux environment:

from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline
import pyspark
import xgboost
from xgboost.spark import SparkXGBClassifier
import logging
import sys
import os


spark = SparkSession.builder.appName("XGBoostExample").getOrCreate()

# Create a small deterministic DataFrame with two features ("feature1", "feature2") and one label ("label")
data = [(i, i * 2, 1 if i % 2 == 0 else 0) for i in range(100)]
columns = ["feature1", "feature2", "label"]
df = spark.createDataFrame(data, columns)

# Assemble features into a vector column
feature_columns = ["feature1", "feature2"]
vector_assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
assembled_df = vector_assembler.transform(df)

# Split data into training and testing sets
train_data, test_data = assembled_df.randomSplit([0.8, 0.2], seed=123)

# Initialize XGBoost classifier
xgb_classifier = SparkXGBClassifier()

# Create a pipeline for training
pipeline = Pipeline(stages=[xgb_classifier])

# Train the model
model = pipeline.fit(train_data)

# Evaluate the model on the test data
predictions = model.transform(test_data)

# Show prediction results
predictions.select("label", "prediction", "probability").show()

# Save the XGBoost model
model_path = "xgboost_model"
model.write().overwrite().save(model_path)

# Stop the Spark session
spark.stop()

When deploying it on EMR, I get the following error:

  File "/tmp/spark-50d59c7a-a3d7-4d22-a320-9b8ea43dac19/test_xgboost.py", line 42, in <module>
    model = pipeline.fit(train_data)
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/base.py", line 205, in fit
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/pipeline.py", line 134, in _fit
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/ml/base.py", line 205, in fit
  File "/home/hadoop/environment/lib/python3.10/site-packages/xgboost/spark/core.py", line 837, in _fit
    (config, booster) = _run_job()
  File "/home/hadoop/environment/lib/python3.10/site-packages/xgboost/spark/core.py", line 833, in _run_job
    .collect()[0]
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/rdd.py", line 1814, in collect
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1322, in __call__
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/errors/exceptions/captured.py", line 169, in deco
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/protocol.py", line 326, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(5, 0) finished unsuccessfully.
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/hadoop/environment/lib/python3.10/site-packages/xgboost/spark/core.py", line 790, in _train_booster
    _rabit_args = _get_rabit_args(context, num_workers)
  File "/home/hadoop/environment/lib/python3.10/site-packages/xgboost/spark/utils.py", line 77, in _get_rabit_args
    env = _start_tracker(context, n_workers)
  File "/home/hadoop/environment/lib/python3.10/site-packages/xgboost/spark/utils.py", line 63, in _start_tracker
    rabit_context = RabitTracker(host_ip=host, n_workers=n_workers)
  File "/home/hadoop/environment/lib/python3.10/site-packages/xgboost/tracker.py", line 208, in __init__
    sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
  File "/home/hadoop/environment/lib/python3.10/site-packages/xgboost/tracker.py", line 67, in get_family
    return socket.getaddrinfo(addr, None)[0][0]
  File "/home/hadoop/environment/lib/python3.10/socket.py", line 955, in getaddrinfo
    for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
socket.gaierror: [Errno -2] Name or service not known
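
The last frames suggest the failure is a name lookup rather than a refused connection: get_family calls socket.getaddrinfo(addr, None) on the host string the tracker was given. Below is a minimal sketch of the same lookup, which could be run on the EMR side to check whether a given host string resolves there. can_resolve is a hypothetical helper name, not part of xgboost.

```python
import socket


def can_resolve(host):
    """Return (True, address_family) if `host` resolves, else (False, error message)."""
    try:
        # Same lookup as the failing frame: getaddrinfo(addr, None)[0][0]
        family = socket.getaddrinfo(host, None)[0][0]
        return True, family
    except socket.gaierror as exc:
        return False, str(exc)


if __name__ == "__main__":
    # Check the machine's own host name and a literal loopback address.
    for host in (socket.gethostname(), "127.0.0.1"):
        print(host, can_resolve(host))
```

If the machine's own host name comes back as (False, ...) while the literal IP resolves, that would be consistent with the gaierror above.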

I updated the Python interpreter on EMR to 3.10.12, and models that ship with PySpark train without problems.
The PySpark version on EMR is 3.4.0+amzn.0, and I’m using xgboost 1.7.0.

My understanding is that the Rabit tracker, which is needed to manage the communication between the distributed training tasks, fails to start. Judging by the final socket.gaierror, the host address it is given cannot be resolved, but I don’t know how to fix this.
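
To see which host string the tracker would receive, a barrier stage can report the task addresses Spark hands out. This is only a sketch: lookup_error, probe_partition and run_probe are hypothetical names, and it assumes (not verified against the 1.7.0 source) that xgboost derives the tracker host from the first entry of BarrierTaskContext.getTaskInfos().

```python
import socket


def lookup_error(host):
    """Return None if `host` resolves, else the gaierror message."""
    try:
        socket.getaddrinfo(host, None)
        return None
    except socket.gaierror as exc:
        return str(exc)


def probe_partition(_rows):
    """Runs inside a barrier task and reports the first task's address."""
    # Imported here so this module can be loaded where pyspark is absent.
    from pyspark import BarrierTaskContext

    context = BarrierTaskContext.get()
    # Assumption: the tracker host is derived from the first task's address.
    host = context.getTaskInfos()[0].address.split(":")[0]
    yield context.partitionId(), host, lookup_error(host)


def run_probe(spark):
    """Launch a one-task barrier stage and collect the probe result on the driver."""
    rdd = spark.sparkContext.parallelize([0], 1)
    return rdd.barrier().mapPartitions(probe_partition).collect()
```

Calling run_probe(spark) before pipeline.fit should return a list of (partition id, host, error) tuples. An error that is not None would point at the executor being unable to resolve the address Spark reports for it.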

Thank you!