Hi, please pardon me if this has been discussed already, but I couldn’t find any examples of how to run multi-node, multi-GPU training with Dask. The examples use LocalCUDACluster and are limited to single-node training.
However, I found this sample and was able to run multi-node using the following steps:
- Start the scheduler process on the master node, and start dask-cuda-worker processes on all nodes.
- Read the input into a DaskDMatrix on the master node, and start the training job from that node.
From what I understand, the master node then distributes the data to all workers.
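For concreteness, here is roughly how the two launch steps above look as commands. This is only a sketch: the helper names `scheduler_command` and `worker_command` are made up for illustration, port 8786 is assumed because it is the port my script connects to, and the IP in the usage lines is a placeholder.

```python
import shlex

# Assumed port; it matches the address passed to Client in my script.
SCHEDULER_PORT = 8786


def scheduler_command(port=SCHEDULER_PORT):
    """Command line to run once, on the master node (hypothetical helper)."""
    return ["dask-scheduler", "--port", str(port)]


def worker_command(master_host_ip, port=SCHEDULER_PORT):
    """Command line to run on every node, pointing at the scheduler (hypothetical helper)."""
    return ["dask-cuda-worker", f"tcp://{master_host_ip}:{port}"]


if __name__ == "__main__":
    # Print the commands instead of running them; 10.0.0.1 is a placeholder IP.
    print(shlex.join(scheduler_command()))
    print(shlex.join(worker_command("10.0.0.1")))
```

Each list can be passed to `subprocess.Popen` on the matching node; I print them here so the sketch runs without a cluster.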
My question: I have a lot of training data, and I want to use AWS SageMaker’s distribution=“ShardedByS3Key” to split it across the individual instances beforehand. How do I start the training job so that each worker reads the data present on its own instance? In other words, instead of loading all the data on one node and distributing it with DaskDMatrix, how do I load the data on the individual nodes during training?
Here’s my current code:
    if current_host_ip == master_host_ip:
        with Client("127.0.0.1:8786") as client:
            # Loads the entire dataset in this master node, and divides it into train/valid.
            X_train, X_valid, y_train, y_valid = load_higgs_csv(args.train)
            dtrain = xgb.dask.DaskDMatrix(client, X_train, y_train)
            dvalid = xgb.dask.DaskDMatrix(client, X_valid, y_valid)
            watchlist = [(dtrain, "train"), (dvalid, "validation")]
            output = xgb.dask.train(client,
                                    train_hp,
                                    dtrain,
                                    evals=watchlist)
            booster = output['booster']  # booster is the trained model
            history = output['history']  # A dictionary containing evaluation results
            # Save the model to file
            booster.save_model(args.model_dir + '/xgboost-model')
    else:
        # Non-master nodes just poll the scheduler port and exit once it goes away.
        while True:
            scheduler = (master_host_ip, SCHEDULER_PORT)
            alive_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            alive_check = alive_socket.connect_ex(scheduler)
            if alive_check == 0:
                pass
            else:
                print("Received a shutdown signal from Dask cluster")
                sys.exit(0)
            alive_socket.close()
            time.sleep(2)
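The polling loop on the non-master nodes boils down to a TCP port check. A minimal sketch of that check as a standalone helper is below; the name `scheduler_is_alive` and the `timeout` parameter are my own additions for illustration, not part of any Dask API. Unlike the loop above, it closes the socket on every path.

```python
import socket


def scheduler_is_alive(host, port, timeout=2.0):
    """Return True if a TCP connection to (host, port) succeeds (hypothetical helper)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Without a timeout, connect_ex could block for a long time on an unreachable host.
        sock.settimeout(timeout)
        # connect_ex returns 0 on success and an errno value on failure.
        return sock.connect_ex((host, port)) == 0
```

The worker loop would then be `while scheduler_is_alive(master_host_ip, SCHEDULER_PORT): time.sleep(2)` followed by `sys.exit(0)`.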
I assume this is a basic question that I can’t answer as a newbie. Any help is appreciated!
Thanks!