Unverified · Commit 6664a232 authored by Jiabin Yang and committed by GitHub

【Prim】Fix prim amp (#50518)

* change amp with to_prim

* fix prim amp

* fix rules

* fix linear

* add amp test

* add test

* disable this test on cpu

* disable this test on cpu

---------
Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
Parent 07c416c8
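The pattern this commit targets is composite (prim) lowering combined with pure-fp16 (O2) AMP under @to_static. A rough, self-contained sketch of that usage, as illustration only (TinyNet below is made up for the example; the unit test added further down exercises the same path with a conv + batch_norm network):

import paddle
import paddle.nn as nn
from paddle.fluid import core

class TinyNet(nn.Layer):  # illustrative stand-in, not part of this commit
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(6, 6)

    def forward(self, x):
        return paddle.mean(self.fc(x))

core._set_prim_all_enabled(True)                    # enable composite (prim) lowering
net = TinyNet()
net = paddle.amp.decorate(models=net, level='O2')   # pure-fp16 AMP
net = paddle.jit.to_static(net)                     # dynamic-to-static entry point

x = paddle.randn([4, 6], dtype="float32")
with paddle.amp.auto_cast(level='O2'):
    loss = net(x)
loss.backward()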
......@@ -67,6 +67,12 @@ class ElementwiseAddCompositeGradOpMaker
auto* dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy);
int axis = static_cast<int>(this->Attr<int>("axis"));
PADDLE_ENFORCE_EQ(
axis,
-1,
phi::errors::InvalidArgument(
"We only support axis = -1 in composite add_grad but we got: ",
axis));
VLOG(6) << "Runing add_grad composite func";
prim::add_grad<prim::DescTensor>(x, y, out_grad, axis, dx_ptr, dy_ptr);
this->RecoverOutputName(dx, dx_name);
......
......@@ -84,6 +84,11 @@ class ElementwiseDivCompositeGradOpMaker
auto dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy);
int axis = static_cast<int>(this->Attr<int>("axis"));
PADDLE_ENFORCE_EQ(
axis,
-1,
phi::errors::InvalidArgument(
"We only support axis = -1 in composite div but we got: ", axis));
VLOG(6) << "Runing div_grad composite func";
prim::divide_grad<prim::DescTensor>(
x, y, out, out_grad, axis, dx_ptr, dy_ptr);
......
......@@ -81,13 +81,15 @@ class ElementwiseMulCompositeGradOpMaker
auto y_grad = this->GetSingleInputGrad("Y");
auto y_grad_p = this->GetOutputPtr(&y_grad);
auto y_grad_name = this->GetOutputName(y_grad);
int axis = static_cast<int>(this->Attr<int>("axis"));
PADDLE_ENFORCE_EQ(
axis,
-1,
phi::errors::InvalidArgument(
"We only support axis = -1 in composite mul_grad but we got: ",
axis));
prim::multiply_grad<prim::DescTensor>(
    x, y, out_grad, axis, x_grad_p, y_grad_p);
VLOG(6) << "Runing mul_grad composite func";
this->RecoverOutputName(x_grad, x_grad_name);
this->RecoverOutputName(y_grad, y_grad_name);
......
......@@ -70,6 +70,12 @@ class ElementwiseSubCompositeGradOpMaker
auto dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy);
int axis = static_cast<int>(this->Attr<int>("axis"));
PADDLE_ENFORCE_EQ(
axis,
-1,
phi::errors::InvalidArgument(
"We only support axis = -1 in composite sub_grad but we got: ",
axis));
VLOG(6) << "Runing sub_grad composite func";
prim::subtract_grad<prim::DescTensor>(x, y, out_grad, axis, dx_ptr, dy_ptr);
this->RecoverOutputName(dx, dx_name);
......
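All four composite grad makers above pin axis to -1, i.e. they only cover the default trailing-dimension broadcast. A small illustration of what that default means, using only the public paddle.add API (not part of the commit):

import paddle

x = paddle.rand([4, 3, 2])
y = paddle.rand([2])
out = paddle.add(x, y)   # with the default axis (-1), y broadcasts over the last dimension
print(out.shape)         # [4, 3, 2]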
......@@ -18,8 +18,12 @@ import numpy as np
from utils import SUB_TOLERANCE
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.fluid import core, framework
from paddle.nn import BatchNorm
from paddle.tensor import ones # noqa: F401
from paddle.tensor import zeros # noqa: F401
np.random.seed(2023)
......@@ -258,5 +262,69 @@ class TestCompositeBatchNorm(unittest.TestCase):
self.compare_forward()
def apply_to_static(net, use_cinn):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(net, build_strategy=build_strategy)
class PrimeNet(paddle.nn.Layer):
def __init__(self):
super(PrimeNet, self).__init__()
self.conv = nn.Conv2D(4, 2, (3, 3), bias_attr=False)
self.bn = BatchNorm(2, act="relu")
self.run_mean = zeros([2])
self.run_var = ones([2])
self.scale = ones([2])
self.bias = ones([2])
def forward(self, x):
y = self.conv(x)
out = self.bn(y)
res = F.max_pool2d(out, kernel_size=2, stride=2, padding=0)
return res
class TestPrimForwardAndBackward(unittest.TestCase):
"""
Test PrimeNet with @to_static + prim forward + prim backward + CINN vs. Dygraph
"""
def setUp(self):
paddle.seed(2022)
self.x = paddle.randn([4, 4, 6, 6], dtype="float32")
self.x.stop_gradient = False
def train(self, use_prim):
core._set_prim_all_enabled(use_prim)
paddle.seed(2022)
net = PrimeNet()
sgd = paddle.optimizer.SGD(
learning_rate=0.1, parameters=net.parameters()
)
net = paddle.amp.decorate(models=net, level='O2')
net = apply_to_static(net, False)
with paddle.amp.auto_cast(level='O2'):
out = net(self.x)
loss = paddle.mean(out)
loss.backward()
sgd.step()
sgd.clear_grad()
return loss
def test_amp(self):
if not isinstance(framework._current_expected_place(), core.CPUPlace):
expected = self.train(False)
actual = self.train(True)
np.testing.assert_allclose(
expected,
actual,
rtol=1e-3,
atol=1e-3,
)
if __name__ == '__main__':
unittest.main()
......@@ -61,6 +61,13 @@ def composite_batchnorm(
trainable_statistics,
):
"""define composite rule of op batch_norm"""
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
if convert_dtype(x.dtype) == "float16":
print("Running batch_norm in amp")
is_amp = True
x = cast(x, "float32")
feature_axis = (
1 if data_layout in ('NC', 'NCL', 'NCHW', 'NCHWD') else len(x.shape) - 1
......@@ -99,6 +106,8 @@ def composite_batchnorm(
reshape(run_var, stats_shape) + epsilon
)
y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape)
if is_amp:
y = cast(y, "float16")
# use assign to detach the tensors and avoid unsafe changes outside the rule.
batch_mean_ = assign(reshape(batch_mean, run_mean.shape))
......
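The batch_norm rule above handles AMP by casting float16 inputs to float32, running the normalization math in float32, and casting the result back. A minimal sketch of that cast-in / cast-out pattern, with fp32_body standing in for the inlined batch_norm computation (the helper name is invented for illustration):

import paddle
from paddle.fluid.data_feeder import convert_dtype

def amp_safe_rule(x, fp32_body):
    # Run fp32_body in float32 and hand float16 back to the AMP program.
    is_amp = convert_dtype(x.dtype) == "float16"
    if is_amp:
        x = paddle.cast(x, "float32")
    y = fp32_body(x)
    if is_amp:
        y = paddle.cast(y, "float16")
    return y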
......@@ -18,7 +18,6 @@ import numpy as np
import paddle
from paddle import _legacy_C_ops
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid import backward, core, framework, program_guard
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.dygraph import layers
......@@ -277,6 +276,11 @@ class PartialProgramLayer:
paddle.static.amp.fp16_utils.cast_model_to_fp16(
pure_fp16_program, self._amp_list, use_fp16_guard=False
)
core.check_and_set_prim_all_enabled()
from paddle.incubate.autograd.primapi import to_prim
to_prim(pure_fp16_program.blocks)
if is_infer_mode:
return pure_fp16_program
else:
......@@ -431,6 +435,8 @@ class PartialProgramLayer:
"""
Return current train or eval program hash id.
"""
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
if self.training:
if _in_amp_guard():
return self._train_amp_program_id
......@@ -448,6 +454,8 @@ class PartialProgramLayer:
@property
def train_program(self):
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
if _in_amp_guard():
return self._train_amp_program
elif _in_pure_fp16_guard():
......@@ -457,6 +465,8 @@ class PartialProgramLayer:
@property
def infer_program(self):
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
if _in_amp_guard():
return self._infer_amp_program
elif _in_pure_fp16_guard():
......@@ -466,6 +476,8 @@ class PartialProgramLayer:
@property
def forward_program(self):
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
if self.training:
if _in_amp_guard():
progs = self._train_amp_forward_backward_program
......@@ -479,6 +491,8 @@ class PartialProgramLayer:
@property
def backward_program(self):
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
if self.training:
if _in_amp_guard():
progs = self._train_amp_forward_backward_program
......@@ -663,6 +677,8 @@ class PartialProgramLayer:
return self._valid_vars(double_grads)
def _cast_fp16_if_pure_fp16(self, in_vars):
from paddle.amp.auto_cast import _in_pure_fp16_guard
if _in_pure_fp16_guard():
for i, var in enumerate(in_vars):
name = var.name
......
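The central change in this file is that the pure-fp16 program is lowered to primitive ops right after cast_model_to_fp16, so the composite rules (including the AMP-aware batch_norm rule above) operate on the fp16 program. A rough standalone sketch of lowering a static program with to_prim, assuming the program contains ops that have composite rules (softmax is used here purely as an example):

import paddle
from paddle.fluid import core
from paddle.incubate.autograd.primapi import to_prim

paddle.enable_static()
core._set_prim_all_enabled(True)

main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data("x", [4, 2, 6, 6], dtype="float32")
    y = paddle.nn.functional.softmax(x)   # softmax has a composite rule

to_prim(main.blocks)                       # rewrite supported ops into primitives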
......@@ -19,7 +19,7 @@ import threading
import warnings
import weakref
import paddle
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid import _non_static_mode, core, framework
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import layers
......@@ -325,11 +325,10 @@ class StaticFunction:
if input_spec is not None and prim_or_cinn_is_enabled(
kwargs.get("build_strategy", None)
):
from paddle.static import InputSpec
for spec in flatten(input_spec):
if isinstance(spec, InputSpec) and -1 in spec.shape:
input_spec = None
warnings.warn(
'Now prim and cinn do not support -1 shape, but input_spec has -1 shape so we set it to None.'
......@@ -1190,8 +1189,8 @@ class ProgramCache:
var.name, var.shape
)
)
if not _in_amp_guard() and not _in_pure_fp16_guard():
    concrete_program._to_prim()
return concrete_program, partial_program_from(concrete_program)
def __getitem__(self, item):
......
......@@ -1889,7 +1889,7 @@ def linear(x, weight, bias=None, name=None):
type='elementwise_add',
inputs={'X': [tmp], 'Y': [bias]},
outputs={'Out': [res]},
attrs={'axis': -1},
)
else:
res = tmp
......
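This last hunk makes the bias add inside the static-graph linear use the default axis = -1, matching the axis = -1 restriction enforced by the composite add_grad above. A quick check of the user-visible behaviour through the public API (illustrative only):

import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 3, 4])
w = paddle.randn([4, 5])
b = paddle.randn([5])
out = F.linear(x, w, b)   # bias is broadcast over the trailing axis
print(out.shape)          # [2, 3, 5]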