diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc
index a00458ea068dd703d2c7f362511ed08bc212d2a8..77936cf58e361dd1d2c5198b8583f687a26de077 100644
--- a/paddle/operators/parallel_do_op.cc
+++ b/paddle/operators/parallel_do_op.cc
@@ -31,6 +31,7 @@ static constexpr char kParallelScopes[] = "parallel_scopes";
 static constexpr char kParallelBlock[] = "sub_block";
 
 using LoDTensor = framework::LoDTensor;
+using SelectedRows = framework::SelectedRows;
 
 static void SplitTensorAndMoveTensorToScopes(
     const framework::Scope &scope, std::vector<framework::Scope *> *sub_scopes,
@@ -64,6 +65,30 @@ static void SplitTensorAndMoveTensorToScopes(
   }
 }
 
+inline void CopyOrShare(const framework::Variable &src,
+                        const platform::Place &dst_place,
+                        framework::Variable *dst) {
+  if (src.IsType<LoDTensor>()) {
+    if (src.Get<LoDTensor>().place() == dst_place) {
+      dst->GetMutable<LoDTensor>()->ShareDataWith(src.Get<LoDTensor>());
+    } else {
+      Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
+    }
+  } else if (src.IsType<SelectedRows>()) {
+    auto &src_sr = src.Get<SelectedRows>();
+    auto *dst_sr = dst->GetMutable<SelectedRows>();
+    dst_sr->set_rows(src_sr.rows());
+    dst_sr->set_height(src_sr.height());
+    Copy(src_sr.value(), dst_place, dst_sr->mutable_value());
+//    if (src_sr.value().place() == dst_place) {
+//      dst_sr->mutable_value()->ShareDataWith(src_sr.value());
+//    } else {
+//    }
+  } else {
+    PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name());
+  }
+}
+
 void WaitOnPlace(const platform::Place place) {
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto &dev_ctx = *pool.Get(place);
@@ -149,6 +174,7 @@ class ParallelDoOp : public framework::OperatorBase {
       lod_tensor_to_be_merged->MergeLoDTensor(lod_tensors, dev_ctx.GetPlace());
     }
     WaitOnPlaces(places);
+    LOG(INFO) << "End of ParallelGradDo";
   }
 };
 
@@ -210,21 +236,27 @@ class ParallelDoGradOp : public framework::OperatorBase {
     }
     WaitOnPlaces(places);
 
-    // merge grad
+    AccumulateGrad(scope, place, sub_scopes, places);
+    LOG(INFO) << "End of ParallelDoGrad";
+  }
+
+  void AccumulateGrad(const framework::Scope &scope,
+                      const platform::Place &place,
+                      const std::vector<framework::Scope *> &sub_scopes,
+                      const platform::PlaceList &places) const {
     for (auto &s : Outputs(framework::GradVarName(kParameters))) {
-      auto &result = sub_scopes[0]->FindVar(s)->Get<LoDTensor>();
-      std::string tmp_name;
-      auto *tmp = sub_scopes[0]->Var(&tmp_name)->GetMutable<LoDTensor>();
+      std::__cxx11::string tmp_name;
+      auto *tmp = sub_scopes[0]->Var(&tmp_name);
+      LOG(INFO) << "---" << s;
       for (size_t i = 1; i < sub_scopes.size(); ++i) {
-        auto &tensor_to_merge = sub_scopes[i]->FindVar(s)->Get<LoDTensor>();
         if (!(places[i] == places[0])) {
-          framework::Copy(tensor_to_merge, places[0], tmp);
+          LOG(INFO) << "---";
+          CopyOrShare(*sub_scopes[i]->FindVar(s), places[0], tmp);
           WaitOnPlace(places[0]);
-        } else {
-          tmp->ShareDataWith(tensor_to_merge);
         }
 
+        LOG(INFO) << "---";
         auto sum_op = framework::OpRegistry::CreateOp(
             "sum", {{"X", {s, tmp_name}}}, {{"Out", {s}}},
             framework::AttributeMap{});
@@ -232,8 +264,8 @@
         WaitOnPlace(places[0]);
       }
 
-      VLOG(3) << result;
-      framework::Copy(result, place, scope.FindVar(s)->GetMutable<LoDTensor>());
+      LOG(INFO) << "---";
+      CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s));
     }
     WaitOnPlaces(places);
   }
@@ -289,7 +321,7 @@ class ParallelDoGradOpShapeInference : public framework::InferShapeBase {
     PADDLE_ENFORCE(ctx->HasInputs(kParameters));
     PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kParameters)));
-    PADDLE_ENFORCE(ctx->HasInput(kInputs));
+    PADDLE_ENFORCE(ctx->HasInputs(kInputs));

     for (auto &s : output) {
       PADDLE_ENFORCE(ctx->HasInputs(s));
diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py
index 2f1188c542cc0208a189511a1eef1eddc411007c..0d6062fcebc5060c1b76abcfee42e5c053aebb3e 100644
--- a/python/paddle/v2/fluid/layers/control_flow.py
+++ b/python/paddle/v2/fluid/layers/control_flow.py
@@ -270,6 +270,7 @@ class ParallelDo(object):
                 for in_var_name in op.input(iname):
                     if in_var_name not in local_inputs:
                         params.append(in_var_name)

+        params = list(set(params))

         return [parent_block.var(name) for name in params]
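
Note on the CopyOrShare helper added in the C++ hunk above: for a LoDTensor it shares the underlying buffer when the source already lives on dst_place and falls back to a cross-place copy otherwise, while for SelectedRows it copies rows/height and always deep-copies the value tensor. The standalone sketch below mirrors that share-vs-copy dispatch; the Place/Tensor/SelectedRows/Variable types here are toy stand-ins invented for illustration only and are not Paddle APIs, and the snippet is not part of this patch.

#include <cassert>
#include <memory>
#include <string>
#include <variant>
#include <vector>

// Toy stand-ins for platform::Place and the framework tensor types.
struct Place {
  std::string name;
  bool operator==(const Place &o) const { return name == o.name; }
};

struct Tensor {
  Place place;
  std::shared_ptr<std::vector<float>> data;  // shared buffer, like ShareDataWith
};

struct SelectedRows {
  std::vector<long> rows;
  long height = 0;
  Tensor value;
};

using Variable = std::variant<Tensor, SelectedRows>;

// Deep copy of a tensor onto dst_place (analogue of framework::Copy).
Tensor Copy(const Tensor &src, const Place &dst_place) {
  return Tensor{dst_place, std::make_shared<std::vector<float>>(*src.data)};
}

// Analogue of CopyOrShare: share the buffer when the source is already on
// dst_place, otherwise deep-copy; SelectedRows always copies its value tensor.
void CopyOrShare(const Variable &src, const Place &dst_place, Variable *dst) {
  if (auto *t = std::get_if<Tensor>(&src)) {
    *dst = (t->place == dst_place) ? *t /* share buffer */ : Copy(*t, dst_place);
  } else {
    const auto &sr = std::get<SelectedRows>(src);
    SelectedRows out;
    out.rows = sr.rows;
    out.height = sr.height;
    out.value = Copy(sr.value, dst_place);
    *dst = out;
  }
}

int main() {
  Tensor a{{"gpu:0"}, std::make_shared<std::vector<float>>(3, 1.0f)};
  Variable src = a, dst;
  CopyOrShare(src, Place{"gpu:0"}, &dst);  // same place: buffer is shared
  assert(std::get<Tensor>(dst).data == a.data);
  CopyOrShare(src, Place{"gpu:1"}, &dst);  // different place: deep copy
  assert(std::get<Tensor>(dst).data != a.data);
  return 0;
}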