hivemind.moe.server

A hivemind server hosts one or several experts and processes incoming requests to those experts. It periodically re-publishes these experts to the DHT via a dedicated hivemind.dht.DHT peer that runs in the background. The experts can be accessed directly as hivemind.moe.client.RemoteExpert("addr:port", "expert.uid.here") or as a part of hivemind.moe.client.RemoteMixtureOfExperts that finds the most suitable experts across the DHT.

The hivemind.moe.server module is organized as follows:

  • Server is the main class that publishes experts, accepts incoming requests, and passes them to Runtime for compute.

  • ModuleBackend is a wrapper for torch.nn.Module that can be accessed by remote clients. It has two TaskPool instances, one for forward and one for backward requests.

  • Runtime balances the device (GPU) usage between several ModuleBackend instances that each service one expert.

  • TaskPool stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches and offers those batches to Runtime for processing.
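
For orientation, here is a minimal client-side sketch. It assumes a server that hosts an "ffn" expert with hidden_dim=1024 under the placeholder uid "expert.0"; the endpoint and uid follow the constructor form shown above.

>>> import torch
>>> from hivemind.moe.client import RemoteExpert
>>> expert = RemoteExpert("127.0.0.1:1337", "expert.0")  # placeholder endpoint and expert uid
>>> inputs = torch.randn(4, 1024, requires_grad=True)  # 0-th dimension is the batch size
>>> outputs = expert(inputs)  # forward pass is computed remotely by the server
>>> outputs.sum().backward()  # backward pass also runs remotely; inputs.grad is filled locally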

class hivemind.moe.server.Server(dht: hivemind.dht.dht.DHT, module_backends: Dict[str, hivemind.moe.server.module_backend.ModuleBackend], num_connection_handlers: int = 1, update_period: float = 30, expiration: Optional[float] = None, start=False, checkpoint_dir=None, **kwargs)[source]

Server allows you to host “experts” - pytorch subnetworks that can be accessed remotely by peers. After creation, a server should be started: see Server.run or Server.run_in_background.

A working server does two things:
  • processes incoming forward/backward requests via Runtime (created by the server)

  • publishes updates to expert status every :update_period: seconds

Parameters
  • module_backends – dict{expert uid (str) : ModuleBackend} for all experts hosted by this server.

  • num_connection_handlers – maximum number of simultaneous requests. Please note that the default value of 1 is too small for normal operation; we recommend 4 handlers per expert backend.

  • update_period – how often will server attempt to publish its state (i.e. experts) to the DHT; if dht is None, this parameter is ignored.

  • expiration – when server declares its experts to the DHT, these entries will expire after this many seconds

  • start – if True, the server will immediately start as a background thread and return control once the server is ready (see .ready below)
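
A minimal sketch of hosting a single expert with these parameters (the uid "expert.0", the Linear module and the tensor sizes are placeholders; Server.create below automates these steps):

>>> import torch
>>> import hivemind
>>> from hivemind.moe.server import ModuleBackend, Server
>>> from hivemind.utils.tensor_descr import BatchTensorDescriptor
>>> dht = hivemind.DHT(start=True)
>>> module = torch.nn.Linear(1024, 1024)
>>> backend = ModuleBackend(name="expert.0", module=module,
...                         optimizer=torch.optim.Adam(module.parameters()),
...                         args_schema=(BatchTensorDescriptor(1024),),
...                         outputs_schema=BatchTensorDescriptor(1024),
...                         max_batch_size=4096)  # max_batch_size is forwarded to the TaskPools
>>> server = Server(dht, {"expert.0": backend}, num_connection_handlers=4, start=True)
>>> server.ready.wait()  # block until the server can process requests
>>> server.shutdown()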

classmethod create(num_experts: Optional[int] = None, expert_uids: Optional[str] = None, expert_pattern: Optional[str] = None, expert_cls='ffn', hidden_dim=1024, optim_cls=<class 'torch.optim.adam.Adam'>, scheduler: str = 'none', num_warmup_steps=None, num_training_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1, max_batch_size=4096, device=None, initial_peers=(), checkpoint_dir: Optional[pathlib.Path] = None, compression=0, stats_report_interval: Optional[int] = None, custom_module_path=None, update_period: float = 30, expiration: Optional[float] = None, *, start: bool, **kwargs) hivemind.moe.server.server.Server[source]

Instantiate a server with several identical modules. See argparse comments below for details

Parameters
  • num_experts – run this many identical experts

  • expert_pattern – a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256] means "sample random experts between myprefix.0.0 and myprefix.31.255";

  • expert_uids – spawn experts with these exact uids, overrides num_experts and expert_pattern

  • expert_cls – expert type from hivemind.moe.server.layers, e.g. ‘ffn’ or ‘transformer’;

  • hidden_dim – main dimension for expert_cls

  • num_handlers – server will use this many parallel processes to handle incoming requests

  • min_batch_size – total number of examples in a batch will be at least this value

  • max_batch_size – total number of examples in a batch will not exceed this value

  • device – all experts will use this device in torch notation; default: cuda if available else cpu

  • optim_cls – uses this optimizer to train all experts

  • scheduler – if not none, the name of the expert LR scheduler

  • num_warmup_steps – the number of warmup steps for LR schedule

  • num_training_steps – the total number of steps for LR schedule

  • clip_grad_norm – maximum gradient norm used for clipping

  • initial_peers – multiaddrs of one or more active DHT peers (if you want to join an existing DHT)

  • checkpoint_dir – directory to save and load expert checkpoints

  • compression – if specified, use this compression to pack all inputs, outputs and gradients for all experts hosted on this server. For more fine-grained compression, start the server in Python and specify compression for each BatchTensorDescriptor in ModuleBackend for the respective experts.

  • start – if True, starts server right away and returns when server is ready for requests

  • stats_report_interval – interval between two reports of batch processing performance statistics

  • kwargs – any other params will be forwarded to DHT upon creation
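
A sketch of the typical convenience path (the pattern, dimension and handler count below are illustrative values only):

>>> from hivemind.moe.server import Server
>>> server = Server.create(num_experts=4, expert_pattern="myprefix.[0:32].[0:256]",
...                        expert_cls="ffn", hidden_dim=1024, num_handlers=16, start=True)
>>> # the server now hosts 4 "ffn" experts with uids sampled from the pattern
>>> server.shutdown()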

run()[source]

Starts Server in the current thread. Initializes the DHT if necessary, starts connection handlers, runs Runtime (self.runtime) to process incoming requests.

run_in_background(await_ready=True, timeout=None)[source]

Starts Server in a background thread. If await_ready, this method will wait until the background server is ready to process incoming requests, or for at most :timeout: seconds.

property ready: multiprocessing.synchronize.Event[source]

An event (multiprocessing.Event) that is set when the server is ready to process requests.

Example

>>> server.start()
>>> server.ready.wait(timeout=10)
>>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
shutdown()[source]

Gracefully terminate the server, process-safe. Please note that terminating the server otherwise (e.g. by killing processes) may result in zombie processes. If you have already caused a zombie outbreak, your only option is to kill them with -9 (SIGKILL).

class hivemind.moe.server.ModuleBackend(name: str, module: torch.nn.modules.module.Module, *, optimizer: Optional[torch.optim.optimizer.Optimizer] = None, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, args_schema: Optional[Tuple[hivemind.utils.tensor_descr.BatchTensorDescriptor, ...]] = None, kwargs_schema: Optional[Dict[str, hivemind.utils.tensor_descr.BatchTensorDescriptor]] = None, outputs_schema: Optional[Union[hivemind.utils.tensor_descr.BatchTensorDescriptor, Tuple[hivemind.utils.tensor_descr.BatchTensorDescriptor, ...]]] = None, **kwargs)[source]

ModuleBackend is a wrapper around a torch module that allows it to run tasks asynchronously with Runtime. By default, ModuleBackend handles three types of requests:

  • forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization.

  • backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and update expert. Also batched.

  • get_info - return expert metadata. Not batched.

Parameters
  • module

    nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:

    • Experts must always receive the same set of args and kwargs and produce output tensors of the same type

    • All args, kwargs and outputs must be tensors where the 0-th dimension represents the batch size

    • We recommend using experts that are ~invariant to the order in which they process batches

    • Using randomness (e.g. Dropout) leads to different samples at forward and backward. If you want consistency, you should explicitly register these random variables as model inputs or outputs. See hivemind.utils.custom_layers.DeterministicDropout for an example

  • optimizer – torch optimizer to be applied on every backward call

  • scheduler – learning rate scheduler for the expert

  • args_schema – description of positional arguments to expert.forward, list of BatchTensorDescriptor

  • kwargs_schema – description of keyword arguments to expert.forward, dict of BatchTensorDescriptor

  • outputs_schema – description of outputs from expert.forward, nested structure of BatchTensorDescriptor

  • kwargs – extra parameters to be forwarded into TaskPool.__init__
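
To illustrate the schema arguments, here is a sketch for a module that takes one positional and one keyword tensor; the module, uid and tensor sizes are hypothetical:

>>> import torch
>>> from hivemind.moe.server import ModuleBackend
>>> from hivemind.utils.tensor_descr import BatchTensorDescriptor
>>> class GatedFFN(torch.nn.Module):
...     def forward(self, x, gate):
...         return torch.relu(x) * gate  # all inputs and outputs are batch-first tensors
>>> backend = ModuleBackend(name="expert.gated", module=GatedFFN(),
...                         args_schema=(BatchTensorDescriptor(1024),),
...                         kwargs_schema={"gate": BatchTensorDescriptor(1024)},
...                         outputs_schema=BatchTensorDescriptor(1024),
...                         max_batch_size=4096)
>>> backend.get_pools()  # the forward and backward TaskPools that Runtime will serve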

forward(*inputs: torch.Tensor) Tuple[torch.Tensor, ...][source]

Apply forward pass to an aggregated batch of requests. Used by Runtime; do not call this manually. To submit a request for asynchronous processing, please use ModuleBackend.forward_pool.submit_task.

Subclassing:

This method receives a sequence of torch tensors following nested_flatten(self.forward_schema); it should return outputs that follow nested_flatten(self.outputs_schema).

backward(*inputs: torch.Tensor) Tuple[torch.Tensor, ...][source]

Apply backward pass to an aggregated batch of requests. Used by Runtime; do not call this manually. To submit a request for asynchronous processing, please use ModuleBackend.backward_pool.submit_task.

Subclassing:

This method receives a sequence of torch tensors following nested_flatten(self.backward_schema);

It should return gradients w.r.t. inputs that follow nested_flatten(self.forward_schema);

Runtime doesn't guarantee that backward will be performed in the same order and for the same data as forward, so we recommend a stateless backward pass that re-runs the expert's forward pass inside backward.

Please make sure to call ModuleBackend.on_backward after each call to backward

on_backward(batch_size: int) None[source]

Train the expert for one step. This method is called by ModuleBackend.backward after computing gradients.
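
For example, a subclass may clip gradients before the update step. This is a sketch that assumes the wrapped module is available as self.module (as passed to the constructor above):

>>> import torch
>>> from hivemind.moe.server import ModuleBackend
>>> class ClippedBackend(ModuleBackend):
...     def on_backward(self, batch_size: int) -> None:
...         torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm=1.0)
...         super().on_backward(batch_size)  # perform the usual training step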

get_info() Dict[str, Any][source]

Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration.

get_pools() Sequence[hivemind.moe.server.task_pool.TaskPool][source]

return all pools that should be processed by Runtime

class hivemind.moe.server.runtime.Runtime(module_backends: Dict[str, hivemind.moe.server.module_backend.ModuleBackend], prefetch_batches=64, sender_threads: int = 1, device: Optional[torch.device] = None, stats_report_interval: Optional[int] = None)[source]

A group of processes that processes incoming requests for multiple module backends on a shared device. Runtime is usually created and managed by Server, humans need not apply.

For debugging, you can start runtime manually with .start() or .run()

>>> module_backends = {'expert_name': ModuleBackend(**kwargs)}
>>> runtime = Runtime(module_backends)
>>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
>>> runtime.ready.wait()  # await for runtime to load all experts on device and create request pools
>>> future = runtime.module_backends['expert_name'].forward_pool.submit_task(*module_inputs)
>>> print("Returned:", future.result())
>>> runtime.shutdown()
Parameters
  • module_backends – a dict [expert uid -> ModuleBackend]

  • prefetch_batches – form up to this many batches in advance

  • sender_threads – dispatches outputs from finished batches using this many asynchronous threads

  • device – if specified, moves all experts and data to this device via .to(device=device). If you want to manually specify devices for each expert (in their forward pass), leave device=None (default)

  • stats_report_interval – interval to collect and log statistics about runtime performance

run()[source]

Method representing the thread’s activity.

You may override this method in a subclass. The standard run() method invokes the callable object passed to the object’s constructor as the target argument, if any, with sequential and keyword arguments taken from the args and kwargs arguments, respectively.

process_batch(pool: hivemind.moe.server.task_pool.TaskPoolBase, batch_index: int, *batch: torch.Tensor) Tuple[Any, int][source]

process one batch of tasks from a given pool, return a batch of results and total batch size

shutdown()[source]

Gracefully terminate a running runtime.

iterate_minibatches_from_pools(timeout=None)[source]

Iteratively select the non-empty pool with the highest priority and load a batch from that pool

class hivemind.moe.server.task_pool.TaskPool(process_func: callable, max_batch_size: int, name: str, min_batch_size=1, timeout=None, pool_size=None, prefetch_batches=1, daemon=True, start=False)[source]

Request aggregator that accepts processing requests, groups them into batches, waits for Runtime to process these batches and dispatches results back to request sources. Operates as a background process.

Parameters
  • process_func – function to be applied to every formed batch; called by Runtime. Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors

  • max_batch_size – process at most this many inputs in a batch (a task contains one or several inputs)

  • name – pool name

  • min_batch_size – process at least this many inputs in a batch, otherwise wait for more

  • timeout – wait for a subsequent task for at most this many seconds

  • pool_size – store at most this many unprocessed tasks in a queue

  • prefetch_batches – prepare up to this many batches in background for faster off-loading to runtime

  • start – if True, start automatically at the end of __init__
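
A self-contained sketch of the request path through a pool; in normal operation the load/process/send steps below are performed by Runtime. The pool name and tensor sizes are illustrative, and max_batch_size matches the submitted batch so that a batch forms immediately:

>>> import torch
>>> from hivemind.moe.server.task_pool import TaskPool
>>> def process_func(x):
...     return (torch.relu(x),)  # accepts positional tensors, returns a flat tuple of tensors
>>> pool = TaskPool(process_func, max_batch_size=16, name="demo_pool", start=True)
>>> future = pool.submit_task(torch.randn(16, 1024))  # Future for this task's outputs
>>> batch_index, batch = pool.load_batch_to_runtime()  # what Runtime would do next
>>> pool.send_outputs_from_runtime(batch_index, list(process_func(*batch)))
>>> print(future.result())  # resolves once outputs for this task are sent back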

submit_task(*args: torch.Tensor) concurrent.futures._base.Future[source]

Add task to this pool’s queue, return Future for its output

iterate_minibatches(*args, **kwargs)[source]

Form minibatches by grouping one or more tasks together up to self.max_batch_size

load_batch_to_runtime(timeout=None, device=None) Tuple[Any, List[torch.Tensor]][source]

receive next batch of tensors

send_outputs_from_runtime(batch_index: int, batch_outputs: List[torch.Tensor])[source]

send results for a processed batch, previously loaded through load_batch_to_runtime

get_task_size(task: hivemind.moe.server.task_pool.Task) int[source]

compute task processing complexity (used for batching); defaults to batch size
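
For instance, a pool serving variable-length sequence experts could weight tasks by the total number of tokens instead of the number of examples. This sketch assumes that a task exposes its input tensors as task.args:

>>> from hivemind.moe.server.task_pool import Task, TaskPool
>>> class TokenWeightedTaskPool(TaskPool):
...     def get_task_size(self, task: Task) -> int:
...         inputs = task.args[0]  # assumed layout: first input tensor is (batch, seq_len, ...)
...         return inputs.shape[0] * inputs.shape[1]  # batch_size * sequence_length

With this override, max_batch_size effectively bounds the total number of tokens per batch rather than the number of examples.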