diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc
index ad3e5543f10ae05865565110ba2231c897c205b8..84151d70b9997ba717d1942b93a9be0c31f534a3 100644
--- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/math/softmax.h"
+#include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/framework/op_registry.h"

 namespace paddle {
@@ -24,22 +25,40 @@ template <typename T>
 class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
     auto* X = context.Input<Tensor>("X");
     auto* Out = context.Output<Tensor>("Out");
+    const int axis = context.Attr<int>("axis");
+    int rank = X->dims().size();

     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());

-    auto dims = X->dims();
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::LoDTensor flattened_x;
-    framework::LoDTensor flattened_out;
-    flattened_x.ShareDataWith(*X).Resize(flattened_dims);
-    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
+    std::vector<int> perm, shape;
+    CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
+
+    Tensor X_2d, Out_2d;
+    Tensor X_trans, Out_trans;
+    if (axis != -1 && axis != rank - 1) {
+      X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
+      X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
+    } else {
+      X_2d = framework::ReshapeToMatrix(*X, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
+    }

     math::SoftmaxCUDNNFunctor<T>()(
         context.template device_context<platform::CUDADeviceContext>(),
-        &flattened_x, &flattened_out);
+        &X_2d, &Out_2d);
+
+    if (axis != -1 && axis != rank - 1) {
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
+    }
   }
 };

@@ -47,25 +66,44 @@ template <typename T>
 class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
     auto* Out = context.Input<Tensor>("Out");
     auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+    const int axis = context.Attr<int>("axis");
+    int rank = Out->dims().size();

     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());

-    auto dims = Out->dims();
-    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-    framework::LoDTensor flattened_out;
-    framework::LoDTensor flattened_d_out;
-    framework::LoDTensor flattened_d_x;
-    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
-    flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
-    flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
+    std::vector<int> perm, shape;
+    CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
+
+    Tensor dX_2d, Out_2d, dOut_2d;
+    Tensor dX_trans, Out_trans, dOut_trans;
+    if (axis != -1 && axis != rank - 1) {
+      dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      dOut_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
+      dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
+      dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
+    } else {
+      dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
+      dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
+    }

     math::SoftmaxGradCUDNNFunctor<T>()(
         context.template device_context<platform::CUDADeviceContext>(),
-        &flattened_out, &flattened_d_out, &flattened_d_x);
+        &Out_2d, &dOut_2d, &dX_2d);
+
+    if (axis != -1 && axis != rank - 1) {
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
+    }
   }
 };
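Note: both CUDNN kernels above follow one strategy: when `axis` is not the last dimension, swap `axis` with the last dimension (the permutation built by CalcTransPermAndShapeByAxis), flatten the transposed tensor to a 2-D matrix, run the existing row-wise softmax, and transpose the result back. The sketch below mirrors that control flow in NumPy; it is only an illustration, and the helper names (swap_axis_to_end, row_softmax, softmax_along_axis) are hypothetical, not part of the patch.

import numpy as np

def swap_axis_to_end(x, axis):
    # Same permutation CalcTransPermAndShapeByAxis builds: swap `axis` with
    # the last dim and keep every other dim in place. A swap is its own inverse.
    perm = list(range(x.ndim))
    perm[axis], perm[-1] = perm[-1], perm[axis]
    return perm

def row_softmax(x_2d):
    # The 2-D case handled by SoftmaxCUDNNFunctor: softmax over each row.
    shifted = x_2d - x_2d.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

def softmax_along_axis(x, axis=-1):
    rank = x.ndim
    if axis == -1 or axis == rank - 1:
        # Fast path: no transpose, just ReshapeToMatrix semantics.
        return row_softmax(x.reshape(-1, x.shape[-1])).reshape(x.shape)
    perm = swap_axis_to_end(x, axis)
    x_trans = np.transpose(x, perm)                      # TransCompute
    out_trans = row_softmax(
        x_trans.reshape(-1, x_trans.shape[-1])).reshape(x_trans.shape)
    return np.transpose(out_trans, perm)                 # transpose back

x = np.random.uniform(0.1, 1, (2, 3, 4, 5)).astype(np.float32)
ref = np.moveaxis(softmax_along_axis(np.moveaxis(x, 1, -1)), -1, 1)
assert np.allclose(softmax_along_axis(x, axis=1), ref, atol=1e-6)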
diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h
index ad41e52116c2ecc8d7488f71d5e4b118638204b9..1810b23e0d456245a0a6e5bbec4c9e36850a433f 100644
--- a/paddle/fluid/operators/softmax_op.h
+++ b/paddle/fluid/operators/softmax_op.h
@@ -23,59 +23,58 @@ namespace operators {

 using Tensor = framework::Tensor;

-template <typename DeviceContext, typename T>
-static inline void TransposeAxisToEnd(const Tensor& x, const Tensor& out,
-                                      Tensor* x_trans, Tensor* out_trans,
-                                      const int axis, std::vector<int> perm,
-                                      const framework::ExecutionContext& ctx) {
+static inline void CalcTransPermAndShapeByAxis(const Tensor& x, const int axis,
+                                               std::vector<int>* perm,
+                                               std::vector<int>* shape) {
   auto dim_x = x.dims();
   int rank = dim_x.size();

   if (axis == -1 || axis == rank - 1) {
-    *x_trans = x;
-    *out_trans = out;
     return;
   }

-  auto& dev_ctx = ctx.template device_context<DeviceContext>();
-  std::vector<int> shape;
   for (int i = 0; i < rank - 1; i++) {
     if (i == axis) {
-      perm.push_back(rank - 1);
-      shape.push_back(dim_x[rank - 1]);
+      perm->push_back(rank - 1);
+      shape->push_back(dim_x[rank - 1]);
     } else {
-      perm.push_back(i);
-      shape.push_back(dim_x[i]);
+      perm->push_back(i);
+      shape->push_back(dim_x[i]);
     }
   }
-  perm.push_back(axis);
-  shape.push_back(dim_x[axis]);
-
-  x_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-  out_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-  TransCompute<DeviceContext, T>(rank, dev_ctx, x, x_trans, perm);
-  TransCompute<DeviceContext, T>(rank, dev_ctx, out, out_trans, perm);
+  perm->push_back(axis);
+  shape->push_back(dim_x[axis]);
 }

 template <typename DeviceContext, typename T>
 class SoftmaxKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto& dev_ctx = context.template device_context<DeviceContext>();
     auto* X = context.Input<Tensor>("X");
     auto* Out = context.Output<Tensor>("Out");
     const int axis = context.Attr<int>("axis");
+    int rank = X->dims().size();

     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());

+    std::vector<int> perm, shape;
+    CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
+
+    Tensor X_2d, Out_2d;
     Tensor X_trans, Out_trans;
-    std::vector<int> perm;
-    TransposeAxisToEnd<DeviceContext, T>(*X, *Out, &X_trans, &Out_trans, axis,
-                                         perm, context);
+    if (axis != -1 && axis != rank - 1) {
+      X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      TransCompute<DeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
+      TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
+      X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
+    } else {
+      X_2d = framework::ReshapeToMatrix(*X, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
+    }

-    int rank = X->dims().size();
-    Tensor X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
-    Tensor Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
 #ifdef PADDLE_ON_INFERENCE
     math::SoftmaxFunctor<DeviceContext, T, true>()(
@@ -86,7 +85,6 @@ class SoftmaxKernel : public framework::OpKernel<T> {
 #endif

     if (axis != -1 && axis != rank - 1) {
-      auto& dev_ctx = context.template device_context<DeviceContext>();
       TransCompute<DeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
     }
   }
@@ -96,21 +94,44 @@ template <typename DeviceContext, typename T>
 class SoftmaxGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto& dev_ctx = context.template device_context<DeviceContext>();
     auto* Out = context.Input<Tensor>("Out");
     auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+    const int axis = context.Attr<int>("axis");
+    int rank = Out->dims().size();

     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());

-    int rank = Out->dims().size();
-    Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
-    Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
-    Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
+    std::vector<int> perm, shape;
+    CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
+
+    Tensor dX_2d, Out_2d, dOut_2d;
+    Tensor dX_trans, Out_trans, dOut_trans;
+    if (axis != -1 && axis != rank - 1) {
+      dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      dOut_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      TransCompute<DeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
+      TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
+      TransCompute<DeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
+      dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
+      dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
+    } else {
+      dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
+      Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
+      dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
+    }

     math::SoftmaxGradFunctor<DeviceContext, T>()(
         context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
         &dX_2d);
+
+    if (axis != -1 && axis != rank - 1) {
+      TransCompute<DeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
+    }
   }
 };
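The grad kernels in both files apply the same transpose-to-last-axis trick to Out, dOut and dX before calling the existing 2-D SoftmaxGradFunctor / SoftmaxGradCUDNNFunctor, whose row-wise formula is dX = (dOut - sum(dOut * Out, per row)) * Out. Below is a minimal NumPy sketch of that backward computation; the helper names (row_softmax_grad, softmax_grad_along_axis) are hypothetical and not part of the patch.

import numpy as np

def row_softmax_grad(out_2d, dout_2d):
    # Row-wise softmax backward: dX = (dOut - sum(dOut * Out, per row)) * Out.
    dot = (dout_2d * out_2d).sum(axis=1, keepdims=True)
    return (dout_2d - dot) * out_2d

def softmax_grad_along_axis(out, dout, axis=-1):
    rank = out.ndim
    if axis == -1 or axis == rank - 1:
        return row_softmax_grad(out.reshape(-1, out.shape[-1]),
                                dout.reshape(-1, out.shape[-1])).reshape(out.shape)
    perm = list(range(rank))
    perm[axis], perm[-1] = perm[-1], perm[axis]        # swap axis with last dim
    out_t = np.transpose(out, perm)
    dout_t = np.transpose(dout, perm)
    dx_t = row_softmax_grad(out_t.reshape(-1, out_t.shape[-1]),
                            dout_t.reshape(-1, out_t.shape[-1])).reshape(out_t.shape)
    return np.transpose(dx_t, perm)                    # the swap is its own inverse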
diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py
index 5c56de6779d238064f03a65b54f3c73a77119f60..084fa869e3a7ea407ce7d8f8a6f5c9b81644dbfb 100644
--- a/python/paddle/fluid/tests/unittests/test_softmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py
@@ -31,6 +31,9 @@ class TestSoftmaxOp(OpTest):
     def get_x_shape(self):
         return [10, 10]

+    def get_axis(self):
+        return -1
+
     def setUp(self):
         self.op_type = "softmax"
         self.use_cudnn = False
@@ -38,15 +41,15 @@ class TestSoftmaxOp(OpTest):
         self.dtype = np.float32
         self.init_kernel_type()
         self.shape = self.get_x_shape()
+        self.axis = self.get_axis()

         x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
-        out = np.apply_along_axis(stable_softmax, 1,
-                                  x.reshape([-1, self.shape[-1]]))
-        out = out.reshape(self.shape)
+        out = np.apply_along_axis(stable_softmax, self.axis, x)

         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}
         self.attrs = {
+            'axis': self.axis,
             'use_cudnn': self.use_cudnn,
             'use_mkldnn': self.use_mkldnn
         }
@@ -76,6 +79,38 @@ class TestSoftmaxOp2(TestSoftmaxOp):
         return [2, 3, 4, 5]


+class TestSoftmaxOp3(TestSoftmaxOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+    def get_axis(self):
+        return 0
+
+
+class TestSoftmaxOp4(TestSoftmaxOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+    def get_axis(self):
+        return 1
+
+
+class TestSoftmaxOp5(TestSoftmaxOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+    def get_axis(self):
+        return 2
+
+
+class TestSoftmaxOp6(TestSoftmaxOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+    def get_axis(self):
+        return 3
+
+
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestSoftmaxCUDNNOp(TestSoftmaxOp):
@@ -90,6 +125,26 @@ class TestSoftmaxCUDNNOp2(TestSoftmaxCUDNNOp):
         return [2, 3, 4, 5]


+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestSoftmaxCUDNNOp3(TestSoftmaxCUDNNOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+    def get_axis(self):
+        return 1
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestSoftmaxCUDNNOp4(TestSoftmaxCUDNNOp):
+    def get_x_shape(self):
+        return [2, 3, 4, 5]
+
+    def get_axis(self):
+        return 2
+
+
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestSoftmaxFP16Op(TestSoftmaxOp):
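The tests rely on stable_softmax, which is defined earlier in test_softmax_op.py and is not shown in this diff. A typical max-shifted definition (an assumption here, for readers without the full file) and the reference output the updated setUp builds for, e.g., axis = 1:

import numpy as np

def stable_softmax(x):
    """Softmax of a 1-D vector, shifted by its max for numerical stability."""
    # Assumed definition; the actual helper lives outside this diff.
    shiftx = x - np.max(x)
    exps = np.exp(shiftx)
    return exps / np.sum(exps)

# What the updated setUp computes for shape [2, 3, 4, 5] and axis = 1:
x = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(np.float32)
out = np.apply_along_axis(stable_softmax, 1, x)
assert np.allclose(out.sum(axis=1), 1.0, atol=1e-6)   # each axis-1 slice sums to 1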