Unverified commit d85a9dc4, authored by helinwang, committed by GitHub

Merge pull request #7621 from helinwang/remote_optimize

Recv OP: use BlockDesc* instead of ProgramDesc proto as Attribute
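What the change does: the recv operator previously took the optimize program as a serialized ProgramDesc proto in a string attribute ("OptimizeProgram") and re-parsed it inside RecvOp::Run; it now takes a framework::BlockDesc * attribute (kOptimizeBlock) and executes that block directly, dropping the serialize/parse round trip and the extra ProgramDesc copy. Below is a minimal, self-contained sketch of the two attribute-passing styles; BlockDesc, Attribute, and the two Run* helpers are toy stand-ins, not Paddle APIs.

// Toy sketch of the before/after attribute flow. Build: g++ -std=c++17
#include <iostream>
#include <map>
#include <string>
#include <variant>

struct BlockDesc {       // stand-in for framework::BlockDesc
  int id = 0;
  std::string ops;       // pretend this is the block's op list
};

using Attribute = std::variant<std::string, BlockDesc *>;
using AttributeMap = std::map<std::string, Attribute>;

// Old style: the program arrives serialized; every run must parse it back
// into a descriptor and rebuild a program copy before executing it.
void RunFromSerialized(const AttributeMap &attrs) {
  const auto &str = std::get<std::string>(attrs.at("OptimizeProgram"));
  BlockDesc parsed{0, str};  // stands in for ParseFromString + ProgramDesc copy
  std::cout << "parsed then ran: " << parsed.ops << "\n";
}

// New style: the attribute is a pointer to the descriptor the caller already
// built, so the op executes it directly -- no round trip, no copy.
void RunFromBlock(const AttributeMap &attrs) {
  auto *block = std::get<BlockDesc *>(attrs.at("OptimizeBlock"));
  std::cout << "ran block " << block->id << ": " << block->ops << "\n";
}

int main() {
  BlockDesc block{0, "sgd step"};
  RunFromSerialized({{"OptimizeProgram", std::string("sgd step")}});
  RunFromBlock({{"OptimizeBlock", &block}});
  return 0;
}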
@@ -34,6 +34,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
+constexpr char kOptimizeBlock[] = "OptimizeBlock";
 constexpr int kCondStart = 0;
 constexpr int kCondRunning = 1;
 constexpr int kCondDone = 2;
@@ -99,10 +100,8 @@ class RecvOp : public framework::OperatorBase {
     auto fan_in = Attr<int>("Fanin");
     size_t param_count = param_list.size();
-    std::string program_str = Attr<std::string>("OptimizeProgram");
-    framework::proto::ProgramDesc program_desc;
-    program_desc.ParseFromString(program_str);
-    framework::ProgramDesc program(program_desc);
+    auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
+    auto *program = block->Program();
     framework::Executor executor(dev_place);
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
@@ -142,8 +141,9 @@ class RecvOp : public framework::OperatorBase {
       if (exit_flag) {
        break;
       }
       try {
-        executor.Run(program, &recv_scope, 0, /*global_block*/
-                     false /*create_local_scope*/, false /*create_vars*/);
+        executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
+                     false /*create_local_scope*/, false /*create_vars*/);
       } catch (std::exception &e) {
         LOG(ERROR) << "run sub program error " << e.what();
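A detail worth noting in the hunk above: the old code always executed block 0 of the program it had just deserialized, while the new code asks the passed-in block for its own index via block->ID(), so the call stays correct even if the attached optimize block is not the program's global block. A toy, runnable illustration (hypothetical types, not Paddle code):

// Why the block id should come from the block itself, not a hard-coded 0.
#include <iostream>
#include <string>
#include <vector>

struct BlockDesc {       // stand-in: a block knows its own index
  int id;
  std::string ops;
};
struct ProgramDesc {     // stand-in: a program is an indexed list of blocks
  std::vector<BlockDesc> blocks;
};

// Executes one block of the program by index, like the Executor::Run call above.
void Run(const ProgramDesc &program, int block_id) {
  std::cout << "executing block " << block_id << ": "
            << program.blocks.at(block_id).ops << "\n";
}

int main() {
  ProgramDesc program{{{0, "forward + backward"}, {1, "sgd optimize"}}};
  const BlockDesc &optimize = program.blocks[1];
  Run(program, 0);            // hard-coded 0 runs the wrong block here
  Run(program, optimize.id);  // taking the id from the block is always safe
  return 0;
}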
@@ -175,8 +175,8 @@ This operator will recv tensor from send_op
                   "IP address to listen on.")
         .SetDefault("127.0.0.1:6164")
         .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
-    AddAttr<std::string>("OptimizeProgram", "type string",
-                         "Serialized ProgramDesc string for recv to run.");
+    AddAttr<framework::BlockDesc *>(
+        kOptimizeBlock, "Serialized ProgramDesc string for recv to run.");
     AddAttr<std::vector<std::string>>(
         "ParamList", "type list of string",
         "grad->param name mapping to find which param to optimize.")
......
@@ -130,10 +130,7 @@ void StartServerNet(bool is_sparse) {
   attrs.insert({"endpoint", std::string("127.0.0.1:6174")});
   attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
   attrs.insert({"GradList", std::vector<std::string>({"x1"})});
-  std::string program_proto;
-  PADDLE_ENFORCE(program.Proto()->SerializeToString(&program_proto));
-  attrs.insert({"OptimizeProgram", program_proto});
+  attrs.insert({"OptimizeBlock", block});
   recv_op = f::OpRegistry::CreateOp("recv", {{"RX", {"x1"}}}, {}, attrs);
   recv_op->Run(scope, place);
 }
......
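The block inserted into attrs in the test hunk above must be a BlockDesc * owned by the program the server-side op will run; the lines that create it are elided from this diff (the "......" above). A plausible reconstruction, assuming fluid's ProgramDesc::MutableBlock accessor and the test file's f namespace alias; this is hypothetical, not the verbatim elided code:

// Hypothetical reconstruction of the elided setup, not the verbatim test code.
f::ProgramDesc program;                         // f = paddle::framework alias
f::BlockDesc *block = program.MutableBlock(0);  // global block of the program
// ... append the optimize ops to `block` here ...
attrs.insert({"OptimizeBlock", block});         // pass the raw pointer as attr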
@@ -452,7 +452,7 @@ class DistributeTranspiler:
                 },  # grads to recv
                 outputs={},
                 attrs={
-                    "OptimizeProgram": optimize_sub_program.desc,
+                    "OptimizeBlock": optimize_sub_program.global_block(),
                     "endpoint": endpoint,
                     "ParamList": [
                         p.name
......
@@ -243,7 +243,7 @@ class SimpleDistributeTranspiler:
                     self.param_grad_map[endpoint]["grads"]},  # grads to recv
                 outputs={},
                 attrs={
-                    "OptimizeProgram": optimize_sub_program.desc,
+                    "OptimizeBlock": optimize_sub_program.global_block(),
                     "endpoint": endpoint,
                     "ParamList":
                         [p.name for p in self.param_grad_map[endpoint]["params"]],
......