From 1a67061fee99e3205cf40cfc4d1153198bd371fa Mon Sep 17 00:00:00 2001
From: Xin Pan <panxin.grad@gmail.com>
Date: Fri, 24 Aug 2018 13:49:41 +0800
Subject: [PATCH] graph to program pass

fix a few other things
---
 paddle/fluid/framework/CMakeLists.txt         |  4 +-
 paddle/fluid/framework/ir/CMakeLists.txt      |  2 +
 .../framework/ir/graph_to_program_pass.cc     | 65 +++++++++++++++++++
 .../framework/ir/graph_to_program_pass.h      | 30 +++++++++
 .../ir/graph_to_program_pass_test.cc          | 21 ++++++
 paddle/fluid/framework/op_desc.cc             |  4 +-
 paddle/fluid/framework/program_desc.cc        |  6 ++
 paddle/fluid/framework/program_desc.h         |  2 +
 paddle/fluid/inference/CMakeLists.txt         |  2 +-
 paddle/fluid/inference/io.cc                  |  1 -
 paddle/fluid/inference/tests/test_helper.h    | 14 ++++
 paddle/fluid/operators/parallel_do_op.cc      |  1 +
 .../test_memopt_image_classification_train.py |  4 +-
 .../test_memopt_machine_translation.py        |  4 +-
 14 files changed, 151 insertions(+), 9 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/graph_to_program_pass.cc
 create mode 100644 paddle/fluid/framework/ir/graph_to_program_pass.h
 create mode 100644 paddle/fluid/framework/ir/graph_to_program_pass_test.cc

diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 2c62d4ed6b0..0668ff43c81 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -107,11 +107,11 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
 cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
 
 if(WITH_DISTRIBUTE)
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method)
+  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
 endif()
 
 if (NOT WIN32)
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index da0955a9a00..9300573d7fb 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -3,6 +3,7 @@ cc_library(graph SRCS graph.cc DEPS node)
 cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
 cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
 cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
+cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper)
 cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
 cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits)
 cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter)
@@ -12,5 +13,6 @@ cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass
 cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
 cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
 cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
+cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
 cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter)
 cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto)
diff --git a/paddle/fluid/framework/ir/graph_to_program_pass.cc b/paddle/fluid/framework/ir/graph_to_program_pass.cc
new file mode 100644
index 00000000000..414d8f79b15
--- /dev/null
+++ b/paddle/fluid/framework/ir/graph_to_program_pass.cc
@@ -0,0 +1,65 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
+
+#include "paddle/fluid/framework/program_desc.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
+    std::unique_ptr<Graph> graph) const {
+  ProgramDesc& program = Get<ProgramDesc>("program");
+
+  std::unique_ptr<proto::ProgramDesc> program_pb(
+      new proto::ProgramDesc(*program.Proto()));
+
+  auto block = program_pb->mutable_blocks(kRootBlockIndex);
+  block->clear_vars();
+  std::unordered_set<std::string> visited_vars;
+  for (ir::Node* n : graph->Nodes()) {
+    if (n->NodeType() == ir::Node::Type::kVariable) {
+      if (n->Var() && visited_vars.count(n->Var()->Name()) == 0) {
+        visited_vars.insert(n->Var()->Name());
+        block->add_vars()->MergeFrom(*n->Var()->Proto());
+      }
+    }
+  }
+
+  block->clear_ops();
+  std::vector<ir::Node*> nodes = TopologySortOperations(*graph);
+  for (ir::Node* n : nodes) {
+    if (!n->Op()) {
+      continue;
+    }
+    block->add_ops()->MergeFrom(*n->Op()->Proto());
+  }
+
+  program.CopyFrom(*program_pb);
+  return graph;
+}
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(graph_to_program_pass, paddle::framework::ir::GraphToProgramPass);
diff --git a/paddle/fluid/framework/ir/graph_to_program_pass.h b/paddle/fluid/framework/ir/graph_to_program_pass.h
new file mode 100644
index 00000000000..124ec5a8e77
--- /dev/null
+++ b/paddle/fluid/framework/ir/graph_to_program_pass.h
@@ -0,0 +1,30 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class GraphToProgramPass : public Pass {
+ protected:
+  std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_to_program_pass_test.cc b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc
new file mode 100644
index 00000000000..3adbf888a8b
--- /dev/null
+++ b/paddle/fluid/framework/ir/graph_to_program_pass_test.cc
@@ -0,0 +1,21 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index 122dc161b41..59b6007284b 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -132,7 +132,9 @@ OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
     std::string attr_name = attr.name();
     // The sub_block referred to by the BLOCK attr hasn't been added
     // to ProgramDesc class yet, we skip setting BLOCK attr here.
-    if (attr.type() != proto::AttrType::BLOCK) {
+    // TODO(paddle-dev): Need copy fix this to copy Block as well.
+    if (attr.type() != proto::AttrType::BLOCK &&
+        attr.type() != proto::AttrType::BLOCKS) {
       attrs_[attr_name] = GetAttrValue(attr);
     }
   }
diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc
index 344c001a69b..c2b91069d9a 100644
--- a/paddle/fluid/framework/program_desc.cc
+++ b/paddle/fluid/framework/program_desc.cc
@@ -80,6 +80,12 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
   InitFromProto();
 }
 
