hivemind.optim

This module contains decentralized optimizers that wrap your regular PyTorch Optimizer to train with peers. Depending on the exact configuration, Optimizer may perform large synchronous updates equivalent to training with a larger batch size, or perform asynchronous local updates and average model parameters.

hivemind.Optimizer

class hivemind.optim.optimizer.Optimizer(*, dht: hivemind.dht.dht.DHT, run_id: str, target_batch_size: int, batch_size_per_step: Optional[int] = None, optimizer: Union[torch.optim.optimizer.Optimizer, Callable[[Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]], torch.optim.optimizer.Optimizer]], params: Optional[Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]] = None, scheduler: Optional[Union[torch.optim.lr_scheduler._LRScheduler, Callable[[torch.optim.optimizer.Optimizer], torch.optim.lr_scheduler._LRScheduler]]] = None, matchmaking_time: Optional[float] = 15.0, averaging_timeout: Optional[float] = 60.0, allreduce_timeout: Optional[float] = None, next_chunk_timeout: Optional[float] = None, load_state_timeout: float = 600.0, reuse_grad_buffers: bool = False, offload_optimizer: Optional[bool] = None, delay_optimizer_step: Optional[bool] = None, delay_grad_averaging: bool = False, delay_state_averaging: bool = True, average_state_every: int = 1, use_local_updates: bool = False, client_mode: Optional[bool] = None, auxiliary: bool = False, grad_compression: hivemind.compression.base.CompressionBase = hivemind.NoCompression(), grad_averager_factory: Optional[Callable[[...], hivemind.optim.grad_averager.TGradientAverager]] = None, state_averaging_compression: hivemind.compression.base.CompressionBase = hivemind.NoCompression(), load_state_compression: hivemind.compression.base.CompressionBase = hivemind.NoCompression(), average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (), averager_opts: Optional[dict] = None, tracker_opts: Optional[dict] = None, performance_ema_alpha: float = 0.1, shutdown_timeout: float = 5, verbose: bool = False)[source]

hivemind.Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.

By default, Optimizer is configured to be exactly equivalent to synchronous training with target_batch_size. There are advanced options that make training semi-asynchronous (delay_optimizer_step and delay_grad_averaging) or even fully asynchronous (use_local_updates=True).

Example

The Optimizer can be used as a drop-in replacement for a regular PyTorch Optimizer:

>>> model = transformers.AutoModel.from_pretrained("albert-xxlarge-v2")
>>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
>>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", batch_size_per_step=4, target_batch_size=4096,
>>>                          params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params))
>>> while True:
>>>     loss = compute_loss_on_batch(model, batch_size=4)
>>>     opt.zero_grad()
>>>     loss.backward()
>>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)

By default, peers will perform the following steps:

  • accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;

  • after peers collectively accumulate target_batch_size, average gradients with peers and perform optimizer step;

  • if your peer lags behind the rest of the swarm, it will download parameters and optimizer state from others;

Unlike regular training, your device may join midway through training, when other peers have already made some progress. For this reason, any learning rate schedulers, curricula and other time-dependent features should be based on optimizer.local_epoch (and not the number of calls to opt.step). Otherwise, peers that joined training late may end up with different learning rates. To do this automatically, specify the scheduler=... parameter described below.
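
For instance, a curriculum or custom warm-up can read the shared counter directly. A minimal sketch (get_difficulty_for_epoch and sample_batch are hypothetical helpers, not part of hivemind):

>>> difficulty = get_difficulty_for_epoch(opt.local_epoch)  # hypothetical curriculum helper
>>> batch = sample_batch(dataset, difficulty)               # hypothetical data sampling
>>> loss = compute_loss_on_batch(model, batch)
>>> # opt.local_epoch is kept in sync across the swarm, so every peer follows the same schedule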

What is an epoch?

Optimizer uses the term epoch to describe intervals between synchronizations. One epoch corresponds to processing a certain number of training samples (target_batch_size) in total across all peers. As with the PyTorch LR Scheduler, an epoch does not necessarily correspond to a full pass over the training data. At the end of an epoch, peers perform synchronous actions such as averaging gradients for a global optimizer update, updating the learning rate scheduler, or simply averaging parameters (if using local updates). The purpose of this is to ensure that changing the number of peers does not require changing hyperparameters. For instance, if the number of peers doubles, they will run all-reduce more frequently to account for the faster training.
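
For example, with target_batch_size=4096 and 8 peers each using batch_size_per_step=4 (illustrative numbers), one epoch takes roughly 4096 / (8 * 4) = 128 local steps per peer; with 16 such peers, only about 64, so peers synchronize roughly twice as often in wall-clock time while the hyperparameters stay unchanged.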

Configuration guide

This guide will help you set up your first collaborative training run. It covers the most important basic options, but ignores features that require significant changes to the training code.

>>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=IF_BEHIND_FIREWALL_OR_VERY_UNRELIABLE, start=True)
>>> opt = hivemind.Optimizer(
>>>    dht=dht, run_id="a_unique_name_that_every_participant_will_see_when_training",
>>>    batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER, target_batch_size=LARGE_GLOBAL_BATCH,
>>>    # ^--- Each global optimizer step will use gradients from 1x-1.1x of target_batch_size (due to latency);
>>>    # It is recommended to train with very large batch sizes to reduce the % of time spent on communication.
>>>
>>>    params=params, optimizer=lambda params: AnyPyTorchOptimizer(params, **hyperparams_for_target_batch_size),
>>>    # tune learning rate for your target_batch_size. Here's a good reference: https://arxiv.org/abs/1904.00962
>>>    scheduler=lambda opt: AnyPyTorchScheduler(opt, **hyperparams_for_target_batch_size),
>>>    # scheduler.step will be called automatically each time peers collectively accumulate target_batch_size
>>>
>>>    offload_optimizer=True,  # saves GPU memory, but increases RAM usage; generally a good practice to use this.
>>>    delay_grad_averaging=OPTIONAL, delay_optimizer_step=OPTIONAL, # train faster, but with 1 round of staleness;
>>>    # setting both to True is equivalent to Delayed Parameter Updates (see https://arxiv.org/abs/2101.06840)
>>>
>>>    grad_compression=hivemind.Float16Compression(),  state_averaging_compression=hivemind.Float16Compression(),
>>>    # ^-- it is usually fine to use pure 16-bit or even lower precision during communication with no precautions;
>>>    # See hivemind/examples/albert for a working example of mixed 8/16-bit compression.
>>>
>>>    matchmaking_time=15.0, # 3-5s for small local runs, 10-15s for training over the internet or with many peers
>>>    averaging_timeout=60.0,  # around 2x the actual time it takes to run all-reduce
>>>    verbose=True  # periodically report the training progress to the console (e.g. "Averaged with N peers")
>>> )  # and you're done!
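
The INITIAL_PEERS above are the multiaddresses of at least one peer that is already online. A minimal sketch of how the very first peer can obtain them (the printing format is up to you; later peers pass these addresses via initial_peers=[...]):

>>> # on the first peer: start a DHT node with no initial peers and print its visible addresses
>>> dht = hivemind.DHT(start=True)
>>> print([str(addr) for addr in dht.get_visible_maddrs()])
>>> # every subsequent peer uses hivemind.DHT(initial_peers=<these addresses>, start=True)
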
Parameters
  • dht – a running hivemind.DHT instance connected to other peers.

  • run_id – a unique identifier of this training run, used as a common prefix for all DHT keys. Note: peers with the same run_id should generally train the same model and use compatible configurations. Some options can be safely changed by individual peers: batch_size_per_step, client_mode, auxiliary, reuse_grad_buffers, offload_optimizer, and verbose. In some cases, other options may also be tuned individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.

  • target_batch_size – global batch size that must be accumulated before the swarm transitions to the next epoch. The actual batch may be slightly larger due to asynchrony (e.g. peers submit more gradients in the last second).

  • batch_size_per_step – you should accumulate gradients over this many samples between calls to optimizer.step.

  • params – parameters or param groups for the optimizer; required if optimizer is a callable(params).

  • optimizer – a callable(parameters) -> torch.optim.Optimizer or a pre-initialized PyTorch optimizer. Note: some advanced options such as offload_optimizer, delay_optimizer_step, or delay_grad_averaging require the callable form and will not work if hivemind.Optimizer is created with a pre-existing PyTorch Optimizer.

  • scheduler – callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler. The learning rate scheduler will adjust learning rate based on global epoch, not the number of local calls to optimizer.step; this is required to keep different peers synchronized.

  • matchmaking_time – when looking for a group, wait for peers to join for up to this many seconds. Increase this if you see “averaged gradients with N peers” where N is below 0.9x the real swarm size on >=25% of epochs. When training over a low-latency network, decreasing matchmaking_time allows training with smaller batch sizes.

  • averaging_timeout – if an averaging step hangs for this long, it will be cancelled automatically. Increase averaging_timeout if you see “Proceeding with local gradients” at least 25% of the time. Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.

  • allreduce_timeout – timeout for a single attempt to run all-reduce, default: equal to averaging_timeout.

  • load_state_timeout – wait for at most this many seconds before giving up on load_state_from_peers.

  • reuse_grad_buffers – if True, use the model’s .grad buffers for gradient accumulation. This is more memory-efficient, but requires that the user does NOT call model.zero_grad or opt.zero_grad at all (see the sketch after this parameter list).

  • offload_optimizer – offload the optimizer to host memory, saving GPU memory for parameters and gradients

  • delay_optimizer_step – run optimizer in background, apply results in future .step; requires offload_optimizer

  • delay_grad_averaging – average gradients in background; requires offload_optimizer and delay_optimizer_step

  • delay_state_averaging – if enabled (default), average parameters and extra tensors in a background thread; if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.

  • average_state_every – average state (parameters, chosen opt tensors) with peers every this many epochs. Averaging state less frequently reduces communication overhead, but can cause parameters to diverge if the interval is too large. The maximal safe average_state_every depends on how often peers diverge from each other: if peers hardly ever skip averaging rounds, they can average state less frequently; in turn, network failures, lossy gradient compression and use_local_updates cause parameters to diverge faster and require more frequent averaging.

  • use_local_updates – if enabled, peers will update parameters on each .step using local gradients; if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients. Even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.

  • client_mode – if True, this peer will not accept incoming connections (firewall-compatible mode)

  • auxiliary – if True, optimizer.step will only assist other peers in averaging (for CPU-only workers)

  • grad_compression – compression strategy used for averaging gradients, default = no compression

  • grad_averager_factory – if provided, creates gradient averager with required averaging strategy

  • state_averaging_compression – compression for averaging params and state tensors, default = no compression

  • load_state_compression – compression strategy for loading state from peers, default = no compression

  • average_opt_statistics – names of optimizer statistics from state dict that should be averaged with peers

  • extra_tensors – if specified, these extra tensors will also be averaged and shared in load_state_from_peers.

  • averager_opts – additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager

  • tracker_opts – additional keyword arguments forwarded to ProgressTracker

  • performance_ema_alpha – moving average alpha in ProgressTracker, TrainingStateAverager and Optimizer

  • verbose – if True, report internal events such as accumulating gradients and running background tasks
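
The reuse_grad_buffers option changes the expected training loop: gradients accumulate directly in the .grad buffers, so user code must not zero them. A minimal sketch of such a loop (assuming the optimizer was created with reuse_grad_buffers=True):

>>> opt = hivemind.Optimizer(..., reuse_grad_buffers=True)  # gradients accumulate in param.grad
>>> for batch in data_loader:
>>>     loss = compute_loss_on_batch(model, batch)
>>>     loss.backward()  # gradients keep accumulating in-place between global optimizer steps
>>>     opt.step()       # the optimizer resets the buffers itself once a global step is applied
>>>     # note: no opt.zero_grad() / model.zero_grad() here -- that would break accumulation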

Note

In large-scale training, peers will inevitably fail and you will see error messages. hivemind.Optimizer is designed to recover from such failures, but may sometimes need a minute or two to re-adjust.

property local_epoch: int[source]

This worker’s current epoch, kept synchronized with peers. If a peer’s local_epoch lags behind the others, it will automatically re-synchronize by downloading state from another peer. An epoch corresponds to accumulating target_batch_size across all active devices.

step(closure: Optional[Callable[[], torch.Tensor]] = None, batch_size: Optional[int] = None, grad_scaler: Optional[hivemind.optim.grad_scaler.GradScaler] = None)[source]

Update training progress after accumulating another local batch size. Depending on the configuration, this will report progress to peers, run global or local optimizer step, average parameters or schedule background tasks.

Parameters
  • closure – A closure that reevaluates the model and returns the loss.

  • batch_size – optional override for batch_size_per_step from init.

  • grad_scaler – if amp is enabled, this must be a hivemind-aware gradient scaler.

Note

This .step differs from that of regular PyTorch optimizers in several key ways. See __init__ for details.

zero_grad(set_to_none: bool = False)[source]

Reset gradients from model. If reuse_grad_buffers=True, this will raise an error.

load_state_from_peers(**kwargs)[source]

Attempt to load the newest collaboration state from other peers within the same run_id.

If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
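
Lagging peers call this automatically, but it can also be invoked explicitly, e.g. right after creating the optimizer so that a freshly joined peer starts from the swarm’s current state (a sketch; opt is created as in the examples above):

>>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", ...)  # same arguments as in the examples above
>>> opt.load_state_from_peers()  # download the latest parameters, optimizer state and local_epoch, if available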

class hivemind.optim.grad_scaler.GradScaler(*args, **kwargs)[source]

A wrapper over pytorch GradScaler made specifically for training hivemind.Optimizer with reuse_grad_buffers=True.

Note

if not using reuse_grad_buffers=True, one can and should train normally without this class, e.g. using standard PyTorch AMP or Apex. This custom GradScaler is more memory-efficient, but requires custom training code.

hivemind.GradScaler makes 3 modifications to the regular PyTorch AMP:

  • bypass .unscale_ and .update calls in order to accumulate gradients over several steps

  • limit increasing gradient scale to only immediately after global optimizer steps

  • allow training with some or all master parameters in float16

Note

The above modifications are enabled automatically. One can (and should) use hivemind.GradScaler exactly like the regular torch.amp.GradScaler.
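
A rough sketch of how this fits into a mixed-precision training loop (assuming the optimizer was created with reuse_grad_buffers=True; the loop itself follows the standard PyTorch AMP recipe, as the note above suggests):

>>> scaler = hivemind.GradScaler()
>>> for batch in data_loader:
>>>     with torch.cuda.amp.autocast():
>>>         loss = compute_loss_on_batch(model, batch)
>>>     scaler.scale(loss).backward()
>>>     scaler.step(opt)     # steps the wrapped hivemind.Optimizer with unscaled gradients
>>>     scaler.update()      # the scale is only allowed to grow right after a global optimizer step
>>>     # no zero_grad here: with reuse_grad_buffers=True the optimizer manages the buffers itself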