Unverified commit 65950324 authored by Charles-hit, committed by GitHub

【AMP Prim OP】support instance_norm prim ops for fp16 and bf16 dtype (#55368)

* [prim]support fp16 for instance_norm and instance_norm_grad

* support fp16 and bf16 dtype for instance_norm prim rules

* fix new ir test

---------
Co-authored-by: cxxly <chenxx_id@163.com>
Parent 788be26d
@@ -1384,40 +1384,80 @@ void instance_norm_grad(const Tensor& x,
const int h = x.dims()[2];
const int w = x.dims()[3];
auto promoted_y_grad = y_grad;
if (x.dtype() == phi::DataType::FLOAT16 ||
x.dtype() == phi::DataType::BFLOAT16) {
promoted_y_grad = cast<T>(y_grad, phi::DataType::FLOAT32);
}
Tensor x_hat;
Tensor std_inv;
if (scale_grad || x_grad) {
auto mean = reshape<T>(saved_mean, IntArray({n, c, 1, 1}))
auto promoted_x = x;
auto promoted_saved_mean = saved_mean;
auto promoted_saved_var = saved_variance;
if (x.dtype() == phi::DataType::FLOAT16 ||
x.dtype() == phi::DataType::BFLOAT16) {
promoted_x = cast<T>(x, phi::DataType::FLOAT32);
promoted_saved_mean = cast<T>(saved_mean, phi::DataType::FLOAT32);
promoted_saved_var = cast<T>(saved_variance, phi::DataType::FLOAT32);
}
auto mean = reshape<T>(promoted_saved_mean, IntArray({n, c, 1, 1}))
.tile(IntArray({1, 1, h, w}));
std_inv = reshape<T>(saved_variance, IntArray({n, c, 1, 1}))
std_inv = reshape<T>(promoted_saved_var, IntArray({n, c, 1, 1}))
.tile(IntArray({1, 1, h, w}));
x_hat = (x - mean) * std_inv;
x_hat = (promoted_x - mean) * std_inv;
}
// x_grad = scale * inv_var * (y_grad - y_grad.mean(2,3) - x_hat * (y_grad *
// x_hat).mean((h,w)))
if (x_grad) {
auto scale_t =
auto scale_data =
reshape<T>(scale.get_ptr() ? scale.get()
: full<T>(IntArray({c}), 1., x.dtype()),
IntArray({1, c, 1, 1}))
.tile(IntArray({n, 1, h, w}));
set_output<T>(
(scale_t * std_inv) *
(y_grad -
y_grad.sum(IntArray({2, 3}), y_grad.dtype(), true) / (h * w) -
(x_hat *
((y_grad * x_hat).sum(IntArray({2, 3}), y_grad.dtype(), true) /
(h * w)))),
x_grad);
auto promoted_scale = scale_data;
if (scale_data.dtype() == phi::DataType::FLOAT16 ||
scale_data.dtype() == phi::DataType::BFLOAT16) {
promoted_scale = cast<T>(scale_data, phi::DataType::FLOAT32);
}
auto result =
(promoted_scale * std_inv) *
(promoted_y_grad -
promoted_y_grad.sum(IntArray({2, 3}), promoted_y_grad.dtype(), true) /
(h * w) -
(x_hat * ((promoted_y_grad * x_hat)
.sum(IntArray({2, 3}), promoted_y_grad.dtype(), true) /
(h * w))));
if (x.dtype() == phi::DataType::FLOAT16 ||
x.dtype() == phi::DataType::BFLOAT16) {
set_output<T>(cast<T>(result, x.dtype()), x_grad);
} else {
set_output<T>(result, x_grad);
}
}
// scale_grad = x_hat * y_grad.sum(n, h, w)
if (scale_grad) {
set_output<T>((y_grad * x_hat).sum(IntArray({0, 2, 3})), scale_grad);
auto result = (promoted_y_grad * x_hat).sum(IntArray({0, 2, 3}));
auto scale_dtype = scale.get_ptr() ? scale.get().dtype() : x.dtype();
if (scale_dtype == phi::DataType::FLOAT16 ||
scale_dtype == phi::DataType::BFLOAT16) {
set_output<T>(cast<T>(result, scale_dtype), scale_grad);
} else {
set_output<T>(result, scale_grad);
}
}
// d_bias = y_grad.sum(n, h, w)
if (bias_grad) {
set_output<T>(y_grad.sum(IntArray({0, 2, 3})), bias_grad);
auto result = promoted_y_grad.sum(IntArray({0, 2, 3}));
auto scale_dtype = scale.get_ptr() ? scale.get().dtype() : x.dtype();
if (scale_dtype == phi::DataType::FLOAT16 ||
scale_dtype == phi::DataType::BFLOAT16) {
set_output<T>(cast<T>(result, scale_dtype), bias_grad);
} else {
set_output<T>(result, bias_grad);
}
}
}
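The hunk above follows one pattern throughout: when x is float16 or bfloat16, every operand (x, y_grad, saved_mean, saved_variance, scale) is promoted to float32, the gradient is evaluated in float32, and only outputs whose declared dtype is low-precision are cast back. A minimal NumPy sketch of that promote-compute-demote pattern for x_grad follows; the function name and sample values are illustrative only, not part of the patch.

import numpy as np

def instance_norm_x_grad_fp16(x, y_grad, scale, saved_mean, saved_inv_std):
    # Illustrative sketch of the prim rule's x_grad formula with fp32 promotion.
    n, c, h, w = x.shape
    xf = x.astype(np.float32)            # promote low-precision inputs
    gf = y_grad.astype(np.float32)
    mean = np.tile(saved_mean.reshape(n, c, 1, 1), (1, 1, h, w))
    inv_std = np.tile(saved_inv_std.reshape(n, c, 1, 1), (1, 1, h, w))
    x_hat = (xf - mean) * inv_std
    scale_t = np.tile(scale.reshape(1, c, 1, 1), (n, 1, h, w))
    # x_grad = scale * inv_std * (g - g.mean((h,w)) - x_hat * (g * x_hat).mean((h,w)))
    x_grad = (scale_t * inv_std) * (
        gf
        - gf.mean(axis=(2, 3), keepdims=True)
        - x_hat * (gf * x_hat).mean(axis=(2, 3), keepdims=True)
    )
    return x_grad.astype(x.dtype)        # demote the result back to the input dtype

n, c, h, w = 2, 3, 4, 4
x16 = np.random.randn(n, c, h, w).astype(np.float16)
g16 = np.ones((n, c, h, w), dtype=np.float16)
m = x16.astype(np.float32).mean(axis=(2, 3)).reshape(n * c)
v = x16.astype(np.float32).var(axis=(2, 3)).reshape(n * c)
dx = instance_norm_x_grad_fp16(
    x16, g16, np.ones(c, dtype=np.float32), m, 1.0 / np.sqrt(v + 1e-5)
)
assert dx.dtype == np.float16 and dx.shape == x16.shape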
@@ -384,11 +384,18 @@ void InstanceNormInferMeta(const MetaTensor& x,
y->share_lod(x);
y->set_dtype(x.dtype());
y->set_layout(x.layout());
phi::DataType x_dtype = x.dtype();
phi::DataType param_type =
(x_dtype == phi::DataType::BFLOAT16 || x_dtype == phi::DataType::FLOAT16)
? phi::DataType::FLOAT32
: x_dtype;
if (saved_mean) {
saved_mean->set_dims({NxC});
saved_mean->set_dtype(param_type);
}
if (saved_variance) {
saved_variance->set_dims({NxC});
saved_variance->set_dtype(param_type);
}
}
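The InferMeta change above pairs with the kernel change: for float16/bfloat16 inputs the saved statistics get shape {NxC} and dtype float32. A tiny, purely illustrative sketch of that dtype rule (string dtype names stand in for phi::DataType):

def saved_stats_dtype(x_dtype):
    # saved_mean / saved_variance stay in float32 for low-precision inputs
    return "float32" if x_dtype in ("float16", "bfloat16") else x_dtype

assert saved_stats_dtype("float16") == "float32"
assert saved_stats_dtype("bfloat16") == "float32"
assert saved_stats_dtype("float32") == "float32"
assert saved_stats_dtype("float64") == "float64"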
@@ -193,6 +193,16 @@ def instancenorm_composite(x, scale, bias, epsilon):
    out = (x - mean(x)) / sqrt(var + epsilon)
var = mean((x-mean(x))^2)
"""
is_amp = False
from paddle.fluid.data_feeder import convert_dtype
dtype = convert_dtype(x.dtype)
if dtype in ["float16", "uint16"]:
is_amp = True
x = cast(x, "float32")
scale = cast(scale, "float32") if scale else scale
bias = cast(bias, "float32") if bias else bias
n, c, h, w = x.shape
axis = tuple(range(2, len(x.shape)))
mean_ = mean(x, axis=axis, keepdim=True)
@@ -213,6 +223,10 @@ def instancenorm_composite(x, scale, bias, epsilon):
mean_ = reshape(mean_, [-1])
saved_variance = 1 / sqrt_var
saved_variance = reshape(saved_variance, [-1])
if is_amp:
out = cast(out, dtype)
return out, mean_, saved_variance
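In the composite (forward) rule above, "uint16" is how bfloat16 tensors are reported; the rule upcasts fp16/bf16 inputs, scale, and bias to float32, normalizes, and casts only out back, so the returned mean and inverse standard deviation stay float32, matching the InferMeta change. A self-contained NumPy sketch of that flow, with an illustrative function name:

import numpy as np

def instancenorm_composite_ref(x, scale, bias, epsilon=1e-5):
    orig_dtype = x.dtype
    is_amp = orig_dtype == np.float16       # bf16 would be treated the same way
    xf = x.astype(np.float32) if is_amp else x
    mean = xf.mean(axis=(2, 3), keepdims=True)
    var = xf.var(axis=(2, 3), keepdims=True)
    inv_std = 1.0 / np.sqrt(var + epsilon)
    out = (xf - mean) * inv_std
    if scale is not None:
        out = out * scale.reshape(1, -1, 1, 1)
    if bias is not None:
        out = out + bias.reshape(1, -1, 1, 1)
    n, c = x.shape[:2]
    saved_mean = mean.reshape(n * c)        # stays float32
    saved_inv_std = inv_std.reshape(n * c)  # stays float32
    if is_amp:
        out = out.astype(orig_dtype)        # only the activation is demoted
    return out, saved_mean, saved_inv_std

x = np.random.randn(2, 3, 5, 5).astype(np.float16)
out, m, s = instancenorm_composite_ref(x, np.ones(3, np.float32), np.zeros(3, np.float32))
assert out.dtype == np.float16 and m.dtype == np.float32 and s.dtype == np.float32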
@@ -12,15 +12,69 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.nn.functional as F
from paddle import fluid, nn
from paddle.fluid import Program, core, framework, program_guard
from paddle import fluid
from paddle.fluid import Program, core, program_guard
def instance_norm_wrapper(
input, weight, bias, epsilon=1e-5, momentum=0.9, data_format='NCHW'
):
if data_format == "AnyLayout":
data_format = "NCDHW"
return paddle.nn.functional.instance_norm(
input, None, None, weight, bias, True, momentum, epsilon, data_format
)
def _reference_instance_norm(x, scale, bias, epsilon):
N, C, H, W = x.shape
mean = np.mean(x, axis=(2, 3), keepdims=True)
variance = np.var(x, axis=(2, 3), keepdims=True)
std = np.sqrt(variance) + epsilon
x_norm = (x - mean) / std
scale = scale.reshape([1, C, 1, 1])
bias = bias.reshape([1, C, 1, 1])
x_norm = scale * x_norm + bias
return x_norm, mean.reshape(N * C), std.reshape(N * C)
def _reference_instance_norm_grad(x, scale, mean, var):
n, c, h, w = x.shape
d_y = np.ones(x.shape) / (np.prod(x.shape))
d_bias = np.ones((c,)) / c
mean_tile = np.reshape(mean, (n, c, 1, 1))
mean_tile = np.tile(mean_tile, (1, 1, h, w))
var_tile = np.reshape(var, (n, c, 1, 1))
var_tile = np.tile(var_tile, (1, 1, h, w))
d_scale = np.sum(d_y * (x - mean_tile) * var_tile, axis=(0, 2, 3))
var_inv = var_tile
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
d_x = (
scale_tile
* var_inv
* (
d_y
- np.mean(d_y, axis=(2, 3), keepdims=True)
- (x - mean_tile)
* var_inv
* np.mean(
d_y * (x - mean_tile) * var_inv, axis=(2, 3), keepdims=True
)
)
)
return d_x, d_scale, d_bias
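The two reference helpers above produce the precomputed gradients that the bf16 test passes to check_grad via user_defined_grads; note that _reference_instance_norm_grad takes the inverse standard deviation in its var argument. A quick illustrative wiring of the two, runnable as a standalone snippet with the helpers above in scope (shapes reuse the [4, 10, 16, 16] example from this file):

np.random.seed(0)
epsilon = 1e-5
x = np.random.random((4, 10, 16, 16)).astype(np.float32)
scale = np.ones(10, dtype=np.float32)
bias = np.zeros(10, dtype=np.float32)

y, saved_mean, saved_std = _reference_instance_norm(x, scale, bias, epsilon)
d_x, d_scale, d_bias = _reference_instance_norm_grad(
    x, scale, saved_mean, 1.0 / saved_std
)
assert y.shape == x.shape
assert d_x.shape == x.shape
assert d_scale.shape == (10,) and d_bias.shape == (10,)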
class TestInstanceNorm(unittest.TestCase):
@@ -86,95 +140,42 @@ class TestInstanceNorm(unittest.TestCase):
np.testing.assert_allclose(y1, y2, rtol=1e-05)
def test_static(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu(
"instance_norm"
):
places.append(fluid.CUDAPlace(0))
for p in places:
exe = fluid.Executor(p)
shape = [4, 10, 16, 16]
def compute_v1(x_np):
with program_guard(Program(), Program()):
ins = paddle.nn.InstanceNorm2D(shape[1])
x = paddle.static.data(
name='x', shape=x_np.shape, dtype=x_np.dtype
)
y = ins(x)
exe.run(fluid.default_startup_program())
r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
return r
def compute_v2(x_np):
with program_guard(Program(), Program()):
ins = paddle.nn.InstanceNorm2D(shape[1])
x = paddle.static.data(
name='x', shape=x_np.shape, dtype=x_np.dtype
)
y = ins(x)
exe.run(fluid.default_startup_program())
r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
return r
x = np.random.randn(*shape).astype("float32")
y1 = compute_v1(x)
y2 = compute_v2(x)
np.testing.assert_allclose(y1, y2, rtol=1e-05)
def instance_norm_warpper(
input, weight, bias, epsilon=1e-5, momentum=0.9, data_format='NCHW'
):
if data_format == "AnyLayout":
data_format = "NCDHW"
return paddle._C_ops.instance_norm(
input, weight, bias, epsilon, momentum, data_format
)
def _reference_instance_norm(x, scale, bias, epsilon):
N, C, H, W = x.shape
mean = np.mean(x, axis=(2, 3), keepdims=True)
variance = np.var(x, axis=(2, 3), keepdims=True)
std = np.sqrt(variance) + epsilon
x_norm = (x - mean) / std
scale = scale.reshape([1, C, 1, 1])
bias = bias.reshape([1, C, 1, 1])
x_norm = scale * x_norm + bias
return x_norm, mean.reshape(N * C), std.reshape(N * C)
def _reference_instance_norm_grad(x, scale, mean, var):
n, c, h, w = x.shape
d_y = np.ones(x.shape) / (np.prod(x.shape))
d_bias = np.ones((c,)) / c
mean_tile = np.reshape(mean, (n, c, 1, 1))
mean_tile = np.tile(mean_tile, (1, 1, h, w))
var_tile = np.reshape(var, (n, c, 1, 1))
var_tile = np.tile(var_tile, (1, 1, h, w))
d_scale = np.sum(d_y * (x - mean_tile) * var_tile, axis=(0, 2, 3))
var_inv = var_tile
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
d_x = (
scale_tile
* var_inv
* (
d_y
- np.mean(d_y, axis=(2, 3), keepdims=True)
- (x - mean_tile)
* var_inv
* np.mean(
d_y * (x - mean_tile) * var_inv, axis=(2, 3), keepdims=True
)
)
)
return d_x, d_scale, d_bias
with paddle.fluid.framework._static_guard():
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu(
"instance_norm"
):
places.append(fluid.CUDAPlace(0))
for p in places:
exe = fluid.Executor(p)
shape = [4, 10, 16, 16]
def compute_v1(x_np):
with program_guard(Program(), Program()):
ins = paddle.nn.InstanceNorm2D(shape[1])
x = paddle.static.data(
name='x', shape=x_np.shape, dtype=x_np.dtype
)
y = ins(x)
exe.run(fluid.default_startup_program())
r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
return r
def compute_v2(x_np):
with program_guard(Program(), Program()):
ins = paddle.nn.InstanceNorm2D(shape[1])
x = paddle.static.data(
name='x', shape=x_np.shape, dtype=x_np.dtype
)
y = ins(x)
exe.run(fluid.default_startup_program())
r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
return r
x = np.random.randn(*shape).astype("float32")
y1 = compute_v1(x)
y2 = compute_v2(x)
np.testing.assert_allclose(y1, y2, rtol=1e-05)
class TestInstanceNormFP32OP(OpTest):
@@ -182,7 +183,6 @@ class TestInstanceNormFP32OP(OpTest):
'''Test instance_norm op with default value'''
self.op_type = "instance_norm"
self.__class__.op_type = self.op_type
self.python_api = instance_norm_warpper
self.data_format = "NCHW"
self.eps = 1e-5
self.init_dtype()
@@ -204,15 +204,18 @@ class TestInstanceNormFP32OP(OpTest):
'SavedMean': mean,
'SavedVariance': 1.0 / variance,
}
self.prim_op_type = "comp"
self.python_api = instance_norm_wrapper
self.public_python_api = instance_norm_wrapper
self.check_prim = (
False if os.getenv("FLAGS_enable_new_ir_in_executor") else True
)
def test_check_output(self):
self.check_output(atol=self.atol)
self.check_output(atol=self.atol, check_prim=self.check_prim)
def test_check_grad(self):
self.check_grad(
['X', 'Scale', 'Bias'],
'Y',
)
self.check_grad(['X', 'Scale', 'Bias'], 'Y', check_prim=self.check_prim)
def init_dtype(self):
self.dtype = np.float32
@@ -228,6 +231,12 @@ class TestInstanceNormFP32OP(OpTest):
def set_err_thre(self):
self.atol = 1e-3
self.fw_comp_rtol = 1e-6
self.fw_comp_atol = 1e-6
self.rev_comp_rtol = 1e-4
self.rev_comp_atol = 1e-4
self.cinn_rtol = 1e-4
self.cinn_atol = 1e-4
@unittest.skipIf(
@@ -236,6 +245,9 @@ class TestInstanceNormFP32OP(OpTest):
"core is not compiled with CUDA or not support the float16",
)
class TestInstanceNormFP16OP(TestInstanceNormFP32OP):
def setUp(self):
super().setUp()
def init_dtype(self):
self.dtype = np.float16
@@ -245,7 +257,9 @@ class TestInstanceNormFP16OP(TestInstanceNormFP32OP):
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=self.atol)
self.check_output_with_place(
place, atol=self.atol, check_prim=self.check_prim
)
def test_check_grad(self):
place = core.CUDAPlace(0)
@@ -254,6 +268,7 @@ class TestInstanceNormFP16OP(TestInstanceNormFP32OP):
['X', 'Scale', 'Bias'],
'Y',
max_relative_error=self.max_relative_error,
check_prim=self.check_prim,
)
@@ -265,8 +280,10 @@ class TestInstanceNormFP16OP(TestInstanceNormFP32OP):
class TestInstanceNormBF16OP(OpTest):
def setUp(self):
self.op_type = "instance_norm"
self.prim_op_type = "comp"
self.__class__.op_type = self.op_type
self.python_api = instance_norm_warpper
self.python_api = instance_norm_wrapper
self.public_python_api = instance_norm_wrapper
self.eps = 1e-5
self.data_format = "NCHW"
self.dtype = np.uint16
@@ -296,6 +313,9 @@ class TestInstanceNormBF16OP(OpTest):
'momentum': 0.9,
'data_format': self.data_format,
}
self.check_prim = (
False if os.getenv("FLAGS_enable_new_ir_in_executor") else True
)
def init_value(self):
np.random.seed(0)
@@ -308,7 +328,7 @@ class TestInstanceNormBF16OP(OpTest):
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_prim=self.check_prim)
def test_check_grad(self):
place = core.CUDAPlace(0)
@@ -317,19 +337,22 @@ class TestInstanceNormBF16OP(OpTest):
['X', 'Scale', 'Bias'],
'Y',
user_defined_grads=self.user_defined_grads,
check_prim=self.check_prim,
)
class PrimNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
self.instance_norm = nn.InstanceNorm2D(4)
self.conv = paddle.nn.Conv2D(2, 4, (3, 3), bias_attr=False)
self.instance_norm = paddle.nn.InstanceNorm2D(4)
def forward(self, x):
y = self.conv(x)
out = self.instance_norm(y)
res = F.max_pool2d(out, kernel_size=2, stride=2, padding=0)
res = paddle.nn.functional.max_pool2d(
out, kernel_size=2, stride=2, padding=0
)
return res
@@ -368,7 +391,9 @@ class TestPrimForwardAndBackward(unittest.TestCase):
return loss
def test_amp_nchw(self):
if not isinstance(framework._current_expected_place(), core.CPUPlace):
if not isinstance(
paddle.fluid.framework._current_expected_place(), core.CPUPlace
):
expected = self.train(False)
actual = self.train(True)
np.testing.assert_allclose(