diff --git a/paddle/fluid/operators/softmax_op_mlu.cc b/paddle/fluid/operators/softmax_op_mlu.cc index 9cb698e94fc56854ed7daa38caba75d33a9eb5cc..9b97e779f29efe088ff391c6c27db9bbe10f44d3 100644 --- a/paddle/fluid/operators/softmax_op_mlu.cc +++ b/paddle/fluid/operators/softmax_op_mlu.cc @@ -19,7 +19,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class SoftmaxMLUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -45,7 +45,7 @@ class SoftmaxMLUKernel : public framework::OpKernel { regard_in_shape = {d1, d2, d3}; } - static const cnnlSoftmaxAlgorithm_t algo = CNNL_SOFTMAX_ACCURATE; + static const cnnlSoftmaxAlgorithm_t algo = softmax_algo; MLUCnnlTensorDesc in_desc(cnnl_softmax_dims, regard_in_shape.data(), ToCnnlDataType()); MLUCnnl::SoftmaxForward(ctx, algo, mode, NULL, in_desc.get(), @@ -54,7 +54,7 @@ class SoftmaxMLUKernel : public framework::OpKernel { } }; -template +template class SoftmaxGradMLUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -82,7 +82,7 @@ class SoftmaxGradMLUKernel : public framework::OpKernel { regard_out_shape = {d1, d2, d3}; } - static const cnnlSoftmaxAlgorithm_t algo = CNNL_SOFTMAX_ACCURATE; + static const cnnlSoftmaxAlgorithm_t algo = softmax_algo; MLUCnnlTensorDesc out_desc(cnnl_softmax_dims, regard_out_shape.data(), ToCnnlDataType()); MLUCnnl::SoftmaxBackward(ctx, algo, mode, out_desc.get(), GetBasePtr(out), @@ -97,7 +97,16 @@ class SoftmaxGradMLUKernel : public framework::OpKernel { namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_MLU_KERNEL(softmax, ops::SoftmaxMLUKernel, - ops::SoftmaxMLUKernel); -REGISTER_OP_MLU_KERNEL(softmax_grad, ops::SoftmaxGradMLUKernel, - ops::SoftmaxGradMLUKernel); +REGISTER_OP_MLU_KERNEL( + softmax, ops::SoftmaxMLUKernel, + ops::SoftmaxMLUKernel); +REGISTER_OP_MLU_KERNEL(softmax_grad, + ops::SoftmaxGradMLUKernel, + ops::SoftmaxGradMLUKernel); +REGISTER_OP_MLU_KERNEL( + log_softmax, ops::SoftmaxMLUKernel, + ops::SoftmaxMLUKernel); +REGISTER_OP_MLU_KERNEL( + log_softmax_grad, ops::SoftmaxGradMLUKernel, + ops::SoftmaxGradMLUKernel); diff --git a/python/paddle/fluid/tests/unittests/mlu/test_log_softmax_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_log_softmax_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..dea6391b8bae049ab7067d83e6f873fda0494afb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_log_softmax_op_mlu.py @@ -0,0 +1,163 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16 +import paddle +import paddle.fluid.core as core +import paddle.nn.functional as F + +np.random.seed(10) +paddle.enable_static() + + +def ref_log_softmax(x): + shiftx = (x - np.max(x)) + out = shiftx - np.log(np.exp(shiftx).sum()) + return out + + +def ref_log_softmax_grad(x, axis): + if axis < 0: + axis += len(x.shape) + out = np.apply_along_axis(ref_log_softmax, axis, x) + axis_dim = x.shape[axis] + dout = np.full_like(x, fill_value=1. / x.size) + dx = dout - np.exp(out) * dout.copy().sum(axis=axis, keepdims=True).repeat( + axis_dim, axis=axis) + return dx + + +class TestLogSoftmaxOp(OpTest): + def setUp(self): + self.op_type = 'log_softmax' + self.set_mlu() + self.python_api = F.log_softmax + self.dtype = 'float32' + self.shape = [2, 3, 4, 5] + self.axis = -1 + self.set_attrs() + + x = np.random.uniform(0.1, 1., self.shape).astype(self.dtype) + out = np.apply_along_axis(ref_log_softmax, self.axis, x) + self.x_grad = ref_log_softmax_grad(x, self.axis) + + self.inputs = {'X': x} + self.outputs = {'Out': out} + self.attrs = {'axis': self.axis} + + def set_attrs(self): + pass + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place( + self.place, ['X'], ['Out'], user_defined_grads=[self.x_grad]) + + +class TestLogSoftmaxShape(TestLogSoftmaxOp): + def set_attrs(self): + self.shape = [12, 10] + + +class TestLogSoftmaxAxis(TestLogSoftmaxOp): + def set_attrs(self): + self.axis = 1 + + +class TestNNLogSoftmaxAPI(unittest.TestCase): + def setUp(self): + self.set_mlu() + self.x_shape = [2, 3, 4, 5] + self.x = np.random.uniform(-1., 1., self.x_shape).astype(np.float32) + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def check_api(self, axis=-1): + ref_out = np.apply_along_axis(ref_log_softmax, axis, self.x) + + logsoftmax = paddle.nn.LogSoftmax(axis) + # test static api + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data(name='x', shape=self.x_shape) + y = logsoftmax(x) + exe = paddle.static.Executor(self.place) + out = exe.run(feed={'x': self.x}, fetch_list=[y]) + self.assertTrue(np.allclose(out[0], ref_out)) + + # test dygrapg api + paddle.disable_static() + x = paddle.to_tensor(self.x) + y = logsoftmax(x) + self.assertTrue(np.allclose(y.numpy(), ref_out)) + paddle.enable_static() + + def test_check_api(self): + for axis in [-1, 1]: + self.check_api(axis) + + +class TestNNFunctionalLogSoftmaxAPI(unittest.TestCase): + def setUp(self): + self.set_mlu() + self.x_shape = [2, 3, 4, 5] + self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32) + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + + def check_api(self, axis=-1, dtype=None): + x = self.x.copy() + if dtype is not None: + x = x.astype(dtype) + ref_out = np.apply_along_axis(ref_log_softmax, axis, x) + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data(name='x', shape=self.x_shape) + y = F.log_softmax(x, axis, dtype) + exe = paddle.static.Executor(self.place) + out = exe.run(feed={'x': self.x}, fetch_list=[y]) + self.assertTrue(np.allclose(out[0], ref_out)) + + paddle.disable_static() + x = paddle.to_tensor(self.x) + y = F.log_softmax(x, axis, dtype) + self.assertTrue(np.allclose(y.numpy(), ref_out), True) + paddle.enable_static() + + def test_check_api(self): + for axis in [-1, 1]: + self.check_api(axis) + self.check_api(-1, 'float32') + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data(name='X1', shape=[100], dtype='int32') + self.assertRaises(TypeError, F.log_softmax, x) + + x = paddle.fluid.data(name='X2', shape=[100], dtype='float32') + self.assertRaises(TypeError, F.log_softmax, x, dtype='int32') + + +if __name__ == "__main__": + unittest.main()