From 6c641827092fb10f6eeb56477819c76f2b331969 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 18 Mar 2019 11:57:16 +0000 Subject: [PATCH] refine softmax kernel. test=develop --- paddle/fluid/operators/math/softmax.h | 9 +- paddle/fluid/operators/math/softmax_impl.h | 22 +-- .../operators/mkldnn/softmax_mkldnn_op.cc | 134 +++++------------- paddle/fluid/operators/softmax_cudnn_op.cu.cc | 85 +++-------- paddle/fluid/operators/softmax_op.h | 114 ++++++--------- .../operators/softmax_with_cross_entropy_op.h | 2 +- paddle/fluid/operators/warpctc_cudnn_op.cu.cc | 2 +- 7 files changed, 119 insertions(+), 249 deletions(-) diff --git a/paddle/fluid/operators/math/softmax.h b/paddle/fluid/operators/math/softmax.h index 81beef56d94..f8e250fa2e7 100644 --- a/paddle/fluid/operators/math/softmax.h +++ b/paddle/fluid/operators/math/softmax.h @@ -23,15 +23,16 @@ template class SoftmaxFunctor { public: - void operator()(const DeviceContext& context, const framework::Tensor* X, - framework::Tensor* Y); + void operator()(const DeviceContext& context, const int axis_dim, + const framework::Tensor* X, framework::Tensor* Y); }; template class SoftmaxGradFunctor { public: - void operator()(const DeviceContext& context, const framework::Tensor* y, - const framework::Tensor* y_grad, framework::Tensor* x_grad); + void operator()(const DeviceContext& context, const int axis_dim, + const framework::Tensor* y, const framework::Tensor* y_grad, + framework::Tensor* x_grad); }; #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/math/softmax_impl.h b/paddle/fluid/operators/math/softmax_impl.h index d77b6712c54..9bcb272b93b 100644 --- a/paddle/fluid/operators/math/softmax_impl.h +++ b/paddle/fluid/operators/math/softmax_impl.h @@ -36,8 +36,8 @@ struct ValueClip { template void SoftmaxFunctor::operator()( - const DeviceContext& context, const framework::Tensor* X, - framework::Tensor* Y) { + const DeviceContext& context, const int axis_dim, + const framework::Tensor* X, framework::Tensor* Y) { auto logits = EigenMatrix::From(*X); auto softmax = EigenMatrix::From(*Y); @@ -46,10 +46,13 @@ void SoftmaxFunctor::operator()( const int batch_size = logits.dimension(kBatchDim); const int num_classes = logits.dimension(kClassDim); + const int num_remain = num_classes / axis_dim; Eigen::DSizes along_class(kClassDim); Eigen::DSizes batch_by_one(batch_size, 1); Eigen::DSizes one_by_class(1, num_classes); + Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain); + Eigen::DSizes one_axis(1, axis_dim); auto shifted_logits = (logits - logits.maximum(along_class) @@ -60,11 +63,11 @@ void SoftmaxFunctor::operator()( softmax.device(*context.eigen_device()) = shifted_logits.exp(); softmax.device(*context.eigen_device()) = (softmax * - softmax.sum(along_class) + softmax.reshape(batch_axis_remain) + .sum(along_class) .inverse() .eval() - .reshape(batch_by_one) - .broadcast(one_by_class)); + .broadcast(one_axis)); } template @@ -90,7 +93,7 @@ class SoftmaxFunctor> { template void SoftmaxGradFunctor::operator()( - const DeviceContext& context, const framework::Tensor* y, + const DeviceContext& context, const int axis_dim, const framework::Tensor* y, const framework::Tensor* y_grad, framework::Tensor* x_grad) { auto softmax = EigenMatrix::From(*y); auto softmax_grad = EigenMatrix::From(*y_grad); @@ -101,16 +104,19 @@ void SoftmaxGradFunctor::operator()( const int batch_size = softmax.dimension(kBatchDim); const int num_classes = softmax.dimension(kClassDim); + const int num_remain = num_classes / axis_dim; Eigen::DSizes along_class(kClassDim); Eigen::DSizes batch_by_one(batch_size, 1); Eigen::DSizes one_by_class(1, num_classes); + Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain); + Eigen::DSizes one_axis(1, axis_dim); auto dot = (softmax * softmax_grad) + .reshape(batch_axis_remain) .sum(along_class) .eval() - .reshape(batch_by_one) - .broadcast(one_by_class); + .broadcast(one_axis); logits_grad.device(*context.eigen_device()) = (softmax_grad - dot) * softmax; } diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc index c73dfd65e76..0ce55221945 100644 --- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc @@ -110,46 +110,28 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel { "It must use CPUPlace."); auto& dev_ctx = ctx.template device_context(); auto mkldnn_engine = dev_ctx.GetEngine(); - const Tensor* X = ctx.Input("X"); - Tensor* Out = ctx.Output("Out"); + const Tensor* input = ctx.Input("X"); + Tensor* output = ctx.Output("Out"); PADDLE_ENFORCE_EQ( - X->dims(), Out->dims(), + input->dims(), output->dims(), "The shape of softmax's input and output must be identical."); - const int axis = ctx.Attr("axis"); - int rank = X->dims().size(); - // make sure 'output' holds memory, which will be shared by // 'flattened_output' later. - Out->mutable_data(ctx.GetPlace()); - - std::vector perm, shape; - CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape); - - Tensor X_2d, Out_2d; - Tensor X_trans, Out_trans; - if (axis != -1 && axis != rank - 1) { - X_trans.mutable_data(framework::make_ddim(shape), ctx.GetPlace()); - Out_trans.mutable_data(framework::make_ddim(shape), ctx.GetPlace()); - TransCompute(rank, dev_ctx, *X, &X_trans, - perm); - TransCompute(rank, dev_ctx, *Out, - &Out_trans, perm); - auto dims = X_trans.dims(); - auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); - X_2d.ShareDataWith(X_trans).Resize(flattened_dims); - Out_2d.ShareDataWith(Out_trans).Resize(flattened_dims); - } else { - auto dims = X->dims(); - auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); - X_2d.ShareDataWith(*X).Resize(flattened_dims); - Out_2d.ShareDataWith(*Out).Resize(flattened_dims); - } + output->mutable_data(ctx.GetPlace()); + + // flatten input and output to 2-D matrixs + auto dims = input->dims(); // input and output share the same shape + auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); + framework::Tensor flattened_input; + framework::Tensor flattened_output; + flattened_input.ShareDataWith(*input).Resize(flattened_dims); + flattened_output.ShareDataWith(*output).Resize(flattened_dims); - const T* input_data = X_2d.data(); - T* output_data = Out_2d.mutable_data(ctx.GetPlace()); + const T* input_data = flattened_input.data(); + T* output_data = flattened_output.mutable_data(ctx.GetPlace()); - std::vector src_tz = paddle::framework::vectorize2int(X_2d.dims()); + std::vector src_tz = paddle::framework::vectorize2int(flattened_dims); std::vector dst_tz = src_tz; // Same memory descriptor to be used for input and output memory::dims softmax_tz = {src_tz[0], src_tz[1]}; @@ -179,16 +161,10 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel { // We cannot use softmax_dst_memory_p to get prim desc as // it contains flattened dims (2D) while output tensor can // have 2,3,4+ dims - if (axis != -1 && axis != rank - 1) { - auto output_mem_pd = paddle::platform::create_prim_desc_from_dims( - shape, mkldnn::memory::format::blocked); - Out_trans.set_mkldnn_prim_desc(output_mem_pd); - } else { - auto output_mem_pd = paddle::platform::create_prim_desc_from_dims( - paddle::framework::vectorize2int(Out->dims()), - mkldnn::memory::format::blocked); - Out->set_mkldnn_prim_desc(output_mem_pd); - } + auto output_mem_pd = paddle::platform::create_prim_desc_from_dims( + paddle::framework::vectorize2int(output->dims()), + mkldnn::memory::format::blocked); + output->set_mkldnn_prim_desc(output_mem_pd); std::vector pipeline{ *(static_cast(softmax_p.get()))}; @@ -202,11 +178,6 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel { output_data[i] < threshold ? threshold : output_data[i]; } } - - if (axis != -1 && axis != rank - 1) { - TransCompute(rank, dev_ctx, Out_trans, Out, - perm); - } } }; @@ -219,55 +190,33 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel { auto& dev_ctx = ctx.template device_context(); auto mkldnn_engine = dev_ctx.GetEngine(); - const Tensor* Out = ctx.Input("Out"); - auto* dOut = ctx.template Input(framework::GradVarName("Out")); - auto* dX = + const Tensor* output = ctx.Input("Out"); + auto* dout = ctx.template Input(framework::GradVarName("Out")); + auto* dx = ctx.template Output(framework::GradVarName("X")); PADDLE_ENFORCE_EQ( - dOut->dims(), dX->dims(), + dout->dims(), dx->dims(), "The shape of softmax_grad's input and output must be identical."); - const int axis = ctx.Attr("axis"); - int rank = Out->dims().size(); - // make sure 'dx' holds memory, which will be shared by 'flattened_dx' // later. - dX->template mutable_data(ctx.GetPlace()); - - std::vector perm, shape; - CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape); - - Tensor dX_2d, Out_2d, dOut_2d; - Tensor dX_trans, Out_trans, dOut_trans; - if (axis != -1 && axis != rank - 1) { - dX_trans.mutable_data(framework::make_ddim(shape), ctx.GetPlace()); - Out_trans.mutable_data(framework::make_ddim(shape), ctx.GetPlace()); - dOut_trans.mutable_data(framework::make_ddim(shape), ctx.GetPlace()); - TransCompute(rank, dev_ctx, *dX, &dX_trans, - perm); - TransCompute(rank, dev_ctx, *Out, - &Out_trans, perm); - TransCompute(rank, dev_ctx, *dOut, - &dOut_trans, perm); - auto dims = dX_trans.dims(); - auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); - dX_2d.ShareDataWith(dX_trans).Resize(flattened_dims); - Out_2d.ShareDataWith(Out_trans).Resize(flattened_dims); - dOut_2d.ShareDataWith(dOut_trans).Resize(flattened_dims); - } else { - auto dims = dX->dims(); - auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); - dX_2d.ShareDataWith(*dX).Resize(flattened_dims); - Out_2d.ShareDataWith(*Out).Resize(flattened_dims); - dOut_2d.ShareDataWith(*dOut).Resize(flattened_dims); - } - - const T* dst_data = Out_2d.data(); - const T* diff_dst_ptr = dOut_2d.template data(); - T* diff_src_ptr = dX_2d.template mutable_data(ctx.GetPlace()); - - std::vector dst_tz = paddle::framework::vectorize2int(Out_2d.dims()); + dx->template mutable_data(ctx.GetPlace()); + + auto dims = dout->dims(); // input and output share the same shape + auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); + framework::Tensor flattened_output; + framework::Tensor flattened_dout; + framework::Tensor flattened_dx; + flattened_output.ShareDataWith(*output).Resize(flattened_dims); + flattened_dout.ShareDataWith(*dout).Resize(flattened_dims); + flattened_dx.ShareDataWith(*dx).Resize(flattened_dims); + + const T* dst_data = flattened_output.data(); + const T* diff_dst_ptr = flattened_dout.template data(); + T* diff_src_ptr = flattened_dx.template mutable_data(ctx.GetPlace()); + + std::vector dst_tz = paddle::framework::vectorize2int(flattened_dims); std::vector src_tz(dst_tz); // Same memory descriptor to be used for input and output @@ -312,11 +261,6 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel { std::vector pipeline{*softmax_bwd_p}; stream(stream::kind::eager).submit(pipeline).wait(); - - if (axis != -1 && axis != rank - 1) { - TransCompute(rank, dev_ctx, dX_trans, dX, - perm); - } } }; } // namespace operators diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index 9e24c76793c..ad3e5543f10 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/softmax_op.h" namespace paddle { namespace operators { @@ -25,44 +24,22 @@ template class SoftmaxCUDNNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto& dev_ctx = - context.template device_context(); auto* X = context.Input("X"); auto* Out = context.Output("Out"); - const int axis = context.Attr("axis"); - int rank = X->dims().size(); // allocate memory on device. Out->mutable_data(context.GetPlace()); - std::vector perm, shape; - CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape); - - Tensor X_2d, Out_2d; - Tensor X_trans, Out_trans; - if (axis != -1 && axis != rank - 1) { - X_trans.mutable_data(framework::make_ddim(shape), context.GetPlace()); - Out_trans.mutable_data(framework::make_ddim(shape), - context.GetPlace()); - TransCompute(rank, dev_ctx, *X, &X_trans, - perm); - TransCompute(rank, dev_ctx, *Out, - &Out_trans, perm); - X_2d = framework::ReshapeToMatrix(X_trans, rank - 1); - Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1); - } else { - X_2d = framework::ReshapeToMatrix(*X, rank - 1); - Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); - } + auto dims = X->dims(); + auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); + framework::LoDTensor flattened_x; + framework::LoDTensor flattened_out; + flattened_x.ShareDataWith(*X).Resize(flattened_dims); + flattened_out.ShareDataWith(*Out).Resize(flattened_dims); math::SoftmaxCUDNNFunctor()( - context.template device_context(), &X_2d, - &Out_2d); - - if (axis != -1 && axis != rank - 1) { - TransCompute(rank, dev_ctx, Out_trans, - Out, perm); - } + context.template device_context(), + &flattened_x, &flattened_out); } }; @@ -70,51 +47,25 @@ template class SoftmaxGradCUDNNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto& dev_ctx = - context.template device_context(); auto* Out = context.Input("Out"); auto* dOut = context.Input(framework::GradVarName("Out")); auto* dX = context.Output(framework::GradVarName("X")); - const int axis = context.Attr("axis"); - int rank = Out->dims().size(); // allocate memory on device. dX->mutable_data(context.GetPlace()); - std::vector perm, shape; - CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape); - - Tensor dX_2d, Out_2d, dOut_2d; - Tensor dX_trans, Out_trans, dOut_trans; - if (axis != -1 && axis != rank - 1) { - dX_trans.mutable_data(framework::make_ddim(shape), context.GetPlace()); - Out_trans.mutable_data(framework::make_ddim(shape), - context.GetPlace()); - dOut_trans.mutable_data(framework::make_ddim(shape), - context.GetPlace()); - TransCompute(rank, dev_ctx, *dX, - &dX_trans, perm); - TransCompute(rank, dev_ctx, *Out, - &Out_trans, perm); - TransCompute(rank, dev_ctx, *dOut, - &dOut_trans, perm); - dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1); - Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1); - dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1); - } else { - dX_2d = framework::ReshapeToMatrix(*dX, rank - 1); - Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); - dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1); - } + auto dims = Out->dims(); + auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); + framework::LoDTensor flattened_out; + framework::LoDTensor flattened_d_out; + framework::LoDTensor flattened_d_x; + flattened_out.ShareDataWith(*Out).Resize(flattened_dims); + flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims); + flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims); math::SoftmaxGradCUDNNFunctor()( - context.template device_context(), &Out_2d, - &dOut_2d, &dX_2d); - - if (axis != -1 && axis != rank - 1) { - TransCompute(rank, dev_ctx, dX_trans, dX, - perm); - } + context.template device_context(), + &flattened_out, &flattened_d_out, &flattened_d_x); } }; diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h index 10b3f63339f..76e8eeab080 100644 --- a/paddle/fluid/operators/softmax_op.h +++ b/paddle/fluid/operators/softmax_op.h @@ -13,81 +13,66 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/softmax.h" -#include "paddle/fluid/operators/transpose_op.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using DDim = framework::DDim; -static inline void CalcTransPermAndShapeByAxis(const Tensor& x, const int axis, - std::vector* perm, - std::vector* shape) { - auto dim_x = x.dims(); - int rank = dim_x.size(); +static inline int CanonicalAxis(const int axis, const int rank) { + if (axis < 0) { + return axis + rank; + } + return axis; +} - if (axis == -1 || axis == rank - 1) { - return; +static inline int SizeToAxis(const int axis, DDim dims) { + int size = 1; + for (int i = 0; i < axis; i++) { + size *= dims[i]; } + return size; +} - for (int i = 0; i < rank - 1; i++) { - if (i == axis) { - perm->push_back(rank - 1); - shape->push_back(dim_x[rank - 1]); - } else { - perm->push_back(i); - shape->push_back(dim_x[i]); - } +static inline int SizeFromAxis(const int axis, DDim dims) { + int size = 1; + for (int i = axis; i < dims.size(); i++) { + size *= dims[i]; } - perm->push_back(axis); - shape->push_back(dim_x[axis]); + return size; } template class SoftmaxKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto& dev_ctx = context.template device_context(); auto* X = context.Input("X"); auto* Out = context.Output("Out"); - const int axis = context.Attr("axis"); - int rank = X->dims().size(); + const int rank = X->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + int axis_dim = X->dims()[axis]; // allocate memory on device. Out->mutable_data(context.GetPlace()); - std::vector perm, shape; - CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape); - + const int n = SizeToAxis(axis, X->dims()); + const int d = SizeFromAxis(axis, X->dims()); Tensor X_2d, Out_2d; - Tensor X_trans, Out_trans; - if (axis != -1 && axis != rank - 1) { - X_trans.mutable_data(framework::make_ddim(shape), context.GetPlace()); - Out_trans.mutable_data(framework::make_ddim(shape), - context.GetPlace()); - TransCompute(rank, dev_ctx, *X, &X_trans, perm); - TransCompute(rank, dev_ctx, *Out, &Out_trans, perm); - X_2d = framework::ReshapeToMatrix(X_trans, rank - 1); - Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1); - } else { - X_2d = framework::ReshapeToMatrix(*X, rank - 1); - Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); - } + X_2d.ShareDataWith(*X).Resize({n, d}); + Out_2d.ShareDataWith(*Out).Resize({n, d}); + // Tensor X_2d = framework::ReshapeToMatrix(*X, axis - 1); + // Tensor Out_2d = framework::ReshapeToMatrix(*Out, axis - 1); #ifdef PADDLE_ON_INFERENCE math::SoftmaxFunctor()( - context.template device_context(), &X_2d, &Out_2d); + context.template device_context(), axis_dim, &X_2d, &Out_2d); #else math::SoftmaxFunctor()( - context.template device_context(), &X_2d, &Out_2d); + context.template device_context(), axis_dim, &X_2d, &Out_2d); #endif - - if (axis != -1 && axis != rank - 1) { - TransCompute(rank, dev_ctx, Out_trans, Out, perm); - } } }; @@ -95,46 +80,29 @@ template class SoftmaxGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto& dev_ctx = context.template device_context(); auto* Out = context.Input("Out"); auto* dOut = context.Input(framework::GradVarName("Out")); auto* dX = context.Output(framework::GradVarName("X")); - const int axis = context.Attr("axis"); - int rank = Out->dims().size(); + const int rank = dX->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + int axis_dim = dX->dims()[axis]; // allocate memory on device. dX->mutable_data(context.GetPlace()); - std::vector perm, shape; - CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape); - + const int n = SizeToAxis(axis, dX->dims()); + const int d = SizeFromAxis(axis, dX->dims()); Tensor dX_2d, Out_2d, dOut_2d; - Tensor dX_trans, Out_trans, dOut_trans; - if (axis != -1 && axis != rank - 1) { - dX_trans.mutable_data(framework::make_ddim(shape), context.GetPlace()); - Out_trans.mutable_data(framework::make_ddim(shape), - context.GetPlace()); - dOut_trans.mutable_data(framework::make_ddim(shape), - context.GetPlace()); - TransCompute(rank, dev_ctx, *dX, &dX_trans, perm); - TransCompute(rank, dev_ctx, *Out, &Out_trans, perm); - TransCompute(rank, dev_ctx, *dOut, &dOut_trans, perm); - dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1); - Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1); - dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1); - } else { - dX_2d = framework::ReshapeToMatrix(*dX, rank - 1); - Out_2d = framework::ReshapeToMatrix(*Out, rank - 1); - dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1); - } + dX_2d.ShareDataWith(*dX).Resize({n, d}); + Out_2d.ShareDataWith(*Out).Resize({n, d}); + dOut_2d.ShareDataWith(*dOut).Resize({n, d}); + // Tensor Out_2d = framework::ReshapeToMatrix(*Out, axis - 1); + // Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, axis - 1); + // Tensor dX_2d = framework::ReshapeToMatrix(*dX, axis - 1); math::SoftmaxGradFunctor()( - context.template device_context(), &Out_2d, &dOut_2d, + context.template device_context(), axis_dim, &Out_2d, &dOut_2d, &dX_2d); - - if (axis != -1 && axis != rank - 1) { - TransCompute(rank, dev_ctx, dX_trans, dX, perm); - } } }; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h index c0530e3d8bc..ff99e4207a7 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.h @@ -43,7 +43,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { auto& dev_ctx = context.template device_context(); math::SoftmaxFunctor()( - dev_ctx, logits, softmax); + dev_ctx, -1, logits, softmax); math::CrossEntropyFunctor()( dev_ctx, loss, softmax, labels, context.Attr("soft_label"), context.Attr("ignore_index")); diff --git a/paddle/fluid/operators/warpctc_cudnn_op.cu.cc b/paddle/fluid/operators/warpctc_cudnn_op.cu.cc index a764d59410c..716faf2995e 100644 --- a/paddle/fluid/operators/warpctc_cudnn_op.cu.cc +++ b/paddle/fluid/operators/warpctc_cudnn_op.cu.cc @@ -69,7 +69,7 @@ class CudnnCTCKernel : public framework::OpKernel { int rank = logits->dims().size(); Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1); Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1); - math::SoftmaxFunctor()(dev_ctx, &in_2d, &out_2d); + math::SoftmaxFunctor()(dev_ctx, -1, &in_2d, &out_2d); // ctc needs sequences data stored in transposed padding format // logits and grad using padding data of layout 'TNC' -- GitLab