Commit e67325cd authored by Yang Yang

update readme

Parent 0621c327
@@ -30,23 +30,45 @@ operator run on each GPU, it will automatically sync with different streams when
 // if op's input is params' grad:
 //   sync with allreduce stream
 //   e.g. sgd should wait for allreduce to be finished
-SyncMultipleStreams(op);
+CallBack->BeforeOp(op);

 op->Run(*local_scope, place_);

 // if op's output is params' grad:
 //   sync with computation stream
 //   e.g. allreduce should wait for fc_grad to be finished.
-SyncMultipleStreams(op);
+CallBack->AfterOp(op);
 ```

-## API
-
-The `ParallelExecutor.run` has a similar interface to `Executor.run`. Besides:
-1. Scope: we don't expose `scope` in `ParallelExecutor.run` since `ParallelExecutor` has its
-   own scope to maintain NCCL.
-1. Feed: we don't expose `feed` in the API either, because the whole point of implementing
-   parallel_executor is the speed. The input for NN should be implemented in a reader OP.
-1. Fetch: we return the fetched value on all GPUs as a list. (e.g. `exe.run(..., fetch=loss)`
-   will return `[loss_on_gpu0, loss_on_gpu1]`)
+
+And the `Callback` object can be implemented as follows:
+
+```c++
+struct AllReduceCallBack {
+  void BeforeOp(framework::OperatorBase* op);
+  void AfterOp(framework::OperatorBase* op);
+
+  std::unordered_set<std::string> reduced_param_grad_names;
+  std::unordered_set<std::string> param_grad_names_;
+
+  platform::DeviceContext* computation_dev_ctx;    // computation device context
+  platform::DeviceContext* communication_dev_ctx;  // communication device context
+
+  framework::Scope* scope;
+  platform::NCCL::Communicator* nccl_com;
+};
+
+AllReduceCallBack::BeforeOp(framework::OperatorBase* op) {
+  if (op->Input() in reduced_param_grad_names) {
+    communication_dev_ctx->Wait();
+    reduced_param_grad_names.erase(op->Input());
+  }
+}
+
+AllReduceCallBack::AfterOp(framework::OperatorBase* op) {
+  if (op->Output() in param_grad_names) {
+    computation_dev_ctx->Wait();
+    reduced_param_grad_names.insert(op->Output());
+    ncclAllreduce(scope, op->Output(), communication_dev_ctx);
+  }
+}
+```
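For reference, below is a minimal compilable sketch of the callback logic this commit describes. The snippet in the diff is pseudocode, so the sketch swaps Paddle's real types for stand-ins: `OpInfo` replaces `framework::OperatorBase`, a toy `DeviceContext` with a blocking `Wait()` replaces `platform::DeviceContext`, a printed placeholder stands in for the actual `ncclAllReduce` call, and `std::unordered_set::count` replaces the pseudocode `in` operator. All names in `main` (op types, gradient variable names) are made up for illustration only.

```c++
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Stand-in for a device context that owns one stream; Wait() blocks the host
// until all work queued on that stream has finished.
struct DeviceContext {
  std::string name;
  void Wait() const { std::cout << "wait on " << name << " stream\n"; }
};

// Stand-in for an operator: just its type and input/output variable names.
struct OpInfo {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

struct AllReduceCallBack {
  // Gradients whose allreduce has been queued but not yet consumed.
  std::unordered_set<std::string> reduced_param_grad_names;
  // All parameter-gradient names in the program.
  std::unordered_set<std::string> param_grad_names;
  DeviceContext* computation_dev_ctx;
  DeviceContext* communication_dev_ctx;

  // Before an op that consumes a reduced gradient (e.g. sgd), wait on the
  // communication stream so the allreduce result is visible.
  void BeforeOp(const OpInfo& op) {
    for (const auto& in : op.inputs) {
      if (reduced_param_grad_names.count(in)) {
        communication_dev_ctx->Wait();
        reduced_param_grad_names.erase(in);
      }
    }
  }

  // After an op that produces a gradient (e.g. fc_grad), wait on the
  // computation stream, then queue the allreduce on the communication stream.
  void AfterOp(const OpInfo& op) {
    for (const auto& out : op.outputs) {
      if (param_grad_names.count(out)) {
        computation_dev_ctx->Wait();
        reduced_param_grad_names.insert(out);
        std::cout << "allreduce(" << out << ") queued on "
                  << communication_dev_ctx->name << " stream\n";
      }
    }
  }
};

int main() {
  DeviceContext comp{"computation"}, comm{"communication"};
  AllReduceCallBack cb{{}, {"fc_0.w@GRAD"}, &comp, &comm};

  // Per-device run loop from the snippet above: BeforeOp / Run / AfterOp.
  std::vector<OpInfo> ops = {
      {"fc_grad", {"loss"}, {"fc_0.w@GRAD"}},
      {"sgd", {"fc_0.w@GRAD"}, {"fc_0.w"}},
  };
  for (const auto& op : ops) {
    cb.BeforeOp(op);
    std::cout << "run " << op.type << "\n";  // op->Run(*local_scope, place_);
    cb.AfterOp(op);
  }
  return 0;
}
```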