Unverified commit 65950324, authored by Charles-hit and committed by GitHub

【AMP Prim OP】support instance_norm prim ops for fp16 and bf16 dtype (#55368)

* [prim] support fp16 for instance_norm and instance_norm_grad

* support fp16 and bf16 dtype for instance_norm prim rules

* fix new IR test

---------
Co-authored-by: Ncxxly <chenxx_id@163.com>
Parent 788be26d
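Every change in this commit follows the same AMP recipe: when the instance_norm input is fp16 or bf16, promote the tensors involved in the normalization math to fp32, compute in fp32, and cast only the primary output back to the original dtype (the saved statistics stay fp32). A minimal NumPy sketch of that recipe, independent of Paddle — the function names here are illustrative, not Paddle APIs, and bf16 is omitted because NumPy has no native bfloat16:

import numpy as np


def _instance_norm_fp32(x, scale, bias, eps=1e-5):
    # Normalize each (sample, channel) slice over its spatial dims in fp32.
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return scale.reshape(1, -1, 1, 1) * x_hat + bias.reshape(1, -1, 1, 1)


def instance_norm_amp(x, scale, bias, eps=1e-5):
    # Promote low-precision inputs, compute in fp32, demote only the output.
    low_precision = x.dtype == np.float16
    if low_precision:
        x, scale, bias = (t.astype(np.float32) for t in (x, scale, bias))
    out = _instance_norm_fp32(x, scale, bias, eps)
    return out.astype(np.float16) if low_precision else out


x = np.random.randn(4, 10, 16, 16).astype(np.float16)
y = instance_norm_amp(x, np.ones(10, np.float16), np.zeros(10, np.float16))
print(y.dtype)  # float16 -- the intermediate math ran in float32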
@@ -1384,40 +1384,80 @@ void instance_norm_grad(const Tensor& x,
   const int h = x.dims()[2];
   const int w = x.dims()[3];
+  auto promoted_y_grad = y_grad;
+  if (x.dtype() == phi::DataType::FLOAT16 ||
+      x.dtype() == phi::DataType::BFLOAT16) {
+    promoted_y_grad = cast<T>(y_grad, phi::DataType::FLOAT32);
+  }
   Tensor x_hat;
   Tensor std_inv;
   if (scale_grad || x_grad) {
-    auto mean = reshape<T>(saved_mean, IntArray({n, c, 1, 1}))
-                    .tile(IntArray({1, 1, h, w}));
-    std_inv = reshape<T>(saved_variance, IntArray({n, c, 1, 1}))
-                  .tile(IntArray({1, 1, h, w}));
-    x_hat = (x - mean) * std_inv;
+    auto promoted_x = x;
+    auto promoted_saved_mean = saved_mean;
+    auto promoted_saved_var = saved_variance;
+    if (x.dtype() == phi::DataType::FLOAT16 ||
+        x.dtype() == phi::DataType::BFLOAT16) {
+      promoted_x = cast<T>(x, phi::DataType::FLOAT32);
+      promoted_saved_mean = cast<T>(saved_mean, phi::DataType::FLOAT32);
+      promoted_saved_var = cast<T>(saved_variance, phi::DataType::FLOAT32);
+    }
+    auto mean = reshape<T>(promoted_saved_mean, IntArray({n, c, 1, 1}))
+                    .tile(IntArray({1, 1, h, w}));
+    std_inv = reshape<T>(promoted_saved_var, IntArray({n, c, 1, 1}))
+                  .tile(IntArray({1, 1, h, w}));
+    x_hat = (promoted_x - mean) * std_inv;
   }
   // x_grad = scale * inv_var * (y_grad - y_grad.mean(2,3) - x_hat * (y_grad *
   // x_hat).mean((h,w)))
   if (x_grad) {
-    auto scale_t =
+    auto scale_data =
         reshape<T>(scale.get_ptr() ? scale.get()
                                    : full<T>(IntArray({c}), 1., x.dtype()),
                    IntArray({1, c, 1, 1}))
            .tile(IntArray({n, 1, h, w}));
-    set_output<T>(
-        (scale_t * std_inv) *
-            (y_grad -
-             y_grad.sum(IntArray({2, 3}), y_grad.dtype(), true) / (h * w) -
-             (x_hat *
-              ((y_grad * x_hat).sum(IntArray({2, 3}), y_grad.dtype(), true) /
-               (h * w)))),
-        x_grad);
+    auto promoted_scale = scale_data;
+    if (scale_data.dtype() == phi::DataType::FLOAT16 ||
+        scale_data.dtype() == phi::DataType::BFLOAT16) {
+      promoted_scale = cast<T>(scale_data, phi::DataType::FLOAT32);
+    }
+    auto result =
+        (promoted_scale * std_inv) *
+        (promoted_y_grad -
+         promoted_y_grad.sum(IntArray({2, 3}), promoted_y_grad.dtype(), true) /
+             (h * w) -
+         (x_hat * ((promoted_y_grad * x_hat)
+                       .sum(IntArray({2, 3}), promoted_y_grad.dtype(), true) /
+                   (h * w))));
+    if (x.dtype() == phi::DataType::FLOAT16 ||
+        x.dtype() == phi::DataType::BFLOAT16) {
+      set_output<T>(cast<T>(result, x.dtype()), x_grad);
+    } else {
+      set_output<T>(result, x_grad);
+    }
   }
   // scale_grad = x_hat * y_grad.sum(n, h, w)
   if (scale_grad) {
-    set_output<T>((y_grad * x_hat).sum(IntArray({0, 2, 3})), scale_grad);
+    auto result = (promoted_y_grad * x_hat).sum(IntArray({0, 2, 3}));
+    auto scale_dtype = scale.get_ptr() ? scale.get().dtype() : x.dtype();
+    if (scale_dtype == phi::DataType::FLOAT16 ||
+        scale_dtype == phi::DataType::BFLOAT16) {
+      set_output<T>(cast<T>(result, scale_dtype), scale_grad);
+    } else {
+      set_output<T>(result, scale_grad);
+    }
   }
   // d_bias = y_grad.sum(n, h, w)
   if (bias_grad) {
-    set_output<T>(y_grad.sum(IntArray({0, 2, 3})), bias_grad);
+    auto result = promoted_y_grad.sum(IntArray({0, 2, 3}));
+    auto scale_dtype = scale.get_ptr() ? scale.get().dtype() : x.dtype();
+    if (scale_dtype == phi::DataType::FLOAT16 ||
+        scale_dtype == phi::DataType::BFLOAT16) {
+      set_output<T>(cast<T>(result, scale_dtype), bias_grad);
+    } else {
+      set_output<T>(result, bias_grad);
+    }
   }
 }
......
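The comments in the backward rule above spell out the decomposed gradients. Written out in NumPy (a sketch that mirrors the reference gradients used by the tests further down; the fp16/bf16 promote-and-demote steps are dropped for brevity, and the helper name is illustrative):

import numpy as np


def instance_norm_grad_ref(x, scale, y_grad, eps=1e-5):
    n, c, h, w = x.shape
    mean = x.mean(axis=(2, 3), keepdims=True)
    std_inv = 1.0 / np.sqrt(x.var(axis=(2, 3), keepdims=True) + eps)
    x_hat = (x - mean) * std_inv

    # x_grad = scale * std_inv * (dy - dy.mean(2,3) - x_hat * (dy * x_hat).mean(2,3))
    x_grad = (
        scale.reshape(1, c, 1, 1)
        * std_inv
        * (
            y_grad
            - y_grad.mean(axis=(2, 3), keepdims=True)
            - x_hat * (y_grad * x_hat).mean(axis=(2, 3), keepdims=True)
        )
    )
    # scale_grad = (dy * x_hat).sum over (n, h, w); bias_grad = dy.sum over (n, h, w)
    scale_grad = (y_grad * x_hat).sum(axis=(0, 2, 3))
    bias_grad = y_grad.sum(axis=(0, 2, 3))
    return x_grad, scale_grad, bias_grad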
@@ -384,11 +384,18 @@ void InstanceNormInferMeta(const MetaTensor& x,
   y->share_lod(x);
   y->set_dtype(x.dtype());
   y->set_layout(x.layout());
+  phi::DataType x_dtype = x.dtype();
+  phi::DataType param_type =
+      (x_dtype == phi::DataType::BFLOAT16 || x_dtype == phi::DataType::FLOAT16)
+          ? phi::DataType::FLOAT32
+          : x_dtype;
   if (saved_mean) {
     saved_mean->set_dims({NxC});
+    saved_mean->set_dtype(param_type);
   }
   if (saved_variance) {
     saved_variance->set_dims({NxC});
+    saved_variance->set_dtype(param_type);
   }
 }
......
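The InferMeta change above declares SavedMean and SavedVariance as fp32 whenever X is fp16 or bf16, because accumulating per-instance statistics in half precision loses noticeable accuracy. A small NumPy illustration of the effect (the exact magnitudes depend on the data; bf16 behaves similarly but has no NumPy equivalent):

import numpy as np

np.random.seed(0)
x = (np.random.randn(4, 10, 16, 16) * 10).astype(np.float16)

# Force the reduction to accumulate in fp16, versus an fp32 reference.
mean_fp16 = x.mean(axis=(2, 3), dtype=np.float16)
mean_fp32 = x.astype(np.float32).mean(axis=(2, 3))

# The fp16-accumulated statistics drift well beyond fp32 round-off.
print(np.abs(mean_fp16.astype(np.float32) - mean_fp32).max())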
@@ -193,6 +193,16 @@ def instancenorm_composite(x, scale, bias, epsilon):
     out = (x - mean(x)) / sqrt(var + epsilon))
     var = mean((x-mean(x))^2)
     """
+    is_amp = False
+    from paddle.fluid.data_feeder import convert_dtype
+
+    dtype = convert_dtype(x.dtype)
+    if dtype in ["float16", "uint16"]:
+        is_amp = True
+        x = cast(x, "float32")
+        scale = cast(scale, "float32") if scale else scale
+        bias = cast(bias, "float32") if bias else bias
+
     n, c, h, w = x.shape
     axis = tuple(range(2, len(x.shape)))
     mean_ = mean(x, axis=axis, keepdim=True)
@@ -213,6 +223,10 @@ def instancenorm_composite(x, scale, bias, epsilon):
     mean_ = reshape(mean_, [-1])
     saved_variance = 1 / sqrt_var
     saved_variance = reshape(saved_variance, [-1])
+
+    if is_amp:
+        out = cast(out, dtype)
+
     return out, mean_, saved_variance
......
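The composite rule above casts fp16/bf16 activations to fp32 before decomposing instance_norm into primitive ops and casts only `out` back, so `saved_mean`/`saved_variance` remain fp32 in line with the InferMeta change. A hedged usage sketch against the public functional API (it assumes a build where fp16 instance_norm is available, e.g. a CUDA build; the tolerances are illustrative, not the ones used by the OpTests below):

import numpy as np
import paddle

x16 = paddle.randn([4, 10, 16, 16]).astype('float16')
w16 = paddle.ones([10], dtype='float16')
b16 = paddle.zeros([10], dtype='float16')

# fp16 in, fp16 out; on the prim/composite path the math is promoted to fp32.
y16 = paddle.nn.functional.instance_norm(x16, weight=w16, bias=b16, eps=1e-5)
y32 = paddle.nn.functional.instance_norm(
    x16.astype('float32'),
    weight=w16.astype('float32'),
    bias=b16.astype('float32'),
    eps=1e-5,
)

print(y16.dtype)  # paddle.float16
np.testing.assert_allclose(
    y16.astype('float32').numpy(), y32.numpy(), rtol=1e-3, atol=1e-3
)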
@@ -12,15 +12,69 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import unittest

 import numpy as np
 from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
-import paddle.nn.functional as F
-from paddle import fluid, nn
-from paddle.fluid import Program, core, framework, program_guard
+from paddle import fluid
+from paddle.fluid import Program, core, program_guard
+
+
+def instance_norm_wrapper(
+    input, weight, bias, epsilon=1e-5, momentum=0.9, data_format='NCHW'
+):
+    if data_format == "AnyLayout":
+        data_format = "NCDHW"
+    return paddle.nn.functional.instance_norm(
+        input, None, None, weight, bias, True, momentum, epsilon, data_format
+    )
+
+
+def _reference_instance_norm(x, scale, bias, epsilon):
+    N, C, H, W = x.shape
+    mean = np.mean(x, axis=(2, 3), keepdims=True)
+    variance = np.var(x, axis=(2, 3), keepdims=True)
+    std = np.sqrt(variance) + epsilon
+    x_norm = (x - mean) / std
+    scale = scale.reshape([1, C, 1, 1])
+    bias = bias.reshape([1, C, 1, 1])
+    x_norm = scale * x_norm + bias
+    return x_norm, mean.reshape(N * C), std.reshape(N * C)
+
+
+def _reference_instance_norm_grad(x, scale, mean, var):
+    n, c, h, w = x.shape
+    d_y = np.ones(x.shape) / (np.prod(x.shape))
+    d_bias = np.ones((c,)) / c
+
+    mean_tile = np.reshape(mean, (n, c, 1, 1))
+    mean_tile = np.tile(mean_tile, (1, 1, h, w))
+    var_tile = np.reshape(var, (n, c, 1, 1))
+    var_tile = np.tile(var_tile, (1, 1, h, w))
+
+    d_scale = np.sum(d_y * (x - mean_tile) * var_tile, axis=(0, 2, 3))
+    var_inv = var_tile
+    scale_tile = np.reshape(scale, (1, c, 1, 1))
+    scale_tile = np.tile(scale_tile, (n, 1, h, w))
+
+    d_x = (
+        scale_tile
+        * var_inv
+        * (
+            d_y
+            - np.mean(d_y, axis=(2, 3), keepdims=True)
+            - (x - mean_tile)
+            * var_inv
+            * np.mean(
+                d_y * (x - mean_tile) * var_inv, axis=(2, 3), keepdims=True
+            )
+        )
+    )
+
+    return d_x, d_scale, d_bias


 class TestInstanceNorm(unittest.TestCase):
@@ -86,95 +140,42 @@ class TestInstanceNorm(unittest.TestCase):
         np.testing.assert_allclose(y1, y2, rtol=1e-05)

     def test_static(self):
-        places = [fluid.CPUPlace()]
-        if core.is_compiled_with_cuda() and core.op_support_gpu(
-            "instance_norm"
-        ):
-            places.append(fluid.CUDAPlace(0))
-        for p in places:
-            exe = fluid.Executor(p)
-            shape = [4, 10, 16, 16]
-
-            def compute_v1(x_np):
-                with program_guard(Program(), Program()):
-                    ins = paddle.nn.InstanceNorm2D(shape[1])
-                    x = paddle.static.data(
-                        name='x', shape=x_np.shape, dtype=x_np.dtype
-                    )
-                    y = ins(x)
-                    exe.run(fluid.default_startup_program())
-                    r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
-                return r
-
-            def compute_v2(x_np):
-                with program_guard(Program(), Program()):
-                    ins = paddle.nn.InstanceNorm2D(shape[1])
-                    x = paddle.static.data(
-                        name='x', shape=x_np.shape, dtype=x_np.dtype
-                    )
-                    y = ins(x)
-                    exe.run(fluid.default_startup_program())
-                    r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
-                return r
-
-            x = np.random.randn(*shape).astype("float32")
-            y1 = compute_v1(x)
-            y2 = compute_v2(x)
-            np.testing.assert_allclose(y1, y2, rtol=1e-05)
+        with paddle.fluid.framework._static_guard():
+            places = [fluid.CPUPlace()]
+            if core.is_compiled_with_cuda() and core.op_support_gpu(
+                "instance_norm"
+            ):
+                places.append(fluid.CUDAPlace(0))
+            for p in places:
+                exe = fluid.Executor(p)
+                shape = [4, 10, 16, 16]
+
+                def compute_v1(x_np):
+                    with program_guard(Program(), Program()):
+                        ins = paddle.nn.InstanceNorm2D(shape[1])
+                        x = paddle.static.data(
+                            name='x', shape=x_np.shape, dtype=x_np.dtype
+                        )
+                        y = ins(x)
+                        exe.run(fluid.default_startup_program())
+                        r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
+                    return r
+
+                def compute_v2(x_np):
+                    with program_guard(Program(), Program()):
+                        ins = paddle.nn.InstanceNorm2D(shape[1])
+                        x = paddle.static.data(
+                            name='x', shape=x_np.shape, dtype=x_np.dtype
+                        )
+                        y = ins(x)
+                        exe.run(fluid.default_startup_program())
+                        r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
+                    return r
+
+                x = np.random.randn(*shape).astype("float32")
+                y1 = compute_v1(x)
+                y2 = compute_v2(x)
+                np.testing.assert_allclose(y1, y2, rtol=1e-05)
-
-
-def instance_norm_warpper(
-    input, weight, bias, epsilon=1e-5, momentum=0.9, data_format='NCHW'
-):
-    if data_format == "AnyLayout":
-        data_format = "NCDHW"
-    return paddle._C_ops.instance_norm(
-        input, weight, bias, epsilon, momentum, data_format
-    )
-
-
-def _reference_instance_norm(x, scale, bias, epsilon):
-    N, C, H, W = x.shape
-    mean = np.mean(x, axis=(2, 3), keepdims=True)
-    variance = np.var(x, axis=(2, 3), keepdims=True)
-    std = np.sqrt(variance) + epsilon
-    x_norm = (x - mean) / std
-    scale = scale.reshape([1, C, 1, 1])
-    bias = bias.reshape([1, C, 1, 1])
-    x_norm = scale * x_norm + bias
-    return x_norm, mean.reshape(N * C), std.reshape(N * C)
-
-
-def _reference_instance_norm_grad(x, scale, mean, var):
-    n, c, h, w = x.shape
-    d_y = np.ones(x.shape) / (np.prod(x.shape))
-    d_bias = np.ones((c,)) / c
-
-    mean_tile = np.reshape(mean, (n, c, 1, 1))
-    mean_tile = np.tile(mean_tile, (1, 1, h, w))
-    var_tile = np.reshape(var, (n, c, 1, 1))
-    var_tile = np.tile(var_tile, (1, 1, h, w))
-
-    d_scale = np.sum(d_y * (x - mean_tile) * var_tile, axis=(0, 2, 3))
-    var_inv = var_tile
-    scale_tile = np.reshape(scale, (1, c, 1, 1))
-    scale_tile = np.tile(scale_tile, (n, 1, h, w))
-
-    d_x = (
-        scale_tile
-        * var_inv
-        * (
-            d_y
-            - np.mean(d_y, axis=(2, 3), keepdims=True)
-            - (x - mean_tile)
-            * var_inv
-            * np.mean(
-                d_y * (x - mean_tile) * var_inv, axis=(2, 3), keepdims=True
-            )
-        )
-    )
-
-    return d_x, d_scale, d_bias


 class TestInstanceNormFP32OP(OpTest):
@@ -182,7 +183,6 @@ class TestInstanceNormFP32OP(OpTest):
         '''Test instance_norm op with default value'''
         self.op_type = "instance_norm"
         self.__class__.op_type = self.op_type
-        self.python_api = instance_norm_warpper
         self.data_format = "NCHW"
         self.eps = 1e-5
         self.init_dtype()
@@ -204,15 +204,18 @@ class TestInstanceNormFP32OP(OpTest):
             'SavedMean': mean,
             'SavedVariance': 1.0 / variance,
         }
+        self.prim_op_type = "comp"
+        self.python_api = instance_norm_wrapper
+        self.public_python_api = instance_norm_wrapper
+        self.check_prim = (
+            False if os.getenv("FLAGS_enable_new_ir_in_executor") else True
+        )

     def test_check_output(self):
-        self.check_output(atol=self.atol)
+        self.check_output(atol=self.atol, check_prim=self.check_prim)

     def test_check_grad(self):
-        self.check_grad(
-            ['X', 'Scale', 'Bias'],
-            'Y',
-        )
+        self.check_grad(['X', 'Scale', 'Bias'], 'Y', check_prim=self.check_prim)

     def init_dtype(self):
         self.dtype = np.float32
@@ -228,6 +231,12 @@ class TestInstanceNormFP32OP(OpTest):
     def set_err_thre(self):
         self.atol = 1e-3
+        self.fw_comp_rtol = 1e-6
+        self.fw_comp_atol = 1e-6
+        self.rev_comp_rtol = 1e-4
+        self.rev_comp_atol = 1e-4
+        self.cinn_rtol = 1e-4
+        self.cinn_atol = 1e-4


 @unittest.skipIf(
@@ -236,6 +245,9 @@ class TestInstanceNormFP32OP(OpTest):
     "core is not compiled with CUDA or not support the float16",
 )
 class TestInstanceNormFP16OP(TestInstanceNormFP32OP):
+    def setUp(self):
+        super().setUp()
+
     def init_dtype(self):
         self.dtype = np.float16

@@ -245,7 +257,9 @@ class TestInstanceNormFP16OP(TestInstanceNormFP32OP):
     def test_check_output(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(place, atol=self.atol)
+        self.check_output_with_place(
+            place, atol=self.atol, check_prim=self.check_prim
+        )

     def test_check_grad(self):
         place = core.CUDAPlace(0)
@@ -254,6 +268,7 @@ class TestInstanceNormFP16OP(TestInstanceNormFP32OP):
             ['X', 'Scale', 'Bias'],
             'Y',
             max_relative_error=self.max_relative_error,
+            check_prim=self.check_prim,
         )
@@ -265,8 +280,10 @@ class TestInstanceNormFP16OP(TestInstanceNormFP32OP):
 class TestInstanceNormBF16OP(OpTest):
     def setUp(self):
         self.op_type = "instance_norm"
+        self.prim_op_type = "comp"
         self.__class__.op_type = self.op_type
-        self.python_api = instance_norm_warpper
+        self.python_api = instance_norm_wrapper
+        self.public_python_api = instance_norm_wrapper
         self.eps = 1e-5
         self.data_format = "NCHW"
         self.dtype = np.uint16
@@ -296,6 +313,9 @@ class TestInstanceNormBF16OP(OpTest):
             'momentum': 0.9,
             'data_format': self.data_format,
         }
+        self.check_prim = (
+            False if os.getenv("FLAGS_enable_new_ir_in_executor") else True
+        )

     def init_value(self):
         np.random.seed(0)
@@ -308,7 +328,7 @@ class TestInstanceNormBF16OP(OpTest):
     def test_check_output(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(place)
+        self.check_output_with_place(place, check_prim=self.check_prim)

     def test_check_grad(self):
         place = core.CUDAPlace(0)
@@ -317,19 +337,22 @@
             ['X', 'Scale', 'Bias'],
             'Y',
             user_defined_grads=self.user_defined_grads,
+            check_prim=self.check_prim,
         )


 class PrimNet(paddle.nn.Layer):
     def __init__(self):
         super().__init__()
-        self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False)
-        self.instance_norm = nn.InstanceNorm2D(4)
+        self.conv = paddle.nn.Conv2D(2, 4, (3, 3), bias_attr=False)
+        self.instance_norm = paddle.nn.InstanceNorm2D(4)

     def forward(self, x):
         y = self.conv(x)
         out = self.instance_norm(y)
-        res = F.max_pool2d(out, kernel_size=2, stride=2, padding=0)
+        res = paddle.nn.functional.max_pool2d(
+            out, kernel_size=2, stride=2, padding=0
+        )
         return res
@@ -368,7 +391,9 @@ class TestPrimForwardAndBackward(unittest.TestCase):
         return loss

     def test_amp_nchw(self):
-        if not isinstance(framework._current_expected_place(), core.CPUPlace):
+        if not isinstance(
+            paddle.fluid.framework._current_expected_place(), core.CPUPlace
+        ):
             expected = self.train(False)
             actual = self.train(True)
             np.testing.assert_allclose(
......
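One detail worth noting about the bf16 test above: OpTest represents bfloat16 tensors as np.uint16 arrays holding the raw bits, which is why the class sets self.dtype = np.uint16 and converts inputs with convert_float_to_uint16. A truncation-based sketch of that representation (the real helper may round rather than truncate):

import numpy as np


def float32_to_bf16_bits(x):
    # bfloat16 keeps float32's sign/exponent plus the top 7 mantissa bits,
    # so its bit pattern is simply the upper 16 bits of the float32 encoding.
    bits32 = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits32 >> 16).astype(np.uint16)


def bf16_bits_to_float32(bits):
    # Re-widen to 32 bits and reinterpret the pattern as float32.
    return (bits.astype(np.uint32) << 16).astype(np.uint32).view(np.float32)


x = np.random.randn(2, 3).astype(np.float32)
bits = float32_to_bf16_bits(x)
roundtrip = bf16_bits_to_float32(bits)
print(np.abs(x - roundtrip).max())  # bf16 keeps only ~2-3 significant decimal digits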