diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc
index 3f743e5a0e42afc68de03058aeea2ac09ce9e9fc..540ad322b0351ff230f150fd31cce47c0cd61e57 100644
--- a/paddle/fluid/operators/cast_op.cc
+++ b/paddle/fluid/operators/cast_op.cc
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
 #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
 #include "paddle/fluid/prim/utils/static/desc_tensor.h"
+#include "paddle/phi/core/utils/data_type.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 3afa78d09c34b17dd11065ad79e2b393f002f4f1..f7eb007348f5e7c7f7ab265b716454c7228e2eb9 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -944,6 +944,7 @@ void maximum_grad(const Tensor& x,
   }
 }
 
+template <typename T>
 void dropout_grad(const Tensor& mask,
                   const Tensor& out_grad,
                   const Scalar& p,
diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py
index 57d816c654a09dc835ebf1fee554e514ee6b544d..af183e8793e56c563296d27debfdf57a974ff7ff 100644
--- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py
+++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py
@@ -244,22 +244,22 @@ class TestCompositeBatchNorm(unittest.TestCase):
             atol=attrs.get_atol("forward"),
         )
 
-    # def test_forward(self):
-    #     for i in self.training:
-    #         for j in self.dtypes:
-    #             for m in self.momentum:
-    #                 attrs.set_training(i)
-    #                 attrs.set_dtype(j)
-    #                 attrs.set_momentum(m)
-    #                 self.compare_forward()
-
-    #     for n in self.shapes:
-    #         for s in self.data_formats:
-    #             for t in self.use_global_stats:
-    #                 attrs.set_shape(n)
-    #                 attrs.set_data_format(s)
-    #                 attrs.set_use_global_stats(t)
-    #                 self.compare_forward()
+    def test_forward(self):
+        for i in self.training:
+            for j in self.dtypes:
+                for m in self.momentum:
+                    attrs.set_training(i)
+                    attrs.set_dtype(j)
+                    attrs.set_momentum(m)
+                    self.compare_forward()
+
+        for n in self.shapes:
+            for s in self.data_formats:
+                for t in self.use_global_stats:
+                    attrs.set_shape(n)
+                    attrs.set_data_format(s)
+                    attrs.set_use_global_stats(t)
+                    self.compare_forward()
 
 
 def apply_to_static(net, use_cinn):
diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py
index 3757bc9917e65a176be3a504697227ad3602190e..5bfd05156c3786ce31d510de6dccd4152fd18b5d 100644
--- a/python/paddle/incubate/autograd/primapi.py
+++ b/python/paddle/incubate/autograd/primapi.py
@@ -226,7 +226,7 @@ def to_prim(blocks, exclude=frozenset()):
     if not core._is_fwd_prim_enabled():
         return
     if isinstance(blocks, paddle.fluid.framework.Block):
-        logging.debug("Atomize composite op to primitive ops begin.")
+        logging.info("Atomize composite op to primitive ops begin.")
         main_program = blocks.program
     elif isinstance(blocks, typing.Sequence):
         for item in blocks:
@@ -245,9 +245,9 @@ def to_prim(blocks, exclude=frozenset()):
         )
 
     with framework.program_guard(main_program):
-        logging.debug("Lowering composite forward ops begin...")
+        print("Lowering composite forward ops begin...")
         primx._lower_composite(
             blocks, prim_config["forward_blacklist"] | exclude
         )
         replace_ops = prim_config["composite_ops_record"]
-        logging.debug(f"Lowering composite forward ops finish: {replace_ops}")
+        print(f"Lowering composite forward ops finish: {replace_ops}")
diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py
index 996da16fba9547d5547a3d08849df42a94c64270..69a6e004606af393e33425a2cc6e7f57a7826411 100644
--- a/python/paddle/jit/dy2static/program_translator.py
+++ b/python/paddle/jit/dy2static/program_translator.py
@@ -1218,7 +1218,11 @@ class ProgramCache:
                 return infer_program
 
         partial_program = partial_program_from(concrete_program)
-        if not _in_amp_guard() and not _in_pure_fp16_guard():
+        if (
+            core._is_fwd_prim_enabled()
+            and not _in_amp_guard()
+            and not _in_pure_fp16_guard()
+        ):
             partial_program.set_hooker(PrimHooker())
         return concrete_program, partial_program
 