diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index ad3e5543f10ae05865565110ba2231c897c205b8..94e54266f0f922efef5ea4a1b23338b6ce02d131 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -12,60 +12,90 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/softmax_op.h" +#include "paddle/fluid/platform/cudnn_desc.h" +#include "paddle/fluid/platform/cudnn_helper.h" namespace paddle { namespace operators { +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using DataLayout = platform::DataLayout; using Tensor = framework::Tensor; +static inline int SizeOutAxis(const int axis, DDim dims) { + int size = 1; + for (int i = axis + 1; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + template class SoftmaxCUDNNKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* Out = context.Output("Out"); - - // allocate memory on device. - Out->mutable_data(context.GetPlace()); - - auto dims = X->dims(); - auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); - framework::LoDTensor flattened_x; - framework::LoDTensor flattened_out; - flattened_x.ShareDataWith(*X).Resize(flattened_dims); - flattened_out.ShareDataWith(*Out).Resize(flattened_dims); - - math::SoftmaxCUDNNFunctor()( - context.template device_context(), - &flattened_x, &flattened_out); + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + auto* out_data = out->data(); + + auto dims = x->dims(); + const int rank = dims.size(); + const int axis = CanonicalAxis(ctx.Attr("axis"), rank); + const int dim = dims[axis]; + const int N = SizeToAxis(axis, dims); + const int D = SizeOutAxis(axis, dims); + + ScopedTensorDescriptor desc; + std::vector tensor_dims = {N, dim, D, 1}; + DataLayout layout = DataLayout::kNCHW; + cudnnTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); + + auto& dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE + : CUDNN_SOFTMAX_MODE_CHANNEL; + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( + handle, CUDNN_SOFTMAX_ACCURATE, mode, + platform::CudnnDataType::kOne(), desc_, x->data(), + platform::CudnnDataType::kZero(), desc_, out_data)); } }; template class SoftmaxGradCUDNNKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* Out = context.Input("Out"); - auto* dOut = context.Input(framework::GradVarName("Out")); - auto* dX = context.Output(framework::GradVarName("X")); - - // allocate memory on device. - dX->mutable_data(context.GetPlace()); - - auto dims = Out->dims(); - auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); - framework::LoDTensor flattened_out; - framework::LoDTensor flattened_d_out; - framework::LoDTensor flattened_d_x; - flattened_out.ShareDataWith(*Out).Resize(flattened_dims); - flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims); - flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims); - - math::SoftmaxGradCUDNNFunctor()( - context.template device_context(), - &flattened_out, &flattened_d_out, &flattened_d_x); + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Input("Out"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + dx->mutable_data(ctx.GetPlace()); + auto* dx_data = dx->data(); + + auto dims = out->dims(); + const int rank = dims.size(); + const int axis = CanonicalAxis(ctx.Attr("axis"), rank); + const int dim = dims[axis]; + const int N = SizeToAxis(axis, dims); + const int D = SizeOutAxis(axis, dims); + + ScopedTensorDescriptor desc; + std::vector tensor_dims = {N, dim, D, 1}; + DataLayout layout = DataLayout::kNCHW; + cudnnTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); + + auto& dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE + : CUDNN_SOFTMAX_MODE_CHANNEL; + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward( + handle, CUDNN_SOFTMAX_ACCURATE, mode, + platform::CudnnDataType::kOne(), desc_, out->data(), desc_, + dout->data(), platform::CudnnDataType::kZero(), desc_, dx_data)); } }; diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 2a6ca7975f0c591701400e71feab0be36300480b..cf46b4fc3bdad486f65afa5ac994d506c20344cb 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -53,13 +53,6 @@ class SoftmaxOp : public framework::OperatorWithKernel { "Attr(axis) value should be in range [-R, R-1], " "R is the rank of Input(X).")); - auto use_cudnn = ctx->Attrs().Get("use_cudnn"); - if (axis != rank_x - 1 && axis != -1) { - PADDLE_ENFORCE_EQ(use_cudnn, false, - platform::errors::InvalidArgument( - "CUDNN kernel only support axis as -1.")); - } - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ "Out"); } diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py index 04d5cc941a4636da0352fe9221cdad8bdfcd2bd9..a37fad9cf0ca0772fc4939e6d59e27f6fc7efac1 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py @@ -153,16 +153,103 @@ class TestSoftmaxCUDNNOp2(TestSoftmaxCUDNNOp): return [2, 3, 4, 5] +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp3(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + def get_axis(self): + return 0 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp4(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + def get_axis(self): + return 1 + + @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestSoftmaxCUDNNOp5(TestSoftmaxCUDNNOp): def get_x_shape(self): return [2, 3, 4, 5] + def get_axis(self): + return 2 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp6(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5] + def get_axis(self): return 3 +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp7(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5, 6] + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp8(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5, 6] + + def get_axis(self): + return 0 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp9(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5, 6] + + def get_axis(self): + return 1 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp10(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5, 6] + + def get_axis(self): + return 2 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp11(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5, 6] + + def get_axis(self): + return 3 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp12(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5, 6] + + def get_axis(self): + return 4 + + @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestSoftmaxFP16Op(TestSoftmaxOp):