提交 12a0e2ed 编写于 作者: X Xin Pan

polish codes

test=develop
上级 19d78f67
...@@ -50,7 +50,7 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl( ...@@ -50,7 +50,7 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
std::unordered_map<std::string, int> vars; std::unordered_map<std::string, int> vars;
// TODO(gongwb): use graph topology sort to find the order of operators. // TODO(gongwb): use graph topology sort to find the order of operators.
// Note that must assert topology sort is stable // Note that must assert topology sort is stable
auto& ops = graph->Get<const std::vector<OpDesc*>>(kAllOpDescs); auto& ops = graph->Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
for (auto* op_desc : ops) { for (auto* op_desc : ops) {
auto outputs = op_desc->Outputs(); auto outputs = op_desc->Outputs();
for (auto& o_it : outputs) { for (auto& o_it : outputs) {
...@@ -120,4 +120,4 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl( ...@@ -120,4 +120,4 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
REGISTER_PASS(all_reduce_deps_pass, REGISTER_PASS(all_reduce_deps_pass,
paddle::framework::details::AllReduceDepsPass) paddle::framework::details::AllReduceDepsPass)
.RequireGraphAttr(paddle::framework::details::kAllOpDescs); .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
...@@ -33,10 +33,10 @@ namespace details { ...@@ -33,10 +33,10 @@ namespace details {
using paddle::framework::VarDesc; using paddle::framework::VarDesc;
std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) { std::vector<ir::Node*> SortOpLikeDescOrder(const ir::Graph& graph) {
PADDLE_ENFORCE(graph.Has(kAllOpDescs), PADDLE_ENFORCE(graph.Has(kStaleProgramOpDescs),
"Graph has no attribute of kAllOpDescs."); "Graph has no attribute of kStaleProgramOpDescs.");
// 1. get op desc order // 1. get op desc order
auto& op_descs = graph.Get<const std::vector<OpDesc*>>(kAllOpDescs); auto& op_descs = graph.Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
// 2. topology sort order // 2. topology sort order
auto nodes = graph.Nodes(); auto nodes = graph.Nodes();
......
...@@ -336,4 +336,5 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var, ...@@ -336,4 +336,5 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
} // namespace paddle } // namespace paddle
REGISTER_PASS(memory_optimize_pass, REGISTER_PASS(memory_optimize_pass,
paddle::framework::details::MemoryOptimizePass); paddle::framework::details::MemoryOptimizePass)
.RequireGraphAttr(paddle::framework::details::kAllOpDescs);
...@@ -40,7 +40,7 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl( ...@@ -40,7 +40,7 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
static std::unordered_set<std::string> skip_dist_ops{ static std::unordered_set<std::string> skip_dist_ops{
"send", "recv", "send_barrier", "fetch_barrier"}; "send", "recv", "send_barrier", "fetch_barrier"};
auto &ops = graph->Get<const std::vector<OpDesc *>>(kAllOpDescs); auto &ops = graph->Get<const std::vector<OpDesc *>>(kStaleProgramOpDescs);
std::vector<ir::Node *> op_node_list; std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size()); op_node_list.reserve(ops.size());
...@@ -107,4 +107,4 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl( ...@@ -107,4 +107,4 @@ std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
REGISTER_PASS(sequential_execution_pass, REGISTER_PASS(sequential_execution_pass,
paddle::framework::details::SequentialExecutionPass) paddle::framework::details::SequentialExecutionPass)
.RequireGraphAttr(paddle::framework::details::kAllOpDescs); .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
...@@ -77,7 +77,7 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram( ...@@ -77,7 +77,7 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
} }
} }
Set<const std::vector<OpDesc *>>( Set<const std::vector<OpDesc *>>(
details::kAllOpDescs, details::kStaleProgramOpDescs,
new std::vector<OpDesc *>(program.Block(0).AllOps())); new std::vector<OpDesc *>(program.Block(0).AllOps()));
return var_nodes; return var_nodes;
} }
......
...@@ -31,7 +31,7 @@ namespace details { ...@@ -31,7 +31,7 @@ namespace details {
// This attr is not recommended, because the graph should not dependence // This attr is not recommended, because the graph should not dependence
// the program once it is built. // the program once it is built.
constexpr char kAllOpDescs[] = "all_op_descs"; constexpr char kStaleProgramOpDescs[] = "stale_program_op_descs";
} // namespace details } // namespace details
namespace ir { namespace ir {
......
...@@ -2322,7 +2322,7 @@ class Program(object): ...@@ -2322,7 +2322,7 @@ class Program(object):
@staticmethod @staticmethod
def _construct_from_desc(desc): def _construct_from_desc(desc):
""" """
Construct a program from program desc. (Experiment) Construct a program from program desc.
Args: Args:
desc(core.ProgramDesc): The program desc for constructing. desc(core.ProgramDesc): The program desc for constructing.
...@@ -2332,7 +2332,6 @@ class Program(object): ...@@ -2332,7 +2332,6 @@ class Program(object):
""" """
p = Program() p = Program()
p.desc = desc p.desc = desc
# TODO(wangzhen): Block.vars/ops are not filled, should fix it.
p.blocks = [Block(p, i) for i in six.moves.range(p.desc.num_blocks())] p.blocks = [Block(p, i) for i in six.moves.range(p.desc.num_blocks())]
p._sync_with_cpp() p._sync_with_cpp()
return p return p
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册