From d8643cb6d4364f3adb30eff71e4c9719488d40a4 Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Thu, 2 Feb 2023 18:40:39 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PRIM=E3=80=91Support=20use=20operator'?= =?UTF-8?q?s=20output=20metadata=20info=20=20in=20constructing=20static=20?= =?UTF-8?q?backward=20composite=20(#50043)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [prim] support custom target_gradients * support infershape after append one gradop * [prim] add simple net test * fix test_loop segment fault bug * [prim] fix infer shape segment fault bug when output of grad_op_desc is empty --- .../api/generated/prim_api/static_prim_api.cc | 5 +- .../manual/backward/composite_backward_api.h | 8 +- .../utils/static/composite_grad_desc_maker.h | 13 +- .../prim/utils/static/static_global_utils.h | 9 ++ paddle/fluid/prim/utils/utils.cc | 6 + paddle/fluid/prim/utils/utils.h | 4 + paddle/fluid/pybind/pybind.cc | 2 + python/paddle/fluid/backward.py | 119 ++++++++++++++--- python/paddle/fluid/core.py | 1 + .../fluid/tests/unittests/prim/CMakeLists.txt | 1 + .../tests/unittests/prim/model/CMakeLists.txt | 9 ++ .../prim/model/test_comp_model_simple_net.py | 120 ++++++++++++++++++ .../vjp/static/test_comp_multiply_grad.py | 4 +- ...test_comp_get_grad_op_desc_prim_enabled.py | 1 - .../tests/unittests/test_eager_run_program.py | 2 +- .../tests/unittests/test_run_program_op.py | 2 +- .../paddle/jit/dy2static/partial_program.py | 6 +- python/paddle/jit/dy2static/utils.py | 7 +- python/paddle/jit/translated_layer.py | 2 +- 19 files changed, 279 insertions(+), 42 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/prim/model/CMakeLists.txt create mode 100644 python/paddle/fluid/tests/unittests/prim/model/test_comp_model_simple_net.py diff --git a/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc b/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc index 30a82b49899..b879ade5a9e 100644 --- a/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc +++ b/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc @@ -110,6 +110,7 @@ Tensor unsqueeze(const Tensor& x, const IntArray& axis) { op->SetAttr("axes", new_shape); op->CheckAttrs(); op->InferVarType(block); + op->InferShape(*block); return out; } @@ -209,7 +210,7 @@ Tensor sum(const Tensor& x, "Out", {std::static_pointer_cast(out.impl())->Name()}); op->CheckAttrs(); op->InferVarType(block); - // TODO(jiabin, cxxly): This may have runtime shape skip infershape for now. + op->InferShape(*block); return out; } @@ -232,7 +233,7 @@ Tensor reshape(const Tensor& x, const IntArray& shape) { "Out", {std::static_pointer_cast(out.impl())->Name()}); op->CheckAttrs(); op->InferVarType(block); - // TODO(jiabin, cxxly): This may have runtime shape skip infershape for now. 
+ op->InferShape(*block); return out; } diff --git a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h index 99ef82d0888..a9c8953a228 100644 --- a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h @@ -232,8 +232,8 @@ void multiply_grad(const Tensor& x, Tensor* y_grad) { if (x_grad) { auto x_grad_unreduce = multiply(out_grad, y); - if (x.dims() != y.dims()) { - auto axes = get_reduce_dims(x.dims(), y.dims()); + if (x_grad_unreduce.dims() != x.dims()) { + auto axes = get_reduce_dims_from_out(x_grad_unreduce.dims(), x.dims()); if (!axes.size()) { set_output(x_grad_unreduce, x_grad); } else { @@ -252,8 +252,8 @@ void multiply_grad(const Tensor& x, } if (y_grad) { auto y_grad_unreduce = multiply(out_grad, x); - if (y.dims() != x.dims()) { - auto axes = get_reduce_dims(y.dims(), x.dims()); + if (y_grad_unreduce.dims() != y.dims()) { + auto axes = get_reduce_dims_from_out(y_grad_unreduce.dims(), y.dims()); if (!axes.size()) { set_output(y_grad_unreduce, y_grad); } else { diff --git a/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h b/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h index e391d8ac530..a7969382685 100644 --- a/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h +++ b/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h @@ -318,6 +318,7 @@ class CompositeGradOpMakerBase { grad_var_name = framework::kEmptyVarName; if (drop_empty_grad) return nullptr; } + if (original_block_->HasVar(grad_var_name)) { // Copy Var from original block to active block, or create a new one. CopyVarFromOrig(grad_var_name); @@ -333,6 +334,12 @@ class CompositeGradOpMakerBase { auto grad_var_name = framework::GradVarName(var_name); (*this->grad_to_var_)[grad_var_name] = var_name; VLOG(8) << "Valid gradients: " << grad_var_name; + + auto target_grad = StaticCompositeContext::Instance().GetTargetGradName(); + if (target_grad.find(grad_var_name) != target_grad.end()) { + grad_var_name = target_grad.at(grad_var_name); + } + if (original_block_->HasVar(grad_var_name)) { // Copy Var from original block to active block, or create a new one. CopyVarFromOrig(grad_var_name); @@ -421,7 +428,11 @@ class CompositeGradOpMakerBase { return g_name; }); std::vector grad_out; - for (const auto& name : ret_val) { + for (auto name : ret_val) { + auto target_grad = StaticCompositeContext::Instance().GetTargetGradName(); + if (target_grad.find(name) != target_grad.end()) { + name = target_grad.at(name); + } // TODO(jiabin): Will this cause fill zeros error? if (original_block_->HasVar(name)) { // Copy Var from original block to active block, or create a new one. 
diff --git a/paddle/fluid/prim/utils/static/static_global_utils.h b/paddle/fluid/prim/utils/static/static_global_utils.h index 08407013673..e878c857f26 100644 --- a/paddle/fluid/prim/utils/static/static_global_utils.h +++ b/paddle/fluid/prim/utils/static/static_global_utils.h @@ -69,12 +69,21 @@ class StaticCompositeContext { enable_bwd_prim_ = enable_prim; } + void SetTargetGradName(const std::map<std::string, std::string>& m) { + target_grad_name_ = m; + } + + std::map<std::string, std::string> GetTargetGradName() { + return target_grad_name_; + } + private: StaticCompositeContext() : current_block_desc_(nullptr), generator_(new UniqueNameGenerator()) {} framework::BlockDesc* current_block_desc_; std::unique_ptr<UniqueNameGenerator> generator_; + std::map<std::string, std::string> target_grad_name_; static thread_local bool enable_bwd_prim_; static thread_local bool enable_fwd_prim_; static StaticCompositeContext* static_composite_context_; diff --git a/paddle/fluid/prim/utils/utils.cc b/paddle/fluid/prim/utils/utils.cc index fb415262c8d..a869e5609b9 100644 --- a/paddle/fluid/prim/utils/utils.cc +++ b/paddle/fluid/prim/utils/utils.cc @@ -38,5 +38,11 @@ void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) { void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) { return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim); } + +void PrimCommonUtils::SetTargetGradName( + const std::map<std::string, std::string>& m) { + StaticCompositeContext::Instance().SetTargetGradName(m); +} + } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/utils/utils.h b/paddle/fluid/prim/utils/utils.h index 38973dc87b8..4ede84a947b 100644 --- a/paddle/fluid/prim/utils/utils.h +++ b/paddle/fluid/prim/utils/utils.h @@ -14,6 +14,9 @@ #pragma once +#include <map> +#include <string> + namespace paddle { namespace prim { class PrimCommonUtils { @@ -23,6 +26,7 @@ class PrimCommonUtils { static bool IsFwdPrimEnabled(); static void SetFwdPrimEnabled(bool enabled); static void SetAllPrimEnabled(bool enabled); + static void SetTargetGradName(const std::map<std::string, std::string>& m); }; } // namespace prim } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 020a926b473..8712e428bdf 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -673,6 +673,8 @@ PYBIND11_MODULE(libpaddle, m) { &paddle::prim::PrimCommonUtils::IsFwdPrimEnabled); m.def("__set_all_prim_enabled", &paddle::prim::PrimCommonUtils::SetAllPrimEnabled); + m.def("_set_prim_target_grad_name", + &paddle::prim::PrimCommonUtils::SetTargetGradName); m.def("set_num_threads", &platform::SetNumThreads); m.def("disable_signal_handler", &DisableSignalHandler); diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 5169f9f085f..1ba11e1fba4 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -1337,6 +1337,18 @@ def _append_backward_ops_( rename_var_map = {} assert isinstance(rename_var_map, dict) + if core._is_bwd_prim_enabled(): + composite_block = program.clone().current_block() + # Infer shapes for operators whose outputs haven't been created yet.
+ for op in composite_block.ops: + if not all( + tuple( + composite_block._find_var_recursive(arg) + for arg in op.output_arg_names + ) + ): + infershape_for_composite(composite_block, op.desc) + # add grad_op_desc by reversed ops for op in reversed(ops): grad_sub_block_list = [] @@ -1365,11 +1377,42 @@ def _append_backward_ops_( program._rollback() grad_sub_block_list.append(grad_sub_block.desc) + # In primitive mode, a raw phi GradOp will be split into multiple small + # primitive operators, and the split rules are defined at the C++ level; + # see details: paddle/fluid/prim/api/manual/backward/composite_backward_api.h + # This means that the output shapes and dtypes of previous operators, which + # may be used as inputs of later operators, must be known. Therefore, + # we infer shape and dtype in a sandbox block (named composite_block) for + # use at the C++ level. + # For example: + # forward: + # z = multiply(x, y) // may broadcast in the kernel + # backward: + # x_grad_unreduce = z_grad * y // may be unreduced + # reduced_axes = get_reduced_axes(x_grad.shape, x.shape) // needs known shapes + # x_grad = reduce_sum(x_grad_unreduce) + grad_op_desc = [] + op_grad_to_var = {} + if core._is_bwd_prim_enabled(): + + def find_op_index(block_desc, cur_op_desc): + for idx in range(block_desc.op_size()): + if cur_op_desc == block_desc.op(idx): + return idx + return -1 - # Getting op's corresponding grad_op - grad_op_desc, op_grad_to_var = core.get_grad_op_desc( - op.desc, no_grad_dict[block.idx], grad_sub_block_list - ) + grad_op_desc, op_grad_to_var = core.get_grad_op_desc( + composite_block.desc.op(find_op_index(block.desc, op.desc)), + no_grad_dict[composite_block.idx], + grad_sub_block_list, + ) + for desc in grad_op_desc: + infershape_for_composite(composite_block, desc) + else: + # Getting op's corresponding grad_op + grad_op_desc, op_grad_to_var = core.get_grad_op_desc( + op.desc, no_grad_dict[block.idx], grad_sub_block_list + ) # record the mapping between fwd and bwd if grad_op_id_to_fwd_op is not None: @@ -1655,7 +1698,43 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): block.desc._remove_op(op_idx, op_idx + 1) -def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map): +def infershape_for_composite(block, grad_op_desc): + # prune grad ops with empty outputs + if len(grad_op_desc.output_arg_names()) == 0: + return + + # append op to block + op_desc = block.desc.append_op() + op_desc.copy_from(grad_op_desc) + op_desc._set_attr( + core.op_proto_and_checker_maker.kOpRoleAttrName(), + core.op_proto_and_checker_maker.OpRole.Backward, + ) + + # create output var + new_vars = set() + # create new gradient variables + for grad_var_name in op_desc.output_arg_names(): + if not ( + block.desc.has_var_recursive(grad_var_name.encode()) + or grad_var_name == core.empty_var_name() + ): + block.desc.var(grad_var_name.encode()) + new_vars.add(grad_var_name) + + # infer shape and dtype + op_desc.check_attrs() + op_desc.infer_var_type(block.desc) + op_desc.infer_shape(block.desc) + + for arg in op_desc.output_arg_names(): + if arg in new_vars: + _infer_var_data_type_shape_(arg, block) + + +def _rename_grad_( + block, start_op_idx, grad_to_var, target_grad_map, skip_rename_var_list +): var_map = copy.copy(target_grad_map) for op_idx in range(start_op_idx, block.desc.op_size()): op_desc = block.desc.op(op_idx) @@ -1667,6 +1746,8 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map): if "@GRAD" not in name: continue if block.desc.find_var(name.encode("ascii")): + if name in
skip_rename_var_list: + continue new_name = unique_name.generate(name) op_desc._rename_output(name, new_name) var_map[name] = new_name @@ -1993,7 +2074,7 @@ def append_backward( # Because append_backward may be called multiple times, # we need rename the internal gradient variables so that they have # different names. - _rename_grad_(target_grad_block, fwd_op_num, grad_to_var, {}) + _rename_grad_(target_grad_block, fwd_op_num, grad_to_var, {}, []) _append_backward_vars_( target_grad_block, fwd_op_num, grad_to_var, grad_info_map @@ -2297,33 +2378,24 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): target_grad_map = {} rename_var_map = {} + skip_rename_var_list = [] for i, grad in enumerate(target_gradients): target = targets[i] grad_name = _append_grad_suffix_(target.name) if grad is None: - target_shape = target.name + '_shape' - block.desc.append_op().copy_from( - _create_op_desc_( - "shape", - {'Input': [target.name]}, - {"Out": [target_shape]}, - {}, - ) - ) - input_grad_names_set.add(target_shape) op_desc = _create_op_desc_( - "fill_constant", - {"ShapeTensor": [target_shape]}, + "fill_any_like", + {"X": [target.name]}, {"Out": [grad_name]}, { - "shape": target.shape, "value": 1.0, "dtype": target.dtype, }, ) - block.desc.append_op().copy_from(op_desc) + block.program._sync_with_cpp() input_grad_names_set.add(grad_name) + skip_rename_var_list.append(grad_name) else: if target.block.idx != block_idx or target.block.program != prog: raise ValueError("all targets must be in the same block") @@ -2336,6 +2408,9 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): input_grad_names_set.add(grad.name) rename_var_map[grad_name] = grad.name + if core._is_bwd_prim_enabled(): + core._set_prim_target_grad_name(target_grad_map) + # For double backward, input_grad_names is used for filter # some non-used gradients op. rename_var_map is used to # associate target_grad var name with first grad_op input name. @@ -2378,7 +2453,9 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): # Because calc_gradient may be called multiple times, # we need rename the internal gradient variables so that they have # different names. 
- _rename_grad_(block, fwd_op_num, grad_to_var, target_grad_map) + _rename_grad_( + block, fwd_op_num, grad_to_var, target_grad_map, skip_rename_var_list + ) _append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map) prog._sync_with_cpp() diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index 9aaf0f684f1..54b51f0face 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -313,6 +313,7 @@ try: from .libpaddle import __set_fwd_prim_enabled from .libpaddle import _is_fwd_prim_enabled from .libpaddle import __set_all_prim_enabled + from .libpaddle import _set_prim_target_grad_name # custom devivce from .libpaddle import _get_current_custom_device_stream diff --git a/python/paddle/fluid/tests/unittests/prim/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/CMakeLists.txt index ab3ee7ba1a3..7fd5f5ecebf 100644 --- a/python/paddle/fluid/tests/unittests/prim/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/prim/CMakeLists.txt @@ -9,3 +9,4 @@ foreach(TEST_OP ${TEST_OPS}) endforeach() add_subdirectory(prim) +add_subdirectory(model) diff --git a/python/paddle/fluid/tests/unittests/prim/model/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/model/CMakeLists.txt new file mode 100644 index 00000000000..72c6bbd7d05 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/model/CMakeLists.txt @@ -0,0 +1,9 @@ +file( + GLOB TEST_OPS + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(TEST_OP ${TEST_OPS}) + py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) +endforeach() diff --git a/python/paddle/fluid/tests/unittests/prim/model/test_comp_model_simple_net.py b/python/paddle/fluid/tests/unittests/prim/model/test_comp_model_simple_net.py new file mode 100644 index 00000000000..27b300e9afd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/model/test_comp_model_simple_net.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core, framework + + +@param.parameterized_class( + ('name', 'primals', 'stop_gradients', 'cotangents', 'dtype'), + ( + ( + 'test_normal_case', + (np.random.rand(2, 3, 4), np.random.rand(2, 3, 4)), + (False, False), + (np.random.rand(2, 3, 4),), + np.float32, + ), + ( + 'test_broadcast_diff_rank', + (np.random.rand(2, 3, 1, 4), np.random.rand(3, 3, 4)), + (False, False), + (np.random.rand(2, 3, 3, 4),), + np.float32, + ), + ( + 'test_broadcast_same_rank', + (np.random.rand(2, 3, 1, 4), np.random.rand(2, 1, 3, 4)), + (False, False), + (np.random.rand(2, 3, 3, 4),), + np.float32, + ), + ( + 'test_stop_gradient', + (np.random.rand(2, 3, 1, 4), np.random.rand(2, 1, 3, 4)), + (False, True), + (np.random.rand(2, 3, 3, 4),), + np.float32, + ), + ), +) +class TestMultiplyGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primals = tuple(primal.astype(cls.dtype) for primal in cls.primals) + cls.cotangents = tuple(co.astype(cls.dtype) for co in cls.cotangents) + + def setUp(self): + paddle.enable_static() + + def tearDown(self): + paddle.disable_static() + + def as_tuple(self, x): + return (x,) if isinstance(x, framework.Variable) else x + + def net(self): + primals, cotangents = self.primals, self.cotangents + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + primals = tuple( + paddle.static.data(f'primal{i}', primal.shape, primal.dtype) + for i, primal in enumerate(primals) + ) + for primal, flag in zip(primals, self.stop_gradients): + primal.stop_gradient = flag + cotangents = tuple( + paddle.static.data(f'cotangent{i}', co.shape, co.dtype) + for i, co in enumerate(cotangents) + ) + out = self.as_tuple(paddle.tanh(paddle.multiply(*primals))) + grads = paddle.static.gradients(out, primals) + exe = paddle.static.Executor() + exe.run(sp) + return exe.run( + program=mp, + feed={ + **{ + f'primal{i}': primal + for i, primal in enumerate(self.primals) + }, + **{f'cotangent{i}': co for i, co in enumerate(self.cotangents)}, + }, + fetch_list=[g for g in grads if g is not None], + ) + + def test_comp(self): + core._set_prim_backward_enabled(True) + actual = self.net() + + core._set_prim_backward_enabled(False) + desired = self.net() + + self.assertEqual(len(actual), len(desired)) + for i, j in zip(actual, desired): + np.testing.assert_allclose( + i, + j, + rtol=1e-6, + atol=0, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py index 2d1a10a6d4b..c2f15b6ab84 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py @@ -56,7 +56,7 @@ from paddle.fluid import core, framework 'test_reduce_axe_empty', (np.random.rand(2, 3, 3, 4), np.random.rand(2, 1, 3, 4)), (False, False), - (np.random.rand(2, 1, 3, 1),), + (np.random.rand(2, 3, 3, 4),), np.float32, ), ), @@ -91,7 +91,7 @@ class TestMultiplyGradComp(unittest.TestCase): for i, co in enumerate(cotangents) ) out = self.as_tuple(paddle.multiply(*primals)) - grads = paddle.static.gradients(out, primals) + grads = paddle.static.gradients(out, primals, cotangents) exe = paddle.static.Executor() exe.run(sp) return exe.run( diff --git 
a/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py b/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py index 18b445f38da..c576be20388 100644 --- a/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py +++ b/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py @@ -75,7 +75,6 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase): self.fwd, self.no_grad_var, self.grad_sub_block )[0] ) - print(actual) self.assertEquals(actual, self.desired_ops) core._set_prim_backward_enabled(False) diff --git a/python/paddle/fluid/tests/unittests/test_eager_run_program.py b/python/paddle/fluid/tests/unittests/test_eager_run_program.py index 33472f85e73..4863f46f487 100644 --- a/python/paddle/fluid/tests/unittests/test_eager_run_program.py +++ b/python/paddle/fluid/tests/unittests/test_eager_run_program.py @@ -105,7 +105,7 @@ class TestRunProgram(unittest.TestCase): ) backward_program = _add_build_strategy_for( program, - main_program.desc.block(0).op_size() + 2, + main_program.desc.block(0).op_size() + 1, program.desc.block(0).op_size(), ) diff --git a/python/paddle/fluid/tests/unittests/test_run_program_op.py b/python/paddle/fluid/tests/unittests/test_run_program_op.py index 73ad833a3ef..35c31deb3f8 100644 --- a/python/paddle/fluid/tests/unittests/test_run_program_op.py +++ b/python/paddle/fluid/tests/unittests/test_run_program_op.py @@ -131,7 +131,7 @@ class RunProgramOpTest(unittest.TestCase): forward_program = _add_build_strategy_for(program, 0, forward_op_num) backward_program = _add_build_strategy_for( program, - forward_op_num + 2 * output_num, + forward_op_num + output_num, program.desc.block(0).op_size(), ) return forward_program.desc, backward_program.desc diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py index 19d44b8e35c..fd509d74e53 100644 --- a/python/paddle/jit/dy2static/partial_program.py +++ b/python/paddle/jit/dy2static/partial_program.py @@ -576,9 +576,7 @@ class PartialProgramLayer: core.check_and_set_prim_all_enabled() backward.gradients(targets=targets, inputs=[]) - start_idx = len(main_program.block(0).ops) + 2 * len( - self._outputs.tolist() - ) + start_idx = len(main_program.block(0).ops) + len(self._outputs.tolist()) self.prepare_gradient_aggregation(start_idx, main_program, program) @@ -753,7 +751,7 @@ class PartialProgramLayer: ): # NOTE(dev): We apply build_strategy for backward firstly to # avoid skipping more gc variables. 
- backward_start_op_index = forward_end_op_index + 2 * len( + backward_start_op_index = forward_end_op_index + len( self._outputs.var_ids ) backward_end_op_index = whole_program.desc.block(0).op_size() diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index 4397728576b..cad4c295561 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -1512,12 +1512,11 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size): """ names = [] for i in range( - fwd_end_op_index + 1, - min(fwd_end_op_index + 2 * out_size, program_desc.block(0).op_size()), - 2, + fwd_end_op_index, + min(fwd_end_op_index + out_size, program_desc.block(0).op_size()), ): op = program_desc.block(0).op(i) - if op.type() == 'fill_constant': + if op.type() == 'fill_any_like': var_name = op.output('Out')[0] names.append(var_name) return names diff --git a/python/paddle/jit/translated_layer.py b/python/paddle/jit/translated_layer.py index c488c758f4a..45563584f16 100644 --- a/python/paddle/jit/translated_layer.py +++ b/python/paddle/jit/translated_layer.py @@ -373,7 +373,7 @@ class _ProgramHolder: @switch_to_static_graph def _create_backward_train_program(self): whole_program = _build_program_by_desc(self._train_program_desc) - start_op_index = self._infer_program_desc.block(0).op_size() + 2 * len( + start_op_index = self._infer_program_desc.block(0).op_size() + len( self._output_descs ) end_op_index = whole_program.desc.block(0).op_size() -- GitLab
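
For reference, a minimal usage sketch of what this patch enables: passing a custom target gradient (cotangent) to paddle.static.gradients while the composite (prim) backward is enabled. It is adapted from the updated test_comp_multiply_grad.py above; the tensor shapes and variable names are illustrative only, and the APIs used (core._set_prim_backward_enabled, the target_gradients argument of paddle.static.gradients) are the ones exercised by this patch.

import numpy as np
import paddle
from paddle.fluid import core

paddle.enable_static()
core._set_prim_backward_enabled(True)  # lower grad ops into primitive ops

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', [2, 3, 1, 4], 'float32')
    y = paddle.static.data('y', [2, 1, 3, 4], 'float32')
    x.stop_gradient = False
    y.stop_gradient = False
    # Custom target gradient (cotangent) for the broadcast output shape.
    cot = paddle.static.data('cot', [2, 3, 3, 4], 'float32')
    out = paddle.multiply(x, y)
    # With prim enabled, calc_gradient forwards the cotangent names to C++
    # via core._set_prim_target_grad_name before building the grad ops.
    grads = paddle.static.gradients([out], [x, y], target_gradients=[cot])

exe = paddle.static.Executor()
exe.run(startup)
dx, dy = exe.run(
    main,
    feed={
        'x': np.random.rand(2, 3, 1, 4).astype('float32'),
        'y': np.random.rand(2, 1, 3, 4).astype('float32'),
        'cot': np.random.rand(2, 3, 3, 4).astype('float32'),
    },
    fetch_list=grads,
)
core._set_prim_backward_enabled(False)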