hivemind.moe.server
A hivemind server hosts one or several experts and processes incoming requests to those experts. It periodically re-publishes these experts to the DHT via a dedicated hivemind.dht.DHT peer that runs in the background. The experts can be accessed directly as hivemind.moe.client.RemoteExpert("addr:port", "expert.uid.here") or as part of a hivemind.moe.client.RemoteMixtureOfExperts that finds the most suitable experts across the DHT.
The hivemind.moe.server module is organized as follows:
Server is the main class that publishes experts, accepts incoming requests, and passes them to Runtime for compute.
ModuleBackend is a wrapper for torch.nn.Module that can be accessed by remote clients. It has two TaskPools: one for forward and one for backward requests.
Runtime balances the device (GPU) usage between several ModuleBackend instances, each of which serves one expert.
TaskPool stores incoming requests for a batch-parallel computation (e.g. forward pass), groups them into batches and offers those batches to Runtime for processing.
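For orientation, the sketch below wires these pieces together: it hosts a single expert with Server.create (documented below) and queries it with RemoteExpert, using the address-and-uid form shown in the introduction above. The uid "expert.0", the hidden dimension, and "addr:port" are illustrative placeholders, not fixed API values.

import torch
from hivemind.moe.server import Server
from hivemind.moe.client import RemoteExpert

# Host one feed-forward expert; see Server.create below for all options.
# expert_uids is assumed to accept a list of exact uids (see its description below).
server = Server.create(expert_uids=["expert.0"], expert_cls="ffn", hidden_dim=512, start=True)

# On a client, connect to the expert directly ("addr:port" is a placeholder)
# and call it like a local torch module:
expert = RemoteExpert("addr:port", "expert.0")
output = expert(torch.randn(4, 512))  # the forward pass is sent to the server

server.shutdown()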
- class hivemind.moe.server.Server(dht: hivemind.dht.dht.DHT, module_backends: Dict[str, hivemind.moe.server.module_backend.ModuleBackend], num_connection_handlers: int = 1, update_period: float = 30, expiration: Optional[float] = None, start=False, checkpoint_dir=None, **kwargs)[source]
Server allows you to host “experts” - pytorch subnetworks that can be accessed remotely by peers. After creation, a server should be started: see Server.run or Server.run_in_background.
- A working server does two things:
processes incoming forward/backward requests via Runtime (created by the server)
publishes updates to expert status every update_period seconds
- Parameters
module_backends – dict {expert uid (str): ModuleBackend} for all experts hosted by this server.
num_connection_handlers – maximum number of simultaneous requests. Please note that the default value of 1 is too small for normal functioning; we recommend 4 handlers per expert backend.
update_period – how often the server will attempt to publish its state (i.e. experts) to the DHT; if dht is None, this parameter is ignored.
expiration – when the server declares its experts to the DHT, these entries will expire after this many seconds
start – if True, the server will immediately start as a background thread and return control after the server is ready (see .ready below)
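If your experts are not identical (so create below does not fit), one reasonable pattern is to build the ModuleBackend dict by hand and pass it to this constructor. A minimal sketch, assuming a linear expert and the BatchTensorDescriptor schema class from hivemind.utils; names and sizes are placeholders:

import torch
from hivemind import DHT
from hivemind.moe.server import ModuleBackend, Server
from hivemind.utils import BatchTensorDescriptor

hidden_dim = 512
module = torch.nn.Linear(hidden_dim, hidden_dim)
backend = ModuleBackend(
    name="expert.0",
    module=module,
    optimizer=torch.optim.Adam(module.parameters()),
    args_schema=(BatchTensorDescriptor(hidden_dim),),  # one batch-first input
    outputs_schema=BatchTensorDescriptor(hidden_dim),  # one batch-first output
    max_batch_size=4096,  # forwarded to the forward/backward TaskPools
)

dht = DHT(start=True)
server = Server(dht, {"expert.0": backend}, num_connection_handlers=4, start=True)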
- classmethod create(num_experts: Optional[int] = None, expert_uids: Optional[str] = None, expert_pattern: Optional[str] = None, expert_cls='ffn', hidden_dim=1024, optim_cls=<class 'torch.optim.adam.Adam'>, scheduler: str = 'none', num_warmup_steps=None, num_training_steps=None, clip_grad_norm=None, num_handlers=None, min_batch_size=1, max_batch_size=4096, device=None, initial_peers=(), checkpoint_dir: Optional[pathlib.Path] = None, compression=0, stats_report_interval: Optional[int] = None, custom_module_path=None, update_period: float = 30, expiration: Optional[float] = None, *, start: bool, **kwargs) → hivemind.moe.server.server.Server [source]
Instantiate a server with several identical modules. See the parameter descriptions below for details.
- Parameters
num_experts – run this many identical experts
expert_pattern – a string pattern or a list of expert uids, example: myprefix.[0:32].[0:256] means "sample random experts between myprefix.0.0 and myprefix.31.255"
expert_uids – spawn experts with these exact uids, overrides num_experts and expert_pattern
expert_cls – expert type from hivemind.moe.server.layers, e.g. ‘ffn’ or ‘transformer’;
hidden_dim – main dimension for expert_cls
num_handlers – server will use this many parallel processes to handle incoming requests
min_batch_size – total number of examples in the same batch will be at least this value
max_batch_size – total number of examples in the same batch will not exceed this value
device – all experts will use this device in torch notation; default: cuda if available else cpu
optim_cls – uses this optimizer to train all experts
scheduler – if not 'none', the name of the LR scheduler used for every expert
num_warmup_steps – the number of warmup steps for LR schedule
num_training_steps – the total number of steps for LR schedule
clip_grad_norm – maximum gradient norm used for clipping
initial_peers – multiaddrs of one or more active DHT peers (if you want to join an existing DHT)
checkpoint_dir – directory to save and load expert checkpoints
compression – if specified, use this compression to pack all inputs, outputs, and gradients for all experts hosted on this server. For more fine-grained compression, start the server in Python and specify compression for each BatchTensorProto in ModuleBackend for the respective experts.
start – if True, starts server right away and returns when server is ready for requests
stats_report_interval – interval between two reports of batch processing performance statistics
kwargs – any other params will be forwarded to DHT upon creation
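A typical invocation of create, with illustrative values and the 4-handlers-per-expert ratio recommended above:

from hivemind.moe.server import Server

server = Server.create(
    num_experts=4,
    expert_pattern="myprefix.[0:32].[0:256]",  # sample 4 random expert uids from this pattern
    expert_cls="ffn",
    hidden_dim=1024,
    num_handlers=16,  # ~4 connection handlers per expert backend, as recommended above
    device="cuda",    # all experts on this device; defaults to cuda if available
    start=True,       # return only once the server is ready for requests
)
# ... serve requests, then:
server.shutdown()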
- run()[source]
Starts Server in the current thread. Initializes the DHT if necessary, starts connection handlers, and runs Runtime (self.runtime) to process incoming requests.
- run_in_background(await_ready=True, timeout=None)[source]
Starts Server in a background thread. If await_ready, this method will wait until the background server is ready to process incoming requests, or for at most timeout seconds.
- property ready: multiprocessing.synchronize.Event[source]
An event (multiprocessing.Event) that is set when the server is ready to process requests.
Example
>>> server.start()
>>> server.ready.wait(timeout=10)
>>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
- class hivemind.moe.server.ModuleBackend(name: str, module: torch.nn.modules.module.Module, *, optimizer: Optional[torch.optim.optimizer.Optimizer] = None, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, args_schema: Optional[Tuple[hivemind.utils.tensor_descr.BatchTensorDescriptor, ...]] = None, kwargs_schema: Optional[Dict[str, hivemind.utils.tensor_descr.BatchTensorDescriptor]] = None, outputs_schema: Optional[Union[hivemind.utils.tensor_descr.BatchTensorDescriptor, Tuple[hivemind.utils.tensor_descr.BatchTensorDescriptor, ...]]] = None, **kwargs)[source]
ModuleBackend is a wrapper around a torch module that allows it to run tasks asynchronously with Runtime. By default, ModuleBackend handles three types of requests:
forward - receive inputs and compute outputs. Concurrent requests will be batched for better GPU utilization.
backward - receive gradients w.r.t. outputs, compute gradients w.r.t. inputs and update expert. Also batched.
get_info - return expert metadata. Not batched.
- Parameters
module –
nn.Module to be wrapped into a backend. An arbitrary pytorch module with a few limitations:
Experts must always receive the same set of args and kwargs and produce output tensors of the same type
All args, kwargs and outputs must be tensors where the 0-th dimension represents the batch size
We recommend using experts that are ~invariant to the order in which they process batches
Using randomness (e.g. Dropout) leads to different samples at forward and backward. If you want consistency, you should explicitly register these random variables as model inputs or outputs. See hivemind.utils.custom_layers.DeterministicDropout for an example
optimizer – torch optimizer to be applied on every backward call
scheduler – a function to create the learning rate scheduler for the expert
args_schema – description of positional arguments to expert.forward, list of BatchTensorProto
kwargs_schema – description of keyword arguments to expert.forward, dict of BatchTensorProto
outputs_schema – description of outputs from expert.forward, nested structure of BatchTensorProto
kwargs – extra parameters to be forwarded into TaskPool.__init__
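As an illustration of the schema arguments, here is a hedged sketch of a backend whose module takes a keyword-argument mask alongside its positional input; MaskedFFN and all sizes are made up for the example, and BatchTensorDescriptor is assumed importable from hivemind.utils:

import torch
from hivemind.moe.server import ModuleBackend
from hivemind.utils import BatchTensorDescriptor

class MaskedFFN(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, *, mask: torch.Tensor) -> torch.Tensor:
        # every input and output is a tensor whose 0-th dimension is the batch size
        return self.proj(x) * mask

dim = 256
module = MaskedFFN(dim)
backend = ModuleBackend(
    name="expert.masked.0",
    module=module,
    optimizer=torch.optim.SGD(module.parameters(), lr=1e-3),
    args_schema=(BatchTensorDescriptor(dim),),           # positional inputs to forward
    kwargs_schema={"mask": BatchTensorDescriptor(dim)},  # keyword inputs to forward
    outputs_schema=BatchTensorDescriptor(dim),           # outputs of forward
)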
- forward(*inputs: torch.Tensor) → Tuple[torch.Tensor, ...] [source]
Apply forward pass to an aggregated batch of requests. Used by Runtime; do not call this manually. To submit a request for asynchronous processing, please use ModuleBackend.forward_pool.submit_task.
- Subclassing:
This method receives a sequence of torch tensors following nested_flatten(self.forward_schema); it should return outputs that follow nested_flatten(self.outputs_schema).
- backward(*inputs: torch.Tensor) → Tuple[torch.Tensor, ...] [source]
Apply backward pass to an aggregated batch of requests. Used by Runtime; do not call this manually. To submit a request for asynchronous processing, please use ModuleBackend.backward_pool.submit_task.
- Subclassing:
This method receives a sequence of torch tensors following nested_flatten(self.backward_schema); it should return gradients w.r.t. inputs that follow nested_flatten(self.forward_schema).
Runtime doesn't guarantee that backward will be performed in the same order and for the same data as forward, so we recommend a stateless backward pass that re-runs the expert's forward pass inside backward.
Please make sure to call ModuleBackend.on_backward after each call to backward.
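The recommendation above (a stateless backward that re-runs the forward pass) might look like the following sketch for a backend with a single positional input and a single output; this is an illustration of the contract, not the library's built-in implementation:

import torch
from hivemind.moe.server import ModuleBackend

class StatelessBackend(ModuleBackend):
    def backward(self, x: torch.Tensor, grad_output: torch.Tensor):
        # inputs follow nested_flatten(self.backward_schema):
        # the original forward inputs, then gradients w.r.t. the outputs
        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            output = self.module(x)  # re-run forward: independent of request ordering
            torch.autograd.backward(output, grad_tensors=grad_output)
        self.on_backward(batch_size=x.shape[0])  # apply the training step (see next method)
        return (x.grad,)  # gradients w.r.t. inputs, following forward_schema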
- on_backward(batch_size: int) → None [source]
Train the expert for one step. This method is called by ModuleBackend.backward after computing gradients.
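on_backward is also the natural hook for custom training logic; for instance, a hedged sketch that clips gradients before delegating the optimizer/scheduler step to the parent class (ClippedBackend and max_norm are illustrative):

import torch
from hivemind.moe.server import ModuleBackend

class ClippedBackend(ModuleBackend):
    def on_backward(self, batch_size: int) -> None:
        # clip the accumulated gradients, then run the usual training step
        torch.nn.utils.clip_grad_norm_(self.module.parameters(), max_norm=1.0)
        super().on_backward(batch_size)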
- get_info() → Dict[str, Any] [source]
Get expert parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration.
- get_pools() → Sequence[hivemind.moe.server.task_pool.TaskPool] [source]
Return all pools that should be processed by Runtime.
- class hivemind.moe.server.runtime.Runtime(module_backends: Dict[str, hivemind.moe.server.module_backend.ModuleBackend], prefetch_batches=64, sender_threads: int = 1, device: Optional[torch.device] = None, stats_report_interval: Optional[int] = None)[source]
A group of processes that processes incoming requests for multiple module backends on a shared device. Runtime is usually created and managed by Server, humans need not apply.
For debugging, you can start runtime manually with .start() or .run()
>>> module_backends = {'expert_name': ModuleBackend(**kwargs)}
>>> runtime = Runtime(module_backends)
>>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
>>> runtime.ready.wait()  # wait for runtime to load all experts on device and create request pools
>>> future = runtime.module_backends['expert_name'].forward_pool.submit_task(*module_inputs)
>>> print("Returned:", future.result())
>>> runtime.shutdown()
- Parameters
module_backends – a dict [expert uid -> ModuleBackend]
prefetch_batches – form up to this many batches in advance
sender_threads – dispatches outputs from finished batches using this many asynchronous threads
device – if specified, moves all experts and data to this device via .to(device=device). If you want to manually specify devices for each expert (in their forward pass), leave device=None (default)
stats_report_interval – interval to collect and log statistics about runtime performance
- run()[source]
Runs the Runtime in the current thread: loads all experts onto the device, creates request pools, and processes incoming batches until shutdown.
- class hivemind.moe.server.task_pool.TaskPool(process_func: callable, max_batch_size: int, name: str, min_batch_size=1, timeout=None, pool_size=None, prefetch_batches=1, daemon=True, start=False)[source]
Request aggregator that accepts processing requests, groups them into batches, waits for Runtime to process these batches and dispatches results back to request sources. Operates as a background process.
- Parameters
process_func – function to be applied to every formed batch; called by Runtime. Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
max_batch_size – process at most this many inputs in a batch (a task may contain one or several inputs)
name – pool name
min_batch_size – process at least this many inputs in a batch, otherwise wait for more
timeout – wait for a subsequent task for at most this many seconds
pool_size – store at most this many unprocessed tasks in a queue
prefetch_batches – prepare up to this many batches in background for faster off-loading to runtime
start – if True, start automatically at the end of __init__
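A self-contained sketch of the pool's contract, using submit_task as documented below: process_func takes batched positional tensors and returns a flat tuple. The function name, pool name, and sizes are illustrative, and note that without a Runtime polling the pool, submitted futures will not complete:

import torch
from hivemind.moe.server.task_pool import TaskPool

def process_batch(x: torch.Tensor):
    # called with a batch formed from one or more submitted tasks
    return (x * 2,)

pool = TaskPool(process_batch, max_batch_size=16, name="demo_pool", start=True)
future = pool.submit_task(torch.ones(2, 4))  # a task contributing 2 examples
# in a real server, Runtime calls pool.load_batch_to_runtime(), runs
# process_batch, and dispatches outputs back; only then does future.result() return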
- submit_task(*args: torch.Tensor) → concurrent.futures._base.Future [source]
Add task to this pool’s queue, return Future for its output
- iterate_minibatches(*args, **kwargs)[source]
Form minibatches by grouping one or more tasks together up to self.max_batch_size
- load_batch_to_runtime(timeout=None, device=None) → Tuple[Any, List[torch.Tensor]] [source]
Receive the next batch of input tensors to be processed by Runtime.