From 855c9e3311646ca96ee244f4f70e5fdf9d01304c Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Thu, 9 Aug 2018 15:04:48 +0800
Subject: [PATCH] clean softmax_op code

---
 paddle/fluid/operators/softmax_op.h | 31 +++++++++++++------------------
 1 file changed, 13 insertions(+), 18 deletions(-)

diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h
index 1205bd0587..febd938022 100644
--- a/paddle/fluid/operators/softmax_op.h
+++ b/paddle/fluid/operators/softmax_op.h
@@ -31,16 +31,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());
 
-    auto dims = X->dims();
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::LoDTensor flattened_x;
-    framework::LoDTensor flattened_out;
-    flattened_x.ShareDataWith(*X).Resize(flattened_dims);
-    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
+    int rank = X->dims().size();
+    Tensor X_2d = rank > 2 ? framework::ReshapeToMatrix(*X, rank - 1) : *X;
+    Tensor Out_2d =
+        rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
 
     math::SoftmaxFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), &flattened_x,
-        &flattened_out);
+        context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
   }
 };
 
@@ -55,18 +52,16 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());
 
-    auto dims = Out->dims();
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::LoDTensor flattened_out;
-    framework::LoDTensor flattened_d_out;
-    framework::LoDTensor flattened_d_x;
-    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
-    flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
-    flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
+    int rank = Out->dims().size();
+    Tensor Out_2d =
+        rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
+    Tensor dOut_2d =
+        rank > 2 ? framework::ReshapeToMatrix(*dOut, rank - 1) : *dOut;
+    Tensor dX_2d = rank > 2 ? framework::ReshapeToMatrix(*dX, rank - 1) : *dX;
 
     math::SoftmaxGradFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
-        &flattened_d_out, &flattened_d_x);
+        context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
+        &dX_2d);
   }
 };
-- 
GitLab
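
Note for readers outside the Paddle codebase: the patch's
"rank > 2 ? framework::ReshapeToMatrix(...) : ..." pattern views an N-D
tensor as a 2-D matrix (all leading dimensions collapsed into rows, the
last dimension as the softmax axis), so a single 2-D softmax routine can
serve inputs of any rank. Below is a minimal standalone C++ sketch of
just that dimension arithmetic; FlattenTo2D is a hypothetical stand-in
for the real framework::ReshapeToMatrix, which additionally shares the
tensor's storage rather than copying it.

#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical analogue of framework::ReshapeToMatrix's shape rule:
// dims [0, num_col_dims) collapse into the row count, the remaining
// dims into the column count. The patch passes num_col_dims = rank - 1,
// so softmax always normalizes over the last axis.
std::pair<size_t, size_t> FlattenTo2D(const std::vector<size_t>& dims,
                                      size_t num_col_dims) {
  size_t rows = 1, cols = 1;
  for (size_t i = 0; i < dims.size(); ++i) {
    (i < num_col_dims ? rows : cols) *= dims[i];
  }
  return {rows, cols};
}

int main() {
  // A rank-3 input of shape {2, 3, 4} becomes a 6 x 4 matrix: six rows,
  // each a length-4 vector for the 2-D softmax to normalize.
  const auto shape = FlattenTo2D({2, 3, 4}, /*num_col_dims=*/2);
  std::cout << shape.first << " x " << shape.second << "\n";  // prints: 6 x 4
  return 0;
}

The "rank > 2" guard simply skips the reshape when the input is already a
matrix; for a 2-D tensor the flattening would be an identity anyway, so
the branch is an economy, not a behavior change.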