提交 518325f1 编写于 作者: D dengkaipeng

add softmax_axis CPU kernel. test=develop

上级 6429d2a8
......@@ -37,6 +37,13 @@ class SoftmaxOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SoftmaxOp should not be null.");
auto dim_x = ctx->GetInputDim("X");
auto rank_x = dim_x.size();
auto axis = ctx->Attrs().Get<int>("axis");
// Validate the new `axis` attribute: -1 selects the last dimension,
// otherwise it must index an existing dimension of Input(X).
// FIX: the original message concatenated two adjacent string literals
// without a separating space ("-1and less then") and had grammar errors.
PADDLE_ENFORCE(axis >= -1 && axis < rank_x,
"Attr(axis) value should be greater than or equal to -1 "
"and less than the rank of Input(X)");
// Softmax is shape-preserving: Out has exactly the shape (and LoD) of X.
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
......@@ -80,6 +87,10 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
"The input tensor of softmax, "
"whose last dimension is the input_feature_dimensions.");
AddOutput("Out", "The normalized values with the same shape as X.");
// Attribute controlling which dimension softmax normalizes over.
// FIX: the two adjacent string literals in the description concatenated
// without a space ("softmax,default -1 ..."); add the missing space.
AddAttr<int>("axis",
"The dimension of Input(x) to perform softmax, "
"default -1 for last dimension")
.SetDefault(-1);
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
......
......@@ -13,27 +13,69 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/transpose_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Transposes `x` into `x_trans` (and `out` into `out_trans`) so that
// dimension `axis` becomes the last dimension. The SoftmaxFunctor only
// operates along the last dimension, so this is the pre-pass that makes
// axis != rank-1 work. `perm` is filled with the permutation used, which
// the caller needs to invert the transpose on the output afterwards.
//
// NOTE(review): `perm` is taken BY VALUE, so the permutation built inside
// this function is never visible to the caller -- SoftmaxKernel::Compute
// later hands its own still-empty `perm` to TransCompute for the inverse
// transpose. This looks like a real bug; `perm` should presumably be
// `std::vector<int>*`. TODO: confirm against upstream follow-up commits.
template <typename DeviceContext, typename T>
static inline void TransposeAxisToEnd(const Tensor& x, const Tensor& out,
Tensor* x_trans, Tensor* out_trans,
const int axis, std::vector<int> perm,
const framework::ExecutionContext& ctx) {
auto dim_x = x.dims();
int rank = dim_x.size();
// Fast path: `axis` already is the last dimension. Share the original
// tensors (Tensor assignment is a shallow copy) and do no transpose.
// Note `perm` is left empty in this case.
if (axis == -1 || axis == rank - 1) {
*x_trans = x;
*out_trans = out;
return;
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// Build the permutation that swaps positions `axis` and `rank - 1`
// (identity everywhere else); `shape` is dim_x permuted the same way.
std::vector<int> shape;
for (int i = 0; i < rank - 1; i++) {
if (i == axis) {
perm.push_back(rank - 1);
shape.push_back(dim_x[rank - 1]);
} else {
perm.push_back(i);
shape.push_back(dim_x[i]);
}
}
perm.push_back(axis);
shape.push_back(dim_x[axis]);
// Allocate the transposed buffers on the target device and transpose.
x_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
out_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
TransCompute<DeviceContext, T>(rank, dev_ctx, x, x_trans, perm);
// NOTE(review): `out` holds no meaningful data yet (the kernel only
// writes it after this call), so transposing it here appears to be
// wasted work; allocating `out_trans` alone should suffice.
TransCompute<DeviceContext, T>(rank, dev_ctx, out, out_trans, perm);
}
// CPU/GPU-agnostic softmax kernel. Strategy: transpose the requested
// `axis` to the last dimension, flatten to a 2-D matrix, run the row-wise
// SoftmaxFunctor, then transpose the result back if needed.
template <typename DeviceContext, typename T>
class SoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Out = context.Output<Tensor>("Out");
const int axis = context.Attr<int>("axis");
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
Tensor X_trans, Out_trans;
std::vector<int> perm;
// NOTE(review): TransposeAxisToEnd takes `perm` by value, so `perm`
// here stays EMPTY after this call; the inverse TransCompute below
// therefore receives an empty permutation -- likely a bug. TODO: verify.
TransposeAxisToEnd<DeviceContext, T>(*X, *Out, &X_trans, &Out_trans, axis,
perm, context);
int rank = X->dims().size();
// NOTE(review): the next two lines are the PRE-change code shown by the
// diff view (the scrape dropped the '-' markers); the two lines after
// them are their replacements operating on the transposed tensors.
Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
Tensor X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
Tensor Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
// Inference builds use the (faster) in-place-friendly functor variant.
#ifdef PADDLE_ON_INFERENCE
math::SoftmaxFunctor<DeviceContext, T, true>()(
......@@ -42,6 +84,11 @@ class SoftmaxKernel : public framework::OpKernel<T> {
math::SoftmaxFunctor<DeviceContext, T, false>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
#endif
// Undo the axis transpose so Out has the layout of X. Only needed when
// an actual transpose happened (axis was not the last dimension).
if (axis != -1 && axis != rank - 1) {
auto& dev_ctx = context.template device_context<DeviceContext>();
TransCompute<DeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册