hivemind.client

This module lets you connect to distributed Mixture-of-Experts or individual experts hosted in the cloud on someone else's computer.

class hivemind.client.RemoteExpert(uid, endpoint: str)[source]

A simple module that runs the forward and backward passes of an expert hosted on a remote machine. Works seamlessly with pytorch autograd (this is essentially a simple RPC function).

Warning: RemoteExpert currently assumes that you provide it with correct input shapes. Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to an error in the runtime.

Parameters:
  • uid – unique expert identifier
  • endpoint – network endpoint of a server that services that expert, e.g. “201.123.32.99:1337” or “[::]:8080”
forward(*args, **kwargs)[source]

Call RemoteExpert for the specified inputs and return its output(s). Compatible with torch.autograd.
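The RPC idea can be pictured with a toy, in-process sketch. ToyExpertServer and ToyRemoteExpert are hypothetical names, not hivemind classes, and there is no networking or autograd here. The client holds only a uid and forwards forward and backward calls to whoever hosts that uid, which is the role a real RemoteExpert plays over the network.

```python
from typing import Dict, List


class ToyExpertServer:
    """Hypothetical in-process stand-in for a server that hosts experts (no networking)."""

    def __init__(self) -> None:
        self.weights: Dict[str, float] = {}

    def add_expert(self, uid: str, weight: float) -> None:
        # Each toy expert computes y = weight * x elementwise.
        self.weights[uid] = weight

    def forward(self, uid: str, xs: List[float]) -> List[float]:
        w = self.weights[uid]
        return [w * x for x in xs]

    def backward(self, uid: str, xs: List[float], grad_ys: List[float]) -> List[float]:
        # dy/dx = weight, so the input gradient is weight * grad_y.
        # xs is unused by this linear toy, but a real backward generally needs the forward inputs.
        w = self.weights[uid]
        return [w * g for g in grad_ys]


class ToyRemoteExpert:
    """Hypothetical client-side handle that forwards every call to the server by uid, like an RPC stub."""

    def __init__(self, uid: str, server: ToyExpertServer) -> None:
        self.uid = uid
        self.server = server

    def forward(self, xs: List[float]) -> List[float]:
        return self.server.forward(self.uid, xs)

    def backward(self, xs: List[float], grad_ys: List[float]) -> List[float]:
        return self.server.backward(self.uid, xs, grad_ys)


server = ToyExpertServer()
server.add_expert("expert.0.1", 3.0)
expert = ToyRemoteExpert("expert.0.1", server)
print(expert.forward([1.0, 2.0]))               # [3.0, 6.0]
print(expert.backward([1.0, 2.0], [1.0, 0.5]))  # [3.0, 1.5]
```

In the real module, these two calls would be network requests tied into autograd; here they are plain method calls.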

class hivemind.client.RemoteMixtureOfExperts(*, in_features, grid_size: Tuple[int, ...], dht: hivemind.dht.DHT, uid_prefix: str, k_best: int, k_min: int = 1, forward_timeout: Optional[float] = None, timeout_after_k_min: Optional[float] = None, backward_k_min: int = 1, backward_timeout: Optional[float] = None, detect_anomalies: bool = False, allow_zero_outputs: bool = False, **dht_kwargs)[source]

A torch module that performs Mixture-of-Experts inference with a local gating function and multiple remote experts. Natively supports pytorch autograd.

Note:

By default, not all experts are guaranteed to perform the forward pass. Moreover, not all experts that ran the forward pass are guaranteed to perform the backward pass. In the latter case, the gradient will be averaged over the remaining experts only.
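One way to average "without the missing experts" is to mask them out before normalizing. The sketch below assumes the outputs are combined with softmax weights over the gating scores; this is our assumption, since the text only says outputs are averaged. Experts that did not respond get a score of -inf, which gives them zero weight, and the remaining weights are renormalized. combine_responding_experts is a hypothetical name.

```python
import math
from typing import List, Optional, Sequence


def combine_responding_experts(scores: Sequence[float],
                               outputs: Sequence[Optional[List[float]]]) -> List[float]:
    """Hypothetical sketch: softmax-weighted average over the experts that responded (None = missing)."""
    # Missing experts get a score of -inf, so exp(-inf - top) == 0.0 gives them zero weight.
    masked = [score if out is not None else -math.inf for score, out in zip(scores, outputs)]
    top = max(masked)
    if top == -math.inf:
        raise RuntimeError("no expert responded")
    exps = [math.exp(score - top) for score in masked]
    total = sum(exps)  # renormalizes over the responding experts only
    dim = len(next(out for out in outputs if out is not None))
    combined = [0.0] * dim
    for weight, out in zip(exps, outputs):
        if out is not None:
            for i, value in enumerate(out):
                combined[i] += weight / total * value
    return combined


# The second expert has the highest score but did not respond, so it is ignored.
print(combine_responding_experts([0.0, 5.0, 0.0], [[1.0, 2.0], None, [3.0, 4.0]]))  # [2.0, 3.0]
```

The same renormalization idea can be applied to whichever gradients come back on the backward pass.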

Parameters:
  • in_features – common input size for experts and gating function
  • grid_size – dimensions that form expert uid (see below)
  • uid_prefix – common prefix for all expert uids (must end with ‘.’)
  • dht – a DHT instance used to search for best experts
  • k_best – average this many highest-scoring experts to compute activations
  • k_min – make sure at least this many experts returned output (i.e. didn’t fail)
  • timeout_after_k_min – wait for this many seconds after k_min experts returned results. Any expert that didn’t manage to return output after that delay is considered unavailable
  • detect_anomalies – whether to check input/output tensors for NaN and infinity values
  • allow_zero_outputs – whether to return zeros if no experts respond on forward pass
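The interplay of k_min, timeout_after_k_min and forward_timeout can be sketched with Python's concurrent.futures. This hypothetical helper models each expert call as a plain function and is not hivemind's actual scheduling: it waits until k_min calls have succeeded, then gives the remaining calls a grace period of timeout_after_k_min seconds and treats anything still running as unavailable. Reading forward_timeout as a hard overall deadline is our assumption.

```python
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from typing import Any, Callable, List, Optional, Sequence


def call_experts(calls: Sequence[Callable[[], Any]], k_min: int, timeout_after_k_min: float,
                 forward_timeout: Optional[float] = None) -> List[Any]:
    """Hypothetical sketch: run expert calls concurrently, keep the results that arrive in time."""
    pool = ThreadPoolExecutor(max_workers=max(1, len(calls)))
    pending = {pool.submit(call) for call in calls}
    results: List[Any] = []
    deadline = None if forward_timeout is None else time.monotonic() + forward_timeout
    try:
        while pending and len(results) < k_min:
            remaining = None if deadline is None else max(0.0, deadline - time.monotonic())
            done, pending = wait(pending, timeout=remaining, return_when=FIRST_COMPLETED)
            if not done:
                break  # forward_timeout expired before k_min experts succeeded
            results.extend(f.result() for f in done if f.exception() is None)
        if len(results) >= k_min and pending:
            # Grace period: stragglers get timeout_after_k_min more seconds, then count as unavailable.
            done, pending = wait(pending, timeout=timeout_after_k_min)
            results.extend(f.result() for f in done if f.exception() is None)
    finally:
        pool.shutdown(wait=False)  # do not block on experts that are still running
    if len(results) < k_min:
        raise RuntimeError(f"only {len(results)} of the required {k_min} experts responded")
    return results


def fake_expert(delay: float, value: Any, fail: bool = False) -> Callable[[], Any]:
    """Hypothetical expert call that sleeps for `delay` seconds, then returns or raises."""
    def call() -> Any:
        time.sleep(delay)
        if fail:
            raise RuntimeError("expert failed")
        return value
    return call


results = call_experts([fake_expert(0.01, "a"), fake_expert(0.02, "b"), fake_expert(0.5, "c")],
                       k_min=2, timeout_after_k_min=0.05)
print(sorted(results))
```

With these delays, the third call finishes long after the 0.05-second grace period, so it is treated as unavailable and only the first two results are kept.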
Note:

expert uid follows the pattern {uid_prefix}{0…grid_size[0]-1}.{0…grid_size[1]-1}…{0…grid_size[-1]-1}, where uid_prefix already ends with ‘.’ (e.g. “expert.0.3” for uid_prefix “expert.”)
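Enumerating the grid makes the naming scheme concrete. This hypothetical helper (not part of hivemind) assumes, as the uid_prefix parameter requires, that the prefix already ends with '.', and that index d runs from 0 to grid_size[d] - 1.

```python
import itertools
from typing import Iterator, Tuple


def enumerate_expert_uids(uid_prefix: str, grid_size: Tuple[int, ...]) -> Iterator[str]:
    """Hypothetical helper: yield every uid in the expert grid."""
    if not uid_prefix.endswith("."):
        raise ValueError("uid_prefix must end with '.'")
    # One index per grid dimension; index d ranges over 0 .. grid_size[d] - 1.
    for indices in itertools.product(*(range(size) for size in grid_size)):
        yield uid_prefix + ".".join(str(i) for i in indices)


print(list(enumerate_expert_uids("ffn.", (2, 3))))
# ['ffn.0.0', 'ffn.0.1', 'ffn.0.2', 'ffn.1.0', 'ffn.1.1', 'ffn.1.2']
```

Because compute_expert_scores receives one score tensor per grid dimension, a (32, 32) grid addresses 1024 experts with only 32 + 32 = 64 scores per input.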

forward(input: torch.Tensor, *args, **kwargs)[source]

Choose the k best experts with beam search, then call the chosen experts and average their outputs. The input tensor is averaged over all dimensions except the first and the last (we assume that extra dimensions represent sequence length or image height/width)

Parameters:
  • input – a tensor of values that are used to estimate the gating function, batch-first.
  • args – extra positional parameters that will be passed to each expert after input, batch-first
  • kwargs – extra keyword parameters that will be passed to each expert, batch-first
Returns:

averaged predictions of all experts that delivered their results on time, as a nested structure of batch-first tensors
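The pooling applied to the gating input (averaging over every dimension except the first and the last) can be sketched in plain Python on nested lists. pool_for_gating is a hypothetical name. Since every position in the middle dimensions of a tensor holds a vector of the same length, averaging all innermost vectors of a sample equals taking the mean over those dimensions.

```python
from typing import Any, List, Sequence


def _innermost_vectors(node: Sequence[Any]) -> List[Sequence[float]]:
    """Collect every last-dimension vector nested under one batch element."""
    if len(node) == 0 or isinstance(node[0], (int, float)):
        return [node]
    vectors: List[Sequence[float]] = []
    for child in node:
        vectors.extend(_innermost_vectors(child))
    return vectors


def pool_for_gating(batch: Sequence[Any]) -> List[List[float]]:
    """Hypothetical sketch: average a batch-first nested input over all dims but the first and last."""
    pooled = []
    for sample in batch:
        vectors = _innermost_vectors(sample)
        pooled.append([sum(column) / len(vectors) for column in zip(*vectors)])
    return pooled


# batch=2, sequence length=2, features=2
print(pool_for_gating([[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [2.0, 6.0]]]))  # [[2.0, 3.0], [1.0, 3.0]]
```

A 2-D input (batch, features) passes through unchanged, because there are no middle dimensions to average.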

compute_expert_scores(grid_scores: List[torch.Tensor], batch_experts: List[List[hivemind.client.expert.RemoteExpert]]) → torch.Tensor[source]

Compute scores for each expert by adding up grid scores (autograd-friendly).

Parameters:
  • grid_scores – list of torch tensors; the i-th tensor contains scores for the i-th grid dimension
  • batch_experts – list (batch) of lists (k) of up to k experts selected for this batch
Returns:

a tensor of scores, float32[batch_size, k]

Note:

if some rows in the batch have fewer than the maximum number of experts, their scores will be padded with -inf
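A plain-Python sketch of the additive scoring. To stay self-contained, it identifies experts by their grid index tuples instead of RemoteExpert objects and uses nested lists instead of tensors, so it is not autograd-friendly. compute_expert_scores_sketch is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple

GridIndex = Tuple[int, ...]


def compute_expert_scores_sketch(grid_scores: Sequence[Sequence[Sequence[float]]],
                                 batch_experts: Sequence[Sequence[GridIndex]]) -> List[List[float]]:
    """Hypothetical sketch: sum per-dimension grid scores per expert, pad short rows with -inf.

    grid_scores[d][b][i] is the score of index i along grid dimension d for batch row b.
    """
    k = max((len(experts) for experts in batch_experts), default=0)
    scores = []
    for b, experts in enumerate(batch_experts):
        row = [sum(grid_scores[d][b][i] for d, i in enumerate(index)) for index in experts]
        row.extend([-math.inf] * (k - len(row)))  # rows with fewer than k experts are padded
        scores.append(row)
    return scores


print(compute_expert_scores_sketch(
    grid_scores=[[[1.0, 2.0], [0.0, 0.5]], [[10.0, 20.0, 30.0], [1.0, 2.0, 3.0]]],  # grid_size=(2, 3)
    batch_experts=[[(0, 2), (1, 0)], [(1, 1)]],
))  # [[31.0, 12.0], [2.5, -inf]]
```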

class hivemind.client.RemoteSwitchMixtureOfExperts(*, grid_size: Tuple[int, ...], utilization_alpha: float = 0.9, grid_dropout: float = 1.0, jitter_eps: float = 0.01, k_best=1, k_min=0, backward_k_min=0, allow_zero_outputs=True, **kwargs)[source]

A module implementing Switch Transformers [1] Mixture-of-Experts inference with remote experts.

[1] Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.
William Fedus, Barret Zoph, Noam Shazeer. https://arxiv.org/abs/2101.03961
Note:

By default, not all experts are guaranteed to perform the forward pass. Moreover, not all experts that ran the forward pass are guaranteed to perform the backward pass. In the latter case, the gradient will be averaged over the remaining experts only.

Parameters:
  • in_features – common input size for experts and gating function
  • grid_size – dimensions that form expert uid (see below)
  • uid_prefix – common prefix for all expert uids (must end with ‘.’)
  • dht – a DHT instance used to search for best experts
  • k_best – average this many highest-scoring experts to compute activations
  • k_min – make sure at least this many experts returned output (i.e. didn’t fail)
  • timeout_after_k_min – wait for this many seconds after k_min experts returned results. Any expert that didn’t manage to return output after that delay is considered unavailable
  • detect_anomalies – whether to check input/output tensors for NaN and infinity values
  • allow_zero_outputs – whether to return just the input if no experts respond on forward pass
Note:

expert uid follows the pattern {uid_prefix}{0…grid_size[0]-1}.{0…grid_size[1]-1}…{0…grid_size[-1]-1}, where uid_prefix already ends with ‘.’ (e.g. “expert.0.3” for uid_prefix “expert.”)

forward(input: torch.Tensor, *args, **kwargs)[source]

Choose the k best experts with beam search, then call the chosen experts and average their outputs. The input tensor is averaged over all dimensions except the first and the last (we assume that extra dimensions represent sequence length or image height/width)

Parameters:
  • input – a tensor of values that are used to estimate the gating function, batch-first.
  • args – extra positional parameters that will be passed to each expert after input, batch-first
  • kwargs – extra keyword parameters that will be passed to each expert, batch-first
Returns:

averaged predictions of all experts that delivered their results on time, as a nested structure of batch-first tensors

compute_expert_scores(grid_probs: List[torch.Tensor], batch_experts: List[List[hivemind.client.expert.RemoteExpert]]) → torch.Tensor[source]

Compute scores for each expert by multiplying grid probabilities (autograd-friendly).

Parameters:
  • grid_probs – list of torch tensors; the i-th tensor contains scores for the i-th grid dimension
  • batch_experts – list (batch) of lists (k) of up to k experts selected for this batch
Returns:

a tensor of scores, float32[batch_size, k]

Note:

if some rows in the batch have fewer than the maximum number of experts, their scores will be padded with -inf
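The multiplicative variant can be sketched the same way. Here we assume, since the text does not say, that each dimension's probabilities come from a softmax over that dimension's gating logits. Under that assumption, the product over dimensions is itself a probability distribution over the whole grid, so the scores of all experts in the grid add up to 1. The function names are hypothetical, and experts are again identified by grid index tuples.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def switch_expert_scores_sketch(grid_logits: Sequence[Sequence[Sequence[float]]],
                                batch_experts: Sequence[Sequence[Tuple[int, ...]]]) -> List[List[float]]:
    """Hypothetical sketch: multiply per-dimension probabilities per expert, pad short rows with -inf."""
    # grid_probs[d][b] is a probability distribution over the grid_size[d] indices of dimension d.
    grid_probs = [[softmax(row) for row in dim_logits] for dim_logits in grid_logits]
    k = max((len(experts) for experts in batch_experts), default=0)
    scores = []
    for b, experts in enumerate(batch_experts):
        row = [math.prod(grid_probs[d][b][i] for d, i in enumerate(index)) for index in experts]
        row.extend([-math.inf] * (k - len(row)))
        scores.append(row)
    return scores


# One batch row, grid_size=(2, 3), uniform logits: each expert gets probability 1/2 * 1/3.
print(switch_expert_scores_sketch([[[0.0, 0.0]], [[0.0, 0.0, 0.0]]], [[(0, 0), (1, 2)]]))
```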

class hivemind.client.DecentralizedAverager(averaged_tensors: Sequence[torch.Tensor], dht: DHT, *, start: bool, prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None, averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 65536, allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0, compression_type: runtime_pb2.CompressionType = 0, throughput: Optional[float] = None, min_vector_size: int = 0, auxiliary: bool = False, allow_state_sharing: Optional[bool] = None, listen: bool = True, listen_on: Endpoint = '0.0.0.0:*', daemon: bool = True, channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs)[source]

Parameter averaging service. A trainer can run this service in the background to periodically average its parameters with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small group of peers at a time, but (2) all trainers will converge to the global average in a logarithmic number of steps.
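The logarithmic-convergence claim can be illustrated with the simplest possible schedule: n = 2**d peers, groups of two, and in round r each peer averages with the peer whose index differs only in bit r. After d = log2(n) rounds every peer holds the exact global mean. This butterfly pattern is our illustration only, not hivemind's actual group-matching protocol, which uses the DHT to form groups of up to target_group_size peers.

```python
from typing import List


def butterfly_average(values: List[float]) -> List[float]:
    """Illustrative sketch (not hivemind's protocol): average 2**d peers in d rounds of pairs."""
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("this sketch assumes the number of peers is a power of two")
    values = list(values)
    rounds = n.bit_length() - 1  # log2(n)
    for r in range(rounds):
        # In round r, each peer averages with the peer whose index differs only in bit r.
        values = [(values[i] + values[i ^ (1 << r)]) / 2 for i in range(n)]
    return values


print(butterfly_average([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]))
# [3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5]
```

Each peer only ever talks to one partner per round, yet 8 peers agree after 3 rounds; doubling the number of peers adds just one round.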

Parameters:
  • averaged_tensors – a sequence of pytorch tensors that will be averaged in each all-reduce
  • dht – a DHT node that will be used to find groups
  • start – if True, starts the background process immediately
  • prefix – a shared prefix for all group keys
  • target_group_size – attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
  • initial_group_bits – a string of bits (‘0’ and ‘1’) that define the initial group key (bucket index)
  • averaging_expiration – attempt to find a group for this many seconds, otherwise try again. Note: this expiration time only applies to looking for a group; passing tensors in allreduce may take more time
  • compression_type – optionally compress tensors with this compression algorithm before sending them to peers
  • allreduce_timeout – spend at most this many seconds on allreduce (after the group is formed)
  • averaging_alpha – optional “learning rate” for averaging. If specified, local parameters will be shifted towards the (estimated) average by this coefficient. By default, local parameters are set equal to the average.
  • request_timeout – when looking for a group, wait for a response from the leader for at most this many seconds.
  • chunk_size_bytes – tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
  • throughput – if specified, this value represents the network bandwidth available to the averager. By default, the averager is assumed to have the average bandwidth of its group. If throughput == 0, the averager will rely on its groupmates to do all the averaging.
  • listen – if True (default), this averager will accept incoming requests from other peers and perform allreduce; if False, the averager will register as a freeloader and attempt to fetch vectors from other averagers
  • listen_on – network interface, e.g. “0.0.0.0:1337” or “localhost:*” (* means pick any port) or “[::]:7654”
  • channel_options – options for grpc.aio.insecure_channel, e.g. [(‘grpc.enable_retries’, 0)]; see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
  • kwargs – extra parameters forwarded to grpc.aio.server
  • auxiliary – if this flag is specified, averager.step will only assist others without sending local tensors for averaging
  • allow_state_sharing – if set to True, other peers can download this peer’s state. Can be overwritten with averager.allow_state_sharing = True / False
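To get a feel for chunk_size_bytes, the sketch below splits a flattened tensor into parts of at most that many bytes. It assumes uncompressed float32 values (4 bytes each), so the default of 65536 bytes corresponds to 16384 values per chunk; how hivemind counts bytes once compression_type is applied is not covered here. split_into_chunks is a hypothetical helper.

```python
from typing import List, Sequence


def split_into_chunks(flat_values: Sequence[float], chunk_size_bytes: int = 65536,
                      bytes_per_value: int = 4) -> List[Sequence[float]]:
    """Hypothetical sketch: split a flattened tensor into parts of at most chunk_size_bytes each."""
    values_per_chunk = max(1, chunk_size_bytes // bytes_per_value)  # 16384 float32 values by default
    return [flat_values[start:start + values_per_chunk]
            for start in range(0, len(flat_values), values_per_chunk)]


print([len(chunk) for chunk in split_into_chunks([0.0] * 40000)])  # [16384, 16384, 7232]
```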
Note:

request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.

Example:

>>> from hivemind.client import DecentralizedAverager
>>> averager = DecentralizedAverager(...)
>>> with averager.get_tensors() as tensors:
...     # run some code, modify tensors if necessary
...     tensors[0] += 1
>>> # do not use tensors after the lock is released
>>> # run averaging once (in-place), gather metadata from groupmates
>>> metadata = averager.step(gather=dict(my_batch_size=32))
>>> with averager.get_tensors() as tensors_after_averaging:
...     pass  # use the averaged tensors
serializer[source]

alias of hivemind.utils.serializer.MSGPackSerializer

allow_state_sharing[source]

if set to True, other peers can download this peer’s state

run()[source]

Run the averager function in a background thread. This is needed to avoid a heisenbug with a broken OpenMP (OMP) thread pool after fork: a non-main thread creates a separate OMP pool that works even if the original pool is corrupted. Read more: https://github.com/pytorch/pytorch/issues/17199

run_in_background(await_ready=True, timeout=None)[source]

Start the averager in a background process. If await_ready is True, this method waits until the background averager is ready to process incoming requests, or for at most timeout seconds.
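The await_ready/timeout contract can be sketched with a readiness event. This toy uses a thread instead of a separate process and simulates startup with a sleep. BackgroundService is a hypothetical class, and raising TimeoutError on expiry is our choice, not documented hivemind behavior.

```python
import threading
import time
from typing import Optional


class BackgroundService:
    """Hypothetical toy service illustrating run_in_background(await_ready, timeout)."""

    def __init__(self, startup_delay: float) -> None:
        self.startup_delay = startup_delay
        self.ready = threading.Event()
        self._worker = threading.Thread(target=self.run, daemon=True)

    def run(self) -> None:
        time.sleep(self.startup_delay)  # stand-in for binding sockets, finding peers, etc.
        self.ready.set()

    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
        self._worker.start()
        # Event.wait returns False if the flag is still unset when the timeout expires.
        if await_ready and not self.ready.wait(timeout):
            raise TimeoutError(f"not ready after {timeout} seconds")


service = BackgroundService(startup_delay=0.01)
service.run_in_background(await_ready=True, timeout=1.0)
print(service.ready.is_set())  # True
```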

shutdown() → None[source]

Shut down the averager process

step(gather: Optional[Any] = None, weight: float = 1.0, timeout: Optional[float] = None, allow_retries: bool = True, wait: bool = True) → Union[Dict[str, Any], None, hivemind.utils.mpfuture.MPFuture][source]

Set up the averager to look for a group and run one round of averaging.

Parameters:
  • gather – optionally send this information to all peers in the next group and gather it from every groupmate (this operation is known as all-gather). The gathered data will be available as the output of this function.
  • weight – averaging weight for this peer, int or float, must be strictly positive
  • allow_retries – if the averager fails to run one round of allreduce, this option will allow it to try again within the specified timeout
  • timeout – if the averager was unable to find a group in this many seconds, consider allreduce failed
  • wait – if True (default), return when finished. Otherwise return MPFuture and run in background.
Returns:

on success, update averaged_tensors and return group info; on failure, return None
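What one successful round could do to a peer's tensors is sketched below in two steps, under assumptions not stated above: the group average is weighted by each peer's weight argument, and averaging_alpha (a constructor parameter) then moves local values towards that average. Both function names are hypothetical, and tensors are flattened into lists.

```python
from typing import List, Sequence


def weighted_group_average(tensors: Sequence[Sequence[float]], weights: Sequence[float]) -> List[float]:
    """Hypothetical sketch: weighted elementwise average of the groupmates' flattened tensors."""
    if any(w <= 0 for w in weights):
        raise ValueError("weights must be strictly positive")
    total = sum(weights)
    return [sum(w * t[i] for w, t in zip(weights, tensors)) / total for i in range(len(tensors[0]))]


def apply_averaging_alpha(local: Sequence[float], average: Sequence[float], alpha: float = 1.0) -> List[float]:
    """Hypothetical sketch: shift local values towards the average; alpha=1.0 replaces them with it."""
    return [x + alpha * (avg - x) for x, avg in zip(local, average)]


average = weighted_group_average([[0.0, 4.0], [2.0, 8.0]], weights=[1.0, 3.0])
print(average)                                              # [1.5, 7.0]
print(apply_averaging_alpha([0.0, 4.0], average, alpha=0.5))  # [0.75, 5.5]
```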

get_current_state() → Tuple[Any, Sequence[torch.Tensor]][source]

Get the current state and send it to a peer. Executed in the host process. Meant to be overridden.

Returns:

a tuple of (small metadata, sequence of torch tensors)

Note:

metadata must be serializable with self.serializer (default = MSGPackSerializer)

load_state_from_peers(wait=True) → Optional[Tuple[Any, Sequence[torch.Tensor]]][source]

Try to download the latest optimizer state from one of the existing peers.

Returns:

on success, a 2-tuple (metadata, tensors), where

  • metadata is a small object containing metadata (e.g. hyperparameters, scalars, etc)
  • tensors is a sequence of pytorch tensors meant to contain the peer’s model weights and optimizer statistics

The exact contents of both metadata and tensors are determined by the get_current_state method.
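A toy round trip shows how get_current_state and load_state_from_peers fit together: one peer exports (metadata, tensors), and a newcomer restores from them. ToyTrainer, load_state and transfer_state are hypothetical names, tensors are plain lists, and json stands in for MSGPackSerializer so the example needs no third-party packages.

```python
import json
from typing import Any, List, Sequence, Tuple

State = Tuple[Any, Sequence[List[float]]]


class ToyTrainer:
    """Hypothetical peer whose state is (metadata, tensors), mirroring get_current_state."""

    def __init__(self, step: int, lr: float, weights: List[float]) -> None:
        self.step, self.lr, self.weights = step, lr, weights

    def get_current_state(self) -> State:
        metadata = {"step": self.step, "lr": self.lr}  # small, serializer-friendly values only
        return metadata, [list(self.weights)]

    def load_state(self, state: State) -> None:
        metadata, tensors = state
        self.step, self.lr = metadata["step"], metadata["lr"]
        self.weights = list(tensors[0])


def transfer_state(source: ToyTrainer, target: ToyTrainer) -> None:
    """Hypothetical download: serialize metadata as bytes (JSON stands in for MSGPackSerializer)."""
    metadata, tensors = source.get_current_state()
    wire_metadata = json.dumps(metadata).encode()
    target.load_state((json.loads(wire_metadata.decode()), [list(t) for t in tensors]))


veteran = ToyTrainer(step=1000, lr=0.01, weights=[0.5, -1.0])
newcomer = ToyTrainer(step=0, lr=0.1, weights=[0.0, 0.0])
transfer_state(veteran, newcomer)
print(newcomer.step, newcomer.lr, newcomer.weights)  # 1000 0.01 [0.5, -1.0]
```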

rpc_download_state(request: averaging_pb2.DownloadRequest, context: grpc.ServicerContext) → AsyncIterator[averaging_pb2.DownloadData][source]

Get the up-to-date trainer state from a peer. The state consists of two parts: (serialized_metadata, tensors)

  • serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
  • tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
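Returning an AsyncIterator lets the state be streamed in pieces rather than as one large message. The sketch below mimics that with an async generator: serialized metadata first, then each tensor in parts. The (kind, payload) message layout and both function names are hypothetical; the real messages are averaging_pb2.DownloadData sent over gRPC.

```python
import asyncio
import json
from typing import AsyncIterator, List, Tuple


async def toy_download_state(metadata: dict, tensors: List[List[float]],
                             chunk_len: int = 2) -> AsyncIterator[Tuple[str, object]]:
    """Hypothetical server side: stream serialized metadata first, then tensors piece by piece."""
    yield "metadata", json.dumps(metadata).encode()
    for index, tensor in enumerate(tensors):
        # max(..., 1) makes empty tensors still produce one (empty) part, so they are not lost.
        for start in range(0, max(len(tensor), 1), chunk_len):
            yield "tensor_part", (index, tensor[start:start + chunk_len])
            await asyncio.sleep(0)  # let other tasks run between messages


async def toy_receive_state(stream: AsyncIterator[Tuple[str, object]]) -> Tuple[dict, List[List[float]]]:
    """Hypothetical client side: reassemble (metadata, tensors) from the stream."""
    metadata: dict = {}
    tensors: List[List[float]] = []
    async for kind, payload in stream:
        if kind == "metadata":
            metadata = json.loads(payload)
        else:
            index, part = payload
            while len(tensors) <= index:
                tensors.append([])
            tensors[index].extend(part)
    return metadata, tensors


metadata, tensors = asyncio.run(toy_receive_state(toy_download_state({"step": 7}, [[1.0, 2.0, 3.0], [], [4.0]])))
print(metadata, tensors)  # {'step': 7} [[1.0, 2.0, 3.0], [], [4.0]]
```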

get_group_bits(wait: bool = True)[source]
Parameters:
  • wait – if True, return the bits directly (blocking until the averager responds). Otherwise return an awaitable MPFuture
Returns:

averager’s current group key bits (without prefix)

set_group_bits(group_bits: str, wait: bool = True)[source]
Parameters:
  • group_bits – group bits (string of ‘0’ or ‘1’) to be used in averager’s group key
  • wait – if True, wait until the update is confirmed by the averager. Otherwise return immediately