Dask error : 'DMatrix' object has no attribute 'worker_map'

Hi !

I am training a model with the following setup:

train_DMatrix = xgb.dask.DMatrix(data=df_to_X(dataset_to_df_etl(train_dataset)), 
                            label=df_to_y(dataset_to_df_etl(train_dataset)),
                            enable_categorical=True,
                           )

validation_DMatrix = xgb.dask.DMatrix(data=df_to_X(dataset_to_df_etl(validation_dataset)), 
                                 label=df_to_y(dataset_to_df_etl(validation_dataset)),
                                 enable_categorical=True,
                           )

test_DMatrix = xgb.dask.DMatrix(data=df_to_X(dataset_to_df_etl(test_dataset)),
                           label=df_to_y(dataset_to_df_etl(test_dataset)),
                           enable_categorical=True,
                           )

def rsquared(preds, dmatrix):
    labels = dmatrix.get_label()
    ss_res = ((labels - preds) ** 2).sum()
    ss_tot = ((labels - labels.mean()) ** 2).sum()
    r2 = 1 - ss_res / ss_tot
    return 'r2', r2

# Parameters
params = {
    'max_depth': 3,
    'eta': 0.3,
    'objective': 'reg:squarederror',
    'nthread': -1,
    'eval_metric': ['rmse', rsquared]  
}

# Setting the number of boosting rounds and early stopping
num_boost_round = 100
early_stopping_rounds = int(0.15 * num_boost_round)

# Training with custom evaluation
evals = [(train_DMatrix, 'train'), (validation_DMatrix, 'valid')]
result = xgb.dask.train(client,
                        params,
                        dtrain=train_DMatrix,
                        num_boost_round=num_boost_round,
                        evals=evals,
                        early_stopping_rounds=early_stopping_rounds,
                        verbose_eval=True)

And I get the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File <timed exec>:16

File /opt/conda/lib/python3.10/site-packages/xgboost/core.py:730, in require_keyword_args.<locals>.throw_if.<locals>.inner_f(*args, **kwargs)
    728 for k, arg in zip(sig.parameters, args):
    729     kwargs[k] = arg
--> 730 return func(**kwargs)

File /opt/conda/lib/python3.10/site-packages/xgboost/dask.py:1078, in train(client, params, dtrain, num_boost_round, evals, obj, feval, early_stopping_rounds, xgb_model, verbose_eval, callbacks, custom_metric)
   1076 client = _xgb_get_client(client)
   1077 args = locals()
-> 1078 return client.sync(
   1079     _train_async,
   1080     global_config=config.get_config(),
   1081     dconfig=_get_dask_config(),
   1082     **args,
   1083 )

File /opt/conda/lib/python3.10/site-packages/distributed/utils.py:358, in SyncMethodMixin.sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    356     return future
    357 else:
--> 358     return sync(
    359         self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    360     )

File /opt/conda/lib/python3.10/site-packages/distributed/utils.py:434, in sync(loop, func, callback_timeout, *args, **kwargs)
    431         wait(10)
    433 if error is not None:
--> 434     raise error
    435 else:
    436     return result

File /opt/conda/lib/python3.10/site-packages/distributed/utils.py:408, in sync.<locals>.f()
    406         awaitable = wait_for(awaitable, timeout)
    407     future = asyncio.ensure_future(awaitable)
--> 408     result = yield future
    409 except Exception as exception:
    410     error = exception

File /opt/conda/lib/python3.10/site-packages/tornado/gen.py:767, in Runner.run(self)
    765 try:
    766     try:
--> 767         value = future.result()
    768     except Exception as e:
    769         # Save the exception for later. It's important that
    770         # gen.throw() not be called inside this try/except block
    771         # because that makes sys.exc_info behave unexpectedly.
    772         exc: Optional[Exception] = e

File /opt/conda/lib/python3.10/site-packages/xgboost/dask.py:933, in _train_async(client, global_config, dconfig, params, dtrain, num_boost_round, evals, obj, feval, early_stopping_rounds, verbose_eval, xgb_model, callbacks, custom_metric)
    917 async def _train_async(
    918     client: "distributed.Client",
    919     global_config: Dict[str, Any],
   (...)
    931     custom_metric: Optional[Metric],
    932 ) -> Optional[TrainReturnT]:
--> 933     workers = _get_workers_from_data(dtrain, evals)
    934     await _check_workers_are_alive(workers, client)
    935     _rabit_args = await _get_rabit_args(len(workers), dconfig, client)

File /opt/conda/lib/python3.10/site-packages/xgboost/dask.py:872, in _get_workers_from_data(dtrain, evals)
    869 def _get_workers_from_data(
    870     dtrain: DaskDMatrix, evals: Optional[Sequence[Tuple[DaskDMatrix, str]]]
    871 ) -> List[str]:
--> 872     X_worker_map: Set[str] = set(dtrain.worker_map.keys())
    873     if evals:
    874         for e in evals:

AttributeError: 'DMatrix' object has no attribute 'worker_map'

The instantiation of the DMatrix object goes well, without any errors (and on the dask cluster, it does process all tasks ). However, when trying to do model training, it breaks. Any idea on what this may be?

Also referenced in the Dask discourse: https://dask.discourse.group/t/xgboost-error-dmatrix-object-has-no-attribute-worker-map/2702?u=rlourenco