diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 8d1479bdd6311709baaf2a6c673db3d0de4610f8..f9f1c134d798a4e68a629364c17d1a59a94a7002 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -34,6 +34,7 @@ limitations under the License. */ namespace paddle { namespace operators { +constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr int kCondStart = 0; constexpr int kCondRunning = 1; constexpr int kCondDone = 2; @@ -99,10 +100,8 @@ class RecvOp : public framework::OperatorBase { auto fan_in = Attr("Fanin"); size_t param_count = param_list.size(); - std::string program_str = Attr("OptimizeProgram"); - framework::proto::ProgramDesc program_desc; - program_desc.ParseFromString(program_str); - framework::ProgramDesc program(program_desc); + auto *block = Attr(kOptimizeBlock); + auto *program = block->Program(); framework::Executor executor(dev_place); // TODO(typhoonzero): change this to a while_op for every cluster-batch. @@ -142,8 +141,9 @@ class RecvOp : public framework::OperatorBase { if (exit_flag) { break; } + try { - executor.Run(program, &recv_scope, 0, /*global_block*/ + executor.Run(*program, &recv_scope, block->ID(), /*global_block*/ false /*create_local_scope*/, false /*create_vars*/); } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); @@ -175,8 +175,8 @@ This operator will recv tensor from send_op "IP address to listen on.") .SetDefault("127.0.0.1:6164") .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); - AddAttr("OptimizeProgram", "type string", - "Serialized ProgramDesc string for recv to run."); + AddAttr( + kOptimizeBlock, "Serialized ProgramDesc string for recv to run."); AddAttr>( "ParamList", "type list of string", "grad->param name mapping to find which param to optimize.") diff --git a/paddle/operators/send_recv_op_test.cc b/paddle/operators/send_recv_op_test.cc index ea091694798475dfd9631910a750405be950c20c..045a0f5434f339bab345d14881ed05450ce6588d 100644 --- a/paddle/operators/send_recv_op_test.cc +++ b/paddle/operators/send_recv_op_test.cc @@ -130,10 +130,7 @@ void StartServerNet(bool is_sparse) { attrs.insert({"endpoint", std::string("127.0.0.1:6174")}); attrs.insert({"ParamList", std::vector({"Out"})}); attrs.insert({"GradList", std::vector({"x1"})}); - std::string program_proto; - PADDLE_ENFORCE(program.Proto()->SerializeToString(&program_proto)); - - attrs.insert({"OptimizeProgram", program_proto}); + attrs.insert({"OptimizeBlock", block}); recv_op = f::OpRegistry::CreateOp("recv", {{"RX", {"x1"}}}, {}, attrs); recv_op->Run(scope, place); } diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index bd957f88de5d51a2fa3e482284e2d8080f1be76b..02a0e4cd2639e857bce07afa9858531e8d177ad0 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -452,7 +452,7 @@ class DistributeTranspiler: }, # grads to recv outputs={}, attrs={ - "OptimizeProgram": optimize_sub_program.desc, + "OptimizeBlock": optimize_sub_program.global_block(), "endpoint": endpoint, "ParamList": [ p.name diff --git a/python/paddle/v2/fluid/distribute_transpiler_simple.py b/python/paddle/v2/fluid/distribute_transpiler_simple.py index bd88f02bde0c6a58138e20db2b07cbd06cd40ba3..56ffb56b1247646903485e5859b60f63df9b97a2 100644 --- a/python/paddle/v2/fluid/distribute_transpiler_simple.py +++ b/python/paddle/v2/fluid/distribute_transpiler_simple.py @@ -243,7 +243,7 @@ class SimpleDistributeTranspiler: self.param_grad_map[endpoint]["grads"]}, # grads to recv outputs={}, attrs={ - "OptimizeProgram": optimize_sub_program.desc, + "OptimizeBlock": optimize_sub_program.global_block(), "endpoint": endpoint, "ParamList": [p.name for p in self.param_grad_map[endpoint]["params"]],