[dask] Is distributed training globally "data parallel" and locally "feature parallel"?


Is this a true statement?

When using distributed training in Python with xgb.dask, the data distribution is “data parallel”, meaning that each worker has a subset of the training data’s rows. Locally on each worker, the process of training a model is “feature parallel”, which means that candidate splits for different features are evaluated in parallel using multithreading.
As a result, the same feature + split combinations may be evaluated on multiple workers (using their local pieces of the data).

I ask because I’m trying to form good expectations for how training time with xgb.dask should change as I add more machines.

If my statement above is correct, I’d expect a decreasing rate of improvement from adding more workers because some duplicated work is being done (e.g. workers 1, 2, and 3 might all evaluate the split feat_1 > 0.75 on their local piece of the data).


I have read the original XGBoost paper, and saw this description of the approx tree method used in distributed training:

In order to reduce the cost of sorting, we propose to store the data in in-memory units, which we called block. Data in each block is stored in the compressed column (CSC) format, with each column sorted by the corresponding feature value. This input data layout only needs to be computed once before training, and can be reused in later iterations
The block structure also helps when using the approximate algorithms. Multiple blocks can be used in this case, with each block corresponding to subset of rows in the dataset.
Different blocks can be distributed across machines, or stored on disk in the out-of-core setting. Using the sorted structure, the quantile finding step becomes a linear scan over the sorted columns.

As far as I can tell, the DaskDMatrix isn’t equivalent to that and can’t be, since Dask collections are lazily loaded and xgb.dask tries to avoid moving any data off of the worker it’s loaded onto.

I’m using these terms “data parallel” and “feature parallel” because that is how LightGBM describes them: https://lightgbm.readthedocs.io/en/latest/Features.html#optimization-in-parallel-learning.

Notes for reviewers

I have tried searching this discussion board, XGBoost issues, Stack Overflow, and the source code for xgboost and rabit and could not find an answer to this. Apologies in advance if this is covered somewhere already.

Thanks for your time and consideration!

Nice to see you here, @jameslamb!

No, there is no duplicated work when workers evaluate the same split candidate on their local piece of data.

Other factors may prevent linear speedup when adding more workers, but duplicated work is not the reason. One major factor for slowdown is the use of AllReduce to combine gradient histograms across workers.


Thanks! How is it that XGBoost avoids the case where the same combination of (feature, threshold) is evaluated on multiple workers?

Is the list of candidate splits determined globally first somehow, and then the work of evaluating them divided among the workers?

Yes, we use approximate quantiles for each feature to generate the list of candidate splits (feature, threshold).
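As a rough single-machine illustration of how a global list of (feature, threshold) candidates can come from per-feature quantiles, here is a sketch. It is a simplification: it takes exact quantiles of a small in-memory dataset, not the approximate, mergeable quantile sketch used in distributed training, and the function names and `num_bins` parameter are hypothetical.

```python
from typing import List, Tuple


def candidate_thresholds(column: List[float], num_bins: int) -> List[float]:
    """Return up to num_bins - 1 cut points for one feature.

    Simplification (assumption): exact quantiles of a sorted copy of the
    column, not the approximate quantile sketch XGBoost uses.
    """
    values = sorted(column)
    n = len(values)
    cuts: List[float] = []
    for k in range(1, num_bins):
        # Value at the k/num_bins quantile position (nearest-rank style).
        idx = min(n - 1, (k * n) // num_bins)
        v = values[idx]
        if not cuts or v > cuts[-1]:  # drop duplicate cut points
            cuts.append(v)
    return cuts


def all_candidates(rows: List[List[float]], num_bins: int) -> List[Tuple[int, float]]:
    """Enumerate (feature, threshold) pairs for every feature."""
    pairs: List[Tuple[int, float]] = []
    for j in range(len(rows[0])):
        column = [r[j] for r in rows]
        for t in candidate_thresholds(column, num_bins):
            pairs.append((j, t))
    return pairs


rows = [[0.1, 7.0], [0.4, 3.0], [0.2, 5.0], [0.9, 1.0]]
print(all_candidates(rows, num_bins=2))  # [(0, 0.4), (1, 5.0)]
```

Because every worker would share the same candidate list, the candidates themselves are global; what each worker computes locally is the gradient sums per bin.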

No, the work of evaluating them is not divided among the workers. The split evaluation works as follows:

  1. Each worker generates gradient histograms for the data slice it’s assigned. The histograms let us query the partial sum of gradient pairs

     $$\sum_{i} (g_i, h_i)$$

     where the sum is taken over the set of all data points for which the value of a particular feature is in a particular range. Each range is in the form [q[j, k], q[j, k+1]], where q[j, k] is the k-th (approximate) quantile of feature j. As a result, the number of bins in the histograms is M * K * T, where M is the number of features, K is the number of quantiles (by default 256), and T is the number of tree nodes added so far in the tree.

  2. Workers perform AllReduce to combine the gradient histograms.

  3. Given the combined gradient histograms, every worker is now able to choose the best split candidate.

  4. Workers partition the data according to the chosen split, i.e. every data row gets an updated partition (node) ID.

  5. Workers re-compute the gradient pairs for each data point using the new data partitions. Now go back to Step 1.
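Steps 1 and 2 can be sketched in a single process for one node and one feature. This assumes each worker's rows are already mapped to bin indices and treats AllReduce as a plain element-wise sum, simulated here by adding Python lists rather than calling Rabit or any real collective; the function names are hypothetical.

```python
from typing import List, Tuple

Hist = List[Tuple[float, float]]  # one (sum_g, sum_h) pair per bin


def local_histogram(bins: List[int], g: List[float], h: List[float], num_bins: int) -> Hist:
    """Step 1: a worker sums the gradient pairs of its own rows per bin."""
    hist: Hist = [(0.0, 0.0)] * num_bins
    for b, gi, hi in zip(bins, g, h):
        sg, sh = hist[b]
        hist[b] = (sg + gi, sh + hi)
    return hist


def allreduce_sum(hists: List[Hist]) -> Hist:
    """Step 2 (simulated): element-wise sum of every worker's histogram."""
    num_bins = len(hists[0])
    return [
        (sum(hh[k][0] for hh in hists), sum(hh[k][1] for hh in hists))
        for k in range(num_bins)
    ]


# Two workers, each holding a different slice of the rows.
w1 = local_histogram([0, 1, 1], [0.5, -1.0, 2.0], [1.0, 1.0, 1.0], num_bins=3)
w2 = local_histogram([2, 0], [1.5, -0.5], [1.0, 1.0], num_bins=3)
global_hist = allreduce_sum([w1, w2])
```

Each row is visited by exactly one worker, and the summed result equals the histogram one machine would build over all rows, which is why evaluating the same candidate on several workers does not duplicate the expensive part of the work.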

(I believe LightGBM’s distributed training works in a similar way. Correct me if I’m wrong here.)

Now I do realize that Step 3 represents duplicated work across workers. But given the histograms, the work involved in Step 3 is minimal, relative to the work needed in Steps 4 and 5. The reason is that the histograms constitute sufficient statistics for computing loss values for all split candidates. For example, we can evaluate the threshold q[j, k] as follows:

  1. Compute the “left” sum L by summing all the bins for ranges (-inf, q[j, 0]], …, [q[j, k-1], q[j, k]].
  2. Compute the “right” sum R by summing all the bins for ranges [q[j, k], q[j, k+1]], …, [q[j, K-1], +inf).
  3. Now use L and R to compute the change in loss that would result from creating the new split with (feature=j, threshold=q[j,k]). There exists a closed-form formula for the loss change, and the formula only depends on L and R.
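The procedure above can be sketched for a single feature's histogram using the split-gain formula from the XGBoost paper, with regularization parameters `lam` (λ) and `gamma` (γ). Two assumptions for illustration: cut position k sends bins [0, k) left and [k, end) right, and R is obtained as total minus L (equivalent to summing the right bins directly, just cheaper in a running scan). The function names are hypothetical, not XGBoost internals.

```python
from typing import List, Optional, Tuple

Hist = List[Tuple[float, float]]  # (sum_g, sum_h) per bin, in threshold order


def split_gain(GL: float, HL: float, GR: float, HR: float,
               lam: float = 1.0, gamma: float = 0.0) -> float:
    """Loss reduction from splitting a node into left (GL, HL) and right (GR, HR)."""
    def score(G: float, H: float) -> float:
        return G * G / (H + lam)
    return 0.5 * (score(GL, HL) + score(GR, HR) - score(GL + GR, HL + HR)) - gamma


def best_split(hist: Hist, lam: float = 1.0, gamma: float = 0.0) -> Tuple[Optional[int], float]:
    """Scan cut positions k = 1..len(hist)-1 and return the best (k, gain)."""
    G_total = sum(g for g, _ in hist)
    H_total = sum(h for _, h in hist)
    GL = HL = 0.0
    best_k: Optional[int] = None
    best = float("-inf")
    for k in range(1, len(hist)):
        GL += hist[k - 1][0]  # running "left" sum L
        HL += hist[k - 1][1]
        GR, HR = G_total - GL, H_total - HL  # "right" sum R
        gain = split_gain(GL, HL, GR, HR, lam, gamma)
        if gain > best:
            best_k, best = k, gain
    return best_k, best


hist = [(-2.0, 1.0), (-2.0, 1.0), (2.0, 1.0), (2.0, 1.0)]
k, gain = best_split(hist)  # the cut between bins 1 and 2 separates the signs
```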

In practice, I would be more worried about the bottleneck in AllReduce, as we have to communicate histograms of size M * K * T. They can be quite big if we have high-dimensional data (large M) or if we grow deep trees (large T).
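A back-of-the-envelope estimate of the M * K * T payload, assuming each bin stores one gradient sum and one Hessian sum as 64-bit floats (16 bytes per bin). The storage type inside XGBoost may differ, and the numbers below are illustrative, not measurements; `histogram_bytes` is a hypothetical helper.

```python
def histogram_bytes(num_features: int, num_bins: int, num_nodes: int,
                    bytes_per_bin: int = 16) -> int:
    """Size of M * K * T histogram bins, each holding a (sum_g, sum_h) pair.

    bytes_per_bin=16 assumes two float64 values per bin (an assumption).
    """
    return num_features * num_bins * num_nodes * bytes_per_bin


# Illustrative only: 1,000 features, 256 bins, 32 nodes.
size = histogram_bytes(1_000, 256, 32)
print(f"{size / 2**20:.1f} MiB")  # 125.0 MiB
```

Doubling either the number of features or the number of nodes doubles the amount of data each AllReduce round has to move.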

P.S. If I had more time, I would love to write a whitepaper describing the algorithm in precise mathematical detail. But alas, my hands are currently full.


Thank you very much for the thorough answer!

Ok, this makes a lot more sense to me and I have a better idea of where the bottlenecks might be. I’m going to bookmark this page :grinning:


I’m slightly confused about the histogram building process here – why is the size of the histogram M * K * T? In particular, I’m a bit confused about the T dimension. Why does each tree node have its own histogram in effect?

I thought gradient information is only updated once per tree, not once per node. Am I mistaken?

This is because we need to consider adding a new split to each tree node, and to do that we need to compute one histogram per tree node.

And this is because you need to build a new histogram on only the data instances in the tree node, right? Do you think there is a way to reuse some information from the histogram of, say, the node’s parent?

On that note, if you are building a new histogram for each node as you said, then why do you need to remember all T histograms?


This is already done.

Here T is the number of leaf nodes that are currently considered for expansion, and not the total number of leaf nodes. Sorry for the confusion.

Thanks for the fast reply. Also wanted to check: for the MSE loss function, the gradient of a data point is just the predicted value so far minus the actual value, while the Hessian is just 1. Are my calculations correct?

Yes, your calculation is correct.
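A quick numerical check of that calculation, assuming the squared error is written with the conventional ½ factor, ½(ŷ − y)²; without the ½, both the gradient and the Hessian double. The helper names are hypothetical.

```python
def squared_error_grad_hess(pred: float, label: float):
    """Gradient and Hessian of 0.5 * (pred - label) ** 2 with respect to pred."""
    return pred - label, 1.0


def loss(pred: float, label: float) -> float:
    return 0.5 * (pred - label) ** 2


# Central finite differences as a sanity check (illustrative only).
p, y, eps = 2.5, 1.0, 1e-3
g, h = squared_error_grad_hess(p, y)
g_fd = (loss(p + eps, y) - loss(p - eps, y)) / (2 * eps)
h_fd = (loss(p + eps, y) - 2 * loss(p, y) + loss(p - eps, y)) / eps**2
```

Since the Hessian is constant, each Hessian bin in a histogram simply counts the rows that fall into it.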

I’m wondering why it is necessary to recompute histograms for the children. In fact, can’t all the nodes in a tree reuse the same histogram? Or is it too coarse-grained? I.e. assuming we have data partitioned in 256 bins at the root node, is it really necessary to come up with 256 new bins for the data on the left node and another 256 new bins for the data on the right node, effectively partitioning the data into 512 bins?

I too would really appreciate a white paper on how the distributed algorithm works :slight_smile:

No, this is not possible. The reason is that each bin of a histogram represents the partial sum of gradients summed over the set of data points {x : x is associated with the node AND feature j of x is in [q[j, k], q[j, k+1]]}. Intuitively, histograms partition the training data in one axis and nodes partition the training data in a different axis.
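To make the “two axes” point concrete, here is a tiny hypothetical counterexample: two 2-row datasets whose root histograms are identical for every feature, yet whose left-child histograms differ after the same split. For brevity it assumes features are already binned into {0, 1}, every gradient is 1.0, and Hessians are omitted.

```python
from typing import List, Tuple

Row = Tuple[int, int, float]  # (bin of feature 0, bin of feature 1, gradient)


def hist(rows: List[Row], feature: int, num_bins: int = 2) -> List[float]:
    """Per-bin gradient sums for one feature over the given rows."""
    out = [0.0] * num_bins
    for row in rows:
        out[row[feature]] += row[2]
    return out


data_a = [(0, 0, 1.0), (1, 1, 1.0)]
data_b = [(0, 1, 1.0), (1, 0, 1.0)]

# Root histograms match for both features...
same_root = all(hist(data_a, f) == hist(data_b, f) for f in (0, 1))

# ...but after splitting on "feature 0 is in bin 0", the left children differ.
left_a = hist([r for r in data_a if r[0] == 0], feature=1)
left_b = hist([r for r in data_b if r[0] == 0], feature=1)
```

Here `left_a` is `[1.0, 0.0]` while `left_b` is `[0.0, 1.0]`, so the children’s histograms cannot be recovered from the root histograms alone; they need another pass over the rows in each node (or the subtraction trick for one of two siblings).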


We can achieve a 2x speedup when computing histograms for two sibling nodes: the histogram for the right child node is the difference between the histogram for the parent node (which we have computed before) and the histogram for the left child node, so only one of the two siblings needs a pass over its data. This DOES NOT mean that we can reuse the same histogram information for the entire tree.
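A sketch of this sibling subtraction trick, assuming histograms are lists of (sum_g, sum_h) pairs for one feature and that the parent histogram is kept from the previous level. Only the left child is built by scanning rows; the right child is derived bin by bin as parent minus left. A common refinement is to scan whichever child has fewer rows; the split outcome and function names here are hypothetical.

```python
from typing import List, Tuple

Hist = List[Tuple[float, float]]  # (sum_g, sum_h) per bin


def build_hist(bins: List[int], g: List[float], h: List[float], num_bins: int) -> Hist:
    """Scan rows and accumulate gradient pairs per bin."""
    hist: Hist = [(0.0, 0.0)] * num_bins
    for b, gi, hi in zip(bins, g, h):
        hist[b] = (hist[b][0] + gi, hist[b][1] + hi)
    return hist


def subtract(parent: Hist, child: Hist) -> Hist:
    """Sibling histogram = parent histogram minus the other child's, bin by bin."""
    return [(pg - cg, ph - ch) for (pg, ph), (cg, ch) in zip(parent, child)]


# Rows at the parent node: bin index (for one feature) and gradient pairs.
bins = [0, 1, 2, 1, 0]
g = [0.5, -1.0, 2.0, 1.5, -0.25]
h = [1.0, 1.0, 1.0, 1.0, 1.0]
go_left = [True, False, True, False, False]  # hypothetical split outcome

parent = build_hist(bins, g, h, num_bins=3)
left_idx = [i for i, goes_left in enumerate(go_left) if goes_left]
left = build_hist([bins[i] for i in left_idx], [g[i] for i in left_idx],
                  [h[i] for i in left_idx], num_bins=3)
right = subtract(parent, left)  # no second scan over the right child's rows
```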