Unverified commit 586a6dd3, authored by zhupengyang, committed by GitHub

log_softmax and LogSoftmax: impl kernel and refine docs (#26088)

Parent 23261ff4
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/log_softmax_op.h"
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/common_infer_shape_functions.h"
namespace paddle {
namespace operators {
class LogSoftmaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
return UnaryOpUnchangedInferShapeCheckAxis(ctx);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class LogSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input tensor of softmax, "
"whose dimension :attr:`axis` is the input_feature_dimensions.");
AddOutput("Out", "The normalized values with the same shape as X.");
AddAttr<int>("axis",
"The dimension index of Input(x) to perform log_softmax,"
"default -1 for last dimension")
.SetDefault(-1);
AddComment(R"DOC(
LogSoftmax Operator.
)DOC");
}
};
class LogSoftmaxOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
class LogSoftmaxGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "log_softmax_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@grad", "log_softmax_grad");
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Out"),
ctx->GetInputDim(framework::GradVarName("Out")),
platform::errors::InvalidArgument("Input(Out) and its gradients "
"should have the same shape."));
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
class LogSoftmaxGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("log_softmax_grad");
op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(log_softmax, ops::LogSoftmaxOp, ops::LogSoftmaxOpMaker,
ops::LogSoftmaxOpInferVarType,
ops::LogSoftmaxGradOpMaker<paddle::framework::OpDesc>,
ops::LogSoftmaxGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(log_softmax_grad, ops::LogSoftmaxGradOp);
REGISTER_OP_CPU_KERNEL(
log_softmax,
ops::LogSoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogSoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
log_softmax_grad,
ops::LogSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/log_softmax_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
log_softmax, ops::LogSoftmaxKernel<plat::CUDADeviceContext, float>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, double>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
log_softmax_grad, ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, float>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, double>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
static inline int CanonicalAxis(const int axis, const int rank) {
if (axis < 0) {
return axis + rank;
}
return axis;
}
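// SizeToAxis/SizeFromAxis flatten the input shape around `axis` into a 2-D
// view (n, d): n is the product of the dimensions before axis, and d is the
// product of the dimensions from axis onward. The functors below operate on
// this flattened view.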
static inline int SizeToAxis(const int axis, const framework::DDim dims) {
int size = 1;
for (int i = 0; i < axis; i++) {
size *= dims[i];
}
return size;
}
static inline int SizeFromAxis(const int axis, const framework::DDim dims) {
int size = 1;
for (int i = axis; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
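// After subtracting the per-axis maximum the shifted logits are <= 0; clamping
// them at -64 keeps the later exp()/log() steps away from underflow (this
// mainly matters for low-precision types such as float16).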
template <typename T>
struct ValueClip {
HOSTDEVICE T operator()(const T& x) const {
const T kThreshold = static_cast<T>(-64.);
return x < kThreshold ? kThreshold : x;
}
};
template <typename DeviceContext, typename T>
struct LogSoftmaxFunctor {
void operator()(const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y, const int axis) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int kAxisDim = 1;
int axis_dim = X->dims()[axis];
const int n = SizeToAxis(axis, X->dims());
const int d = SizeFromAxis(axis, X->dims());
framework::DDim dim_2d{n, d};
auto logits = EigenMatrix<T>::From(*X, dim_2d);
auto log_softmax = EigenMatrix<T>::From(*Y, dim_2d);
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_axis(kAxisDim);
Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
    // For numerical stability, the logits are shifted by the maximum value
    // along the axis; the shifted logits are written into the log_softmax
    // tensor so its memory is reused.
if (num_remain == 1) {
      // axis == -1: axis and class share the same dimension, so compute along
      // the class dimension directly for higher performance.
log_softmax.device(*context.eigen_device()) =
(logits -
logits.maximum(along_axis)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class))
.unaryExpr(ValueClip<T>());
} else {
      // axis != -1: the class dimension is split into (axis, remain); max and
      // sum are computed along the axis dimension.
log_softmax.device(*context.eigen_device()) =
(logits.reshape(batch_axis_remain) -
logits.reshape(batch_axis_remain)
.maximum(along_axis)
.eval()
.reshape(batch_one_remain)
.broadcast(one_axis_one)
.reshape(batch_classes))
.unaryExpr(ValueClip<T>());
}
log_softmax.device(*context.eigen_device()) =
log_softmax -
log_softmax.exp()
.eval()
.reshape(batch_axis_remain)
.sum(along_axis)
.log()
.broadcast(one_axis);
}
};
template <typename DeviceContext, typename T>
class LogSoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Out = context.Output<framework::Tensor>("Out");
const int rank = X->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
LogSoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), X, Out, axis);
}
};
template <typename DeviceContext, typename T>
struct LogSoftmaxGradFunctor {
void operator()(const DeviceContext& context, const framework::Tensor* Y,
const framework::Tensor* dY, framework::Tensor* dX,
const int axis) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
const int n = SizeToAxis(axis, Y->dims());
const int d = SizeFromAxis(axis, Y->dims());
framework::DDim dim_2d{n, d};
auto y = EigenMatrix<T>::From(*Y, dim_2d);
auto dy = EigenMatrix<T>::From(*dY, dim_2d);
auto dx = EigenMatrix<T>::From(*dX, dim_2d);
const int axis_dim = Y->dims()[axis];
const int batch_size = y.dimension(kBatchDim);
const int num_classes = y.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
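    // Gradient of log_softmax: with y = log_softmax(x), dx = dy - exp(y) *
    // sum(dy), where the sum runs over the softmax axis (exp(y) recovers
    // softmax(x)).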
dx.device(*context.eigen_device()) =
dy -
(y.exp()) * (dy.reshape(batch_axis_remain)
.sum(along_class)
.broadcast(one_axis));
}
};
template <typename DeviceContext, typename T>
class LogSoftmaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* Out = context.Input<framework::Tensor>("Out");
auto* dOut =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
const int rank = Out->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
LogSoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), Out, dOut, dX, axis);
}
};
} // namespace operators
} // namespace paddle
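
As a reading aid (not part of the patch), the forward pass implemented by
LogSoftmaxFunctor above can be sketched in NumPy; `axis` is assumed to be
already canonicalized to a non-negative value, and the -64 clamp mirrors
ValueClip:

    import numpy as np

    def log_softmax_sketch(x, axis):
        # shift by the maximum along `axis` for numerical stability
        shifted = x - np.max(x, axis=axis, keepdims=True)
        # clamp very negative entries, mirroring ValueClip's -64 threshold
        shifted = np.maximum(shifted, -64.0)
        # subtract the log-sum-exp of the shifted logits
        return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))
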
@@ -14,93 +14,136 @@
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.nn as nn
import paddle.nn.functional as functional
from op_test import OpTest
import paddle
import paddle.nn.functional as F
np.random.seed(10)
def stable_softmax(x):
def ref_log_softmax(x):
shiftx = (x - np.max(x))
exps = np.exp(shiftx)
return exps / np.sum(exps)
out = shiftx - np.log(np.exp(shiftx).sum())
return out
def ref_log_softmax(x, axis=None, dtype=None):
x_t = x.copy()
if dtype is not None:
x_t = x_t.astype(dtype)
if axis is None:
axis = -1
out = np.apply_along_axis(stable_softmax, axis, x_t)
return np.log(out)
def ref_log_softmax_grad(x, axis):
if axis < 0:
axis += len(x.shape)
out = np.apply_along_axis(ref_log_softmax, axis, x)
axis_dim = x.shape[axis]
dout = np.full_like(x, fill_value=1. / x.size)
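    # dout is the uniform output gradient (1 / number of elements), assumed to
    # match the output gradient used by OpTest's numeric gradient check.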
dx = dout - np.exp(out) * dout.copy().sum(axis=axis, keepdims=True).repeat(
axis_dim, axis=axis)
return dx
class TestNNLogSoftmaxAPI(unittest.TestCase):
class TestLogSoftmaxOp(OpTest):
def setUp(self):
self.init_data()
self.op_type = 'log_softmax'
self.dtype = 'float64'
self.shape = [2, 3, 4, 5]
self.axis = -1
self.set_attrs()
def init_data(self):
self.x_shape = [2, 3, 4, 5]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
x = np.random.uniform(0.1, 1., self.shape).astype(self.dtype)
out = np.apply_along_axis(ref_log_softmax, self.axis, x)
self.x_grad = ref_log_softmax_grad(x, self.axis)
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {'axis': self.axis}
def set_attrs(self):
pass
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad])
class TestLogSoftmaxShape(TestLogSoftmaxOp):
def set_attrs(self):
self.shape = [12, 10]
def check_api(self, place=fluid.CPUPlace(), axis=None):
ref_out = ref_log_softmax(self.x, axis)
main_program = fluid.Program()
mylogsoftmax = nn.LogSoftmax(axis)
with fluid.program_guard(main_program):
x = fluid.data(name='x', shape=self.x_shape)
y = mylogsoftmax(x)
exe = fluid.Executor(place)
out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
class TestLogSoftmaxAxis(TestLogSoftmaxOp):
def set_attrs(self):
self.axis = 1
class TestNNLogSoftmaxAPI(unittest.TestCase):
def setUp(self):
self.x_shape = [2, 3, 4, 5]
self.x = np.random.uniform(-1., 1., self.x_shape).astype(np.float32)
self.place = paddle.CUDAPlace(0) \
if paddle.fluid.core.is_compiled_with_cuda() \
else paddle.CPUPlace()
def check_api(self, axis=-1):
ref_out = np.apply_along_axis(ref_log_softmax, axis, self.x)
logsoftmax = paddle.nn.LogSoftmax(axis)
# test static api
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data(name='x', shape=self.x_shape)
y = logsoftmax(x)
exe = paddle.static.Executor(self.place)
out = exe.run(feed={'x': self.x}, fetch_list=[y])
self.assertTrue(np.allclose(out[0], ref_out))
with fluid.dygraph.guard(place):
x = fluid.dygraph.to_variable(self.x)
y = mylogsoftmax(x)
        # test dygraph api
paddle.disable_static()
x = paddle.to_variable(self.x)
y = logsoftmax(x)
self.assertTrue(np.allclose(y.numpy(), ref_out))
paddle.enable_static()
def test_check_api(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
for axis in [None, 2]:
self.check_api(place, axis)
for axis in [-1, 1]:
self.check_api(axis)
class TestNNFunctionalLogSoftmaxAPI(unittest.TestCase):
def setUp(self):
self.init_data()
def init_data(self):
self.x_shape = [2, 3, 4, 5]
self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
def check_api(self, place=fluid.CPUPlace(), axis=None, dtype=None):
ref_out = ref_log_softmax(self.x, axis, dtype)
main_program = fluid.Program()
mylogsoftmax = nn.LogSoftmax(axis)
with fluid.program_guard(main_program):
x = fluid.data(name='x', shape=self.x_shape)
y = functional.log_softmax(x, axis, dtype)
exe = fluid.Executor(place)
out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
self.place = paddle.CUDAPlace(0) \
if paddle.fluid.core.is_compiled_with_cuda() \
else paddle.CPUPlace()
def check_api(self, axis=-1, dtype=None):
x = self.x.copy()
if dtype is not None:
x = x.astype(dtype)
ref_out = np.apply_along_axis(ref_log_softmax, axis, x)
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data(name='x', shape=self.x_shape)
y = F.log_softmax(x, axis, dtype)
exe = paddle.static.Executor(self.place)
out = exe.run(feed={'x': self.x}, fetch_list=[y])
self.assertTrue(np.allclose(out[0], ref_out))
with fluid.dygraph.guard(place):
x = fluid.dygraph.to_variable(self.x)
y = functional.log_softmax(x, axis, dtype)
self.assertTrue(np.allclose(y.numpy(), ref_out))
paddle.disable_static()
x = paddle.to_variable(self.x)
y = F.log_softmax(x, axis, dtype)
        self.assertTrue(np.allclose(y.numpy(), ref_out))
paddle.enable_static()
def test_check_api(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.check_api(place, None, None)
self.check_api(place, None, np.float64)
for axis in [-1, 1]:
self.check_api(axis)
self.check_api(-1, 'float64')
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data(name='X1', shape=[100], dtype='int32')
self.assertRaises(TypeError, F.log_softmax, x)
x = paddle.data(name='X2', shape=[100], dtype='float32')
self.assertRaises(TypeError, F.log_softmax, x, dtype='int32')
if __name__ == "__main__":
......
@@ -65,7 +65,7 @@ import warnings
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_
from ...fluid import core
from ...fluid.data_feeder import check_variable_and_dtype
from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
import paddle
@@ -413,12 +413,10 @@ def softmax(x, axis=-1, name=None):
return paddle.fluid.layers.softmax(input=x, axis=axis, name=name)
def log_softmax(input, axis=None, dtype=None, name=None):
def log_softmax(x, axis=-1, dtype=None, name=None):
"""
:alias_main: paddle.nn.functional.log_softmax
:alias: paddle.nn.functional.log_softmax,paddle.nn.functional.activation.log_softmax
This operator implements the log_softmax layer. The calculation process is as follows:
This operator implements the log_softmax layer. The calculation process is
as follows:
.. math::
@@ -426,78 +424,85 @@ def log_softmax(input, axis=None, dtype=None, name=None):
                  = log(\\frac{\exp(X[i, j])}{\sum_j(\exp(X[i, j]))})
Parameters:
input (Variable): The input variable. A multi-dimension Tensor with type float32, or float64.
axis (int, optional): The index of dimension to perform softmax calculations, it should be in
range :math:`[-1, rank-1]`, while :math:`rank` is the rank of input variable. Default: None.
None and -1 means the last dimension.
dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of returned tensor. If specified,
the input tensor is casted to dtype before the operation is performed. This is useful for
preventing data type overflows. Default: None. Supported dtype: float32 or float64
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
x (Tensor): The input Tensor with data type float32, float64.
        axis (int, optional): The axis along which to perform log_softmax
            calculations. It should be in range [-D, D), where D is the number
            of dimensions of ``x``. If ``axis`` < 0, it works the same way as
            :math:`axis + D`. Default is -1.
dtype (str|np.dtype|core.VarDesc.VarType, optional): The desired data
type of the output tensor. If dtype is specified, ``x`` is casted
to ``dtype`` before the operation is performed. This is useful for
preventing data type overflows. Supported dtype: float32, float64.
If ``dtype`` is None, the output Tensor has the same dtype as x.
Default is None.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable: ``Tensor`` indicates the output of softmax. The data type and shape are the same as ``input``.
        A Tensor with the same shape as ``x``. Its data type is ``dtype`` if
        ``dtype`` is specified, otherwise the same as that of ``x``.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.nn.functional as F
import numpy as np
import paddle
import paddle.nn.functional as F
import numpy as np
data = np.array([[[-2.0, 3.0, -4.0, 5.0],
[3.0, -4.0, 5.0, -6.0],
[-7.0, -8.0, 8.0, 9.0]],
[[1.0, -2.0, -3.0, 4.0],
[-5.0, 6.0, 7.0, -8.0],
[6.0, 7.0, 8.0, 9.0]]]).astype('float32')
with fluid.dygraph.guard():
data = fluid.dygraph.to_variable(data)
res = F.log_softmax(data, -1)
# [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948]
# [ -2.1270514 -9.127051 -0.12705144 -11.127051 ]
# [-16.313261 -17.313261 -1.3132617 -0.31326184]]
# [[ -3.0518122 -6.051812 -7.051812 -0.051812 ]
# [-12.313267 -1.3132664 -0.3132665 -15.313267 ]
# [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]]
paddle.disable_static()
x = np.array([[[-2.0, 3.0, -4.0, 5.0],
[3.0, -4.0, 5.0, -6.0],
[-7.0, -8.0, 8.0, 9.0]],
[[1.0, -2.0, -3.0, 4.0],
[-5.0, 6.0, 7.0, -8.0],
[6.0, 7.0, 8.0, 9.0]]], 'float32')
x = paddle.to_tensor(x)
out1 = F.log_softmax(x)
out2 = F.log_softmax(x, dtype='float64')
# out1's data type is float32; out2's data type is float64
# out1 and out2's value is as follows:
# [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948]
# [ -2.1270514 -9.127051 -0.12705144 -11.127051 ]
# [-16.313261 -17.313261 -1.3132617 -0.31326184]]
# [[ -3.0518122 -6.051812 -7.051812 -0.051812 ]
# [-12.313267 -1.3132664 -0.3132665 -15.313267 ]
# [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]]
"""
axis = -1 if axis is None else axis
dtype = convert_np_dtype_to_dtype_(dtype) if dtype is not None else dtype
if axis is None:
axis = -1
if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
outs_cast = input if dtype is None \
else core.ops.cast(input, 'in_dtype', input.dtype, 'out_dtype', dtype)
outs_softmax = core.ops.softmax(outs_cast, 'axis', axis, 'use_cudnn',
False)
return core.ops.log(outs_softmax)
if dtype is not None:
x = core.ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
return core.ops.log_softmax(x, 'axis', axis)
if dtype is None:
check_variable_and_dtype(
input, 'input', ['float16', 'float32', 'float64'], 'log_softmax')
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'log_softmax')
else:
check_dtype(dtype, 'dtype', ['float32', 'float64'], 'log_softmax',
                    'If dtype is not None, it only supports float32 or float64.')
helper = LayerHelper("log_softmax", **locals())
outs_cast = input
out_cast = x
if dtype is not None:
outs_cast = helper.create_variable_for_type_inference(dtype)
out_cast = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='cast',
inputs={'X': input},
outputs={'Out': outs_cast},
attrs={'in_dtype': input.dtype,
inputs={'X': x},
outputs={'Out': out_cast},
attrs={'in_dtype': x.dtype,
'out_dtype': dtype})
outs_softmax = helper.create_variable_for_type_inference(outs_cast.dtype)
helper.append_op(
type='softmax',
inputs={'X': outs_cast},
outputs={'Out': outs_softmax},
attrs={'axis': axis,
'use_cudnn': False})
outs_log = helper.create_variable_for_type_inference(outs_softmax.dtype)
out = helper.create_variable_for_type_inference(out_cast.dtype)
helper.append_op(
type='log', inputs={'X': outs_softmax}, outputs={'Out': outs_log})
type='log_softmax',
inputs={'X': out_cast},
outputs={'Out': out},
attrs={'axis': axis})
return outs_log
return out
@@ -338,9 +338,6 @@ class Sigmoid(layers.Layer):
class LogSoftmax(layers.Layer):
"""
:alias_main: paddle.nn.LogSoftmax
:alias: paddle.nn.LogSoftmax,paddle.nn.layer.LogSoftmax,paddle.nn.layer.activation.LogSoftmax
This operator implements the log_softmax layer. The calculation process is as follows:
.. math::
@@ -349,44 +346,46 @@ class LogSoftmax(layers.Layer):
                  = log(\\frac{\exp(X[i, j])}{\sum_j(\exp(X[i, j]))})
Parameters:
axis (int, optional): The index of dimension to perform softmax calculations, it should be in
range :math:`[-1, rank-1]`, while :math:`rank` is the rank of input variable. Default: None.
None and -1 means the last dimension.
dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of returned tensor. If specified,
the input tensor is casted to dtype before the operation is performed. This is useful for
preventing data type overflows. Default: None. Supported dtype: float32 or float64
        axis (int, optional): The axis along which to perform log_softmax
            calculations. It should be in range [-D, D), where D is the number
            of dimensions of the input Tensor. If ``axis`` < 0, it works the
            same way as :math:`axis + D`. Default is -1.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
None
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.nn as nn
import numpy as np
import paddle
import numpy as np
data = np.array([[[-2.0, 3.0, -4.0, 5.0],
[3.0, -4.0, 5.0, -6.0],
[-7.0, -8.0, 8.0, 9.0]],
[[1.0, -2.0, -3.0, 4.0],
[-5.0, 6.0, 7.0, -8.0],
[6.0, 7.0, 8.0, 9.0]]]).astype('float32')
my_log_softnmax = nn.LogSoftmax()
with fluid.dygraph.guard():
data = fluid.dygraph.to_variable(data)
res = my_log_softnmax(data)
# [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948]
# [ -2.1270514 -9.127051 -0.12705144 -11.127051 ]
# [-16.313261 -17.313261 -1.3132617 -0.31326184]]
# [[ -3.0518122 -6.051812 -7.051812 -0.051812 ]
# [-12.313267 -1.3132664 -0.3132665 -15.313267 ]
# [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]]
paddle.disable_static()
x = np.array([[[-2.0, 3.0, -4.0, 5.0],
[3.0, -4.0, 5.0, -6.0],
[-7.0, -8.0, 8.0, 9.0]],
[[1.0, -2.0, -3.0, 4.0],
[-5.0, 6.0, 7.0, -8.0],
[6.0, 7.0, 8.0, 9.0]]])
m = paddle.nn.LogSoftmax()
x = paddle.to_tensor(x)
out = m(x)
# [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948]
# [ -2.1270514 -9.127051 -0.12705144 -11.127051 ]
# [-16.313261 -17.313261 -1.3132617 -0.31326184]]
# [[ -3.0518122 -6.051812 -7.051812 -0.051812 ]
# [-12.313267 -1.3132664 -0.3132665 -15.313267 ]
# [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]]
"""
def __init__(self, axis=None):
def __init__(self, axis=-1, name=None):
super(LogSoftmax, self).__init__()
self._axis = axis
self._name = name
def forward(self, input):
return F.log_softmax(input, self._axis)
def forward(self, x):
return F.log_softmax(x, self._axis)