+void ProgramDesc::CopyFrom(const proto::ProgramDesc &desc) {
+  blocks_.clear();
+  desc_ = desc;
+  InitFromProto();
+}
+
 ProgramDesc::ProgramDesc(const std::string &binary_str) {
   PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
                  "Fail to parse program_desc from binary string.");
diff --git a/paddle/fluid/framework/program_desc.h b/paddle/fluid/framework/program_desc.h
index f3afc85eb92..a0e81cade18 100644
--- a/paddle/fluid/framework/program_desc.h
+++ b/paddle/fluid/framework/program_desc.h
@@ -53,6 +53,8 @@ class ProgramDesc {
 
   void Flush();
 
+  void CopyFrom(const proto::ProgramDesc &desc);
+
   proto::ProgramDesc *Proto();
 
   // The output variable of feed_op is referenced as feed_target.
diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt
index ba7645aa024..a4f6364ae5b 100644
--- a/paddle/fluid/inference/CMakeLists.txt
+++ b/paddle/fluid/inference/CMakeLists.txt
@@ -10,7 +10,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor)
 # TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal?
 cc_library(paddle_fluid_api
     SRCS io.cc
-    DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
+    DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} graph_to_program_pass)
 
 get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
 
diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc
index 181868977dd..f29b190a730 100644
--- a/paddle/fluid/inference/io.cc
+++ b/paddle/fluid/inference/io.cc
@@ -138,7 +138,6 @@ std::unique_ptr<framework::ProgramDesc> Load(
 
   std::unique_ptr<framework::ProgramDesc> main_program(
       new framework::ProgramDesc(program_desc_str));
-
   LoadPersistables(executor, scope, *main_program, "", param_filename);
   return main_program;
 }
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index 695790a37dc..94f0550df57 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
 
+#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -135,6 +136,15 @@ std::vector<std::vector<int64_t>> GetFeedTargetShapes(
   return feed_target_shapes;
 }
 
+void Compile(paddle::framework::ProgramDesc* program) {
+  std::unique_ptr<paddle::framework::ir::Graph> g(
+      new paddle::framework::ir::Graph(*program));
+  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
+      "graph_to_program_pass");
+  pass->SetNotOwned<paddle::framework::ProgramDesc>("program", program);
+  pass->Apply(std::move(g));
+}
+
 template <typename Place, bool CreateVars = true, bool PrepareContext = false>
 void TestInference(const std::string& dirname,
                    const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
@@ -172,6 +182,8 @@ void TestInference(const std::string& dirname,
         paddle::platform::DeviceContextPool::Instance().Get(place));
     inference_program = InitProgram(&executor, scope, dirname, is_combined);
   }
+  Compile(inference_program.get());
+
   // Disable the profiler and print the timing information
   paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
                                     "load_program_profiler");
@@ -249,3 +261,5 @@ void TestInference(const std::string& dirname,
 
   delete scope;
 }
+
+USE_PASS(graph_to_program_pass);
diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc
index eb09470f37e..97c36a83fc5 100644
--- a/paddle/fluid/operators/parallel_do_op.cc
+++ b/paddle/fluid/operators/parallel_do_op.cc
@@ -355,6 +355,7 @@ class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
         grad->SetInput(framework::GradVarName(output_param), og_names);
       }
     }
+    grad->SetInput("Communicator", {"nccl_com__do_not_change_"});
     grad->SetAttrMap(this->Attrs());
     grad->SetBlockAttr(kParallelBlock, grad_block_[0]);
 
diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py
index 3951e7b8ca6..a231bbfbc8d 100644
--- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py
+++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py
@@ -125,8 +125,8 @@ opts = optimizer.minimize(avg_cost)
 batch_size = fluid.layers.create_tensor(dtype='int64')
 batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size)
 
-# fluid.memory_optimize(fluid.default_main_program(), level=0)
-fluid.release_memory(fluid.default_main_program())
+fluid.memory_optimize(fluid.default_main_program(), level=0)
+# fluid.release_memory(fluid.default_main_program())
 
 BATCH_SIZE = 16
 PASS_NUM = 1
diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py
index 1ad51936b5b..e520c896508 100644
--- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py
+++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_machine_translation.py
@@ -92,8 +92,8 @@ def main():
     optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4)
     optimizer.minimize(avg_cost)
 
-    # fluid.memory_optimize(fluid.default_main_program())
-    fluid.release_memory(fluid.default_main_program())
+    fluid.memory_optimize(fluid.default_main_program())
+    # fluid.release_memory(fluid.default_main_program())
 
     # fix the order of training data
     train_data = paddle.batch(
-- 
GitLab