未验证 提交 586a6dd3 编写于 作者: Z zhupengyang 提交者: GitHub

log_softmax and LogSoftmax: impl kernel and refind docs (#26088)

上级 23261ff4
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/log_softmax_op.h"
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/common_infer_shape_functions.h"
namespace paddle {
namespace operators {
class LogSoftmaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
return UnaryOpUnchangedInferShapeCheckAxis(ctx);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class LogSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input tensor of softmax, "
"whose dimension :attr:`axis` is the input_feature_dimensions.");
AddOutput("Out", "The normalized values with the same shape as X.");
AddAttr<int>("axis",
"The dimension index of Input(x) to perform log_softmax,"
"default -1 for last dimension")
.SetDefault(-1);
AddComment(R"DOC(
LogSoftmax Operator.
)DOC");
}
};
class LogSoftmaxOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
class LogSoftmaxGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "log_softmax_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@grad", "log_softmax_grad");
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Out"),
ctx->GetInputDim(framework::GradVarName("Out")),
platform::errors::InvalidArgument("Input(Out) and its gradients "
"should have the same shape."));
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
class LogSoftmaxGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("log_softmax_grad");
op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(log_softmax, ops::LogSoftmaxOp, ops::LogSoftmaxOpMaker,
ops::LogSoftmaxOpInferVarType,
ops::LogSoftmaxGradOpMaker<paddle::framework::OpDesc>,
ops::LogSoftmaxGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(log_softmax_grad, ops::LogSoftmaxGradOp);
REGISTER_OP_CPU_KERNEL(
log_softmax,
ops::LogSoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogSoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
log_softmax_grad,
ops::LogSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/log_softmax_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
log_softmax, ops::LogSoftmaxKernel<plat::CUDADeviceContext, float>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, double>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
log_softmax_grad, ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, float>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, double>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
static inline int CanonicalAxis(const int axis, const int rank) {
if (axis < 0) {
return axis + rank;
}
return axis;
}
static inline int SizeToAxis(const int axis, const framework::DDim dims) {
int size = 1;
for (int i = 0; i < axis; i++) {
size *= dims[i];
}
return size;
}
static inline int SizeFromAxis(const int axis, const framework::DDim dims) {
int size = 1;
for (int i = axis; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
template <typename T>
struct ValueClip {
HOSTDEVICE T operator()(const T& x) const {
const T kThreshold = static_cast<T>(-64.);
return x < kThreshold ? kThreshold : x;
}
};
template <typename DeviceContext, typename T>
struct LogSoftmaxFunctor {
void operator()(const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y, const int axis) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int kAxisDim = 1;
int axis_dim = X->dims()[axis];
const int n = SizeToAxis(axis, X->dims());
const int d = SizeFromAxis(axis, X->dims());
framework::DDim dim_2d{n, d};
auto logits = EigenMatrix<T>::From(*X, dim_2d);
auto log_softmax = EigenMatrix<T>::From(*Y, dim_2d);
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_axis(kAxisDim);
Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
// For numerical stability, logits should be shifted by maximum number along
// axis, calculate shifted_logits into log_softmax tensor for memory reuse.
if (num_remain == 1) {
// axis == -1, axis and class in same dimension, calculate along
// class dimension directly for higher performance
log_softmax.device(*context.eigen_device()) =
(logits -
logits.maximum(along_axis)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class))
.unaryExpr(ValueClip<T>());
} else {
// axis != -1, class dimension split into (axis, remain), max and sum
// should be calculated along axis dimension
log_softmax.device(*context.eigen_device()) =
(logits.reshape(batch_axis_remain) -
logits.reshape(batch_axis_remain)
.maximum(along_axis)
.eval()
.reshape(batch_one_remain)
.broadcast(one_axis_one)
.reshape(batch_classes))
.unaryExpr(ValueClip<T>());
}
log_softmax.device(*context.eigen_device()) =
log_softmax -
log_softmax.exp()
.eval()
.reshape(batch_axis_remain)
.sum(along_axis)
.log()
.broadcast(one_axis);
}
};
template <typename DeviceContext, typename T>
class LogSoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Out = context.Output<framework::Tensor>("Out");
const int rank = X->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
LogSoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), X, Out, axis);
}
};
template <typename DeviceContext, typename T>
struct LogSoftmaxGradFunctor {
void operator()(const DeviceContext& context, const framework::Tensor* Y,
const framework::Tensor* dY, framework::Tensor* dX,
const int axis) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
const int n = SizeToAxis(axis, Y->dims());
const int d = SizeFromAxis(axis, Y->dims());
framework::DDim dim_2d{n, d};
auto y = EigenMatrix<T>::From(*Y, dim_2d);
auto dy = EigenMatrix<T>::From(*dY, dim_2d);
auto dx = EigenMatrix<T>::From(*dX, dim_2d);
const int axis_dim = Y->dims()[axis];
const int batch_size = y.dimension(kBatchDim);
const int num_classes = y.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
dx.device(*context.eigen_device()) =
dy -
(y.exp()) * (dy.reshape(batch_axis_remain)
.sum(along_class)
.broadcast(one_axis));
}
};
template <typename DeviceContext, typename T>
class LogSoftmaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* Out = context.Input<framework::Tensor>("Out");
auto* dOut =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
const int rank = Out->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
LogSoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), Out, dOut, dX, axis);
}
};
} // namespace operators
} // namespace paddle
...@@ -14,93 +14,136 @@ ...@@ -14,93 +14,136 @@
import unittest import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid from op_test import OpTest
import paddle.fluid.core as core import paddle
import paddle.nn as nn import paddle.nn.functional as F
import paddle.nn.functional as functional
np.random.seed(10)
def stable_softmax(x):
def ref_log_softmax(x):
shiftx = (x - np.max(x)) shiftx = (x - np.max(x))
exps = np.exp(shiftx) out = shiftx - np.log(np.exp(shiftx).sum())
return exps / np.sum(exps) return out
def ref_log_softmax(x, axis=None, dtype=None): def ref_log_softmax_grad(x, axis):
x_t = x.copy() if axis < 0:
if dtype is not None: axis += len(x.shape)
x_t = x_t.astype(dtype) out = np.apply_along_axis(ref_log_softmax, axis, x)
if axis is None: axis_dim = x.shape[axis]
axis = -1 dout = np.full_like(x, fill_value=1. / x.size)
out = np.apply_along_axis(stable_softmax, axis, x_t) dx = dout - np.exp(out) * dout.copy().sum(axis=axis, keepdims=True).repeat(
return np.log(out) axis_dim, axis=axis)
return dx
class TestNNLogSoftmaxAPI(unittest.TestCase): class TestLogSoftmaxOp(OpTest):
def setUp(self): def setUp(self):
self.init_data() self.op_type = 'log_softmax'
self.dtype = 'float64'
self.shape = [2, 3, 4, 5]
self.axis = -1
self.set_attrs()
def init_data(self): x = np.random.uniform(0.1, 1., self.shape).astype(self.dtype)
self.x_shape = [2, 3, 4, 5] out = np.apply_along_axis(ref_log_softmax, self.axis, x)
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32) self.x_grad = ref_log_softmax_grad(x, self.axis)
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {'axis': self.axis}
def set_attrs(self):
pass
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad])
class TestLogSoftmaxShape(TestLogSoftmaxOp):
def set_attrs(self):
self.shape = [12, 10]
def check_api(self, place=fluid.CPUPlace(), axis=None):
ref_out = ref_log_softmax(self.x, axis)
main_program = fluid.Program() class TestLogSoftmaxAxis(TestLogSoftmaxOp):
mylogsoftmax = nn.LogSoftmax(axis) def set_attrs(self):
with fluid.program_guard(main_program): self.axis = 1
x = fluid.data(name='x', shape=self.x_shape)
y = mylogsoftmax(x)
exe = fluid.Executor(place) class TestNNLogSoftmaxAPI(unittest.TestCase):
out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y]) def setUp(self):
self.x_shape = [2, 3, 4, 5]
self.x = np.random.uniform(-1., 1., self.x_shape).astype(np.float32)
self.place = paddle.CUDAPlace(0) \
if paddle.fluid.core.is_compiled_with_cuda() \
else paddle.CPUPlace()
def check_api(self, axis=-1):
ref_out = np.apply_along_axis(ref_log_softmax, axis, self.x)
logsoftmax = paddle.nn.LogSoftmax(axis)
# test static api
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data(name='x', shape=self.x_shape)
y = logsoftmax(x)
exe = paddle.static.Executor(self.place)
out = exe.run(feed={'x': self.x}, fetch_list=[y])
self.assertTrue(np.allclose(out[0], ref_out)) self.assertTrue(np.allclose(out[0], ref_out))
with fluid.dygraph.guard(place): # test dygrapg api
x = fluid.dygraph.to_variable(self.x) paddle.disable_static()
y = mylogsoftmax(x) x = paddle.to_variable(self.x)
y = logsoftmax(x)
self.assertTrue(np.allclose(y.numpy(), ref_out)) self.assertTrue(np.allclose(y.numpy(), ref_out))
paddle.enable_static()
def test_check_api(self): def test_check_api(self):
places = [fluid.CPUPlace()] for axis in [-1, 1]:
if core.is_compiled_with_cuda(): self.check_api(axis)
places.append(fluid.CUDAPlace(0))
for place in places:
for axis in [None, 2]:
self.check_api(place, axis)
class TestNNFunctionalLogSoftmaxAPI(unittest.TestCase): class TestNNFunctionalLogSoftmaxAPI(unittest.TestCase):
def setUp(self): def setUp(self):
self.init_data()
def init_data(self):
self.x_shape = [2, 3, 4, 5] self.x_shape = [2, 3, 4, 5]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32) self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
self.place = paddle.CUDAPlace(0) \
def check_api(self, place=fluid.CPUPlace(), axis=None, dtype=None): if paddle.fluid.core.is_compiled_with_cuda() \
ref_out = ref_log_softmax(self.x, axis, dtype) else paddle.CPUPlace()
main_program = fluid.Program()
mylogsoftmax = nn.LogSoftmax(axis) def check_api(self, axis=-1, dtype=None):
with fluid.program_guard(main_program): x = self.x.copy()
x = fluid.data(name='x', shape=self.x_shape) if dtype is not None:
y = functional.log_softmax(x, axis, dtype) x = x.astype(dtype)
exe = fluid.Executor(place) ref_out = np.apply_along_axis(ref_log_softmax, axis, x)
out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y]) with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data(name='x', shape=self.x_shape)
y = F.log_softmax(x, axis, dtype)
exe = paddle.static.Executor(self.place)
out = exe.run(feed={'x': self.x}, fetch_list=[y])
self.assertTrue(np.allclose(out[0], ref_out)) self.assertTrue(np.allclose(out[0], ref_out))
with fluid.dygraph.guard(place): paddle.disable_static()
x = fluid.dygraph.to_variable(self.x) x = paddle.to_variable(self.x)
y = functional.log_softmax(x, axis, dtype) y = F.log_softmax(x, axis, dtype)
self.assertTrue(np.allclose(y.numpy(), ref_out)) self.assertTrue(np.allclose(y.numpy(), ref_out), True)
paddle.enable_static()
def test_check_api(self): def test_check_api(self):
places = [fluid.CPUPlace()] for axis in [-1, 1]:
if core.is_compiled_with_cuda(): self.check_api(axis)
places.append(fluid.CUDAPlace(0)) self.check_api(-1, 'float64')
for place in places:
self.check_api(place, None, None) def test_errors(self):
self.check_api(place, None, np.float64) with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data(name='X1', shape=[100], dtype='int32')
self.assertRaises(TypeError, F.log_softmax, x)
x = paddle.data(name='X2', shape=[100], dtype='float32')
self.assertRaises(TypeError, F.log_softmax, x, dtype='int32')
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -65,7 +65,7 @@ import warnings ...@@ -65,7 +65,7 @@ import warnings
from ...fluid.layer_helper import LayerHelper from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_ from ...fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_
from ...fluid import core from ...fluid import core
from ...fluid.data_feeder import check_variable_and_dtype from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
import paddle import paddle
...@@ -413,12 +413,10 @@ def softmax(x, axis=-1, name=None): ...@@ -413,12 +413,10 @@ def softmax(x, axis=-1, name=None):
return paddle.fluid.layers.softmax(input=x, axis=axis, name=name) return paddle.fluid.layers.softmax(input=x, axis=axis, name=name)
def log_softmax(input, axis=None, dtype=None, name=None): def log_softmax(x, axis=-1, dtype=None, name=None):
""" """
:alias_main: paddle.nn.functional.log_softmax This operator implements the log_softmax layer. The calculation process is
:alias: paddle.nn.functional.log_softmax,paddle.nn.functional.activation.log_softmax as follows:
This operator implements the log_softmax layer. The calculation process is as follows:
.. math:: .. math::
...@@ -426,78 +424,85 @@ def log_softmax(input, axis=None, dtype=None, name=None): ...@@ -426,78 +424,85 @@ def log_softmax(input, axis=None, dtype=None, name=None):
= log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}) = log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])})
Parameters: Parameters:
input (Variable): The input variable. A multi-dimension Tensor with type float32, or float64. x (Tensor): The input Tensor with data type float32, float64.
axis (int, optional): The index of dimension to perform softmax calculations, it should be in axis (int, optional): The axis along which to perform log_softmax
range :math:`[-1, rank-1]`, while :math:`rank` is the rank of input variable. Default: None. calculations. It should be in range [-D, D), where D is the
None and -1 means the last dimension. dimensions of ``x`` . If ``axis`` < 0, it works the same way as
dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of returned tensor. If specified, :math:`axis + D` . Default is -1.
the input tensor is casted to dtype before the operation is performed. This is useful for dtype (str|np.dtype|core.VarDesc.VarType, optional): The desired data
preventing data type overflows. Default: None. Supported dtype: float32 or float64 type of the output tensor. If dtype is specified, ``x`` is casted
name (str, optional): The default value is None. Normally there is no need for user to set this property. to ``dtype`` before the operation is performed. This is useful for
For more information, please refer to :ref:`api_guide_Name` . preventing data type overflows. Supported dtype: float32, float64.
If ``dtype`` is None, the output Tensor has the same dtype as x.
Default is None.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
Variable: ``Tensor`` indicates the output of softmax. The data type and shape are the same as ``input``. A Tensor with the same shape and data type (use ``dtype`` if it is
specified) as x.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
import numpy as np import numpy as np
data = np.array([[[-2.0, 3.0, -4.0, 5.0], paddle.disable_static()
[3.0, -4.0, 5.0, -6.0],
[-7.0, -8.0, 8.0, 9.0]], x = np.array([[[-2.0, 3.0, -4.0, 5.0],
[[1.0, -2.0, -3.0, 4.0], [3.0, -4.0, 5.0, -6.0],
[-5.0, 6.0, 7.0, -8.0], [-7.0, -8.0, 8.0, 9.0]],
[6.0, 7.0, 8.0, 9.0]]]).astype('float32') [[1.0, -2.0, -3.0, 4.0],
with fluid.dygraph.guard(): [-5.0, 6.0, 7.0, -8.0],
data = fluid.dygraph.to_variable(data) [6.0, 7.0, 8.0, 9.0]]], 'float32')
res = F.log_softmax(data, -1) x = paddle.to_tensor(x)
# [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] out1 = F.log_softmax(x)
# [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] out2 = F.log_softmax(x, dtype='float64')
# [-16.313261 -17.313261 -1.3132617 -0.31326184]] # out1's data type is float32; out2's data type is float64
# [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] # out1 and out2's value is as follows:
# [-12.313267 -1.3132664 -0.3132665 -15.313267 ] # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948]
# [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ]
# [-16.313261 -17.313261 -1.3132617 -0.31326184]]
# [[ -3.0518122 -6.051812 -7.051812 -0.051812 ]
# [-12.313267 -1.3132664 -0.3132665 -15.313267 ]
# [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]]
""" """
axis = -1 if axis is None else axis if axis is None:
dtype = convert_np_dtype_to_dtype_(dtype) if dtype is not None else dtype axis = -1
if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode(): if in_dygraph_mode():
outs_cast = input if dtype is None \ if dtype is not None:
else core.ops.cast(input, 'in_dtype', input.dtype, 'out_dtype', dtype) x = core.ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
outs_softmax = core.ops.softmax(outs_cast, 'axis', axis, 'use_cudnn', return core.ops.log_softmax(x, 'axis', axis)
False)
return core.ops.log(outs_softmax)
if dtype is None: if dtype is None:
check_variable_and_dtype( check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
input, 'input', ['float16', 'float32', 'float64'], 'log_softmax') 'log_softmax')
else:
check_dtype(dtype, 'dtype', ['float32', 'float64'], 'log_softmax',
'If dtype is not None, it only support float32 or float64.')
helper = LayerHelper("log_softmax", **locals()) helper = LayerHelper("log_softmax", **locals())
outs_cast = input out_cast = x
if dtype is not None: if dtype is not None:
outs_cast = helper.create_variable_for_type_inference(dtype) out_cast = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='cast', type='cast',
inputs={'X': input}, inputs={'X': x},
outputs={'Out': outs_cast}, outputs={'Out': out_cast},
attrs={'in_dtype': input.dtype, attrs={'in_dtype': x.dtype,
'out_dtype': dtype}) 'out_dtype': dtype})
outs_softmax = helper.create_variable_for_type_inference(outs_cast.dtype) out = helper.create_variable_for_type_inference(out_cast.dtype)
helper.append_op(
type='softmax',
inputs={'X': outs_cast},
outputs={'Out': outs_softmax},
attrs={'axis': axis,
'use_cudnn': False})
outs_log = helper.create_variable_for_type_inference(outs_softmax.dtype)
helper.append_op( helper.append_op(
type='log', inputs={'X': outs_softmax}, outputs={'Out': outs_log}) type='log_softmax',
inputs={'X': out_cast},
outputs={'Out': out},
attrs={'axis': axis})
return outs_log return out
...@@ -338,9 +338,6 @@ class Sigmoid(layers.Layer): ...@@ -338,9 +338,6 @@ class Sigmoid(layers.Layer):
class LogSoftmax(layers.Layer): class LogSoftmax(layers.Layer):
""" """
:alias_main: paddle.nn.LogSoftmax
:alias: paddle.nn.LogSoftmax,paddle.nn.layer.LogSoftmax,paddle.nn.layer.activation.LogSoftmax
This operator implements the log_softmax layer. The calculation process is as follows: This operator implements the log_softmax layer. The calculation process is as follows:
.. math:: .. math::
...@@ -349,44 +346,46 @@ class LogSoftmax(layers.Layer): ...@@ -349,44 +346,46 @@ class LogSoftmax(layers.Layer):
= log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}) = log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])})
Parameters: Parameters:
axis (int, optional): The index of dimension to perform softmax calculations, it should be in axis (int, optional): The axis along which to perform log_softmax
range :math:`[-1, rank-1]`, while :math:`rank` is the rank of input variable. Default: None. calculations. It should be in range [-D, D), where D is the
None and -1 means the last dimension. dimensions of the input Tensor . If ``axis`` < 0, it works the
dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of returned tensor. If specified, same way as :math:`axis + D` . Default is -1.
the input tensor is casted to dtype before the operation is performed. This is useful for name (str, optional): Name for the operation (optional, default is None).
preventing data type overflows. Default: None. Supported dtype: float32 or float64 For more information, please refer to :ref:`api_guide_Name`.
Returns: Shape:
None - input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
import paddle.nn as nn import numpy as np
import numpy as np
data = np.array([[[-2.0, 3.0, -4.0, 5.0], paddle.disable_static()
[3.0, -4.0, 5.0, -6.0],
[-7.0, -8.0, 8.0, 9.0]], x = np.array([[[-2.0, 3.0, -4.0, 5.0],
[[1.0, -2.0, -3.0, 4.0], [3.0, -4.0, 5.0, -6.0],
[-5.0, 6.0, 7.0, -8.0], [-7.0, -8.0, 8.0, 9.0]],
[6.0, 7.0, 8.0, 9.0]]]).astype('float32') [[1.0, -2.0, -3.0, 4.0],
my_log_softnmax = nn.LogSoftmax() [-5.0, 6.0, 7.0, -8.0],
with fluid.dygraph.guard(): [6.0, 7.0, 8.0, 9.0]]])
data = fluid.dygraph.to_variable(data) m = paddle.nn.LogSoftmax()
res = my_log_softnmax(data) x = paddle.to_tensor(x)
# [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] out = m(x)
# [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948]
# [-16.313261 -17.313261 -1.3132617 -0.31326184]] # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ]
# [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] # [-16.313261 -17.313261 -1.3132617 -0.31326184]]
# [-12.313267 -1.3132664 -0.3132665 -15.313267 ] # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ]
# [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] # [-12.313267 -1.3132664 -0.3132665 -15.313267 ]
# [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]]
""" """
def __init__(self, axis=None): def __init__(self, axis=-1, name=None):
super(LogSoftmax, self).__init__() super(LogSoftmax, self).__init__()
self._axis = axis self._axis = axis
self._name = name
def forward(self, input): def forward(self, x):
return F.log_softmax(input, self._axis) return F.log_softmax(x, self._axis)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册