未验证 提交 d1d43c22 编写于 作者: C Charles-hit 提交者: GitHub

support some prim ops bf16 dtype (#54263)

上级 585f1136
......@@ -542,7 +542,15 @@ void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
set_output<T>(out_grad * out, x_grad);
if (out.dtype() == phi::DataType::FLOAT16 ||
out.dtype() == phi::DataType::BFLOAT16) {
Tensor out_promote = cast<T>(out, phi::DataType::FLOAT32);
Tensor out_grad_promote = cast<T>(out_grad, phi::DataType::FLOAT32);
set_output<T>(cast<T>(out_promote * out_grad_promote, out.dtype()),
x_grad);
} else {
set_output<T>(out_grad * out, x_grad);
}
}
}
......
......@@ -113,6 +113,7 @@ class TestExpFp32_Prim(OpTest):
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.if_enable_cinn()
self.convert_input_output()
def test_check_output(self):
self.check_output()
......@@ -129,6 +130,9 @@ class TestExpFp32_Prim(OpTest):
def if_enable_cinn(self):
pass
def convert_input_output(self):
pass
class TestExpFp64_Prim(TestExpFp32_Prim):
def init_dtype(self):
......@@ -4003,6 +4007,7 @@ def create_test_act_fp16_class(
create_test_act_fp16_class(TestActivation)
create_test_act_fp16_class(TestExpFp32_Prim, check_prim=True, enable_cinn=True)
create_test_act_fp16_class(TestExpm1)
create_test_act_fp16_class(TestSigmoid, check_prim=True, enable_cinn=True)
create_test_act_fp16_class(TestSilu, check_prim=True, enable_cinn=True)
......@@ -4133,6 +4138,7 @@ def create_test_act_bf16_class(
create_test_act_bf16_class(TestActivation)
create_test_act_bf16_class(TestExpFp32_Prim, check_prim=True)
create_test_act_bf16_class(TestExpm1)
create_test_act_bf16_class(TestSigmoid, check_prim=True)
create_test_act_bf16_class(TestSilu, check_prim=True)
......
......@@ -108,6 +108,9 @@ class TestCastOpBf16ToFp32(OpTest):
self.prim_op_type = "prim"
self.python_api = cast_wrapper
self.public_python_api = cast_wrapper
self.if_enable_cinn()
def if_enable_cinn(self):
self.enable_cinn = False
def test_check_output(self):
......@@ -130,6 +133,9 @@ class TestCastOpFp32ToBf16(OpTest):
self.prim_op_type = "prim"
self.python_api = cast_wrapper
self.public_python_api = cast_wrapper
self.if_enable_cinn()
def if_enable_cinn(self):
self.enable_cinn = False
def test_check_output(self):
......
......@@ -222,9 +222,8 @@ class TestElementwiseDivOpBF16(ElementwiseDivOp):
check_args.insert(0, self.place)
self.check_grad_with_place(*check_args, **check_kwargs)
# elementwise_pow does't support bfloat16
def if_check_prim(self):
self.check_prim = False
self.check_prim = True
def if_enable_cinn(self):
self.enable_cinn = False
......
......@@ -211,7 +211,7 @@ class TestReduceMeanBF16Op(OpTest):
self.axis = [0]
self.keepdim = False
self.set_attrs()
self.enable_cinn = False
self.if_enable_cinn()
np.random.seed(10)
x_np = np.random.uniform(-1, 1, self.shape).astype(np.float32)
......@@ -227,6 +227,9 @@ class TestReduceMeanBF16Op(OpTest):
'reduce_all': self.reduce_all,
}
def if_enable_cinn(self):
self.enable_cinn = False
def set_attrs(self):
pass
......
......@@ -398,7 +398,9 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
class TestSoftmaxBF16Op(OpTest):
def setUp(self):
self.op_type = "softmax"
self.prim_op_type = "comp"
self.python_api = F.softmax
self.public_python_api = F.softmax
self.use_cudnn = self.init_cudnn()
self.use_mkldnn = False
self.dtype = np.uint16
......@@ -424,7 +426,9 @@ class TestSoftmaxBF16Op(OpTest):
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, check_dygraph=(not self.use_mkldnn))
self.check_output_with_place(
place, check_dygraph=(not self.use_mkldnn), check_prim=True
)
def test_check_grad(self):
place = core.CUDAPlace(0)
......@@ -434,6 +438,7 @@ class TestSoftmaxBF16Op(OpTest):
"Out",
numeric_grad_delta=0.05,
check_dygraph=(not self.use_mkldnn),
check_prim=True,
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册