From 030c0b91fb272b5899cdd53797aa5b902f1454e1 Mon Sep 17 00:00:00 2001
From: Abhinav Arora <aroraabhinav@baidu.com>
Date: Thu, 18 Jan 2018 15:46:32 -0800
Subject: [PATCH] Do not send to optimize_ops to distribute transpiler again

---
 python/paddle/v2/fluid/distribute_transpiler.py           | 8 ++++----
 .../fluid/tests/book_distribute/notest_dist_fit_a_line.py | 2 +-
 .../book_distribute/notest_dist_label_semantic_roles.py   | 2 +-
 .../fluid/tests/book_distribute/notest_dist_word2vec.py   | 2 +-
 .../book_distribute/notest_recognize_digits_conv_dist.py  | 2 +-
 .../book_distribute/notest_recognize_digits_mlp_dist.py   | 2 +-
 .../notest_understand_sentiment_conv_dist.py              | 2 +-
 7 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index 06a7b6fb0..9591bd92a 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -407,7 +407,7 @@ class DistributeTranspiler:
             outputs=opt_op.outputs,
             attrs=opt_op.attrs)
 
-    def get_pserver_program(self, endpoint, optimize_ops):
+    def get_pserver_program(self, endpoint):
         """
         get pserver side program by endpoint
 
@@ -422,9 +422,9 @@ class DistributeTranspiler:
             self._clone_var(pserver_program.global_block(), v)
         # step6
         optimize_sub_program = Program()
-        for idx, opt_op in enumerate(optimize_ops):
-            is_op_on_pserver = self._is_op_on_pserver(endpoint, optimize_ops,
-                                                      idx)
+        for idx, opt_op in enumerate(self.optimize_ops):
+            is_op_on_pserver = self._is_op_on_pserver(endpoint,
+                                                      self.optimize_ops, idx)
             if not is_op_on_pserver:
                 continue
             if opt_op.inputs.has_key("Grad"):
diff --git a/python/paddle/v2/fluid/tests/book_distribute/notest_dist_fit_a_line.py b/python/paddle/v2/fluid/tests/book_distribute/notest_dist_fit_a_line.py
index b886071f9..881cca2b7 100644
--- a/python/paddle/v2/fluid/tests/book_distribute/notest_dist_fit_a_line.py
+++ b/python/paddle/v2/fluid/tests/book_distribute/notest_dist_fit_a_line.py
@@ -53,7 +53,7 @@ if training_role == "PSERVER":
     if not current_endpoint:
         print("need env SERVER_ENDPOINT")
         exit(1)
-    pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
+    pserver_prog = t.get_pserver_program(current_endpoint)
     exe.run(fluid.default_startup_program())
     exe.run(pserver_prog)
 else:
diff --git a/python/paddle/v2/fluid/tests/book_distribute/notest_dist_label_semantic_roles.py b/python/paddle/v2/fluid/tests/book_distribute/notest_dist_label_semantic_roles.py
index 2b5a098ff..a9730415e 100644
--- a/python/paddle/v2/fluid/tests/book_distribute/notest_dist_label_semantic_roles.py
+++ b/python/paddle/v2/fluid/tests/book_distribute/notest_dist_label_semantic_roles.py
@@ -197,7 +197,7 @@ def main():
         if not current_endpoint:
             print("need env SERVER_ENDPOINT")
             exit(1)
-        pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
+        pserver_prog = t.get_pserver_program(current_endpoint)
         exe.run(fluid.default_startup_program())
         exe.run(pserver_prog)
     elif training_role == "TRAINER":
diff --git a/python/paddle/v2/fluid/tests/book_distribute/notest_dist_word2vec.py b/python/paddle/v2/fluid/tests/book_distribute/notest_dist_word2vec.py
index dc04af5b7..65b81778d 100644
--- a/python/paddle/v2/fluid/tests/book_distribute/notest_dist_word2vec.py
+++ b/python/paddle/v2/fluid/tests/book_distribute/notest_dist_word2vec.py
@@ -87,7 +87,7 @@ if training_role == "PSERVER":
     if not current_endpoint:
         print("need env SERVER_ENDPOINT")
         exit(1)
-    pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
+    pserver_prog = t.get_pserver_program(current_endpoint)
     exe.run(fluid.default_startup_program())
     exe.run(pserver_prog)
 elif training_role == "TRAINER":
diff --git a/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py b/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
index 27512c4f7..c0a3a3650 100644
--- a/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
+++ b/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
@@ -66,7 +66,7 @@ if training_role == "PSERVER":
     if not current_endpoint:
         print("need env SERVER_ENDPOINT")
         exit(1)
-    pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
+    pserver_prog = t.get_pserver_program(current_endpoint)
     exe.run(fluid.default_startup_program())
     exe.run(pserver_prog)
 elif training_role == "TRAINER":
diff --git a/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_mlp_dist.py b/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_mlp_dist.py
index 6de3e4da7..6cd3e84f3 100644
--- a/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_mlp_dist.py
+++ b/python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_mlp_dist.py
@@ -60,7 +60,7 @@ if training_role == "PSERVER":
     if not current_endpoint:
         print("need env SERVER_ENDPOINT")
         exit(1)
-    pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
+    pserver_prog = t.get_pserver_program(current_endpoint)
     exe.run(fluid.default_startup_program())
     exe.run(pserver_prog)
 elif training_role == "TRAINER":
diff --git a/python/paddle/v2/fluid/tests/book_distribute/notest_understand_sentiment_conv_dist.py b/python/paddle/v2/fluid/tests/book_distribute/notest_understand_sentiment_conv_dist.py
index 74f20f3f4..9e0a1a52a 100644
--- a/python/paddle/v2/fluid/tests/book_distribute/notest_understand_sentiment_conv_dist.py
+++ b/python/paddle/v2/fluid/tests/book_distribute/notest_understand_sentiment_conv_dist.py
@@ -98,7 +98,7 @@ def main():
         if not current_endpoint:
             print("need env SERVER_ENDPOINT")
             exit(1)
-        pserver_prog = t.get_pserver_program(current_endpoint, optimize_ops)
+        pserver_prog = t.get_pserver_program(current_endpoint)
         exe.run(pserver_prog)
     elif training_role == "TRAINER":
         trainer_prog = t.get_trainer_program()
-- 
GitLab