hivemind.averaging
This module lets you average tensors in a decentralized manner.
- class hivemind.averaging.DecentralizedAverager(averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.dht.DHT, *, start: bool, prefix: str, target_group_size: Optional[int] = None, min_group_size: int = 2, initial_group_bits: str = '', min_matchmaking_time: float = 5.0, request_timeout: float = 3.0, averaging_alpha: float = 1.0, part_size_bytes: int = 524288, allreduce_timeout: Optional[float] = None, next_chunk_timeout: Optional[float] = None, sender_timeout: Optional[float] = None, reducer_timeout: Optional[float] = None, compression: hivemind.compression.base.CompressionBase = hivemind.NoCompression(), state_compression: hivemind.compression.base.CompressionBase = hivemind.NoCompression(), tensor_infos: Optional[Sequence[hivemind.compression.base.CompressionInfo]] = None, bandwidth: Optional[float] = None, min_vector_size: int = 0, auxiliary: bool = False, allow_state_sharing: Optional[bool] = None, declare_state_period: float = 30, client_mode: Optional[bool] = None, daemon: bool = True, shutdown_timeout: float = 5)
Parameter averaging service. A trainer can run this service in the background to periodically average its parameters with other trainers. The averaging pattern is chosen so that (1) you only need to average with a small group of peers at a time, but (2) all trainers will converge to the global average in a logarithmic number of steps.
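To see why averaging within small groups can still reach the global average in a logarithmic number of rounds, consider a simplified, deterministic schedule: n = g**d peers sit on a d-dimensional grid, and in round r each peer averages with the g peers that differ from it only in coordinate r. This schedule is an illustrative assumption, not hivemind's actual matchmaking (which forms groups dynamically through DHT keys), and the function name `grid_average` is hypothetical.

```python
from statistics import fmean


def grid_average(values, group_size, rounds):
    """Average len(values) == group_size ** rounds scalars using small groups.

    In round r, peers whose indices differ only in the r-th base-`group_size`
    digit form one group, and every member replaces its value with the group mean.
    """
    n = len(values)
    if n != group_size ** rounds:
        raise ValueError("expected exactly group_size ** rounds peers")
    vals = list(values)
    for r in range(rounds):
        stride = group_size ** r
        new_vals = list(vals)
        for i in range(n):
            digit = (i // stride) % group_size
            base = i - digit * stride  # index of the first member of this peer's group
            members = [base + k * stride for k in range(group_size)]
            new_vals[i] = fmean(vals[m] for m in members)
        vals = new_vals
    return vals


# 16 peers in groups of 4 reach the global mean in log_4(16) = 2 rounds
result = grid_average([float(i) for i in range(16)], group_size=4, rounds=2)
print(result[0], result[-1])  # 7.5 7.5
```

Each round only requires communication inside groups of size g, yet after d = log_g(n) rounds every peer holds the mean of all n values.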
- Parameters
averaged_tensors – a sequence of pytorch tensors that will be averaged in each all-reduce
dht – a DHT node that will be used to find groups
start – if True, starts the background process immediately
prefix – a shared prefix for all group keys
target_group_size – attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
initial_group_bits – a string of bits (‘0’ and ‘1’) that define the initial group key (bucket index)
min_matchmaking_time – when looking for a group, wait for requests for at least this many seconds
compression – optionally compress tensors with this compression algorithm before running all-reduce
state_compression – a separate compression strategy for load_state_from_peers (default = no compression)
tensor_infos – CompressionInfo for each respective tensor; this determines how the tensor will be compressed
averaging_alpha – optional “learning rate” for averaging. If specified, local parameters will be shifted towards the (estimated) average by this coefficient. By default (averaging_alpha = 1.0), local parameters are set equal to the average.
request_timeout – when looking for a group, wait for a response from the leader for at most this many seconds.
part_size_bytes – tensors for AllReduce are processed in parts of up to this size (after compression)
bandwidth – if specified, this value represents the network bandwidth available to the averager. By default, the averager is assumed to have the average bandwidth of its group. If bandwidth == 0, the averager will rely on its groupmates to do all the averaging.
client_mode – if False, this averager will accept incoming requests from other peers. If True, the averager will only join existing groups where at least one peer has client_mode=False. By default, this flag is copied from the DHTNode inside the dht instance.
auxiliary – if this flag is specified, averager.step will only assist others without sending local tensors for averaging
allow_state_sharing – if set to True, other peers can download this peer’s state. Can be changed later by setting averager.allow_state_sharing = True / False
declare_state_period – re-declare the averager as a donor for load_state_from_peers once every this many seconds
allreduce_timeout – spend at most this many seconds on all-reduce (after the group is formed)
next_chunk_timeout – during all-reduce and load_state_from_peers, if a peer does not send the next data chunk within this many seconds, consider it failed and proceed with the remaining peers. default: no timeout
sender_timeout – during all-reduce, any sender that fails to send a tensor chunk within this many seconds of the previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
reducer_timeout – during all-reduce, any reducer that fails to send a results chunk within this many seconds of the previous chunk will be marked as failed and excluded from averaging. default: 2 * sender_timeout
shutdown_timeout – when calling .shutdown, wait for up to this many seconds before terminating
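The averaging_alpha description amounts to the update local + alpha * (average - local): alpha = 1.0 (the default) copies the average, while smaller values move the local parameters only part of the way. The sketch below applies that rule to plain Python lists; `shift_towards_average` is a hypothetical helper, not part of hivemind, which operates on torch tensors.

```python
def shift_towards_average(local, average, alpha=1.0):
    """Return local + alpha * (average - local), elementwise."""
    if len(local) != len(average):
        raise ValueError("local and average must have the same length")
    return [x + alpha * (avg - x) for x, avg in zip(local, average)]


params = [0.0, 4.0]
group_average = [2.0, 2.0]
print(shift_towards_average(params, group_average))             # [2.0, 2.0]
print(shift_towards_average(params, group_average, alpha=0.5))  # [1.0, 3.0]
```

With alpha < 1, a peer keeps part of its own parameters, which damps the effect of a noisy (estimated) group average.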
- Note
request_timeout must be smaller than min_matchmaking_time to avoid potential deadlocks.
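The timeout defaults described above chain together: sender_timeout falls back to next_chunk_timeout, and reducer_timeout falls back to 2 * sender_timeout. The hypothetical helper below resolves those defaults and enforces the note that request_timeout must be smaller than min_matchmaking_time. It is a sketch of the documented rules, not hivemind's own validation code; the default values are taken from the class signature.

```python
from typing import Optional, Tuple


def resolve_timeouts(
    min_matchmaking_time: float = 5.0,
    request_timeout: float = 3.0,
    next_chunk_timeout: Optional[float] = None,
    sender_timeout: Optional[float] = None,
    reducer_timeout: Optional[float] = None,
) -> Tuple[Optional[float], Optional[float]]:
    """Return (sender_timeout, reducer_timeout) after applying the documented defaults."""
    if request_timeout >= min_matchmaking_time:
        raise ValueError("request_timeout must be smaller than min_matchmaking_time")
    if sender_timeout is None:
        sender_timeout = next_chunk_timeout  # may stay None, meaning no timeout
    if reducer_timeout is None and sender_timeout is not None:
        reducer_timeout = 2 * sender_timeout
    return sender_timeout, reducer_timeout


print(resolve_timeouts(next_chunk_timeout=10.0))  # (10.0, 20.0)
print(resolve_timeouts())                         # (None, None)
```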
Example:
>>> averager = DecentralizedAverager(...)
>>> with averager.get_tensors() as tensors:
...     # run some code, modify tensors if necessary
...     tensors[0] += 1
>>> # do not use tensors after the lock is released
>>> metadata = averager.step(gather=dict(my_batch_size=32))
>>> # run averaging once (in-place), gather metadata from groupmates
>>> with averager.get_tensors() as tensors_after_averaging:
...     pass  # use the averaged tensors
- property allow_state_sharing: bool
if set to True, other peers can download this peer’s state
- property state_sharing_priority: float
Other peers will preferentially download state from peers with the highest priority.
- run()
Run the averager function in a background thread. This is needed to avoid a heisenbug with broken OMP on fork: it turns out that using a non-main thread creates a separate OMP pool that works even if the original pool is corrupted. Read more: https://github.com/pytorch/pytorch/issues/17199
- run_in_background(await_ready: bool = True, timeout: Optional[float] = None) -> None
Starts the averager in a background process. If await_ready, this method will wait until the background DHT is ready to process incoming requests, or for at most timeout seconds.
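The await_ready / timeout behaviour can be pictured as "start a worker, then block on a readiness signal for at most timeout seconds". The sketch below models that with a thread and a threading.Event; the class `BackgroundWorker` is hypothetical, hivemind itself starts a separate process, and unlike the documented method (which returns None) this sketch returns whether readiness was reached.

```python
import threading
import time
from typing import Optional


class BackgroundWorker:
    """Hypothetical stand-in for a service that becomes ready after some setup."""

    def __init__(self, setup_seconds: float):
        self.setup_seconds = setup_seconds
        self.ready = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self) -> None:
        time.sleep(self.setup_seconds)  # simulate slow startup work
        self.ready.set()

    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> bool:
        """Start the worker; if await_ready, wait up to `timeout` seconds for readiness."""
        self._thread.start()
        if await_ready:
            return self.ready.wait(timeout)  # False if the timeout expired first
        return self.ready.is_set()


print(BackgroundWorker(0.01).run_in_background(timeout=1.0))  # True
```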
- step(gather: Optional[Any] = None, scheduled_time: Optional[float] = None, weight: Optional[float] = None, timeout: Optional[float] = None, allow_retries: bool = True, require_trigger: bool = False, wait: bool = True) -> Union[Dict[hivemind.p2p.p2p_daemon_bindings.datastructures.PeerID, Any], None, hivemind.averaging.control.StepControl]
Set up the averager to look for a group and run one round of averaging. On success, return group info; on failure, return None. If wait=False, return a StepControl immediately instead.
- Parameters
gather – optionally send this information to all peers in the next group and gather it from every groupmate (this operation is known as all-gather). The gathered data will be available as the output of this function.
scheduled_time – when matchmaking, assume that all-reduce will begin at this moment. By default, schedule all-reduce at the current time plus min_matchmaking_time seconds
weight – averaging weight for this peer, int or float, must be strictly positive
allow_retries – if the averager fails to run one round of all-reduce, this option will allow it to try again within the specified timeout
require_trigger – if True, wait for the user to call .allow_allreduce() before running all-reduce
timeout – if the averager is unable to find a group within this many seconds, consider all-reduce failed
wait – if True (default), return when finished. Otherwise return StepControl and run in background.
- Returns
on success, update averaged_tensors and return group info; on failure, return None
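Within one group, step combines a weighted average of the peers' tensors (each peer contributes with its strictly positive weight) with an all-gather of the optional gather payload, returned as a mapping from peer to payload. The sketch below simulates one such round locally: string names stand in for PeerID and lists of floats stand in for tensors. `simulate_group_step` is hypothetical, and treating weights as coefficients of a weighted mean is an assumption for illustration, not a description of hivemind's all-reduce internals.

```python
from typing import Any, Dict, List, Tuple


def simulate_group_step(
    peers: Dict[str, Tuple[List[float], float, Any]],
) -> Tuple[List[float], Dict[str, Any]]:
    """peers maps name -> (tensor, weight, gather payload).

    Returns the weighted average tensor and the all-gathered payloads.
    """
    if not peers:
        raise ValueError("a group needs at least one peer")
    for name, (_, weight, _) in peers.items():
        if weight <= 0:
            raise ValueError(f"weight of {name} must be strictly positive")
    total_weight = sum(weight for _, weight, _ in peers.values())
    length = len(next(iter(peers.values()))[0])
    averaged = [
        sum(tensor[i] * weight for tensor, weight, _ in peers.values()) / total_weight
        for i in range(length)
    ]
    gathered = {name: payload for name, (_, _, payload) in peers.items()}
    return averaged, gathered


averaged, gathered = simulate_group_step({
    "peer_a": ([0.0, 0.0], 1.0, dict(my_batch_size=32)),
    "peer_b": ([3.0, 6.0], 2.0, dict(my_batch_size=64)),
})
print(averaged)            # [2.0, 4.0]
print(gathered["peer_b"])  # {'my_batch_size': 64}
```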
- rpc_download_state(_request: averaging_pb2.DownloadRequest, _context: hivemind.p2p.p2p_daemon.P2PContext) -> AsyncIterator[averaging_pb2.DownloadData]
Get the up-to-date trainer state from a peer. The state consists of two parts: (serialized_metadata, tensors)
serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
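rpc_download_state streams the state as an asynchronous sequence of messages, and next_chunk_timeout bounds how long the receiver waits for each next chunk. The asyncio sketch below mimics that pattern: a producer yields the serialized metadata followed by tensor bytes split into parts of at most part_size_bytes, and the consumer abandons the stream if any single chunk takes longer than the timeout. JSON and raw bytes stand in for hivemind's serializer and DownloadData protobuf messages; `stream_state` and `receive_all` are hypothetical names.

```python
import asyncio
import json
from typing import AsyncIterator, List, Optional


async def stream_state(metadata: dict, tensor_bytes: List[bytes], part_size_bytes: int,
                       delay: float = 0.0) -> AsyncIterator[bytes]:
    """Yield serialized metadata, then each tensor's bytes in parts of <= part_size_bytes."""
    yield json.dumps(metadata).encode()
    for blob in tensor_bytes:
        for start in range(0, len(blob), part_size_bytes):
            await asyncio.sleep(delay)  # simulate per-chunk network latency
            yield blob[start:start + part_size_bytes]


async def receive_all(stream: AsyncIterator[bytes],
                      next_chunk_timeout: Optional[float]) -> Optional[List[bytes]]:
    """Collect every chunk, or return None if any chunk takes longer than the timeout."""
    chunks = []
    while True:
        try:
            chunk = await asyncio.wait_for(stream.__anext__(), timeout=next_chunk_timeout)
        except StopAsyncIteration:
            return chunks
        except asyncio.TimeoutError:
            return None  # treat the sending peer as failed
        chunks.append(chunk)


async def demo():
    blobs = [bytes(10), bytes(3)]
    fast = await receive_all(stream_state({"step": 7}, blobs, part_size_bytes=4), next_chunk_timeout=1.0)
    slow = await receive_all(stream_state({"step": 7}, blobs, part_size_bytes=4, delay=0.2),
                             next_chunk_timeout=0.05)
    return fast, slow


fast, slow = asyncio.run(demo())
print([len(c) for c in fast[1:]])  # [4, 4, 2, 3]
print(slow)                        # None
```

The demo uses part_size_bytes=4 only for readability; the class default is 524288 bytes.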
- get_current_state() -> Tuple[Any, Sequence[torch.Tensor], Sequence[hivemind.compression.base.CompressionInfo]]
Get the current state and send it to a peer. Executed in the host process. Meant to be overridden.
Returns: a tuple of (small metadata, sequence of torch tensors, sequence of CompressionInfo for those tensors)
Note: metadata must be serializable with self.serializer (default = MSGPackSerializer)
- load_state_from_peers(wait: bool = True, timeout: Optional[float] = None) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]
Try to download the latest optimizer state from one of the existing peers.
Returns: on success, a 2-tuple (metadata, tensors), where
metadata is a small object containing metadata (e.g. hyperparameters, scalars, etc)
tensors is a sequence of pytorch tensors meant to contain peer’s model weights and optimizer statistics
The exact contents of both metadata and tensors are determined by get_current_state method
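Together with state_sharing_priority, this description suggests a simple client-side strategy: try donors in order of decreasing priority and return the first state that downloads successfully, or None if every attempt fails. The sketch below implements that strategy over a hypothetical `download` callable; the ordering and fallback loop are assumptions for illustration, not hivemind's exact logic, and `load_state_from_donors` and `fake_download` are hypothetical names.

```python
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

State = Tuple[Any, Sequence[Any]]  # (metadata, tensors)


def load_state_from_donors(
    priorities: Dict[str, float],
    download: Callable[[str], State],
) -> Optional[State]:
    """Try donors from highest to lowest priority; return the first successful state."""
    for peer in sorted(priorities, key=priorities.get, reverse=True):
        try:
            return download(peer)
        except Exception:  # a failed or timed-out download: move on to the next donor
            continue
    return None


def fake_download(peer: str) -> State:
    if peer == "busy_peer":
        raise ConnectionError(f"{peer} did not respond")
    return {"step": 100, "source": peer}, [[0.1, 0.2]]


metadata, tensors = load_state_from_donors(
    {"busy_peer": 5.0, "idle_peer": 1.0, "slow_peer": 0.5}, fake_download
)
print(metadata["source"])  # idle_peer
```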