From 6004a2ed4f74b864b8ff886d20e18891ac0a8ef3 Mon Sep 17 00:00:00 2001
From: Yang Yang
Date: Wed, 3 Jan 2018 09:38:13 +0000
Subject: [PATCH] add copy skeleton

---
 paddle/operators/parallel_do_op.cc | 29 +++++++++++++++++++----------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc
index fd48c1b54..6ac480b57 100644
--- a/paddle/operators/parallel_do_op.cc
+++ b/paddle/operators/parallel_do_op.cc
@@ -47,7 +47,7 @@ void SplitTensorAndMoveTensorToScopes(
       LOG(INFO) << lod.dims();
     }
 
-    for (int i = 0; i < sub_scopes.size(); ++i) {
+    for (size_t i = 0; i < sub_scopes.size(); ++i) {
       *sub_scopes[i]->Var(argu)->GetMutable<LoDTensor>() = lod_tensors[i];
     }
   }
@@ -73,15 +73,14 @@ class ParallelDoOp : public framework::OperatorBase {
     auto &sub_scopes = *scope.FindVar(Output(kParallelScopes))
                             ->GetMutable<std::vector<framework::Scope *>>();
 
-    // std::vector<framework::Scope *> sub_scopes;
-    for (int place_idx = 0; place_idx < places.size(); ++place_idx) {
+    for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
       sub_scopes.push_back(&scope.NewScope());
     }
 
     SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places,
                                      Inputs(kInputs));
 
-    for (int place_idx = 0; place_idx < places.size(); ++place_idx) {
+    for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
       VLOG(3) << "Run " << place_idx;
 
       auto &place = places[place_idx];
@@ -163,17 +162,12 @@ class ParallelDoGradOp : public OperatorBase {
     }
 
     // exe run
-    for (int place_idx = 0; place_idx < places.size(); ++place_idx) {
+    for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
       VLOG(3) << "Run " << place_idx;
 
       auto &place = places[place_idx];
       auto *cur_scope = sub_scopes[place_idx];
 
-      // copy parameter
-      if (dev_ctx.GetPlace() != place) {
-        PADDLE_THROW("Not Implemented");
-      }
-
       // execute
       auto executor = framework::Executor(place);
       executor.Run(*program, cur_scope, block->ID(),
@@ -181,6 +175,21 @@ class ParallelDoGradOp : public OperatorBase {
     }
 
     // merge grad
+    for (auto &s : Outputs(framework::GradVarName(kParameters))) {
+      LOG(INFO) << s;
+      // std::string s_buf = s + "@BUF";
+      // auto *t_buf = sub_scopes[0]->Var(s_buf)->GetMutable<LoDTensor>();
+      for (size_t place_idx = 1; place_idx < places.size(); ++place_idx) {
+        LOG(INFO) << place_idx;
+        LOG(INFO) << sub_scopes[place_idx]->FindVar(s)->Get<LoDTensor>();
+        // Copy grad[i] to grad_buf[0]
+
+        // sum_op
+      }
+
+      // Copy grad[0] to grad
+      // auto *t = scope.FindVar(s)->GetMutable<LoDTensor>();
+    }
   }
 };
 
-- 
GitLab
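
The skeleton above only logs each parameter gradient; the commented steps (copy grad[i] into a buffer on the first device, accumulate with a sum op, copy the result back to the parent scope) are left unimplemented. Below is a minimal sketch of how that loop might later be filled in. The buffer name s + "@BUF" comes from the patch's own commented-out lines; the framework::CopyFrom helper, the OpRegistry::CreateOp call for the registered "sum" operator, and the Run(scope, dev_ctx) convention are assumptions about this era of the codebase, not part of the patch.

    // Sketch only: completes the "merge grad" skeleton under the
    // assumptions named above. Helper signatures are assumed, not verified.
    for (auto &s : Outputs(framework::GradVarName(kParameters))) {
      std::string s_buf = s + "@BUF";  // staging buffer in sub_scopes[0]
      auto *t_buf = sub_scopes[0]->Var(s_buf)->GetMutable<LoDTensor>();

      for (size_t place_idx = 1; place_idx < places.size(); ++place_idx) {
        // Copy grad[i] to grad_buf[0]: stage this device's gradient on place 0
        auto &grad_i = sub_scopes[place_idx]->FindVar(s)->Get<LoDTensor>();
        framework::CopyFrom(grad_i, places[0], dev_ctx, t_buf);

        // sum_op: accumulate the staged gradient into grad[0] in sub_scopes[0]
        auto sum_op = framework::OpRegistry::CreateOp(
            "sum", {{"X", {s, s_buf}}}, {{"Out", {s}}},
            framework::AttributeMap{});
        sum_op->Run(*sub_scopes[0], dev_ctx);
      }

      // Copy grad[0] to grad in the parent scope
      auto *t = scope.FindVar(s)->GetMutable<LoDTensor>();
      framework::CopyFrom(sub_scopes[0]->FindVar(s)->Get<LoDTensor>(),
                          dev_ctx.GetPlace(), dev_ctx, t);
    }

Accumulating everything into sub_scopes[0] means one staging buffer per parameter and a serial chain of sum ops, which appears to be the design the "@BUF" naming in the skeleton suggests.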