Commit 3067114f authored by Yang Yang

clean up

Parent cd9e660d
@@ -47,11 +47,8 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
     auto ins = ctx.MultiInput<LoDTensor>("X");
     auto outs = ctx.MultiOutput<LoDTensor>("Out");
-    LOG(INFO) << "------------------";

     std::string reduction = ctx.Attr<std::string>("reduction");
-    LOG(INFO) << "------------------";
     ncclRedOp_t reduction_op_ = ncclSum;
-    LOG(INFO) << "------------------";

     if (reduction == "ncclMin") {
       reduction_op_ = ncclMin;
@@ -65,19 +62,14 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
       PADDLE_THROW("Invalid reduction. default ncclSum.");
     }
-    LOG(INFO) << "------------------";

     auto* comm = ctx.Input<Communicator>("Communicator");
-    LOG(INFO) << "------------------";

     auto stream = ctx.cuda_device_context().stream();
-    LOG(INFO) << "------------------";

     // device id
     int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
-    LOG(INFO) << "------------------";
     int idx = comm->GetCommId(gpu_id);
-    LOG(INFO) << "------------------";

     for (size_t i = 0; i < ins.size(); ++i) {
       VLOG(1) << "gpu : "
               << " invoke allreduce. send " << ins[i]->numel() << " recv "
......
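The hunk above is the attribute-handling part of NCCLAllReduceKernel: the string `reduction` attribute is mapped onto an `ncclRedOp_t`, and the kernel then issues one `ncclAllReduce` per input/output pair on the device stream (the call itself sits in the elided part of the file). Below is a minimal sketch of that mapping and call, not the kernel itself: `send`, `recv`, `count`, `comm`, and `stream` are hypothetical stand-ins for the buffers, communicator, and CUDA stream the kernel obtains from its execution context, and the `ncclMax`/`ncclProd` branches are assumed from the standard NCCL enum since that part of the hunk is collapsed here.

```cpp
// Sketch only: map the "reduction" attribute string to an ncclRedOp_t and
// perform one allreduce. Placeholder arguments stand in for what the kernel
// pulls out of its ExecutionContext.
#include <nccl.h>

#include <cstddef>
#include <stdexcept>
#include <string>

ncclRedOp_t ToNcclRedOp(const std::string& reduction) {
  if (reduction == "ncclMin") return ncclMin;
  if (reduction == "ncclMax") return ncclMax;    // assumed branch
  if (reduction == "ncclProd") return ncclProd;  // assumed branch
  if (reduction == "ncclSum") return ncclSum;
  // Mirrors the PADDLE_THROW on an unrecognized reduction string.
  throw std::invalid_argument("Invalid reduction. default ncclSum.");
}

void AllReduce(const float* send, float* recv, std::size_t count,
               const std::string& reduction, ncclComm_t comm,
               cudaStream_t stream) {
  // Every rank in `comm` contributes `count` floats; the reduced result is
  // written to `recv` on every rank, asynchronously on `stream`.
  ncclAllReduce(send, recv, count, ncclFloat, ToNcclRedOp(reduction), comm,
                stream);
}
```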
@@ -151,7 +151,6 @@ class ParallelDoOp : public framework::OperatorBase {
     }
     WaitOnPlaces(places);

-    // PADDLE_ENFORCE_EQ(places.size(), sub_scopes.size());
     std::vector<std::future<void>> workers;
     workers.reserve(places.size());
     for (size_t place_idx = 0; place_idx < sub_scopes.size(); ++place_idx) {
@@ -219,21 +218,18 @@ class ParallelDoGradOp : public framework::OperatorBase {
     auto &sub_scopes = scope.FindVar(Input(kParallelScopes))
                            ->Get<std::vector<framework::Scope *>>();
     auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>();
-    // PADDLE_ENFORCE_EQ(places.size(), sub_scopes.size());

     // feed output@grad
     SplitTensorAndMoveTensorToScopes(
         scope, const_cast<std::vector<framework::Scope *> *>(&sub_scopes),
         places, Inputs(framework::GradVarName(kOutputs)));
     WaitOnPlaces(places);
-    LOG(INFO) << "places " << places.size();

     // exe run
     std::vector<std::future<void>> workers;
     for (size_t i = 0; i < sub_scopes.size(); ++i) {
       auto &place = places[i];
       auto *cur_scope = sub_scopes[i];
-      LOG(INFO) << place;

       // execute
       workers.emplace_back(framework::Async([program, cur_scope, place, block] {
@@ -242,7 +238,6 @@ class ParallelDoGradOp : public framework::OperatorBase {
                      false /*create_local_scope*/);
       }));
     }
-    LOG(INFO) << "places " << places.size();
     for (auto &worker : workers) {
       worker.wait();
     }
......
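ParallelDoGradOp, shown above, fans the backward block out to one asynchronous worker per (place, sub-scope) pair and joins every future before gradients are merged. A generic sketch of that fan-out/join pattern follows; it uses `std::async` as a stand-in for Paddle's `framework::Async`, and a placeholder lambda body where the real op calls the executor.

```cpp
// Generic fan-out/join sketch: one asynchronous worker per place, then wait
// on every future, as ParallelDoGradOp does before merging gradients.
#include <cstddef>
#include <future>
#include <iostream>
#include <vector>

void RunOnAllPlaces(std::size_t num_places) {
  std::vector<std::future<void>> workers;
  workers.reserve(num_places);
  for (std::size_t i = 0; i < num_places; ++i) {
    // In the real op, each lambda runs the program block on its own
    // place and sub-scope via an Executor.
    workers.emplace_back(std::async(std::launch::async, [i] {
      std::cout << "worker for place " << i << " running\n";
    }));
  }
  for (auto &worker : workers) {
    worker.wait();  // join; exceptions would surface via worker.get()
  }
}
```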
@@ -230,44 +230,19 @@ def _callback_lookup_(op):
         def __call__(self, block, context):
             if not self.has_inserted_nccl_init:
-                # global_block = block.program.global_block()
-                # op_desc = global_block.desc.append_op()
-                # var_desc = global_block.desc.var('nccl_com__do_not_change_')
-                # var_desc.set_type(core.VarDesc.VarType.NCCL_COM)
-                # self.nccl_com = global_block.create_var(
-                #     name='nccl_com', type=core.VarDesc.VarType.NCCL_COM)
-                # framework.Operator(
-                #     global_block,
-                #     type='ncclInit',
-                #     desc=op_desc,
-                #     inputs={},
-                #     outputs={'Communicator': [self.nccl_com]})
                 op_desc = _create_op_desc_(
                     "ncclInit",
                     {"parallel_scopes": self.parallel_scopes_name},
                     {"Communicator": ['nccl_com__do_not_change_']}, {})
-                # block.desc.append_op().copy_from(op_desc)
                 print(serialize_op_decs(op_desc))
                 block.program.global_block().desc.append_op().copy_from(
                     op_desc)
                 self.has_inserted_nccl_init = True

             current_op_desc = context["__current_op_desc__"]
-            # print(serialize_op_decs(context))
             for o_param in current_op_desc.output_names():
                 for o_argu in current_op_desc.output(o_param):
                     if o_argu in self.param_grad_names:
-                        # # print("reduce", o_argu)
-                        # op_desc = block.desc.append_op()
-                        # op_desc.set_type("ncclAllReduce")
-                        # op_desc.set_input("X", [o_argu])
-                        #
-                        # # FIXME(tonyyang-svail):
-                        # #   Looks like nccl_com has been changed to nccl_com_0
-                        # op_desc.set_input("Communicator", ['nccl_com_0'])
-                        # out_var = block.create_var()
-                        # op_desc.set_output("Out", [out_var.name])
-                        # op_desc.set_attr("reduction", "ncclSum")
                         allreduce_out_name = o_argu + "__nccl_all_reduce__"
                         op_desc = _create_op_desc_(
                             "ncclAllReduce", {
......
@@ -175,7 +175,9 @@ class ParallelOpTest(BaseParallelForTest):
     def test_simple_fc(self):
         self.run_test(
             callback=self.__network__,
-            feed={'img': numpy.random.random(size=(8, 784)).astype('float32')},
+            feed={
+                'img': numpy.random.random(size=(51, 784)).astype('float32')
+            },
             fetch=['fc1.w@GRAD'])

     def test_fc_with_tiny_data(self):
......