Commit 4bb492e7 authored by Yang Yang

pass tiny data

Parent bb3ae206
@@ -19,6 +19,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+static constexpr char kParallelScopes[] = "parallel_scopes";
+
 // NCCLinitOp
 class NCCLInitOp : public framework::OperatorBase {
  public:
@@ -29,24 +31,37 @@ class NCCLInitOp : public framework::OperatorBase {
   void Run(const framework::Scope &scope,
            const platform::Place &place) const override {
+    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kParallelScopes)),
+                            "Can not find variable '%s' in the scope.",
+                            kParallelScopes);
     const auto &name = Output("Communicator");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
                             "Can not find variable '%s' in the scope.", name);
-    int count = platform::GetCUDADeviceCount();
-    std::vector<int> gpus(count);
-    for (int i = 0; i < count; ++i) {
+    // A parallel do may not use all the gpus. For example, the batch size is 7
+    // in the last batch while we have 8 gpu. In this case, parallel_do will
+    // create 7 parallel scopes, so should ncclInitOp create 7 gpu peers
+    LOG(INFO) << "---------------";
+    auto &parallel_scopes = scope.FindVar(Input(kParallelScopes))
+                                ->Get<std::vector<framework::Scope *>>();
+    LOG(INFO) << "---------------";
+    std::vector<int> gpus(parallel_scopes.size());
+    for (int i = 0; i < static_cast<int>(parallel_scopes.size()); ++i) {
       gpus[i] = i;
     }
+    LOG(INFO) << "---------------";
     PADDLE_ENFORCE(!gpus.empty(), "NCCL init with 0 gpus.");
+    LOG(INFO) << "---------------";
 
     if (scope.FindVar(name) == nullptr) {
       PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
     }
+    LOG(INFO) << "---------------";
 
     platform::Communicator *comm =
         scope.FindVar(name)->GetMutable<platform::Communicator>();
+    LOG(INFO) << "---------------";
     comm->InitAll(gpus);
+    LOG(INFO) << "---------------";
   }
 };
@@ -70,6 +85,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(kParallelScopes, "The working place of parallel do.");
     AddOutput("Communicator",
               "Create Communicator for communicating between gpus");
     AddComment(R"DOC(
...
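The heart of the C++ change above is that ncclInit now sizes the communicator by the number of parallel scopes handed over by parallel_do, instead of by platform::GetCUDADeviceCount(), so a partial last batch only initializes as many GPU peers as it actually uses. A minimal, standalone Python sketch of that sizing rule (illustration only, not Paddle API; the function name is hypothetical):

    def gpu_peers_for_scopes(num_parallel_scopes):
        # One NCCL peer per parallel scope, mirroring the new loop in
        # NCCLInitOp::Run, rather than one peer per physical GPU.
        if num_parallel_scopes == 0:
            raise ValueError("NCCL init with 0 gpus.")
        return list(range(num_parallel_scopes))

    # Example from the comment in the diff: a batch of 7 on an 8-GPU machine
    # makes parallel_do create 7 scopes, so only peers 0..6 are initialized.
    assert gpu_peers_for_scopes(7) == [0, 1, 2, 3, 4, 5, 6]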
@@ -223,9 +223,10 @@ def _callback_lookup_(op):
         param_grad_names = [n + "@GRAD" for n in param_names]
 
         class ParallelDoCallBack(object):
-            def __init__(self, param_grad_names):
+            def __init__(self, param_grad_names, parallel_scopes_name):
                 self.has_inserted_nccl_init = False
                 self.param_grad_names = param_grad_names
+                self.parallel_scopes_name = parallel_scopes_name
 
             def __call__(self, block, context):
                 if not self.has_inserted_nccl_init:
@@ -242,7 +243,8 @@ def _callback_lookup_(op):
                     # inputs={},
                     # outputs={'Communicator': [self.nccl_com]})
                     op_desc = _create_op_desc_(
-                        "ncclInit", {},
+                        "ncclInit",
+                        {"parallel_scopes": self.parallel_scopes_name},
                         {"Communicator": ['nccl_com__do_not_change_']}, {})
                     # block.desc.append_op().copy_from(op_desc)
                     print(serialize_op_decs(op_desc))
@@ -281,7 +283,8 @@ def _callback_lookup_(op):
                                          {"Out": [o_argu]}, {})
                 block.desc.append_op().copy_from(op_desc)
-        return ParallelDoCallBack(param_grad_names)
+        return ParallelDoCallBack(param_grad_names,
+                                  op.output("parallel_scopes"))
     else:
         return None
...
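On the Python side, the backward callback now receives the parallel_scopes output of the parallel_do op and feeds it to the ncclInit op it inserts, which is how the C++ operator above learns how many scopes exist. A simplified, self-contained sketch of that wiring (the real code builds an op desc via _create_op_desc_ and appends it to the block; here the callback only returns the (type, inputs, outputs) triple to show the data flow):

    class ParallelDoCallBack(object):
        # Simplified stand-in: shows how parallel_scopes_name travels from
        # the parallel_do op into the generated ncclInit op description.
        def __init__(self, param_grad_names, parallel_scopes_name):
            self.has_inserted_nccl_init = False
            self.param_grad_names = param_grad_names
            self.parallel_scopes_name = parallel_scopes_name

        def nccl_init_desc(self):
            # Only the first invocation inserts the init op.
            if self.has_inserted_nccl_init:
                return None
            self.has_inserted_nccl_init = True
            return ("ncclInit",
                    {"parallel_scopes": self.parallel_scopes_name},
                    {"Communicator": ["nccl_com__do_not_change_"]})

    # Hypothetical argument names, for illustration only.
    cb = ParallelDoCallBack(["fc.w@GRAD"], "parallel_scopes_var")
    print(cb.nccl_init_desc())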