Commit 4bb492e7 authored by Yang Yang

pass tiny data

Parent bb3ae206
@@ -19,6 +19,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
static constexpr char kParallelScopes[] = "parallel_scopes";
// NCCLinitOp
class NCCLInitOp : public framework::OperatorBase {
public:
@@ -29,24 +31,37 @@ class NCCLInitOp : public framework::OperatorBase {
void Run(const framework::Scope &scope,
const platform::Place &place) const override {
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kParallelScopes)),
"Can not find variable '%s' in the scope.",
kParallelScopes);
const auto &name = Output("Communicator");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
"Can not find variable '%s' in the scope.", name);
int count = platform::GetCUDADeviceCount();
std::vector<int> gpus(count);
for (int i = 0; i < count; ++i) {
// A parallel_do may not use all the gpus. For example, the batch size is 7
// in the last batch while we have 8 gpus. In this case, parallel_do will
// create 7 parallel scopes, so ncclInitOp should create 7 gpu peers as well.
LOG(INFO) << "---------------";
auto &parallel_scopes = scope.FindVar(Input(kParallelScopes))
->Get<std::vector<framework::Scope *>>();
LOG(INFO) << "---------------";
std::vector<int> gpus(parallel_scopes.size());
for (int i = 0; i < static_cast<int>(parallel_scopes.size()); ++i) {
gpus[i] = i;
}
LOG(INFO) << "---------------";
PADDLE_ENFORCE(!gpus.empty(), "NCCL init with 0 gpus.");
LOG(INFO) << "---------------";
if (scope.FindVar(name) == nullptr) {
PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
}
LOG(INFO) << "---------------";
platform::Communicator *comm =
scope.FindVar(name)->GetMutable<platform::Communicator>();
LOG(INFO) << "---------------";
comm->InitAll(gpus);
LOG(INFO) << "---------------";
}
};
@@ -70,6 +85,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kParallelScopes, "The working place of parallel do.");
AddOutput("Communicator",
"Create Communicator for communicating between gpus");
AddComment(R"DOC(
......
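In the hunk above, NCCLInitOp::Run now derives its GPU peer list from the number of scopes that parallel_do actually created (read through the new kParallelScopes input) instead of from platform::GetCUDADeviceCount(). A minimal standalone sketch of that selection rule, written in plain Python purely for illustration (this is not PaddlePaddle API):

```python
def build_gpu_peers(num_parallel_scopes):
    """Pick one NCCL peer per parallel scope created by parallel_do."""
    gpus = list(range(num_parallel_scopes))
    assert gpus, "NCCL init with 0 gpus."
    return gpus


# Example from the code comment above: a last batch of size 7 on an
# 8-GPU machine yields 7 parallel scopes, hence peers 0..6 rather than 0..7.
print(build_gpu_peers(7))  # [0, 1, 2, 3, 4, 5, 6]
```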
@@ -223,9 +223,10 @@ def _callback_lookup_(op):
param_grad_names = [n + "@GRAD" for n in param_names]
class ParallelDoCallBack(object):
def __init__(self, param_grad_names):
def __init__(self, param_grad_names, parallel_scopes_name):
self.has_inserted_nccl_init = False
self.param_grad_names = param_grad_names
self.parallel_scopes_name = parallel_scopes_name
def __call__(self, block, context):
if not self.has_inserted_nccl_init:
@@ -242,7 +243,8 @@ def _callback_lookup_(op):
# inputs={},
# outputs={'Communicator': [self.nccl_com]})
op_desc = _create_op_desc_(
"ncclInit", {},
"ncclInit",
{"parallel_scopes": self.parallel_scopes_name},
{"Communicator": ['nccl_com__do_not_change_']}, {})
# block.desc.append_op().copy_from(op_desc)
print(serialize_op_decs(op_desc))
@@ -281,7 +283,8 @@ def _callback_lookup_(op):
{"Out": [o_argu]}, {})
block.desc.append_op().copy_from(op_desc)
return ParallelDoCallBack(param_grad_names)
return ParallelDoCallBack(param_grad_names,
op.output("parallel_scopes"))
else:
return None
......
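On the Python side, the backward callback now carries the name of the parallel_do output `parallel_scopes` and feeds it to the ncclInit op it inserts, so the operator can count the scopes at run time. A hedged, self-contained sketch of that wiring (the stub `_create_op_desc_` and the example names `fc_0.w@GRAD` / `parallel_do_scopes` are illustrative stand-ins, not the real framework helpers):

```python
# Stand-in for the framework helper of the same name; the real one builds a
# C++ OpDesc, here it just returns a tuple so the sketch runs on its own.
def _create_op_desc_(op_type, inputs, outputs, attrs):
    return (op_type, inputs, outputs, attrs)


class ParallelDoCallBack(object):
    def __init__(self, param_grad_names, parallel_scopes_name):
        self.has_inserted_nccl_init = False
        self.param_grad_names = param_grad_names
        self.parallel_scopes_name = parallel_scopes_name

    def nccl_init_desc(self):
        # Insert the ncclInit op only once, mirroring has_inserted_nccl_init
        # in the diff above.
        if self.has_inserted_nccl_init:
            return None
        self.has_inserted_nccl_init = True
        return _create_op_desc_(
            "ncclInit",
            {"parallel_scopes": self.parallel_scopes_name},
            {"Communicator": ['nccl_com__do_not_change_']}, {})


# Hypothetical names, for illustration only.
cb = ParallelDoCallBack(["fc_0.w@GRAD"], "parallel_do_scopes")
print(cb.nccl_init_desc())
```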