提交 b314a695 编写于 作者: F fengjiayi

make softmax supporting tensors

上级 b1af7e5d
...@@ -37,10 +37,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { ...@@ -37,10 +37,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SoftmaxOp should not be null."); "Output(Out) of SoftmaxOp should not be null.");
auto x_dims = ctx->GetInputDim("X"); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
PADDLE_ENFORCE(x_dims.size() == 2UL,
"The input of softmax op must be a matrix.");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
...@@ -81,8 +78,8 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -81,8 +78,8 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", AddInput("X",
"The input tensor of softmax. " "The input tensor of softmax, "
"2-D with shape [batch_size, input_feature_dimensions]."); "whose last dimension is the input_feature_dimensions.");
AddOutput("Out", "The normalized values with the same shape as X.") AddOutput("Out", "The normalized values with the same shape as X.")
.Reuse("X"); .Reuse("X");
AddAttr<bool>( AddAttr<bool>(
...@@ -105,20 +102,23 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -105,20 +102,23 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC( AddComment(R"DOC(
Softmax Operator. Softmax Operator.
The input of the softmax operator is a 2-D tensor with shape N x K (N is the The input of the softmax operator is a tensor of any rank. The output tensor
batch_size, K is the dimension of input feature). The output tensor has the has the same shape as the input.
same shape as the input tensor.
For each row of the input tensor, the softmax operator squashes the The input tensor will first be logically flattened to a 2-D matrix. The matrix's
K-dimensional vector of arbitrary real values to a K-dimensional vector of real second dimension(row length) is as same as the last dimension of the input
values in the range [0, 1] that add up to 1. tensor, and the first dimension(column length) is the product of all other
dimensions of the input tensor. For each row of the matrix, the softmax operator
squashes the K-dimensional(K is the width of the matrix, which is also the size
of the input tensor's last dimension) vector of arbitrary real values to a
K-dimensional vector of real values in the range [0, 1] that add up to 1.
It computes the exponential of the given dimension and the sum of exponential It computes the exponential of the given dimension and the sum of exponential
values of all the other dimensions in the K-dimensional vector input. values of all the other dimensions in the K-dimensional vector input.
Then the ratio of the exponential of the given dimension and the sum of Then the ratio of the exponential of the given dimension and the sum of
exponential values of all the other dimensions is the output of the softmax exponential values of all the other dimensions is the output of the softmax
operator. operator.
For each row $i$ and each column $j$ in Input(X), we have: For each row $i$ and each column $j$ in the matrix, we have:
$$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$ $$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$
)DOC"); )DOC");
......
...@@ -31,8 +31,16 @@ class SoftmaxKernel : public framework::OpKernel<T> { ...@@ -31,8 +31,16 @@ class SoftmaxKernel : public framework::OpKernel<T> {
// allocate memory on device. // allocate memory on device.
Out->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
auto dims = X->dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
framework::LoDTensor flattened_x;
framework::LoDTensor flattened_out;
flattened_x.ShareDataWith(*X);
flattened_out.ShareDataWith(*Out);
math::SoftmaxFunctor<DeviceContext, T>()( math::SoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), X, Out); context.template device_context<DeviceContext>(), &flattened_x,
&flattened_out);
} }
}; };
...@@ -47,8 +55,18 @@ class SoftmaxGradKernel : public framework::OpKernel<T> { ...@@ -47,8 +55,18 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
// allocate memory on device. // allocate memory on device.
dX->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
auto dims = Out->dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
framework::LoDTensor flattened_out;
framework::LoDTensor flattened_d_out;
framework::LoDTensor flattened_d_x;
flattened_out.ShareDataWith(*Out);
flattened_d_out.ShareDataWith(*dOut);
flattened_d_x.ShareDataWith(*dX);
math::SoftmaxGradFunctor<DeviceContext, T>()( math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), Out, dOut, dX); context.template device_context<DeviceContext>(), &flattened_out,
&flattened_d_out, &flattened_d_x);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册