Unverified · Commit d1d43c22 · authored by Charles-hit · committed by GitHub

support some prim ops bf16 dtype (#54263)

Parent 585f1136
@@ -542,7 +542,15 @@ void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
 template <typename T>
 void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    set_output<T>(out_grad * out, x_grad);
+    if (out.dtype() == phi::DataType::FLOAT16 ||
+        out.dtype() == phi::DataType::BFLOAT16) {
+      Tensor out_promote = cast<T>(out, phi::DataType::FLOAT32);
+      Tensor out_grad_promote = cast<T>(out_grad, phi::DataType::FLOAT32);
+      set_output<T>(cast<T>(out_promote * out_grad_promote, out.dtype()),
+                    x_grad);
+    } else {
+      set_output<T>(out_grad * out, x_grad);
+    }
   }
 }
......
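The exp_grad change above promotes FP16/BF16 operands to FP32 before the multiply and casts the result back to the original dtype. Below is a minimal Python sketch of that promotion pattern, assuming only the public paddle.cast API and the paddle.float16 / paddle.bfloat16 dtypes; it illustrates the idea and is not the actual composite backward rule.

```python
import paddle

def exp_grad_promoted(out, out_grad):
    # Sketch of the promotion pattern used in exp_grad above:
    # low-precision tensors are cast to float32 for the multiply,
    # and the result is cast back to the original dtype.
    if out.dtype in (paddle.float16, paddle.bfloat16):
        out_fp32 = paddle.cast(out, paddle.float32)
        out_grad_fp32 = paddle.cast(out_grad, paddle.float32)
        return paddle.cast(out_fp32 * out_grad_fp32, out.dtype)
    return out_grad * out
```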
@@ -113,6 +113,7 @@ class TestExpFp32_Prim(OpTest):
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}
         self.if_enable_cinn()
+        self.convert_input_output()

     def test_check_output(self):
         self.check_output()
@@ -129,6 +130,9 @@ class TestExpFp32_Prim(OpTest):
     def if_enable_cinn(self):
         pass

+    def convert_input_output(self):
+        pass
+

 class TestExpFp64_Prim(TestExpFp32_Prim):
     def init_dtype(self):
@@ -4003,6 +4007,7 @@ def create_test_act_fp16_class(
 create_test_act_fp16_class(TestActivation)
+create_test_act_fp16_class(TestExpFp32_Prim, check_prim=True, enable_cinn=True)
 create_test_act_fp16_class(TestExpm1)
 create_test_act_fp16_class(TestSigmoid, check_prim=True, enable_cinn=True)
 create_test_act_fp16_class(TestSilu, check_prim=True, enable_cinn=True)
@@ -4133,6 +4138,7 @@ def create_test_act_bf16_class(
 create_test_act_bf16_class(TestActivation)
+create_test_act_bf16_class(TestExpFp32_Prim, check_prim=True)
 create_test_act_bf16_class(TestExpm1)
 create_test_act_bf16_class(TestSigmoid, check_prim=True)
 create_test_act_bf16_class(TestSilu, check_prim=True)
......
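For reference, bf16 activation tests in this suite generally override the new convert_input_output() hook so that float32 reference data is repacked into the uint16 carrier dtype used for bfloat16. A hedged sketch of such a subclass follows; the class name is hypothetical and the convert_float_to_uint16 helper is assumed to come from the op-test utilities these files already import.

```python
import numpy as np
from eager_op_test import convert_float_to_uint16  # assumed op-test helper

# Hypothetical bf16 variant of TestExpFp32_Prim illustrating what the
# convert_input_output() hook added above is typically overridden to do.
class TestExpBF16_Prim(TestExpFp32_Prim):
    def init_dtype(self):
        self.dtype = np.uint16

    def convert_input_output(self):
        # Repack the float32 reference arrays as uint16-encoded bfloat16.
        self.inputs = {'X': convert_float_to_uint16(self.inputs['X'])}
        self.outputs = {'Out': convert_float_to_uint16(self.outputs['Out'])}
```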
@@ -108,6 +108,9 @@ class TestCastOpBf16ToFp32(OpTest):
         self.prim_op_type = "prim"
         self.python_api = cast_wrapper
         self.public_python_api = cast_wrapper
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
         self.enable_cinn = False

     def test_check_output(self):
@@ -130,6 +133,9 @@ class TestCastOpFp32ToBf16(OpTest):
         self.prim_op_type = "prim"
         self.python_api = cast_wrapper
         self.public_python_api = cast_wrapper
+        self.if_enable_cinn()
+
+    def if_enable_cinn(self):
         self.enable_cinn = False

     def test_check_output(self):
......
@@ -222,9 +222,8 @@ class TestElementwiseDivOpBF16(ElementwiseDivOp):
             check_args.insert(0, self.place)
             self.check_grad_with_place(*check_args, **check_kwargs)

-    # elementwise_pow does't support bfloat16
     def if_check_prim(self):
-        self.check_prim = False
+        self.check_prim = True

     def if_enable_cinn(self):
         self.enable_cinn = False
......
@@ -211,7 +211,7 @@ class TestReduceMeanBF16Op(OpTest):
         self.axis = [0]
         self.keepdim = False
         self.set_attrs()
-        self.enable_cinn = False
+        self.if_enable_cinn()
         np.random.seed(10)

         x_np = np.random.uniform(-1, 1, self.shape).astype(np.float32)
@@ -227,6 +227,9 @@ class TestReduceMeanBF16Op(OpTest):
             'reduce_all': self.reduce_all,
         }

+    def if_enable_cinn(self):
+        self.enable_cinn = False
+
     def set_attrs(self):
         pass
......
@@ -398,7 +398,9 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
 class TestSoftmaxBF16Op(OpTest):
     def setUp(self):
         self.op_type = "softmax"
+        self.prim_op_type = "comp"
         self.python_api = F.softmax
+        self.public_python_api = F.softmax
         self.use_cudnn = self.init_cudnn()
         self.use_mkldnn = False
         self.dtype = np.uint16
@@ -424,7 +426,9 @@ class TestSoftmaxBF16Op(OpTest):
     def test_check_output(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(place, check_dygraph=(not self.use_mkldnn))
+        self.check_output_with_place(
+            place, check_dygraph=(not self.use_mkldnn), check_prim=True
+        )

     def test_check_grad(self):
         place = core.CUDAPlace(0)
@@ -434,6 +438,7 @@ class TestSoftmaxBF16Op(OpTest):
             "Out",
             numeric_grad_delta=0.05,
             check_dygraph=(not self.use_mkldnn),
+            check_prim=True,
         )
......
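Setting prim_op_type = "comp" together with check_prim=True makes the bf16 softmax output and gradient additionally checked against the op's composite decomposition. A rough float32 reference of the usual max-shifted decomposition is sketched below for orientation; the exact decomposition used by the prim framework may differ in detail.

```python
import numpy as np

def softmax_ref(x, axis=-1):
    # Numerically stabilized softmax computed in float32, the kind of
    # reference a bf16 check_prim comparison is made against.
    x = x.astype(np.float32)
    x_max = np.max(x, axis=axis, keepdims=True)
    e = np.exp(x - x_max)
    return e / np.sum(e, axis=axis, keepdims=True)
```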