From bbca66f2fada26a03cb82ef851a5e53f75771b82 Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Thu, 2 Mar 2023 17:17:23 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Prim=E3=80=91Fix=20slice=20error=20and?= =?UTF-8?q?=20eager=20comp=20(#51086)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix attrs copy error * fix bert by fix slice error * fix op test --- .../generator/eager_gen.py | 4 +- paddle/fluid/operators/slice_op.cc | 23 ++++++---- .../composite_backward_api.h | 16 ++++--- .../prim/utils/static/static_global_utils.cc | 1 + .../prim/utils/static/static_global_utils.h | 7 +++ paddle/fluid/prim/utils/utils.cc | 8 ++++ paddle/fluid/prim/utils/utils.h | 2 + paddle/fluid/pybind/pybind.cc | 4 ++ python/paddle/fluid/core.py | 29 +++++++++---- .../fluid/tests/unittests/prim/model/bert.py | 8 ++-- .../prim/model/test_bert_prim_cinn.py | 5 +-- .../prim/prim/flags/test_prim_flags.py | 3 ++ .../vjp/eager/test_comp_eager_add_grad.py | 8 ++-- .../vjp/eager/test_comp_eager_cast_grad.py | 4 +- .../vjp/eager/test_comp_eager_div_grad.py | 8 ++-- .../vjp/eager/test_comp_eager_exp_grad.py | 4 +- .../vjp/eager/test_comp_eager_expand_grad.py | 6 +-- .../vjp/eager/test_comp_eager_gather_grad.py | 6 +-- .../test_comp_eager_matmul_double_grad.py | 12 +++--- .../eager/test_comp_eager_multiply_grad.py | 4 +- .../vjp/eager/test_comp_eager_reshape_grad.py | 8 ++-- .../vjp/eager/test_comp_eager_sqrt_grad.py | 4 +- .../vjp/eager/test_comp_eager_sub_grad.py | 8 ++-- .../vjp/eager/test_comp_eager_sum_grad.py | 4 +- .../vjp/eager/test_comp_eager_tanh_grad.py | 4 +- .../eager/test_comp_eager_transpose_grad.py | 8 ++-- .../fluid/tests/unittests/prim_op_test.py | 3 +- .../fluid/tests/unittests/test_slice_op.py | 43 ++++++------------- 28 files changed, 138 insertions(+), 106 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 8c48e40aea6..8a5e2aef19b 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -1840,7 +1840,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): if is_composite_grad_api and next_grad_node_creation_str != '': next_grad_node_creation_str = f""" - if (!paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{ + if (!paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) {{ {next_grad_node_creation_str} }} """ @@ -2260,7 +2260,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): # TODO(Ruting):using composite only when we don't have backward kernel in the future. 
        elif is_composite_grad_api:
            grad_function_call_str = f"""
-  if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
+  if (paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) {{
  {indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
  VLOG(4) << "Composite api {composite_grad_api_name} is called ";
  }}else{{
diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc
index 2519f3f97f4..0b653d44bde 100644
--- a/paddle/fluid/operators/slice_op.cc
+++ b/paddle/fluid/operators/slice_op.cc
@@ -423,19 +423,25 @@ class SliceCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
     auto dx_ptr = this->GetOutputPtr(&input_grad);
     std::string dx_name = this->GetOutputName(input_grad);
-    auto axes = this->Attr<std::vector<int64_t>>("axes");
-    auto starts = this->Attr<std::vector<int64_t>>("starts");
-    auto ends = this->Attr<std::vector<int64_t>>("ends");
-    auto infer_flags = this->Attr<std::vector<int64_t>>("infer_flags");
-    auto decrease_axis = this->Attr<std::vector<int64_t>>("decrease_axis");
+    auto axes = this->Attr<std::vector<int>>("axes");
+    auto starts = this->Attr<std::vector<int>>("starts");
+    auto ends = this->Attr<std::vector<int>>("ends");
+    auto infer_flags = this->Attr<std::vector<int>>("infer_flags");
+    auto decrease_axis = this->Attr<std::vector<int>>("decrease_axis");
     VLOG(6) << "Runing slice_grad composite func";
+    std::vector<int64_t> new_axes =
+        std::vector<int64_t>(axes.begin(), axes.end());
+    std::vector<int64_t> new_infer_flags =
+        std::vector<int64_t>(infer_flags.begin(), infer_flags.end());
+    std::vector<int64_t> new_decrease_axis =
+        std::vector<int64_t>(decrease_axis.begin(), decrease_axis.end());
     prim::slice_grad<prim::DescTensor>(input,
                                        out_grad,
-                                       axes,
+                                       new_axes,
                                        paddle::experimental::IntArray(starts),
                                        paddle::experimental::IntArray(ends),
-                                       infer_flags,
-                                       decrease_axis,
+                                       new_infer_flags,
+                                       new_decrease_axis,
                                        dx_ptr);
     this->RecoverOutputName(input_grad, dx_name);
   }
@@ -478,6 +484,7 @@ REGISTER_OPERATOR(slice,
                   ops::SliceOpMaker,
                   ops::SliceOpGradMaker<paddle::framework::OpDesc>,
                   ops::SliceOpGradMaker<paddle::imperative::OpBase>,
+                  ops::SliceCompositeGradOpMaker,
                   ops::SliceOpVarTypeInference);
 REGISTER_OPERATOR(slice_grad,
                   ops::SliceOpGrad,
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index a7fc0a4a930..5792daee7d1 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -704,6 +704,7 @@ void slice_grad(const Tensor& input,
   if (input_grad) {
     size_t rank = input.dims().size();
     auto out_dims = out_grad.dims();
+    std::vector<int64_t> origin_out_shape;
     auto in_dims = input.dims();
     auto decrease_size = decrease_axis.size();
@@ -712,7 +713,7 @@ void slice_grad(const Tensor& input,
         // all dims decrease
         out_dims = phi::make_ddim(std::vector<int64_t>(decrease_size, 1));
       } else {
-        std::vector<int64_t> origin_out_shape(out_dims.size() + decrease_size, -1);
+        origin_out_shape.resize(out_dims.size() + decrease_size, -1);
         for (size_t i = 0; i < decrease_size; ++i) {
           origin_out_shape[decrease_axis[i]] = 1;
         }
@@ -734,7 +735,6 @@ void slice_grad(const Tensor& input,
       offsets[i] = 0;
       extents[i] = out_dims[i];
     }
-
     for (size_t i = 0; i < axes.size(); ++i) {
       int axis = axes[i];
       int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
@@ -747,9 +747,15 @@ void slice_grad(const Tensor& input,
       paddings.push_back(offsets[i]);
       paddings.push_back((in_dims[i] - out_dims[i]) - offsets[i]);
     }
-
-    auto out_tmp = pad<T>(out_grad, paddings, 0.0);
-    set_output<T>(out_tmp, input_grad);
+    if (decrease_size > 0 &&
+        (decrease_size != static_cast<size_t>(in_dims.size()))) {
+      auto out_tmp =
+          pad<T>(reshape<T>(out_grad, origin_out_shape), paddings, 0.0);
+      set_output<T>(out_tmp, input_grad);
+    } else {
+      auto out_tmp = pad<T>(out_grad, paddings, 0.0);
+      set_output<T>(out_tmp, input_grad);
+    }
   }
 }
diff --git a/paddle/fluid/prim/utils/static/static_global_utils.cc b/paddle/fluid/prim/utils/static/static_global_utils.cc
index 9631994ab2b..3d1aa215804 100644
--- a/paddle/fluid/prim/utils/static/static_global_utils.cc
+++ b/paddle/fluid/prim/utils/static/static_global_utils.cc
@@ -20,5 +20,6 @@ StaticCompositeContext* StaticCompositeContext::static_composite_context_ =
     new StaticCompositeContext();
 thread_local bool StaticCompositeContext::enable_bwd_prim_ = false;
 thread_local bool StaticCompositeContext::enable_fwd_prim_ = false;
+thread_local bool StaticCompositeContext::enable_eager_prim_ = false;
 }  // namespace prim
 }  // namespace paddle
diff --git a/paddle/fluid/prim/utils/static/static_global_utils.h b/paddle/fluid/prim/utils/static/static_global_utils.h
index 0d5620c8e4e..c08405bb18d 100644
--- a/paddle/fluid/prim/utils/static/static_global_utils.h
+++ b/paddle/fluid/prim/utils/static/static_global_utils.h
@@ -65,6 +65,12 @@ class StaticCompositeContext {
   bool IsFwdPrimEnabled() { return enable_fwd_prim_; }
+  void SetEagerPrimEnabled(bool enable_prim) {
+    enable_eager_prim_ = enable_prim;
+  }
+
+  bool IsEagerPrimEnabled() { return enable_eager_prim_; }
+
   void SetAllPrimEnabled(bool enable_prim) {
     enable_fwd_prim_ = enable_prim;
     enable_bwd_prim_ = enable_prim;
@@ -102,6 +108,7 @@ class StaticCompositeContext {
   std::map<std::string, std::string> target_grad_name_;
   static thread_local bool enable_bwd_prim_;
   static thread_local bool enable_fwd_prim_;
+  static thread_local bool enable_eager_prim_;
   static StaticCompositeContext* static_composite_context_;
   DISABLE_COPY_AND_ASSIGN(StaticCompositeContext);
 };
diff --git a/paddle/fluid/prim/utils/utils.cc b/paddle/fluid/prim/utils/utils.cc
index e7653161680..aa5255f532a 100644
--- a/paddle/fluid/prim/utils/utils.cc
+++ b/paddle/fluid/prim/utils/utils.cc
@@ -27,6 +27,14 @@ void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
   StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
 }
+bool PrimCommonUtils::IsEagerPrimEnabled() {
+  return StaticCompositeContext::Instance().IsEagerPrimEnabled();
+}
+
+void PrimCommonUtils::SetEagerPrimEnabled(bool enable_prim) {
+  StaticCompositeContext::Instance().SetEagerPrimEnabled(enable_prim);
+}
+
 bool PrimCommonUtils::IsFwdPrimEnabled() {
   return StaticCompositeContext::Instance().IsFwdPrimEnabled();
 }
diff --git a/paddle/fluid/prim/utils/utils.h b/paddle/fluid/prim/utils/utils.h
index 8718496b3f1..bdde972e494 100644
--- a/paddle/fluid/prim/utils/utils.h
+++ b/paddle/fluid/prim/utils/utils.h
@@ -23,6 +23,8 @@ class PrimCommonUtils {
  public:
   static bool IsBwdPrimEnabled();
   static void SetBwdPrimEnabled(bool enabled);
+  static bool IsEagerPrimEnabled();
+  static void SetEagerPrimEnabled(bool enabled);
   static bool IsFwdPrimEnabled();
   static void SetFwdPrimEnabled(bool enabled);
   static void SetAllPrimEnabled(bool enabled);
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 37c5249b89a..02d09d58b4f 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -681,6 +681,10 @@ PYBIND11_MODULE(libpaddle, m) {
         &paddle::prim::PrimCommonUtils::IsFwdPrimEnabled);
   m.def("__set_all_prim_enabled",
         &paddle::prim::PrimCommonUtils::SetAllPrimEnabled);
+  m.def("_is_eager_prim_enabled",
+        &paddle::prim::PrimCommonUtils::IsEagerPrimEnabled);
+  m.def("__set_eager_prim_enabled",
+        &paddle::prim::PrimCommonUtils::SetEagerPrimEnabled);
   m.def("_set_prim_target_grad_name",
         &paddle::prim::PrimCommonUtils::SetTargetGradName);
   m.def("set_num_threads", &platform::SetNumThreads);
diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py
index c3a50f7767a..db3a7c29788 100644
--- a/python/paddle/fluid/core.py
+++ b/python/paddle/fluid/core.py
@@ -316,6 +316,8 @@ try:
     from .libpaddle import __set_fwd_prim_enabled
     from .libpaddle import _is_fwd_prim_enabled
     from .libpaddle import __set_all_prim_enabled
+    from .libpaddle import _is_eager_prim_enabled
+    from .libpaddle import __set_eager_prim_enabled
     from .libpaddle import _set_prim_target_grad_name

     # custom devivce
@@ -475,26 +477,36 @@ def _set_prim_forward_blacklist(ops=None):

 def _set_prim_backward_enabled(value):
     __set_bwd_prim_enabled(bool(value))
-    print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
+    if os.getenv("FLAGS_prim_log") == "1":
+        print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))


 def _set_prim_forward_enabled(value):
     __set_fwd_prim_enabled(bool(value))
-    print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
+    if os.getenv("FLAGS_prim_log") == "1":
+        print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
+
+
+def set_prim_eager_enabled(value):
+    __set_eager_prim_enabled(bool(value))
+    if os.getenv("FLAGS_prim_log") == "1":
+        print("eager prim enabled: ", bool(_is_eager_prim_enabled()))


 def _set_prim_all_enabled(value):
     __set_all_prim_enabled(bool(value))
-    print(
-        "all prim enabled: ",
-        bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
-    )
+    if os.getenv("FLAGS_prim_log") == "1":
+        print(
+            "all prim enabled: ",
+            bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()),
+        )


 def __sync_prim_backward_status():
     flag_value = os.getenv("FLAGS_prim_backward")
     if flag_value is None:
-        print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
+        if os.getenv("FLAGS_prim_log") == "1":
+            print("backward prim enabled: ", bool(_is_bwd_prim_enabled()))
     else:
         __sync_stat_with_flag("FLAGS_prim_backward")

@@ -502,7 +514,8 @@ def __sync_prim_backward_status():
 def __sync_prim_forward_status():
     flag_value = os.getenv("FLAGS_prim_forward")
     if flag_value is None:
-        print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
+        if os.getenv("FLAGS_prim_log") == "1":
+            print("forward prim enabled: ", bool(_is_fwd_prim_enabled()))
     else:
         __sync_stat_with_flag("FLAGS_prim_forward")

diff --git a/python/paddle/fluid/tests/unittests/prim/model/bert.py b/python/paddle/fluid/tests/unittests/prim/model/bert.py
index 8a0eadf8698..240179a0697 100644
--- a/python/paddle/fluid/tests/unittests/prim/model/bert.py
+++ b/python/paddle/fluid/tests/unittests/prim/model/bert.py
@@ -207,7 +207,7 @@ class BertPooler(nn.Layer):


 class BertModel(nn.Layer):
-    def __init__(self, config: BertConfig):
+    def __init__(self, config: BertConfig, to_static):
         super(BertModel, self).__init__()
         self.config = config
         self.pad_token_id = config.pad_token_id
@@ -247,6 +247,8 @@ class BertModel(nn.Layer):
         self.encoder = nn.TransformerEncoder(
             encoder_layer, config.num_hidden_layers
         )
+        if to_static:
+            self.encoder =
paddle.jit.to_static(self.encoder) self.pooler = BertPooler(config) # self.apply(self.init_weights) @@ -364,10 +366,10 @@ class BertModel(nn.Layer): class Bert(nn.Layer): - def __init__(self): + def __init__(self, to_static): super(Bert, self).__init__() config = BertConfig() - self.bert = BertModel(config) + self.bert = BertModel(config, to_static) self.cls = BertPretrainingHeads( config, embedding_weights=self.bert.embeddings.word_embeddings.weight, diff --git a/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py b/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py index 99a2404d763..04f72d486ec 100644 --- a/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py +++ b/python/paddle/fluid/tests/unittests/prim/model/test_bert_prim_cinn.py @@ -58,7 +58,7 @@ def train(to_static, enable_prim, enable_cinn): worker_init=None, ) - bert = Bert() + bert = Bert(to_static) criterion = BertPretrainingCriterion() if to_static: # input_sepc = [ @@ -72,9 +72,6 @@ def train(to_static, enable_prim, enable_cinn): build_strategy = paddle.static.BuildStrategy() if enable_cinn: build_strategy.build_cinn_pass = True - bert = paddle.jit.to_static( - bert, input_sepc, build_strategy=build_strategy - ) optimizer = fluid.optimizer.Adam(parameter_list=bert.parameters()) diff --git a/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py index fd156f3ea2f..2c6d5133123 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py @@ -58,6 +58,9 @@ class TestPrimFlags(unittest.TestCase): core.check_and_set_prim_all_enabled() self.assertFalse(core._is_fwd_prim_enabled()) + core.set_prim_eager_enabled(True) + self.assertTrue(core._is_eager_prim_enabled()) + with self.assertRaises(TypeError): core._test_use_sync("aaaa") diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py index f3aa6375cef..b50e49a3e44 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py @@ -20,7 +20,7 @@ import parameterized as param import paddle from paddle.fluid import core -core._set_prim_backward_enabled(True) +core.set_prim_eager_enabled(True) @param.parameterized_class( @@ -61,7 +61,7 @@ class TestAddGradComp(unittest.TestCase): def test_add_grad_comp(self): def actual(primal0, primal1): - core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) @@ -72,7 +72,7 @@ class TestAddGradComp(unittest.TestCase): return res[0].numpy(), res[1].numpy() def desired(primal0, primal1): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) @@ -98,7 +98,7 @@ class TestAddGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) if __name__ == '__main__': diff --git 
a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_cast_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_cast_grad.py index 9ac4bd239a6..b8c47466e85 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_cast_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_cast_grad.py @@ -52,7 +52,7 @@ class TestCastGradComp(unittest.TestCase): cls.cotangent = cls.cotangent.astype(cls.src_dtype) def test_cast_grad_comp(self): - core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) def actual(primal, cotangent): x = paddle.to_tensor(primal) @@ -78,7 +78,7 @@ class TestCastGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py index 989888230a8..6546776a207 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py @@ -20,7 +20,7 @@ import parameterized as param import paddle from paddle.fluid import core -core._set_prim_backward_enabled(True) +core.set_prim_eager_enabled(True) @param.parameterized_class( @@ -61,7 +61,7 @@ class TestDivGradComp(unittest.TestCase): def test_div_grad_comp(self): def actual(primal0, primal1): - core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) @@ -72,7 +72,7 @@ class TestDivGradComp(unittest.TestCase): return res[0].numpy(), res[1].numpy() def desired(primal0, primal1): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) @@ -98,7 +98,7 @@ class TestDivGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py index 85974031280..172a37956c8 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py @@ -32,14 +32,14 @@ from paddle.fluid import core class TestExpGradComp(unittest.TestCase): @classmethod def setUpClass(cls): - core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) cls.primal = cls.primal.astype(cls.dtype) if cls.cotangent is not None: cls.cotangent = cls.cotangent.astype(cls.dtype) @classmethod def tearDownClass(cls): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) def test_exp_grad_comp(self): def actual(primal, cotangent): diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py index 92b0b98942c..1991bd06f0c 100644 --- 
a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py @@ -62,7 +62,7 @@ class TestExpandGradComp(unittest.TestCase): @classmethod def tearDownClass(cls): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) def test_comp(self): def func(primal, cotangent, shape): @@ -74,11 +74,11 @@ class TestExpandGradComp(unittest.TestCase): ] def actual(primal, cotangent, shape): - core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) return func(primal, cotangent, shape) def desired(primal, cotangent, shape): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) return func(primal, cotangent, shape) np.testing.assert_allclose( diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_gather_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_gather_grad.py index 2da25afb7a9..7d71e5187f2 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_gather_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_gather_grad.py @@ -75,11 +75,11 @@ class TestGatherGradComp(unittest.TestCase): @classmethod def tearDownClass(cls): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) def test_exp_grad_comp(self): def actual(primal0, index, axis): - core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) paddle.disable_static() x = paddle.to_tensor( primal0, dtype=primal0.dtype, stop_gradient=False @@ -92,7 +92,7 @@ class TestGatherGradComp(unittest.TestCase): return res[0].numpy() def desired(primal0, index, axis): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) paddle.disable_static() x = paddle.to_tensor( primal0, dtype=primal0.dtype, stop_gradient=False diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_matmul_double_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_matmul_double_grad.py index d697c160093..3d24604419d 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_matmul_double_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_matmul_double_grad.py @@ -20,7 +20,7 @@ import parameterized as param import paddle from paddle.fluid import core -core._set_prim_backward_enabled(True) +core.set_prim_eager_enabled(True) # vector * vector out.shape = (1) # matrix * vector out.shape = (2) @@ -267,7 +267,7 @@ class TestMatmulDoubleGradComp(unittest.TestCase): def test_matmul_grad_comp(self): def actual(primal0, primal1, trans_0, trans_1, dtype_): - core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) paddle.disable_static() x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False) y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False) @@ -287,7 +287,7 @@ class TestMatmulDoubleGradComp(unittest.TestCase): ) def desired(primal0, primal1, trans_0, trans_1, dtype_): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) paddle.disable_static() x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False) y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False) @@ -428,7 +428,7 @@ class TestMatmulTribleGradComp(unittest.TestCase): def test_matmul_grad_comp(self): def actual(primal0, primal1, trans_0, trans_1, dtype_): - 
core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) paddle.disable_static() x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False) y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False) @@ -465,7 +465,7 @@ class TestMatmulTribleGradComp(unittest.TestCase): ) def desired(primal0, primal1, trans_0, trans_1, dtype_): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) paddle.disable_static() x = paddle.to_tensor(primal0, dtype=dtype_, stop_gradient=False) y = paddle.to_tensor(primal1, dtype=dtype_, stop_gradient=False) @@ -549,7 +549,7 @@ class TestMatmulTribleGradComp(unittest.TestCase): atol=TOLERANCE[d_type]['atol'], ) - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py index fdef2779a41..207e3f414f2 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py @@ -81,10 +81,10 @@ class TestMultiplyGradComp(unittest.TestCase): return [g for g in grads if g is not None] def test_comp(self): - core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) actual = self.vjp() - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) desired = self.vjp() for i, j in zip(actual, desired): diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_reshape_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_reshape_grad.py index 1840307af69..e98f8ba58c3 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_reshape_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_reshape_grad.py @@ -20,7 +20,7 @@ import parameterized as param import paddle from paddle.fluid import core -core._set_prim_backward_enabled(True) +core.set_prim_eager_enabled(True) @param.parameterized_class( @@ -42,7 +42,7 @@ class TestReshapeGradComp(unittest.TestCase): def test_reshape_grad_comp(self): def actual(primal0, shape): - core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) x.stop_gradient = False @@ -51,7 +51,7 @@ class TestReshapeGradComp(unittest.TestCase): return res[0].numpy() def desired(primal0, shape): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) x.stop_gradient = False @@ -69,7 +69,7 @@ class TestReshapeGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py index de5106bd0ca..4f6dc8b2ada 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sqrt_grad.py @@ -22,7 +22,7 @@ import parameterized as param import paddle from paddle.fluid import core -core._set_prim_backward_enabled(True) 
+core.set_prim_eager_enabled(True) @param.parameterized_class( @@ -57,7 +57,7 @@ class TestSqrtGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py index ff704270595..62aa0e936f7 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py @@ -20,7 +20,7 @@ import parameterized as param import paddle from paddle.fluid import core -core._set_prim_backward_enabled(True) +core.set_prim_eager_enabled(True) @param.parameterized_class( @@ -61,7 +61,7 @@ class TestSubGradComp(unittest.TestCase): def test_sub_grad_comp(self): def actual(primal0, primal1): - core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) @@ -72,7 +72,7 @@ class TestSubGradComp(unittest.TestCase): return res[0].numpy(), res[1].numpy() def desired(primal0, primal1): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) @@ -98,7 +98,7 @@ class TestSubGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py index 030e6e0b64b..bed88e39521 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sum_grad.py @@ -21,7 +21,7 @@ from paddle.fluid import core def actual(primal, cotangent, axis, keep_dim): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False) v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False) y = paddle.sum(x, axis=axis, keepdim=keep_dim) @@ -30,7 +30,7 @@ def actual(primal, cotangent, axis, keep_dim): def desired(primal, cotangent, axis, keep_dim): - core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) x = paddle.to_tensor(primal, dtype='float32', stop_gradient=False) v = paddle.to_tensor(cotangent, dtype='float32', stop_gradient=False) y = paddle.sum(x, axis=axis, keepdim=keep_dim) diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py index f59c94c1a5c..b919ea3e952 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py @@ -20,7 +20,7 @@ import parameterized as param import paddle from paddle.fluid import core -core._set_prim_backward_enabled(True) +core.set_prim_eager_enabled(True) @param.parameterized_class( @@ -68,7 +68,7 @@ class TestTanhGradComp(unittest.TestCase): 
rtol=1e-6, atol=0, ) - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_transpose_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_transpose_grad.py index 266f70fb243..b4a68059a53 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_transpose_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_transpose_grad.py @@ -20,7 +20,7 @@ import parameterized as param import paddle from paddle.fluid import core -core._set_prim_backward_enabled(True) +core.set_prim_eager_enabled(True) @param.parameterized_class( @@ -72,7 +72,7 @@ class TestTransposeGradComp(unittest.TestCase): def test_transpose_grad_comp(self): def actual(primal0, shape): - core._set_prim_backward_enabled(True) + core.set_prim_eager_enabled(True) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) x.stop_gradient = False @@ -81,7 +81,7 @@ class TestTransposeGradComp(unittest.TestCase): return res[0].numpy() def desired(primal0, shape): - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) paddle.disable_static() x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) x.stop_gradient = False @@ -99,7 +99,7 @@ class TestTransposeGradComp(unittest.TestCase): rtol=1e-6, atol=0, ) - core._set_prim_backward_enabled(False) + core.set_prim_eager_enabled(False) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/prim_op_test.py b/python/paddle/fluid/tests/unittests/prim_op_test.py index c3b1d44bb2c..19b9f4c9971 100644 --- a/python/paddle/fluid/tests/unittests/prim_op_test.py +++ b/python/paddle/fluid/tests/unittests/prim_op_test.py @@ -906,7 +906,7 @@ class PrimGradChecker(PrimForwardChecker): paddle.device.set_device("gpu:0") atol = self.rev_comp_atol rtol = self.rev_comp_rtol - core._set_prim_backward_enabled(self.enable_rev_comp) + core.set_prim_eager_enabled(self.enable_rev_comp) actual_ret = self.get_eager_desire() # check static forward if len(actual_ret) != len(self.eager_desire): @@ -941,6 +941,7 @@ class PrimGradChecker(PrimForwardChecker): ) ) raise RuntimeError(msg) + core.set_prim_eager_enabled(False) def check_static_comp(self): paddle.enable_static() diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index 893ec7366cb..ae7528b4900 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -213,9 +213,7 @@ class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim): class TestSliceOp_starts_ListTensor(OpTest): def setUp(self): self.op_type = "slice" - self.prim_op_type = "prim" self.python_api = paddle.slice - # self.enable_cinn = False self.config() starts_tensor = [] @@ -244,12 +242,10 @@ class TestSliceOp_starts_ListTensor(OpTest): self.starts_infer = [-1, 0, -1] def test_check_output(self): - self.check_output(check_prim=True) + self.check_output() def test_check_grad_normal(self): - self.check_grad( - ['Input'], 'Out', max_relative_error=0.006, check_prim=True - ) + self.check_grad(['Input'], 'Out', max_relative_error=0.006) # Situation 2: starts(list, have tensor), ends(list, no tensor) @@ -257,7 +253,6 @@ class TestSliceOp_starts_ListTensor(OpTest): class TestSliceOp_decs_dim_starts_ListTensor(OpTest): def setUp(self): self.op_type = "slice" - self.prim_op_type = 
"prim" self.python_api = paddle.slice self.config() @@ -290,12 +285,10 @@ class TestSliceOp_decs_dim_starts_ListTensor(OpTest): self.starts_infer = [1, -1, 2] def test_check_output(self): - self.check_output(check_prim=True) + self.check_output() def test_check_grad_normal(self): - self.check_grad( - ['Input'], 'Out', max_relative_error=0.006, check_prim=True - ) + self.check_grad(['Input'], 'Out', max_relative_error=0.006) class TestSliceOp_decs_dim_5_starts_ListTensor( @@ -318,7 +311,6 @@ class TestSliceOp_decs_dim_5_starts_ListTensor( class TestSliceOp_decs_dim_starts_OneTensor(OpTest): def setUp(self): self.op_type = "slice" - self.prim_op_type = "prim" self.python_api = paddle.slice self.config() self.inputs = { @@ -344,12 +336,10 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest): self.out = self.input[1, 0:3, 2:4, :] def test_check_output(self): - self.check_output(check_prim=True) + self.check_output() def test_check_grad_normal(self): - self.check_grad( - ['Input'], 'Out', max_relative_error=0.006, check_prim=True - ) + self.check_grad(['Input'], 'Out', max_relative_error=0.006) # Situation 4: starts(tensor), ends(tensor) @@ -357,7 +347,6 @@ class TestSliceOp_decs_dim_starts_OneTensor(OpTest): class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest): def setUp(self): self.op_type = "slice" - self.prim_op_type = "prim" self.python_api = paddle.slice self.config() @@ -383,12 +372,10 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest): self.out = self.input[1:3, 0:3, 2:4, :] def test_check_output(self): - self.check_output(check_prim=True) + self.check_output() def test_check_grad_normal(self): - self.check_grad( - ['Input'], 'Out', max_relative_error=0.006, check_prim=True - ) + self.check_grad(['Input'], 'Out', max_relative_error=0.006) # Situation 5: starts(tensor), ends(tensor) @@ -396,7 +383,6 @@ class TestSliceOp_starts_OneTensor_ends_OneTensor(OpTest): class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest): def setUp(self): self.op_type = "slice" - self.prim_op_type = "prim" self.python_api = paddle.slice self.config() self.inputs = { @@ -423,12 +409,10 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest): self.out = self.input[1, 0, 2:4, :] def test_check_output(self): - self.check_output(check_prim=True) + self.check_output() def test_check_grad_normal(self): - self.check_grad( - ['Input'], 'Out', max_relative_error=0.006, check_prim=True - ) + self.check_grad(['Input'], 'Out', max_relative_error=0.006) # Situation 6: starts(tensor), ends(list, have tensor) @@ -436,7 +420,6 @@ class TestSliceOp_decs_dim_starts_and_ends_OneTensor(OpTest): class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest): def setUp(self): self.op_type = "slice" - self.prim_op_type = "prim" self.python_api = paddle.slice self.config() @@ -470,12 +453,10 @@ class TestSliceOp_starts_OneTensor_ends_ListTensor(OpTest): self.ends_infer = [-1, 3, 4] def test_check_output(self): - self.check_output(check_prim=True) + self.check_output() def test_check_grad_normal(self): - self.check_grad( - ['Input'], 'Out', max_relative_error=0.006, check_prim=True - ) + self.check_grad(['Input'], 'Out', max_relative_error=0.006) # Test CUDA float16 -- GitLab