From 5dda91a80a841e7801cec2d7e429bc2076ed6497 Mon Sep 17 00:00:00 2001
From: cxxly
Date: Thu, 2 Mar 2023 03:21:37 +0000
Subject: [PATCH] fix cast prim and vjp dtype mapping error bug

---
 paddle/fluid/operators/cast_op.cc          |  1 +
 .../composite_backward_api.h               |  1 +
 .../test_composite_batch_norm.py           | 32 +++++++++----------
 python/paddle/incubate/autograd/primapi.py |  6 ++--
 .../jit/dy2static/program_translator.py    |  6 +++-
 5 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc
index 3f743e5a0e4..540ad322b03 100644
--- a/paddle/fluid/operators/cast_op.cc
+++ b/paddle/fluid/operators/cast_op.cc
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
 #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
 #include "paddle/fluid/prim/utils/static/desc_tensor.h"
+#include "paddle/phi/core/utils/data_type.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 3afa78d09c3..f7eb007348f 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -944,6 +944,7 @@ void maximum_grad(const Tensor& x,
   }
 }
 
+template <typename T>
 void dropout_grad(const Tensor& mask,
                   const Tensor& out_grad,
                   const Scalar& p,
diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py
index 57d816c654a..af183e8793e 100644
--- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py
+++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py
@@ -244,22 +244,22 @@ class TestCompositeBatchNorm(unittest.TestCase):
             atol=attrs.get_atol("forward"),
         )
 
-    # def test_forward(self):
-    #     for i in self.training:
-    #         for j in self.dtypes:
-    #             for m in self.momentum:
-    #                 attrs.set_training(i)
-    #                 attrs.set_dtype(j)
-    #                 attrs.set_momentum(m)
-    #                 self.compare_forward()
-
-    #     for n in self.shapes:
-    #         for s in self.data_formats:
-    #             for t in self.use_global_stats:
-    #                 attrs.set_shape(n)
-    #                 attrs.set_data_format(s)
-    #                 attrs.set_use_global_stats(t)
-    #                 self.compare_forward()
+    def test_forward(self):
+        for i in self.training:
+            for j in self.dtypes:
+                for m in self.momentum:
+                    attrs.set_training(i)
+                    attrs.set_dtype(j)
+                    attrs.set_momentum(m)
+                    self.compare_forward()
+
+        for n in self.shapes:
+            for s in self.data_formats:
+                for t in self.use_global_stats:
+                    attrs.set_shape(n)
+                    attrs.set_data_format(s)
+                    attrs.set_use_global_stats(t)
+                    self.compare_forward()
 
 
 def apply_to_static(net, use_cinn):
diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py
index 3757bc9917e..5bfd05156c3 100644
--- a/python/paddle/incubate/autograd/primapi.py
+++ b/python/paddle/incubate/autograd/primapi.py
@@ -226,7 +226,7 @@ def to_prim(blocks, exclude=frozenset()):
     if not core._is_fwd_prim_enabled():
         return
     if isinstance(blocks, paddle.fluid.framework.Block):
-        logging.debug("Atomize composite op to primitive ops begin.")
+        logging.info("Atomize composite op to primitive ops begin.")
         main_program = blocks.program
     elif isinstance(blocks, typing.Sequence):
         for item in blocks:
@@ -245,9 +245,9 @@ def to_prim(blocks, exclude=frozenset()):
         )
 
     with framework.program_guard(main_program):
-        logging.debug("Lowering composite forward ops begin...")
+        print("Lowering composite forward ops begin...")
         primx._lower_composite(
             blocks, prim_config["forward_blacklist"] | exclude
         )
         replace_ops = prim_config["composite_ops_record"]
-        logging.debug(f"Lowering composite forward ops finish: {replace_ops}")
+        print(f"Lowering composite forward ops finish: {replace_ops}")
diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py
index 996da16fba9..69a6e004606 100644
--- a/python/paddle/jit/dy2static/program_translator.py
+++ b/python/paddle/jit/dy2static/program_translator.py
@@ -1218,7 +1218,11 @@ class ProgramCache:
                 return infer_program
 
         partial_program = partial_program_from(concrete_program)
-        if not _in_amp_guard() and not _in_pure_fp16_guard():
+        if (
+            core._is_fwd_prim_enabled()
+            and not _in_amp_guard()
+            and not _in_pure_fp16_guard()
+        ):
             partial_program.set_hooker(PrimHooker())
         return concrete_program, partial_program
 
-- 
GitLab
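Usage sketch (not part of the patch; assumes the Paddle ~2.5 dev APIs referenced in the diff, and the toy Net layer below is hypothetical): with forward prim lowering enabled, ProgramCache attaches PrimHooker only when core._is_fwd_prim_enabled() returns True, so composite ops such as batch_norm are lowered to primitive ops under paddle.jit.to_static; with the flag off, the hook is now skipped entirely.

import paddle
from paddle.fluid import core


class Net(paddle.nn.Layer):  # hypothetical toy layer, not from the patch
    def __init__(self):
        super().__init__()
        self.bn = paddle.nn.BatchNorm2D(4)

    def forward(self, x):
        return paddle.nn.functional.relu(self.bn(x))


# Same switch that core._is_fwd_prim_enabled() reads in ProgramCache.
core._set_prim_forward_enabled(True)
net = paddle.jit.to_static(Net())
out = net(paddle.rand([2, 4, 8, 8]))  # forward runs through the lowered (prim) program

# With the flag off, set_hooker(PrimHooker()) is skipped by the new check.
core._set_prim_forward_enabled(False)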