diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc
index 700a69fa3ce48776f35d44ec74dfd9d2b6905350..5b81438aaf125e5e9201cd3a3c1f5dbb937939a1 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc
@@ -67,6 +67,12 @@ class ElementwiseAddCompositeGradOpMaker
     auto* dy_ptr = this->GetOutputPtr(&dy);
     std::string dy_name = this->GetOutputName(dy);
     int axis = static_cast<int>(this->Attr<int>("axis"));
+    PADDLE_ENFORCE_EQ(
+        axis,
+        -1,
+        phi::errors::InvalidArgument(
+            "We only support axis = -1 in composite add_grad but we got: ",
+            axis));
     VLOG(6) << "Runing add_grad composite func";
     prim::add_grad<prim::DescTensor>(x, y, out_grad, axis, dx_ptr, dy_ptr);
     this->RecoverOutputName(dx, dx_name);
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc
index 97941aa82f3954c34c871f49f9175e639fdd47da..7012f3f671e0ffa6459991fcf5002d456fe931a5 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc
@@ -84,6 +84,11 @@ class ElementwiseDivCompositeGradOpMaker
     auto dy_ptr = this->GetOutputPtr(&dy);
     std::string dy_name = this->GetOutputName(dy);
     int axis = static_cast<int>(this->Attr<int>("axis"));
+    PADDLE_ENFORCE_EQ(
+        axis,
+        -1,
+        phi::errors::InvalidArgument(
+            "We only support axis = -1 in composite div but we got: ", axis));
     VLOG(6) << "Runing div_grad composite func";
     prim::divide_grad<prim::DescTensor>(
         x, y, out, out_grad, axis, dx_ptr, dy_ptr);
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc
index 9821cc226128323d48254f020f3470e919469b80..c4a1060497ea1b9b02b1209e8a39745b685b9399 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc
@@ -81,13 +81,15 @@ class ElementwiseMulCompositeGradOpMaker
     auto y_grad = this->GetSingleInputGrad("Y");
     auto y_grad_p = this->GetOutputPtr(&y_grad);
     auto y_grad_name = this->GetOutputName(y_grad);
+    int axis = static_cast<int>(this->Attr<int>("axis"));
+    PADDLE_ENFORCE_EQ(
+        axis,
+        -1,
+        phi::errors::InvalidArgument(
+            "We only support axis = -1 in composite mul_grad but we got: ",
+            axis));
     prim::multiply_grad<prim::DescTensor>(
-        x,
-        y,
-        out_grad,
-        static_cast<int>(this->Attr<int>("axis")),
-        x_grad_p,
-        y_grad_p);
+        x, y, out_grad, axis, x_grad_p, y_grad_p);
     VLOG(6) << "Runing mul_grad composite func";
     this->RecoverOutputName(x_grad, x_grad_name);
     this->RecoverOutputName(y_grad, y_grad_name);
diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc
index a7244062632699992533f851277563edce450998..6088ac3d01945c0d9da78a293a0074267083c46b 100644
--- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc
@@ -70,6 +70,12 @@ class ElementwiseSubCompositeGradOpMaker
     auto dy_ptr = this->GetOutputPtr(&dy);
     std::string dy_name = this->GetOutputName(dy);
     int axis = static_cast<int>(this->Attr<int>("axis"));
+    PADDLE_ENFORCE_EQ(
+        axis,
+        -1,
+        phi::errors::InvalidArgument(
+            "We only support axis = -1 in composite sub_grad but we got: ",
+            axis));
     VLOG(6) << "Runing sub_grad composite func";
     prim::subtract_grad<prim::DescTensor>(x, y, out_grad, axis, dx_ptr, dy_ptr);
     this->RecoverOutputName(dx, dx_name);
diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py
index a4780cae6c076621f9a9d3fe78873e9c886b58b0..af183e8793e56c563296d27debfdf57a974ff7ff 100644
--- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py
+++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py
@@ -18,8 +18,12 @@ import numpy as np
 from utils import SUB_TOLERANCE
 
 import paddle
+import paddle.nn as nn
 import paddle.nn.functional as F
-from paddle.fluid import core
+from paddle.fluid import core, framework
+from paddle.nn import BatchNorm
+from paddle.tensor import ones  # noqa: F401
+from paddle.tensor import zeros  # noqa: F401
 
 np.random.seed(2023)
 
@@ -258,5 +262,69 @@ class TestCompositeBatchNorm(unittest.TestCase):
         self.compare_forward()
 
 
+def apply_to_static(net, use_cinn):
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.build_cinn_pass = use_cinn
+    return paddle.jit.to_static(net, build_strategy=build_strategy)
+
+
+class PrimeNet(paddle.nn.Layer):
+    def __init__(self):
+        super(PrimeNet, self).__init__()
+        self.conv = nn.Conv2D(4, 2, (3, 3), bias_attr=False)
+        self.bn = BatchNorm(2, act="relu")
+        self.run_mean = zeros([2])
+        self.run_var = ones([2])
+        self.scale = ones([2])
+        self.bias = ones([2])
+
+    def forward(self, x):
+        y = self.conv(x)
+        out = self.bn(y)
+        res = F.max_pool2d(out, kernel_size=2, stride=2, padding=0)
+        return res
+
+
+class TestPrimForwardAndBackward(unittest.TestCase):
+    """
+    Test PrimeNet with @to_static + prim forward + prim backward + cinn vs. Dygraph.
+    """
+
+    def setUp(self):
+        paddle.seed(2022)
+        self.x = paddle.randn([4, 4, 6, 6], dtype="float32")
+        self.x.stop_gradient = False
+
+    def train(self, use_prim):
+        core._set_prim_all_enabled(use_prim)
+        paddle.seed(2022)
+        net = PrimeNet()
+        sgd = paddle.optimizer.SGD(
+            learning_rate=0.1, parameters=net.parameters()
+        )
+
+        net = paddle.amp.decorate(models=net, level='O2')
+
+        net = apply_to_static(net, False)
+        with paddle.amp.auto_cast(level='O2'):
+            out = net(self.x)
+            loss = paddle.mean(out)
+            loss.backward()
+            sgd.step()
+            sgd.clear_grad()
+            return loss
+
+    def test_amp(self):
+        if not isinstance(framework._current_expected_place(), core.CPUPlace):
+            expected = self.train(False)
+            actual = self.train(True)
+            np.testing.assert_allclose(
+                expected,
+                actual,
+                rtol=1e-3,
+                atol=1e-3,
+            )
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index 1a57c7895706094718d2252a6ce6da7c0d6f8d24..341ed98efb02c2d44b3ad6e46a6f309a0b0efc93 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -61,6 +61,13 @@ def composite_batchnorm(
     trainable_statistics,
 ):
     """define composite rule of op batch_norm"""
+    is_amp = False
+    from paddle.fluid.data_feeder import convert_dtype
+
+    if convert_dtype(x.dtype) == "float16":
+        print("Running batch_norm in amp")
+        is_amp = True
+        x = cast(x, "float32")
 
     feature_axis = (
         1 if data_layout in ('NC', 'NCL', 'NCHW', 'NCHWD') else len(x.shape) - 1
@@ -99,6 +106,8 @@
         reshape(run_var, stats_shape) + epsilon
     )
     y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape)
+    if is_amp:
+        y = cast(y, "float16")
 
     # add op assign to detach tensor in void unsafe change outside the rule.
     batch_mean_ = assign(reshape(batch_mean, run_mean.shape))
diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py
index ce5c15738ea64b465fd5e6977f648660be7fafec..3d86441087f092cd7cac06d3a882260ad182e832 100644
--- a/python/paddle/jit/dy2static/partial_program.py
+++ b/python/paddle/jit/dy2static/partial_program.py
@@ -18,7 +18,6 @@ import numpy as np
 
 import paddle
 from paddle import _legacy_C_ops
-from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
 from paddle.fluid import backward, core, framework, program_guard
 from paddle.fluid.compiler import BuildStrategy
 from paddle.fluid.dygraph import layers
@@ -277,6 +276,11 @@ class PartialProgramLayer:
             paddle.static.amp.fp16_utils.cast_model_to_fp16(
                 pure_fp16_program, self._amp_list, use_fp16_guard=False
             )
+
+        core.check_and_set_prim_all_enabled()
+        from paddle.incubate.autograd.primapi import to_prim
+
+        to_prim(pure_fp16_program.blocks)
         if is_infer_mode:
             return pure_fp16_program
         else:
@@ -431,6 +435,8 @@
         """
         Return current train or eval program hash id.
        """
+        from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
+
         if self.training:
             if _in_amp_guard():
                 return self._train_amp_program_id
@@ -448,6 +454,8 @@
 
     @property
     def train_program(self):
+        from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
+
         if _in_amp_guard():
             return self._train_amp_program
         elif _in_pure_fp16_guard():
@@ -457,6 +465,8 @@
 
     @property
     def infer_program(self):
+        from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
+
         if _in_amp_guard():
             return self._infer_amp_program
         elif _in_pure_fp16_guard():
@@ -466,6 +476,8 @@
 
     @property
     def forward_program(self):
+        from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
+
         if self.training:
             if _in_amp_guard():
                 progs = self._train_amp_forward_backward_program
@@ -479,6 +491,8 @@
 
     @property
     def backward_program(self):
+        from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
+
         if self.training:
             if _in_amp_guard():
                 progs = self._train_amp_forward_backward_program
@@ -663,6 +677,8 @@
         return self._valid_vars(double_grads)
 
     def _cast_fp16_if_pure_fp16(self, in_vars):
+        from paddle.amp.auto_cast import _in_pure_fp16_guard
+
         if _in_pure_fp16_guard():
             for i, var in enumerate(in_vars):
                 name = var.name
diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py
index e3e8d8afdd1ce05f9b59a6bc4ab4abc2a3c7f7fb..73c313204d5da7bdd717962c1c111e9ac6f26bec 100644
--- a/python/paddle/jit/dy2static/program_translator.py
+++ b/python/paddle/jit/dy2static/program_translator.py
@@ -19,7 +19,7 @@ import threading
 import warnings
 import weakref
 
-import paddle
+from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
 from paddle.fluid import _non_static_mode, core, framework
 from paddle.fluid.data_feeder import check_type
 from paddle.fluid.dygraph import layers
@@ -325,11 +325,10 @@
         if input_spec is not None and prim_or_cinn_is_enabled(
             kwargs.get("build_strategy", None)
         ):
+            from paddle.static import InputSpec
+
             for spec in flatten(input_spec):
-                if (
-                    isinstance(spec, paddle.static.InputSpec)
-                    and -1 in spec.shape
-                ):
+                if isinstance(spec, InputSpec) and -1 in spec.shape:
                     input_spec = None
                     warnings.warn(
                         'Now prim and cinn do not support -1 shape, but input_spec has -1 shape so we set it to None.'
@@ -1190,8 +1189,8 @@
                         var.name, var.shape
                     )
                 )
-
-        concrete_program._to_prim()
+        if not _in_amp_guard() and not _in_pure_fp16_guard():
+            concrete_program._to_prim()
         return concrete_program, partial_program_from(concrete_program)
 
     def __getitem__(self, item):
diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index 5fa2c0eda0dbe7248e3b98b4410b282bdc8f39a6..74d9806723ebb9f0d686c8a83bfac7db5c0137c7 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -1889,7 +1889,7 @@ def linear(x, weight, bias=None, name=None):
                 type='elementwise_add',
                 inputs={'X': [tmp], 'Y': [bias]},
                 outputs={'Out': [res]},
-                attrs={'axis': len(x.shape) - 1},
+                attrs={'axis': -1},
             )
         else:
             res = tmp
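
Note (illustrative sketch, not part of the patch): the composite_batchnorm change above follows the usual AMP decomposition pattern — detect a float16 input, cast it to float32, run the decomposed math in float32, then cast the result back to float16 before returning it to the AMP program. Below is a minimal standalone sketch of that pattern using only public paddle APIs; the helper name norm_like_in_fp32 and the toy normalization are assumptions for illustration, not Paddle internals.

import paddle


def norm_like_in_fp32(x, epsilon=1e-5):
    """Toy normalization mirroring the fp16 -> fp32 -> fp16 cast pattern."""
    is_amp = x.dtype == paddle.float16
    if is_amp:
        # Compute in fp32 for numerical stability, as the composite rule does.
        x = paddle.cast(x, "float32")
    mean = paddle.mean(x, axis=0, keepdim=True)
    var = paddle.var(x, axis=0, unbiased=False, keepdim=True)
    y = (x - mean) / paddle.sqrt(var + epsilon)
    if is_amp:
        # Hand fp16 back to the surrounding AMP graph.
        y = paddle.cast(y, "float16")
    return y


if __name__ == "__main__":
    out = norm_like_in_fp32(paddle.randn([4, 8]).astype("float16"))
    print(out.dtype)  # paddle.float16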