From 686f0ecb6a02b29773c4e8a731ea8ad94b979860 Mon Sep 17 00:00:00 2001
From: mapingshuo
Date: Wed, 11 Dec 2019 16:56:35 +0800
Subject: [PATCH] add `no_need_buffer_slots` interface to pybind (#21575)

* add no_need_buffer_slots interface to pybind
---
 paddle/fluid/pybind/pybind.cc                    | 15 +++-
 .../test_infer_no_need_buffer_slots.py          | 72 +++++++++++++++++++
 2 files changed, 86 insertions(+), 1 deletion(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_infer_no_need_buffer_slots.py

diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 13e32f2acf2..960bbaff822 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -1103,7 +1103,20 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("has_infer_inplace", [](const std::string op_type) {
     return framework::OpInfoMap::Instance().Get(op_type).HasInferInplace();
   });
-
+  m.def("infer_no_need_buffer_slots",
+        [](const std::string op_type, const framework::VariableNameMap &inputs,
+           const framework::VariableNameMap &outputs,
+           const framework::AttributeMap &attrs) {
+          auto infer_func = framework::OpInfoMap::Instance()
+                                .Get(op_type)
+                                .NoNeedBufferVarsInferer();
+          if (infer_func) {
+            return infer_func(inputs, outputs, attrs);
+          } else {
+            std::unordered_set<std::string> empty = {};
+            return empty;
+          }
+        });
   m.def("prune", [](const ProgramDesc &origin,
                     const std::set<std::string> &feeded_var_names,
                     const std::vector<std::vector<size_t>> &targets) {
diff --git a/python/paddle/fluid/tests/unittests/test_infer_no_need_buffer_slots.py b/python/paddle/fluid/tests/unittests/test_infer_no_need_buffer_slots.py
new file mode 100644
index 00000000000..3656cdfd5a0
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_infer_no_need_buffer_slots.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.compat as cpt
import paddle.fluid.core as core


class TestInferNoNeedBufferSlots(unittest.TestCase):
    """Exercise ``core.infer_no_need_buffer_slots`` on a tiny add + SGD program.

    For each op in the built program, the test re-packs the op's inputs,
    outputs and attributes into plain dicts and checks which input slots the
    op reports as not needing their tensor buffers kept alive.
    """

    def net(self):
        """Build ``x1 + x2`` in the default main program; return the sum var.

        Two float32 variables of shape [1] are created directly in the global
        block, then combined with ``elementwise_add``.
        """
        x1 = fluid.default_main_program().global_block().create_var(
            dtype="float32", shape=[1], lod_level=0, name="x1")
        x2 = fluid.default_main_program().global_block().create_var(
            dtype="float32", shape=[1], lod_level=0, name="x2")
        x = fluid.layers.elementwise_add(x1, x2)
        return x

    def test_infer_no_need_buffer_slots(self):
        # Build a fresh program pair so the op list below is deterministic.
        program = framework.Program()
        startup_program = framework.Program()
        with fluid.program_guard(program, startup_program):
            loss = self.net()
            sgd = fluid.optimizer.SGD(learning_rate=0.01)
            # minimize() appends the backward (grad) ops after the forward ops.
            sgd.minimize(loss)

        block = program.global_block()
        for idx, op in enumerate(block.ops):
            op_desc = op.desc
            # Re-pack the op's inputs/outputs/attrs into plain dicts — the
            # argument form core.infer_no_need_buffer_slots expects
            # (VariableNameMap / AttributeMap on the C++ side).
            inputs = {}
            for input_name in op_desc.input_names():
                inputs[input_name] = op_desc.input(input_name)
            outputs = {}
            for output_name in op_desc.output_names():
                outputs[output_name] = op_desc.output(output_name)
            attrs = {}
            for attr_name in op_desc.attr_names():
                attrs[attr_name] = op_desc.attr(attr_name)
            # Expected op order after minimize(): forward add, the grad-loss
            # fill_constant, then the backward add — TODO confirm no SGD
            # update ops appear (x1/x2 are plain vars, not parameters).
            if idx == 0:
                # elementwise_add op
                # Forward add needs both input buffers: no slot is reported.
                self.assertEqual(
                    core.infer_no_need_buffer_slots(op.type, inputs, outputs,
                                                    attrs), set([]))
            elif idx == 1:
                # fill constant op
                # fill_constant registers no no-need-buffer inferer, so the
                # binding returns an empty set.
                self.assertEqual(
                    core.infer_no_need_buffer_slots(op.type, inputs, outputs,
                                                    attrs), set([]))
            else:
                # elementwise_add_grad op
                # The grad op only needs the shapes of X and Y, not their
                # data, so both slots are reported as no-need-buffer.
                self.assertEqual(
                    core.infer_no_need_buffer_slots(op.type, inputs, outputs,
                                                    attrs), set(['Y', 'X']))


if __name__ == '__main__':
    unittest.main()
# -- GitLab  (patch trailer from the original email; not part of the module)