From 3f4c088ad8fe99286649194b438fef5a4784056c Mon Sep 17 00:00:00 2001
From: chengduo <30176695+chengduoZH@users.noreply.github.com>
Date: Fri, 9 Aug 2019 23:14:07 +0800
Subject: [PATCH] prune the feed op in compiler (#18997)

test=develop
---
 python/paddle/fluid/compiler.py               | 10 +++
 ...arallel_executor_run_load_infer_program.py | 85 +++++++++++++++++++
 2 files changed, 95 insertions(+)
 create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_executor_run_load_infer_program.py

diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py
index 49c4c4f246c..b2283242ab3 100644
--- a/python/paddle/fluid/compiler.py
+++ b/python/paddle/fluid/compiler.py
@@ -45,6 +45,15 @@ def _is_pserver_mode(main_program):
     return False
 
 
+def _prune_feed_ops(program):
+    # prune the feed ops in the program.
+    pop_idx = []
+    for i, op in enumerate(program.global_block().ops):
+        if op.type == "feed": pop_idx.append(i)
+    for index in pop_idx[::-1]:
+        program.global_block()._remove_op(index)
+
+
 class CompiledProgram(object):
     """
     Compiles to Graph for execution.
@@ -100,6 +109,7 @@ class CompiledProgram(object):
             # don't not create a new program here.
             self._program = None
         elif isinstance(program_or_graph, framework.Program):
+            _prune_feed_ops(program_or_graph)
             self._graph = core.Graph(program_or_graph.desc)
             self._program = program_or_graph
         else:
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_load_infer_program.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_load_infer_program.py
new file mode 100644
index 00000000000..fc76f5d152d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_load_infer_program.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+
+import paddle.fluid as fluid
+from simple_nets import simple_fc_net, init_data
+
+
+class TestMNIST(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.save_dirname = "./"
+        cls.model_filename = "test_parallel_executor_run_load_infer_program_model"
+        cls.params_filename = "test_parallel_executor_run_load_infer_program_parameter"
+        cls.place = fluid.CPUPlace()
+        cls.exe = fluid.Executor(cls.place)
+        img, label = init_data()
+        cls.batch_data = []
+        for img, label in zip(img, label):
+            cls.batch_data.append([img, label])
+
+    def test_simple_fc(self):
+        exe_loss = self.run_with_executor()
+
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(
+             self.save_dirname, self.exe, self.model_filename,
+             self.params_filename)
+
+        train_exe = fluid.ParallelExecutor(
+            use_cuda=False, main_program=inference_program)
+        feed_vars = [
+            inference_program.global_block().var(var_name)
+            for var_name in ["image", "label"]
+        ]
+        feeder = fluid.DataFeeder(place=self.place, feed_list=feed_vars)
+
+        pe_loss = train_exe.run(feed=feeder.feed(self.batch_data),
+                                fetch_list=[fetch_targets[0].name])
+        assert exe_loss == pe_loss
+
+    def run_with_executor(self):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            loss = simple_fc_net()
+
+        feed_vars = [
+            main.global_block().var(var_name)
+            for var_name in ["image", "label"]
+        ]
+        feeder = fluid.DataFeeder(place=self.place, feed_list=feed_vars)
+
+        self.exe.run(startup)
+
+        loss_data = self.exe.run(main,
+                                 feed=feeder.feed(self.batch_data),
+                                 fetch_list=[loss.name])
+
+        fluid.io.save_inference_model(
+            self.save_dirname, ["image", "label"], [loss],
+            self.exe,
+            model_filename=self.model_filename,
+            params_filename=self.params_filename,
+            main_program=main)
+
+        return loss_data
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
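
Note: the new `_prune_feed_ops` helper collects the indices of all `feed` ops first and then removes them in reverse order, so that deleting one op does not shift the indices of the ops still waiting to be removed. A minimal standalone sketch of that index-removal pattern (plain Python over a hypothetical list of op type names, not the PaddlePaddle API):

    # Illustration only: remove items by index in reverse so earlier
    # indices stay valid after each deletion.
    ops = ["feed", "feed", "mul", "elementwise_add", "softmax"]  # hypothetical op types
    pop_idx = [i for i, op in enumerate(ops) if op == "feed"]    # collect indices first
    for index in pop_idx[::-1]:                                  # walk backwards
        ops.pop(index)
    print(ops)  # ['mul', 'elementwise_add', 'softmax']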