提交 6004a2ed 编写于 作者: Y Yang Yang

add copy skeleton

上级 cb0b81f9
......@@ -47,7 +47,7 @@ void SplitTensorAndMoveTensorToScopes(
LOG(INFO) << lod.dims();
}
for (int i = 0; i < sub_scopes.size(); ++i) {
for (size_t i = 0; i < sub_scopes.size(); ++i) {
*sub_scopes[i]->Var(argu)->GetMutable<LoDTensor>() = lod_tensors[i];
}
}
......@@ -73,15 +73,14 @@ class ParallelDoOp : public framework::OperatorBase {
auto &sub_scopes = *scope.FindVar(Output(kParallelScopes))
->GetMutable<std::vector<framework::Scope *>>();
// std::vector<framework::Scope *> sub_scopes;
for (int place_idx = 0; place_idx < places.size(); ++place_idx) {
for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
sub_scopes.push_back(&scope.NewScope());
}
SplitTensorAndMoveTensorToScopes(scope, sub_scopes, places,
Inputs(kInputs));
for (int place_idx = 0; place_idx < places.size(); ++place_idx) {
for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
VLOG(3) << "Run " << place_idx;
auto &place = places[place_idx];
......@@ -163,17 +162,12 @@ class ParallelDoGradOp : public OperatorBase {
}
// exe run
for (int place_idx = 0; place_idx < places.size(); ++place_idx) {
for (size_t place_idx = 0; place_idx < places.size(); ++place_idx) {
VLOG(3) << "Run " << place_idx;
auto &place = places[place_idx];
auto *cur_scope = sub_scopes[place_idx];
// copy parameter
if (dev_ctx.GetPlace() != place) {
PADDLE_THROW("Not Implemented");
}
// execute
auto executor = framework::Executor(place);
executor.Run(*program, cur_scope, block->ID(),
......@@ -181,6 +175,21 @@ class ParallelDoGradOp : public OperatorBase {
}
// merge grad
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
LOG(INFO) << s;
// std::string s_buf = s + "@BUF";
// auto *t_buf = sub_scopes[0]->Var(s_buf)->GetMutable<LoDTensor>();
for (size_t place_idx = 1; place_idx < places.size(); ++place_idx) {
LOG(INFO) << place_idx;
LOG(INFO) << sub_scopes[place_idx]->FindVar(s)->Get<LoDTensor>();
// Copy grad[i] to grad_buf[0]
// sum_op
}
// Copy grad[0] to grad
// auto *t = scope.FindVar(s)->GetMutable<LoDTensor>();
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册