diff --git a/doc/design/parallel_executor.md b/doc/design/parallel_executor.md
index 567eede1bd59bfb519bdb0b96de1684cfea6c61b..78ef74f159d4767055662711c0ac7346e393d860 100644
--- a/doc/design/parallel_executor.md
+++ b/doc/design/parallel_executor.md
@@ -30,23 +30,45 @@ operator run on each GPU, it will automatically sync with different streams when
 // if op's input is params' grad:
 //   sync with allreduce stream
 //   e.g. sgd should wait for allreduce to be finished
-SyncMultipleStreams(op);
+CallBack->BeforeOp(op);
 
 op->Run(*local_scope, place_);
 
 // if op's output is params' grad:
 //   sync with computation stream
 //   e.g. allreduce should wait for fc_grad to be finished.
-SyncMultipleStreams(op);
+CallBack->AfterOp(op);
 ```
 
+And the `CallBack` object can be implemented as follows:
 
-## API
+```c++
+struct AllReduceCallBack {
+  void BeforeOp(framework::OperatorBase* op);
+  void AfterOp(framework::OperatorBase* op);
+
+  std::unordered_set<std::string> reduced_param_grad_names;
+  std::unordered_set<std::string> param_grad_names;
+
+  platform::DeviceContext* computation_dev_ctx;    // computation device context
+  platform::DeviceContext* communication_dev_ctx;  // communication device context
 
-The `ParallelExecutor.run` has similar interface as `Executor.run`. Besides
-1. Scope: we don't expose `scope` in `ParallelExecutor.run` since `ParallelExecutor` has its
-own scope to maintain NCCL.
-1. Feed: we don't expose `feed` in the API either, because the whole point of implementing
-parallel_executor is the speed. The input for NN should be implemented in an reader OP.
-1. Fetch: we return the fetched value on all GPUs as a list. (e.g. `exe.run(..., fetch=loss)`
-with return `[loss_on_gpu0, loss_on_gpu1]`)
+  framework::Scope* scope;
+  platform::NCCL::Communicator* nccl_com;
+};
+
+void AllReduceCallBack::BeforeOp(framework::OperatorBase* op) {
+  if (reduced_param_grad_names.count(op->Input()) > 0) {
+    communication_dev_ctx->Wait();
+    reduced_param_grad_names.erase(op->Input());
+  }
+}
+
+void AllReduceCallBack::AfterOp(framework::OperatorBase* op) {
+  if (param_grad_names.count(op->Output()) > 0) {
+    computation_dev_ctx->Wait();
+    reduced_param_grad_names.insert(op->Output());
+    ncclAllreduce(scope, op->Output(), communication_dev_ctx);
+  }
+}
+```
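
For reference, the following standalone sketch shows how the proposed `BeforeOp`/`AfterOp` hooks would interleave with the per-GPU op loop. It is a minimal illustration only, not PaddlePaddle code: the `Stream`, `FakeOp`, and `AllReduceHook` types, and the `fc.w@GRAD` names, are hypothetical stand-ins for the real device contexts, operators, and gradient variables, and the "allreduce" launch is just a log line.

```c++
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a device stream: Wait() blocks until all work
// previously launched on this stream has finished (here it only logs).
struct Stream {
  std::string name;
  void Wait() { std::cout << "  [sync] wait on " << name << " stream\n"; }
};

// Hypothetical stand-in for an operator: one input name, one output name.
struct FakeOp {
  std::string type, input, output;
  void Run() {
    std::cout << "  run " << type << "(" << input << ") -> " << output << "\n";
  }
};

// Mirrors the AllReduceCallBack from the design doc, with the device
// contexts replaced by the fake streams above.
struct AllReduceHook {
  std::unordered_set<std::string> param_grad_names;          // grads that need allreduce
  std::unordered_set<std::string> reduced_param_grad_names;  // allreduce in flight
  Stream* computation;
  Stream* communication;

  void BeforeOp(const FakeOp& op) {
    // e.g. sgd must wait until the allreduce of its gradient input is done.
    if (reduced_param_grad_names.count(op.input)) {
      communication->Wait();
      reduced_param_grad_names.erase(op.input);
    }
  }

  void AfterOp(const FakeOp& op) {
    // e.g. allreduce must wait until fc_grad has produced the gradient, then
    // the allreduce is launched on the communication stream.
    if (param_grad_names.count(op.output)) {
      computation->Wait();
      reduced_param_grad_names.insert(op.output);
      std::cout << "  launch allreduce(" << op.output << ") on "
                << communication->name << " stream\n";
    }
  }
};

int main() {
  Stream comp{"computation"}, comm{"communication"};
  AllReduceHook hook{{"fc.w@GRAD"}, {}, &comp, &comm};

  // A toy per-GPU op sequence: fc_grad produces fc.w@GRAD, sgd consumes it.
  std::vector<FakeOp> ops = {{"fc_grad", "loss", "fc.w@GRAD"},
                             {"sgd", "fc.w@GRAD", "fc.w"}};
  for (auto& op : ops) {
    hook.BeforeOp(op);
    op.Run();
    hook.AfterOp(op);
  }
  return 0;
}
```

Running the loop prints the allreduce launch right after `fc_grad` and the communication-stream wait right before `sgd`, which is exactly the ordering the callback design is meant to enforce.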