Dask distributed GPU training with already sharded data

Hi, please pardon me if this has been discussed already, but I couldn’t find any examples of how to run multi-node, multi-GPU training with Dask. The examples use LocalCUDACluster and are limited to single-node training.

However, I found this sample and was able to run multi-node using the following steps:

  1. Start the scheduler process on the master node, and start dask-cuda-worker processes on all nodes.
  2. Read the input into a DaskDMatrix on the master node, and start the training job from that node.

From what I understand, the master node then distributes the data to all workers.

My question is: if I have a lot of training data and I want to use AWS SageMaker’s distribution="ShardedByS3Key" to divide the data across individual instances beforehand, how do I start the training job so that each worker reads the data present on its own instance? In other words, instead of loading all the data on one node and distributing it with DaskDMatrix, how do I load the data on the individual nodes during training?

Here’s my current code:

if current_host_ip == master_host_ip:
    with Client("") as client:
        # Loads the entire dataset in this master node, and divides it into train/valid.
        X_train, X_valid, y_train, y_valid = load_higgs_csv(args.train)
        dtrain = xgb.dask.DaskDMatrix(client, X_train, y_train)
        dvalid = xgb.dask.DaskDMatrix(client, X_valid, y_valid)

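        watchlist = [(dtrain, "train"), (dvalid, "validation")]
        # `params` is the XGBoost parameter dict, assumed to be defined elsewhere.
        output = xgb.dask.train(client, params, dtrain, evals=watchlist)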
        booster = output['booster']  # booster is the trained model
        history = output['history']  # A dictionary containing evaluation results
        # Save the model to file
        booster.save_model(args.model_dir + '/xgboost-model')
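else:
    # Worker nodes poll the scheduler port and stop once it is unreachable.
    while True:
        scheduler = (master_host_ip, SCHEDULER_PORT)
        alive_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        alive_check = alive_socket.connect_ex(scheduler)
        alive_socket.close()
        # connect_ex returns 0 while the scheduler port still accepts connections.
        if alive_check != 0:
            print("Received a shutdown signal from Dask cluster")
            break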
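The worker-side loop above relies on socket.connect_ex, which returns 0 when the TCP connection succeeds and an errno value otherwise. A minimal standard-library sketch of that liveness check follows; the names scheduler_is_alive and wait_for_shutdown, the timeout, and the poll interval are illustrative assumptions, not part of the original script.

```python
import socket
import time


def scheduler_is_alive(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to (host, port) succeeds."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)
        # connect_ex returns 0 on success and an errno value on failure.
        return sock.connect_ex((host, port)) == 0


def wait_for_shutdown(host: str, port: int, poll_seconds: float = 5.0) -> None:
    """Block until the scheduler port stops accepting connections."""
    while scheduler_is_alive(host, port):
        # Sleep between probes so the loop does not spin.
        time.sleep(poll_seconds)
    print("Scheduler is no longer reachable; leaving the worker loop")
```

On a worker node this could be called as wait_for_shutdown(master_host_ip, SCHEDULER_PORT) in place of the hand-written loop.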

I assume this is a basic question that I can’t answer because I’m a newbie. Any help is appreciated!


Hi, first allow me to clarify that:

  1. DaskDMatrix is just a reference to the input dask array/dataframe. It doesn’t store any data.
  2. DaskDMatrix doesn’t move data.

I think your questions are more related to dask/distributed than to xgboost. XGBoost consumes a dask dataframe or dask array that represents the already processed data. All data loading and processing is managed by dask; xgboost does the machine learning.

“instead of loading all the data in one node and distributing it with DaskDMatrix, how do I load data in individual nodes during training”

This is a question about how to use dask to load sharded data. I don’t think it needs any special attention: you can load every data shard with the normal dd.read_parquet and friends, then concatenate the results into a single dataframe. All computations in dask are lazy, and dask handles data locality automatically, so the data stays on the workers.

“From what I understand, the master node then distributes the data to all workers.”

By master node, I think you mean the scheduler. In most cases, dask loads data on the worker nodes using worker processes, and the data doesn’t need to go through the scheduler. The scheduler does, well, scheduling.

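from dask import dataframe as dd
from dask.distributed import Client

with Client(scheduler_file="sched.json") as client:
    # Each call only builds a lazy task graph; nothing is read yet.
    df0 = dd.read_parquet("s3://...0.parquet")
    df1 = dd.read_parquet("s3://...1.parquet")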

The dfs are lazy objects that only describe the work to be done. The actual data is loaded by worker processes, not by the scheduler.

Lastly, there’s the client object, which can run anywhere from an AWS node to your laptop and is the interface you use to launch tasks. It doesn’t store or touch any data unless you call compute, which pulls the result from the workers to the client (you).

I recommend reading the dask/distributed introduction first to gain an intuition on how it works.


@jiamingy Thanks for the pointers! I was away for some time, but wanted to follow up. I have been running with the data fully replicated on all nodes when training on AWS SageMaker, and this takes a long time. Is there no way to run with the distribution="ShardedByS3Key" option offered by SageMaker? I get the following error when I try:

Traceback (most recent call last):
  File "/miniconda3/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/miniconda3/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/opt/ml/code/dask_entry_script.py", line 218, in <module>
    dtrain = xgb.dask.DaskDeviceQuantileDMatrix(client, X_train, y_train)
  File "/miniconda3/lib/python3.8/site-packages/xgboost/core.py", line 506, in inner_f
    return f(**kwargs)
  File "/miniconda3/lib/python3.8/site-packages/xgboost/dask.py", line 645, in __init__
  File "/miniconda3/lib/python3.8/site-packages/xgboost/core.py", line 506, in inner_f
    return f(**kwargs)
  File "/miniconda3/lib/python3.8/site-packages/xgboost/dask.py", line 298, in __init__
    self._init = client.sync(
  File "/miniconda3/lib/python3.8/site-packages/distributed/utils.py", line 339, in sync
    return sync(
  File "/miniconda3/lib/python3.8/site-packages/distributed/utils.py", line 406, in sync
    raise exc.with_traceback(tb)
  File "/miniconda3/lib/python3.8/site-packages/distributed/utils.py", line 379, in f
    result = yield future
  File "/miniconda3/lib/python3.8/site-packages/tornado/gen.py", line 762, in run
    value = future.result()
  File "/miniconda3/lib/python3.8/site-packages/xgboost/dask.py", line 405, in _map_local_data
    assert part.status == 'finished', part.status
AssertionError: error
2023-01-11 23:48:23,820 - distributed.worker - WARNING - Compute Failed
Key:       ('getitem-2bf7ea2dffeb67f46f1be3403a0e834e', 14)
Function:  subgraph_callable-bb500a8d-2a0a-4808-873e-ef7e67ed
args:      ({'piece': ('/opt/ml/input/data/train/xgboost_benchmark_95.parquet', [0], [])})
kwargs:    {}
Exception: "FileNotFoundError('/opt/ml/input/data/train/xgboost_benchmark_95.parquet')"
2023-01-11 23:48:23,822 - distributed.worker - WARNING - Compute Failed
Key:       ('getitem-2bf7ea2dffeb67f46f1be3403a0e834e', 20)
Function:  subgraph_callable-bb500a8d-2a0a-4808-873e-ef7e67ed
args:      ({'piece': ('/opt/ml/input/data/train/xgboost_benchmark_142.parquet', [0], [])})
kwargs:    {}

It looks like Dask’s read expects every machine to have all parts of the Parquet data (this is with read_parquet, but it also happens with multi-part CSV). I know this is likely a Dask topic, but the failing code is in the XGBoost package, so I am extending the conversation above.
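One way to confirm which shards each node can actually see is to run a small check on every machine. The sketch below uses only the standard library; the function name report_local_shards and the shape of its return value are illustrative assumptions.

```python
import os
import socket
from typing import Dict, List


def report_local_shards(data_dir: str, expected: List[str]) -> Dict[str, object]:
    """Report which of the expected shard files exist in data_dir on this host."""
    present = sorted(
        name for name in expected if os.path.exists(os.path.join(data_dir, name))
    )
    missing = sorted(set(expected) - set(present))
    return {"host": socket.gethostname(), "present": present, "missing": missing}
```

With a running cluster, client.run(report_local_shards, "/opt/ml/input/data/train", expected) should execute the function on every worker and return the reports keyed by worker address. With ShardedByS3Key, each worker would be expected to report a different subset as present.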