From 586a6dd358fe63ca82af43880f38ac20267386f3 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Tue, 18 Aug 2020 15:51:05 +0800 Subject: [PATCH] log_softmax and LogSoftmax: impl kernel and refind docs (#26088) --- paddle/fluid/operators/log_softmax_op.cc | 128 ++++++++++++ paddle/fluid/operators/log_softmax_op.cu | 26 +++ paddle/fluid/operators/log_softmax_op.h | 192 ++++++++++++++++++ .../fluid/tests/unittests/test_log_softmax.py | 165 +++++++++------ python/paddle/nn/functional/activation.py | 123 +++++------ python/paddle/nn/layer/activation.py | 65 +++--- 6 files changed, 546 insertions(+), 153 deletions(-) create mode 100644 paddle/fluid/operators/log_softmax_op.cc create mode 100644 paddle/fluid/operators/log_softmax_op.cu create mode 100644 paddle/fluid/operators/log_softmax_op.h diff --git a/paddle/fluid/operators/log_softmax_op.cc b/paddle/fluid/operators/log_softmax_op.cc new file mode 100644 index 0000000000..d6e2b3ecff --- /dev/null +++ b/paddle/fluid/operators/log_softmax_op.cc @@ -0,0 +1,128 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/log_softmax_op.h" +#include +#include +#include "paddle/fluid/operators/common_infer_shape_functions.h" + +namespace paddle { +namespace operators { + +class LogSoftmaxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + return UnaryOpUnchangedInferShapeCheckAxis(ctx); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class LogSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input tensor of softmax, " + "whose dimension :attr:`axis` is the input_feature_dimensions."); + AddOutput("Out", "The normalized values with the same shape as X."); + AddAttr("axis", + "The dimension index of Input(x) to perform log_softmax," + "default -1 for last dimension") + .SetDefault(-1); + AddComment(R"DOC( +LogSoftmax Operator. + +)DOC"); + } +}; + +class LogSoftmaxOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map& GetInputOutputWithSameType() + const override { + static std::unordered_map m{{"X", /*->*/ "Out"}}; + return m; + } +}; + +class LogSoftmaxGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "log_softmax_grad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@grad", "log_softmax_grad"); + PADDLE_ENFORCE_EQ( + ctx->GetInputDim("Out"), + ctx->GetInputDim(framework::GradVarName("Out")), + platform::errors::InvalidArgument("Input(Out) and its gradients " + "should have the same shape.")); + + ctx->SetOutputDim(framework::GradVarName("X"), + ctx->GetInputDim(framework::GradVarName("Out"))); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +template +class LogSoftmaxGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("log_softmax_grad"); + op->SetInput("Out", this->Output("Out")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(log_softmax, ops::LogSoftmaxOp, ops::LogSoftmaxOpMaker, + ops::LogSoftmaxOpInferVarType, + ops::LogSoftmaxGradOpMaker, + ops::LogSoftmaxGradOpMaker); +REGISTER_OPERATOR(log_softmax_grad, ops::LogSoftmaxGradOp); + +REGISTER_OP_CPU_KERNEL( + log_softmax, + ops::LogSoftmaxKernel, + ops::LogSoftmaxKernel); +REGISTER_OP_CPU_KERNEL( + log_softmax_grad, + ops::LogSoftmaxGradKernel, + ops::LogSoftmaxGradKernel); diff --git a/paddle/fluid/operators/log_softmax_op.cu b/paddle/fluid/operators/log_softmax_op.cu new file mode 100644 index 0000000000..02fca246d2 --- /dev/null +++ b/paddle/fluid/operators/log_softmax_op.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/log_softmax_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + log_softmax, ops::LogSoftmaxKernel, + ops::LogSoftmaxKernel, + ops::LogSoftmaxKernel); +REGISTER_OP_CUDA_KERNEL( + log_softmax_grad, ops::LogSoftmaxGradKernel, + ops::LogSoftmaxGradKernel, + ops::LogSoftmaxGradKernel); diff --git a/paddle/fluid/operators/log_softmax_op.h b/paddle/fluid/operators/log_softmax_op.h new file mode 100644 index 0000000000..b983ac5415 --- /dev/null +++ b/paddle/fluid/operators/log_softmax_op.h @@ -0,0 +1,192 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +using EigenMatrix = framework::EigenMatrix; + +static inline int CanonicalAxis(const int axis, const int rank) { + if (axis < 0) { + return axis + rank; + } + return axis; +} + +static inline int SizeToAxis(const int axis, const framework::DDim dims) { + int size = 1; + for (int i = 0; i < axis; i++) { + size *= dims[i]; + } + return size; +} + +static inline int SizeFromAxis(const int axis, const framework::DDim dims) { + int size = 1; + for (int i = axis; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + +template +struct ValueClip { + HOSTDEVICE T operator()(const T& x) const { + const T kThreshold = static_cast(-64.); + return x < kThreshold ? kThreshold : x; + } +}; + +template +struct LogSoftmaxFunctor { + void operator()(const DeviceContext& context, const framework::Tensor* X, + framework::Tensor* Y, const int axis) { + constexpr int kBatchDim = 0; + constexpr int kClassDim = 1; + constexpr int kAxisDim = 1; + + int axis_dim = X->dims()[axis]; + const int n = SizeToAxis(axis, X->dims()); + const int d = SizeFromAxis(axis, X->dims()); + framework::DDim dim_2d{n, d}; + + auto logits = EigenMatrix::From(*X, dim_2d); + auto log_softmax = EigenMatrix::From(*Y, dim_2d); + + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); + const int num_remain = num_classes / axis_dim; + + Eigen::DSizes along_axis(kAxisDim); + Eigen::DSizes batch_classes(batch_size, num_classes); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); + Eigen::DSizes batch_one_remain(batch_size, 1, num_remain); + Eigen::DSizes one_axis_one(1, axis_dim, 1); + Eigen::DSizes one_axis(1, axis_dim); + Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain); + + // For numerical stability, logits should be shifted by maximum number along + // axis, calculate shifted_logits into log_softmax tensor for memory reuse. + if (num_remain == 1) { + // axis == -1, axis and class in same dimension, calculate along + // class dimension directly for higher performance + log_softmax.device(*context.eigen_device()) = + (logits - + logits.maximum(along_axis) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)) + .unaryExpr(ValueClip()); + } else { + // axis != -1, class dimension split into (axis, remain), max and sum + // should be calculated along axis dimension + log_softmax.device(*context.eigen_device()) = + (logits.reshape(batch_axis_remain) - + logits.reshape(batch_axis_remain) + .maximum(along_axis) + .eval() + .reshape(batch_one_remain) + .broadcast(one_axis_one) + .reshape(batch_classes)) + .unaryExpr(ValueClip()); + } + + log_softmax.device(*context.eigen_device()) = + log_softmax - + log_softmax.exp() + .eval() + .reshape(batch_axis_remain) + .sum(along_axis) + .log() + .broadcast(one_axis); + } +}; + +template +class LogSoftmaxKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Out = context.Output("Out"); + const int rank = X->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + + // allocate memory on device. + Out->mutable_data(context.GetPlace()); + + LogSoftmaxFunctor()( + context.template device_context(), X, Out, axis); + } +}; + +template +struct LogSoftmaxGradFunctor { + void operator()(const DeviceContext& context, const framework::Tensor* Y, + const framework::Tensor* dY, framework::Tensor* dX, + const int axis) { + constexpr int kBatchDim = 0; + constexpr int kClassDim = 1; + + const int n = SizeToAxis(axis, Y->dims()); + const int d = SizeFromAxis(axis, Y->dims()); + framework::DDim dim_2d{n, d}; + + auto y = EigenMatrix::From(*Y, dim_2d); + auto dy = EigenMatrix::From(*dY, dim_2d); + auto dx = EigenMatrix::From(*dX, dim_2d); + + const int axis_dim = Y->dims()[axis]; + const int batch_size = y.dimension(kBatchDim); + const int num_classes = y.dimension(kClassDim); + const int num_remain = num_classes / axis_dim; + + Eigen::DSizes along_class(kClassDim); + Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain); + Eigen::DSizes one_axis(1, axis_dim); + + dx.device(*context.eigen_device()) = + dy - + (y.exp()) * (dy.reshape(batch_axis_remain) + .sum(along_class) + .broadcast(one_axis)); + } +}; + +template +class LogSoftmaxGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* Out = context.Input("Out"); + auto* dOut = + context.Input(framework::GradVarName("Out")); + auto* dX = context.Output(framework::GradVarName("X")); + const int rank = Out->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + + // allocate memory on device. + dX->mutable_data(context.GetPlace()); + + LogSoftmaxGradFunctor()( + context.template device_context(), Out, dOut, dX, axis); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_log_softmax.py b/python/paddle/fluid/tests/unittests/test_log_softmax.py index 2b77624734..e3d7003ece 100644 --- a/python/paddle/fluid/tests/unittests/test_log_softmax.py +++ b/python/paddle/fluid/tests/unittests/test_log_softmax.py @@ -14,93 +14,136 @@ import unittest import numpy as np -import paddle.fluid as fluid -import paddle.fluid.core as core -import paddle.nn as nn -import paddle.nn.functional as functional +from op_test import OpTest +import paddle +import paddle.nn.functional as F +np.random.seed(10) -def stable_softmax(x): + +def ref_log_softmax(x): shiftx = (x - np.max(x)) - exps = np.exp(shiftx) - return exps / np.sum(exps) + out = shiftx - np.log(np.exp(shiftx).sum()) + return out -def ref_log_softmax(x, axis=None, dtype=None): - x_t = x.copy() - if dtype is not None: - x_t = x_t.astype(dtype) - if axis is None: - axis = -1 - out = np.apply_along_axis(stable_softmax, axis, x_t) - return np.log(out) +def ref_log_softmax_grad(x, axis): + if axis < 0: + axis += len(x.shape) + out = np.apply_along_axis(ref_log_softmax, axis, x) + axis_dim = x.shape[axis] + dout = np.full_like(x, fill_value=1. / x.size) + dx = dout - np.exp(out) * dout.copy().sum(axis=axis, keepdims=True).repeat( + axis_dim, axis=axis) + return dx -class TestNNLogSoftmaxAPI(unittest.TestCase): +class TestLogSoftmaxOp(OpTest): def setUp(self): - self.init_data() + self.op_type = 'log_softmax' + self.dtype = 'float64' + self.shape = [2, 3, 4, 5] + self.axis = -1 + self.set_attrs() - def init_data(self): - self.x_shape = [2, 3, 4, 5] - self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32) + x = np.random.uniform(0.1, 1., self.shape).astype(self.dtype) + out = np.apply_along_axis(ref_log_softmax, self.axis, x) + self.x_grad = ref_log_softmax_grad(x, self.axis) + + self.inputs = {'X': x} + self.outputs = {'Out': out} + self.attrs = {'axis': self.axis} + + def set_attrs(self): + pass + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad]) + + +class TestLogSoftmaxShape(TestLogSoftmaxOp): + def set_attrs(self): + self.shape = [12, 10] - def check_api(self, place=fluid.CPUPlace(), axis=None): - ref_out = ref_log_softmax(self.x, axis) - main_program = fluid.Program() - mylogsoftmax = nn.LogSoftmax(axis) - with fluid.program_guard(main_program): - x = fluid.data(name='x', shape=self.x_shape) - y = mylogsoftmax(x) - exe = fluid.Executor(place) - out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y]) +class TestLogSoftmaxAxis(TestLogSoftmaxOp): + def set_attrs(self): + self.axis = 1 + + +class TestNNLogSoftmaxAPI(unittest.TestCase): + def setUp(self): + self.x_shape = [2, 3, 4, 5] + self.x = np.random.uniform(-1., 1., self.x_shape).astype(np.float32) + self.place = paddle.CUDAPlace(0) \ + if paddle.fluid.core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def check_api(self, axis=-1): + ref_out = np.apply_along_axis(ref_log_softmax, axis, self.x) + + logsoftmax = paddle.nn.LogSoftmax(axis) + # test static api + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data(name='x', shape=self.x_shape) + y = logsoftmax(x) + exe = paddle.static.Executor(self.place) + out = exe.run(feed={'x': self.x}, fetch_list=[y]) self.assertTrue(np.allclose(out[0], ref_out)) - with fluid.dygraph.guard(place): - x = fluid.dygraph.to_variable(self.x) - y = mylogsoftmax(x) + # test dygrapg api + paddle.disable_static() + x = paddle.to_variable(self.x) + y = logsoftmax(x) self.assertTrue(np.allclose(y.numpy(), ref_out)) + paddle.enable_static() def test_check_api(self): - places = [fluid.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for place in places: - for axis in [None, 2]: - self.check_api(place, axis) + for axis in [-1, 1]: + self.check_api(axis) class TestNNFunctionalLogSoftmaxAPI(unittest.TestCase): def setUp(self): - self.init_data() - - def init_data(self): self.x_shape = [2, 3, 4, 5] self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32) - - def check_api(self, place=fluid.CPUPlace(), axis=None, dtype=None): - ref_out = ref_log_softmax(self.x, axis, dtype) - main_program = fluid.Program() - mylogsoftmax = nn.LogSoftmax(axis) - with fluid.program_guard(main_program): - x = fluid.data(name='x', shape=self.x_shape) - y = functional.log_softmax(x, axis, dtype) - exe = fluid.Executor(place) - out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y]) + self.place = paddle.CUDAPlace(0) \ + if paddle.fluid.core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def check_api(self, axis=-1, dtype=None): + x = self.x.copy() + if dtype is not None: + x = x.astype(dtype) + ref_out = np.apply_along_axis(ref_log_softmax, axis, x) + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data(name='x', shape=self.x_shape) + y = F.log_softmax(x, axis, dtype) + exe = paddle.static.Executor(self.place) + out = exe.run(feed={'x': self.x}, fetch_list=[y]) self.assertTrue(np.allclose(out[0], ref_out)) - with fluid.dygraph.guard(place): - x = fluid.dygraph.to_variable(self.x) - y = functional.log_softmax(x, axis, dtype) - self.assertTrue(np.allclose(y.numpy(), ref_out)) + paddle.disable_static() + x = paddle.to_variable(self.x) + y = F.log_softmax(x, axis, dtype) + self.assertTrue(np.allclose(y.numpy(), ref_out), True) + paddle.enable_static() def test_check_api(self): - places = [fluid.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for place in places: - self.check_api(place, None, None) - self.check_api(place, None, np.float64) + for axis in [-1, 1]: + self.check_api(axis) + self.check_api(-1, 'float64') + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data(name='X1', shape=[100], dtype='int32') + self.assertRaises(TypeError, F.log_softmax, x) + + x = paddle.data(name='X2', shape=[100], dtype='float32') + self.assertRaises(TypeError, F.log_softmax, x, dtype='int32') if __name__ == "__main__": diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 218ec4cfdc..4dd770139d 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -65,7 +65,7 @@ import warnings from ...fluid.layer_helper import LayerHelper from ...fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_ from ...fluid import core -from ...fluid.data_feeder import check_variable_and_dtype +from ...fluid.data_feeder import check_variable_and_dtype, check_dtype import paddle @@ -413,12 +413,10 @@ def softmax(x, axis=-1, name=None): return paddle.fluid.layers.softmax(input=x, axis=axis, name=name) -def log_softmax(input, axis=None, dtype=None, name=None): +def log_softmax(x, axis=-1, dtype=None, name=None): """ - :alias_main: paddle.nn.functional.log_softmax - :alias: paddle.nn.functional.log_softmax,paddle.nn.functional.activation.log_softmax - - This operator implements the log_softmax layer. The calculation process is as follows: + This operator implements the log_softmax layer. The calculation process is + as follows: .. math:: @@ -426,78 +424,85 @@ def log_softmax(input, axis=None, dtype=None, name=None): = log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}) Parameters: - input (Variable): The input variable. A multi-dimension Tensor with type float32, or float64. - axis (int, optional): The index of dimension to perform softmax calculations, it should be in - range :math:`[-1, rank-1]`, while :math:`rank` is the rank of input variable. Default: None. - None and -1 means the last dimension. - dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of returned tensor. If specified, - the input tensor is casted to dtype before the operation is performed. This is useful for - preventing data type overflows. Default: None. Supported dtype: float32 or float64 - name (str, optional): The default value is None. Normally there is no need for user to set this property. - For more information, please refer to :ref:`api_guide_Name` . + x (Tensor): The input Tensor with data type float32, float64. + axis (int, optional): The axis along which to perform log_softmax + calculations. It should be in range [-D, D), where D is the + dimensions of ``x`` . If ``axis`` < 0, it works the same way as + :math:`axis + D` . Default is -1. + dtype (str|np.dtype|core.VarDesc.VarType, optional): The desired data + type of the output tensor. If dtype is specified, ``x`` is casted + to ``dtype`` before the operation is performed. This is useful for + preventing data type overflows. Supported dtype: float32, float64. + If ``dtype`` is None, the output Tensor has the same dtype as x. + Default is None. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. Returns: - Variable: ``Tensor`` indicates the output of softmax. The data type and shape are the same as ``input``. + A Tensor with the same shape and data type (use ``dtype`` if it is + specified) as x. Examples: .. code-block:: python - import paddle.fluid as fluid - import paddle.nn.functional as F - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np - data = np.array([[[-2.0, 3.0, -4.0, 5.0], - [3.0, -4.0, 5.0, -6.0], - [-7.0, -8.0, 8.0, 9.0]], - [[1.0, -2.0, -3.0, 4.0], - [-5.0, 6.0, 7.0, -8.0], - [6.0, 7.0, 8.0, 9.0]]]).astype('float32') - with fluid.dygraph.guard(): - data = fluid.dygraph.to_variable(data) - res = F.log_softmax(data, -1) - # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] - # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] - # [-16.313261 -17.313261 -1.3132617 -0.31326184]] - # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] - # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] - # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] + paddle.disable_static() + + x = np.array([[[-2.0, 3.0, -4.0, 5.0], + [3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [6.0, 7.0, 8.0, 9.0]]], 'float32') + x = paddle.to_tensor(x) + out1 = F.log_softmax(x) + out2 = F.log_softmax(x, dtype='float64') + # out1's data type is float32; out2's data type is float64 + # out1 and out2's value is as follows: + # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] + # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] + # [-16.313261 -17.313261 -1.3132617 -0.31326184]] + # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] + # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] + # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] """ - axis = -1 if axis is None else axis - dtype = convert_np_dtype_to_dtype_(dtype) if dtype is not None else dtype + if axis is None: + axis = -1 + if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)): + dtype = convert_np_dtype_to_dtype_(dtype) if in_dygraph_mode(): - outs_cast = input if dtype is None \ - else core.ops.cast(input, 'in_dtype', input.dtype, 'out_dtype', dtype) - outs_softmax = core.ops.softmax(outs_cast, 'axis', axis, 'use_cudnn', - False) - return core.ops.log(outs_softmax) + if dtype is not None: + x = core.ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype) + return core.ops.log_softmax(x, 'axis', axis) if dtype is None: - check_variable_and_dtype( - input, 'input', ['float16', 'float32', 'float64'], 'log_softmax') + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'log_softmax') + else: + check_dtype(dtype, 'dtype', ['float32', 'float64'], 'log_softmax', + 'If dtype is not None, it only support float32 or float64.') helper = LayerHelper("log_softmax", **locals()) - outs_cast = input + out_cast = x if dtype is not None: - outs_cast = helper.create_variable_for_type_inference(dtype) + out_cast = helper.create_variable_for_type_inference(dtype) helper.append_op( type='cast', - inputs={'X': input}, - outputs={'Out': outs_cast}, - attrs={'in_dtype': input.dtype, + inputs={'X': x}, + outputs={'Out': out_cast}, + attrs={'in_dtype': x.dtype, 'out_dtype': dtype}) - outs_softmax = helper.create_variable_for_type_inference(outs_cast.dtype) - helper.append_op( - type='softmax', - inputs={'X': outs_cast}, - outputs={'Out': outs_softmax}, - attrs={'axis': axis, - 'use_cudnn': False}) - - outs_log = helper.create_variable_for_type_inference(outs_softmax.dtype) + out = helper.create_variable_for_type_inference(out_cast.dtype) helper.append_op( - type='log', inputs={'X': outs_softmax}, outputs={'Out': outs_log}) + type='log_softmax', + inputs={'X': out_cast}, + outputs={'Out': out}, + attrs={'axis': axis}) - return outs_log + return out diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index c82b2d61ff..a0b24eebf5 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -338,9 +338,6 @@ class Sigmoid(layers.Layer): class LogSoftmax(layers.Layer): """ - :alias_main: paddle.nn.LogSoftmax - :alias: paddle.nn.LogSoftmax,paddle.nn.layer.LogSoftmax,paddle.nn.layer.activation.LogSoftmax - This operator implements the log_softmax layer. The calculation process is as follows: .. math:: @@ -349,44 +346,46 @@ class LogSoftmax(layers.Layer): = log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}) Parameters: - axis (int, optional): The index of dimension to perform softmax calculations, it should be in - range :math:`[-1, rank-1]`, while :math:`rank` is the rank of input variable. Default: None. - None and -1 means the last dimension. - dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of returned tensor. If specified, - the input tensor is casted to dtype before the operation is performed. This is useful for - preventing data type overflows. Default: None. Supported dtype: float32 or float64 + axis (int, optional): The axis along which to perform log_softmax + calculations. It should be in range [-D, D), where D is the + dimensions of the input Tensor . If ``axis`` < 0, it works the + same way as :math:`axis + D` . Default is -1. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. - Returns: - None + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. Examples: .. code-block:: python - import paddle.fluid as fluid - import paddle.nn as nn - import numpy as np + import paddle + import numpy as np - data = np.array([[[-2.0, 3.0, -4.0, 5.0], - [3.0, -4.0, 5.0, -6.0], - [-7.0, -8.0, 8.0, 9.0]], - [[1.0, -2.0, -3.0, 4.0], - [-5.0, 6.0, 7.0, -8.0], - [6.0, 7.0, 8.0, 9.0]]]).astype('float32') - my_log_softnmax = nn.LogSoftmax() - with fluid.dygraph.guard(): - data = fluid.dygraph.to_variable(data) - res = my_log_softnmax(data) - # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] - # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] - # [-16.313261 -17.313261 -1.3132617 -0.31326184]] - # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] - # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] - # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] + paddle.disable_static() + + x = np.array([[[-2.0, 3.0, -4.0, 5.0], + [3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [6.0, 7.0, 8.0, 9.0]]]) + m = paddle.nn.LogSoftmax() + x = paddle.to_tensor(x) + out = m(x) + # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] + # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] + # [-16.313261 -17.313261 -1.3132617 -0.31326184]] + # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] + # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] + # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] """ - def __init__(self, axis=None): + def __init__(self, axis=-1, name=None): super(LogSoftmax, self).__init__() self._axis = axis + self._name = name - def forward(self, input): - return F.log_softmax(input, self._axis) + def forward(self, x): + return F.log_softmax(x, self._axis) -- GitLab