未验证 提交 a9f76d07 编写于 作者: C Chenxiao Niu 提交者: GitHub

[MLU] fix log_softmax mode selection. (#44669)

上级 dfeb1942
...@@ -117,10 +117,9 @@ REGISTER_OP_MLU_KERNEL(softmax_grad, ...@@ -117,10 +117,9 @@ REGISTER_OP_MLU_KERNEL(softmax_grad,
ops::SoftmaxGradMLUKernel<CNNL_SOFTMAX_ACCURATE, float>, ops::SoftmaxGradMLUKernel<CNNL_SOFTMAX_ACCURATE, float>,
ops::SoftmaxGradMLUKernel<CNNL_SOFTMAX_ACCURATE, ops::SoftmaxGradMLUKernel<CNNL_SOFTMAX_ACCURATE,
paddle::platform::float16>); paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL( REGISTER_OP_MLU_KERNEL(log_softmax,
log_softmax, ops::SoftmaxMLUKernel<CNNL_SOFTMAX_LOG, float>,
ops::SoftmaxMLUKernel<CNNL_SOFTMAX_LOG, float>, ops::SoftmaxMLUKernel<CNNL_SOFTMAX_LOG, plat::float16>);
ops::SoftmaxMLUKernel<CNNL_SOFTMAX_ACCURATE, plat::float16>);
REGISTER_OP_MLU_KERNEL( REGISTER_OP_MLU_KERNEL(
log_softmax_grad, log_softmax_grad,
ops::SoftmaxGradMLUKernel<CNNL_SOFTMAX_LOG, float>, ops::SoftmaxGradMLUKernel<CNNL_SOFTMAX_LOG, float>,
......
...@@ -86,6 +86,41 @@ class TestLogSoftmaxAxis(TestLogSoftmaxOp): ...@@ -86,6 +86,41 @@ class TestLogSoftmaxAxis(TestLogSoftmaxOp):
self.axis = 1 self.axis = 1
class TestLogSoftmaxOpFp16(OpTest):
def setUp(self):
self.op_type = 'log_softmax'
self.set_mlu()
self.python_api = F.log_softmax
self.dtype = 'float16'
self.shape = [2, 3, 4, 5]
self.axis = -1
self.set_attrs()
x = np.random.uniform(0.1, 1., self.shape).astype(self.dtype)
out = np.apply_along_axis(ref_log_softmax, self.axis, x)
self.x_grad = ref_log_softmax_grad(x, self.axis)
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {'axis': self.axis}
def set_attrs(self):
pass
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-2)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], ['Out'],
user_defined_grads=[self.x_grad],
max_relative_error=0.015)
class TestNNLogSoftmaxAPI(unittest.TestCase): class TestNNLogSoftmaxAPI(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册