diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu
index 38e93fdf15d99eb447948378a599891074c10fc5..34ea6a91ce7743462d378cf471a5ec3a12ca51d1 100644
--- a/paddle/fluid/operators/math/softmax.cu
+++ b/paddle/fluid/operators/math/softmax.cu
@@ -14,13 +14,86 @@ limitations under the License. */
 
 #define EIGEN_USE_GPU
 
+#include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
+#include "paddle/fluid/platform/cudnn_helper.h"
 
 namespace paddle {
 namespace operators {
 namespace math {
 
+using Tensor = framework::Tensor;
+using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
+using DataLayout = platform::DataLayout;
+template <typename T>
+using CudnnDataType = platform::CudnnDataType<T>;
+
+template <typename T>
+void SoftmaxCUDNNFunctor<T>::operator()(
+    const platform::CUDADeviceContext& context, const framework::Tensor* X,
+    framework::Tensor* Y) {
+  // ------------------- cudnn descriptors ---------------------
+  ScopedTensorDescriptor xDesc;
+  ScopedTensorDescriptor yDesc;
+  std::vector<int> cudnn_tensor_dims = framework::vectorize2int(X->dims());
+  DataLayout layout = DataLayout::kNCHW;
+  if (cudnn_tensor_dims.size() == 5) {
+    layout = DataLayout::kNCDHW;
+  }
+  // NOTE(*) : cudnn softmax only supports >= 4D Tensor,
+  // fill 1 at unused dims
+  if (cudnn_tensor_dims.size() <= 2) {
+    cudnn_tensor_dims.resize(4, 1);
+  }
+  cudnnTensorDescriptor_t cudnn_x_desc =
+      xDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  cudnnTensorDescriptor_t cudnn_y_desc =
+      yDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  PADDLE_ENFORCE(platform::dynload::cudnnSoftmaxForward(
+      context.cudnn_handle(), CUDNN_SOFTMAX_ACCURATE,
+      CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_x_desc,
+      X->data<T>(), CudnnDataType<T>::kZero(), cudnn_y_desc,
+      Y->mutable_data<T>(context.GetPlace())));
+}
+
+template <typename T>
+void SoftmaxGradCUDNNFunctor<T>::operator()(
+    const platform::CUDADeviceContext& context, const framework::Tensor* Y,
+    const framework::Tensor* YGrad, framework::Tensor* XGrad) {
+  // ------------------- cudnn descriptors ---------------------
+  ScopedTensorDescriptor yDesc;
+  ScopedTensorDescriptor dyDesc;
+  ScopedTensorDescriptor dxDesc;
+  std::vector<int> cudnn_tensor_dims = framework::vectorize2int(Y->dims());
+  DataLayout layout = DataLayout::kNCHW;
+  if (cudnn_tensor_dims.size() == 5) {
+    layout = DataLayout::kNCDHW;
+  }
+  // NOTE(*) : cudnn softmax only supports >= 4D Tensor,
+  // fill 1 at unused dims
+  if (cudnn_tensor_dims.size() <= 2) {
+    cudnn_tensor_dims.resize(4, 1);
+  }
+  cudnnTensorDescriptor_t cudnn_y_desc =
+      yDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  cudnnTensorDescriptor_t cudnn_xgrad_desc =
+      dxDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  cudnnTensorDescriptor_t cudnn_ygrad_desc =
+      dyDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  PADDLE_ENFORCE(platform::dynload::cudnnSoftmaxBackward(
+      context.cudnn_handle(), CUDNN_SOFTMAX_ACCURATE,
+      CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_y_desc,
+      Y->data<T>(), cudnn_ygrad_desc, YGrad->data<T>(),
+      CudnnDataType<T>::kZero(), cudnn_xgrad_desc,
+      XGrad->mutable_data<T>(context.GetPlace())));
+}
+
+template class SoftmaxCUDNNFunctor<float>;
+template class SoftmaxCUDNNFunctor<double>;
+template class SoftmaxGradCUDNNFunctor<float>;
+template class SoftmaxGradCUDNNFunctor<double>;
+
 template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
 template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
 template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
diff --git a/paddle/fluid/operators/math/softmax.h b/paddle/fluid/operators/math/softmax.h
index 14b2690c2a4e764058270953214a07aee8053444..da1f0b672d3a5fb5da8f4d72892be21964bdbc0d 100644
--- a/paddle/fluid/operators/math/softmax.h
+++ b/paddle/fluid/operators/math/softmax.h
@@ -33,6 +33,23 @@ class SoftmaxGradFunctor {
                   const framework::Tensor* y_grad, framework::Tensor* x_grad);
 };
 
+#ifdef PADDLE_WITH_CUDA
+template <typename T>
+class SoftmaxCUDNNFunctor {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor* X, framework::Tensor* Y);
+};
+
+template <typename T>
+class SoftmaxGradCUDNNFunctor {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor* Y, const framework::Tensor* y_grad,
+                  framework::Tensor* x_grad);
+};
+#endif
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc b/paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5661f4b42f37fed7f589c515e25fd66cfcede2c7
--- /dev/null
+++ b/paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc
@@ -0,0 +1,105 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/softmax.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+
+template <typename T>
+class SequenceSoftmaxCUDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* out = ctx.Output<LoDTensor>("Out");
+
+    auto lod = x->lod();
+    auto dims = x->dims();
+
+    const size_t level = lod.size() - 1;
+    PADDLE_ENFORCE_EQ(dims[0], static_cast<int64_t>(lod[level].back()),
+                      "The first dimension of Input(X) should be equal to the "
+                      "sum of all sequences' lengths.");
+    PADDLE_ENFORCE_EQ(dims[0], x->numel(),
+                      "The width of each timestep in Input(X) of "
+                      "SequenceSoftmaxOp should be 1.");
+
+    out->mutable_data<T>(ctx.GetPlace());
+    for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
+      int start_pos = static_cast<int>(lod[level][i]);
+      int end_pos = static_cast<int>(lod[level][i + 1]);
+      Tensor x_i = x->Slice(start_pos, end_pos);
+      Tensor out_i = out->Slice(start_pos, end_pos);
+
+      // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
+      framework::DDim dims_i =
+          framework::make_ddim({1UL, end_pos - start_pos});
+      x_i.Resize(dims_i);
+      out_i.Resize(dims_i);
+      math::SoftmaxCUDNNFunctor<T>()(
+          ctx.template device_context<platform::CUDADeviceContext>(), &x_i,
+          &out_i);
+    }
+  }
+};
+
+template <typename T>
+class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* out = ctx.Input<LoDTensor>("Out");
+    auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
+
+    auto lod = x->lod();
+    const size_t level = lod.size() - 1;
+
+    x_grad->mutable_data<T>(ctx.GetPlace());
+    for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
+      int start_pos = static_cast<int>(lod[level][i]);
+      int end_pos = static_cast<int>(lod[level][i + 1]);
+
+      Tensor out_i = out->Slice(start_pos, end_pos);
+      Tensor out_grad_i = out_grad->Slice(start_pos, end_pos);
+      Tensor x_grad_i = x_grad->Slice(start_pos, end_pos);
+
+      // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
+      framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
+      out_i.Resize(dims_i);
+      out_grad_i.Resize(dims_i);
+      x_grad_i.Resize(dims_i);
+      math::SoftmaxGradCUDNNFunctor<T>()(
+          ctx.template device_context<platform::CUDADeviceContext>(), &out_i,
+          &out_grad_i, &x_grad_i);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_KERNEL(sequence_softmax, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::SequenceSoftmaxCUDNNKernel<float>,
+                   ops::SequenceSoftmaxCUDNNKernel<double>)
+REGISTER_OP_KERNEL(sequence_softmax_grad, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::SequenceSoftmaxGradCUDNNKernel<float>,
+                   ops::SequenceSoftmaxGradCUDNNKernel<double>)
diff --git a/paddle/fluid/operators/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_softmax_op.cc
index 7e685eb3dc7b12ef38f06b37d99a1212cfbc992c..e8b4df04286d327f568f4c43886f9fcf89cc4a88 100644
--- a/paddle/fluid/operators/sequence_softmax_op.cc
+++ b/paddle/fluid/operators/sequence_softmax_op.cc
@@ -29,6 +29,29 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    // Choose the cuDNN kernel if the runtime supports it.
+    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+    bool runtime_cudnn_support = false;
+#ifdef PADDLE_WITH_CUDA
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      auto& dev_ctx =
+          ctx.template device_context<platform::CUDADeviceContext>();
+      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr;
+    }
+#endif
+    framework::LibraryType library_ = framework::LibraryType::kPlain;
+    if (use_cudnn && runtime_cudnn_support) {
+      library_ = framework::LibraryType::kCUDNN;
+    }
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.GetPlace(), framework::StringToDataLayout(data_format), library_);
+  }
 };
 
 class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -41,6 +64,17 @@ class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out",
               "(LoDTensor) 1-D or 2-D output LoDTensor with the 2-nd dimension "
               "of length 1.");
+    AddAttr<bool>(
+        "use_cudnn",
+        "(bool, default false) Whether to use the cuDNN kernel; requires "
+        "cuDNN to be installed.")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "data_format",
+        "(string, default AnyLayout) An optional string from: \"NHWC\", "
+        "\"NCHW\". Specify the data format of the output data; the input "
+        "will be transformed automatically.")
+        .SetDefault("AnyLayout");
     AddComment(R"DOC(
 Sequence Softmax Operator.
@@ -91,6 +125,29 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
 
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    // Choose the cuDNN kernel if the runtime supports it.
+    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+    bool runtime_cudnn_support = false;
+#ifdef PADDLE_WITH_CUDA
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      auto& dev_ctx =
+          ctx.template device_context<platform::CUDADeviceContext>();
+      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr;
+    }
+#endif
+    framework::LibraryType library_ = framework::LibraryType::kPlain;
+    if (use_cudnn && runtime_cudnn_support) {
+      library_ = framework::LibraryType::kCUDNN;
+    }
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+        ctx.GetPlace(), framework::StringToDataLayout(data_format), library_);
+  }
 };
 
 }  // namespace operators
@@ -102,7 +159,9 @@ REGISTER_OP(sequence_softmax, ops::SequenceSoftmaxOp,
             ops::SequenceSoftmaxGradOp);
 REGISTER_OP_CPU_KERNEL(
     sequence_softmax,
-    ops::SequenceSoftmaxKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SequenceSoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceSoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     sequence_softmax_grad,
-    ops::SequenceSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/sequence_softmax_op.cu.cc b/paddle/fluid/operators/sequence_softmax_op.cu.cc
index 295c68c5b936d6522666a4cc4e621db6f5f5f3ed..57adea3a1b9dbcbb5787d005e4d3ec595f61d4b2 100644
--- a/paddle/fluid/operators/sequence_softmax_op.cu.cc
+++ b/paddle/fluid/operators/sequence_softmax_op.cu.cc
@@ -17,7 +17,10 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     sequence_softmax,
-    ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>)
+    ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, double>)
 REGISTER_OP_CUDA_KERNEL(
     sequence_softmax_grad,
-    ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..47cb336d87f8627d86ac33d6ac32c04d5d93f753
--- /dev/null
+++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc
@@ -0,0 +1,62 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/softmax.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* X = context.Input<Tensor>("X");
+    auto* Out = context.Output<Tensor>("Out");
+
+    // allocate memory on device.
+    Out->mutable_data<T>(context.GetPlace());
+
+    math::SoftmaxCUDNNFunctor<T>()(
+        context.template device_context<platform::CUDADeviceContext>(), X,
+        Out);
+  }
+};
+
+template <typename T>
+class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* Out = context.Input<Tensor>("Out");
+    auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+
+    // allocate memory on device.
+    dX->mutable_data<T>(context.GetPlace());
+
+    math::SoftmaxGradCUDNNFunctor<T>()(
+        context.template device_context<platform::CUDADeviceContext>(), Out,
+        dOut, dX);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_KERNEL(softmax, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::SoftmaxCUDNNKernel<float>);
+REGISTER_OP_KERNEL(softmax_grad, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::SoftmaxGradCUDNNKernel<float>);
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index 09275ef290e8c78dc0902033e904cc4e7ccd7adb..1b63f8a499e5d20d2f10c3cd1024d1bcf78764d4 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -33,6 +33,29 @@ class SoftmaxOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", x_dims);
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    // Choose the cuDNN kernel if the runtime supports it.
+    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+    bool runtime_cudnn_support = false;
+#ifdef PADDLE_WITH_CUDA
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      auto& dev_ctx =
+          ctx.template device_context<platform::CUDADeviceContext>();
+      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr;
+    }
+#endif
+    framework::LibraryType library_ = framework::LibraryType::kPlain;
+    if (use_cudnn && runtime_cudnn_support) {
+      library_ = framework::LibraryType::kCUDNN;
+    }
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        ctx.GetPlace(), framework::StringToDataLayout(data_format), library_);
+  }
 };
 
 class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -43,6 +66,17 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
              "The input tensor of softmax. "
              "2-D with shape [batch_size, input_feature_dimensions].");
     AddOutput("Out", "The normalized values with the same shape as X.");
+    AddAttr<bool>(
+        "use_cudnn",
+        "(bool, default false) Whether to use the cuDNN kernel; requires "
+        "cuDNN to be installed.")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "data_format",
+        "(string, default AnyLayout) An optional string from: \"NHWC\", "
+        "\"NCHW\". Specify the data format of the output data; the input "
+        "will be transformed automatically.")
+        .SetDefault("AnyLayout");
     AddComment(R"DOC(
 Softmax Operator.
 
@@ -80,6 +114,29 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
 
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    // Choose the cuDNN kernel if the runtime supports it.
+    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+    bool runtime_cudnn_support = false;
+#ifdef PADDLE_WITH_CUDA
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      auto& dev_ctx =
+          ctx.template device_context<platform::CUDADeviceContext>();
+      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr;
+    }
+#endif
+    framework::LibraryType library_ = framework::LibraryType::kPlain;
+    if (use_cudnn && runtime_cudnn_support) {
+      library_ = framework::LibraryType::kCUDNN;
+    }
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        ctx.GetPlace(), framework::StringToDataLayout(data_format), library_);
+  }
 };
 
 }  // namespace operators
diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h
index 1842ecd745e3f5cb75600ce00d89018f81682632..9a2ac3ff33df3f8b9e24203f9dba2130e1d16510 100644
--- a/paddle/fluid/platform/cudnn_helper.h
+++ b/paddle/fluid/platform/cudnn_helper.h
@@ -289,7 +289,7 @@ inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
   use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
 #ifdef PADDLE_WITH_CUDA
   if (use_cudnn) {
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
     use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
   }
 #endif
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index ea189749bc6cc1e37c1efc6fea424143b887cecd..a889ab6bdc6ac9494ef992a97292b7a2536c41c4 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -132,7 +132,7 @@ def detection_output(loc,
 
     old_shape = scores.shape
     scores = ops.reshape(x=scores, shape=(-1, old_shape[-1]))
-    scores = ops.softmax(x=scores)
+    scores = nn.softmax(input=scores)
     scores = ops.reshape(x=scores, shape=old_shape)
     scores = nn.transpose(scores, perm=[0, 2, 1])
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index f107261f3df78dfc7197d9719d0258b6ab09d487..bf161d6618b10da66f25d3f11300a4a2b10b875a 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -39,6 +39,8 @@ __all__ = [
     'sequence_conv',
     'conv2d',
     'sequence_pool',
+    'sequence_softmax',
+    'softmax',
     'pool2d',
     'batch_norm',
     'beam_search_decode',
@@ -1085,6 +1087,30 @@ def sequence_conv(input,
     return helper.append_activation(pre_act)
 
 
+def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
+    helper = LayerHelper('sequence_softmax', **locals())
+    dtype = helper.input_dtype()
+    softmax_out = helper.create_tmp_variable(dtype)
+    helper.append_op(
+        type="sequence_softmax",
+        inputs={"X": input},
+        outputs={"Out": softmax_out},
+        attrs={"use_cudnn": use_cudnn})
+    return softmax_out
+
+
+def softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
+    helper = LayerHelper('softmax', **locals())
+    dtype = helper.input_dtype()
+    softmax_out = helper.create_tmp_variable(dtype)
+    helper.append_op(
+        type="softmax",
+        inputs={"X": input},
+        outputs={"Out": softmax_out},
+        attrs={"use_cudnn": use_cudnn})
+    return softmax_out
+
+
 def conv2d(input,
            num_filters,
            filter_size,
diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py
index 14ad18d5085fb945646818cc679f088a43806a70..d7bad221c5fa7b18137bf317125195267437a644 100644
--- a/python/paddle/fluid/layers/ops.py
+++ b/python/paddle/fluid/layers/ops.py
@@ -45,13 +45,30 @@ __activations__ = [
 ]
 
 __all__ = [
-    'mean', 'mul', 'reshape', 'scale', 'sigmoid_cross_entropy_with_logits',
-    'elementwise_add', 'elementwise_div', 'elementwise_sub', 'elementwise_mul',
-    'elementwise_max', 'elementwise_min', 'elementwise_pow', 'clip',
-    'clip_by_norm', 'softmax', 'sequence_softmax', 'logical_and', 'logical_or',
-    'logical_xor', 'logical_not', 'uniform_random',
-    'uniform_random_batch_size_like', 'gaussian_random',
-    'gaussian_random_batch_size_like', 'cumsum', 'scatter'
+    'mean',
+    'mul',
+    'reshape',
+    'scale',
+    'sigmoid_cross_entropy_with_logits',
+    'elementwise_add',
+    'elementwise_div',
+    'elementwise_sub',
+    'elementwise_mul',
+    'elementwise_max',
+    'elementwise_min',
+    'elementwise_pow',
+    'clip',
+    'clip_by_norm',
+    'logical_and',
+    'logical_or',
+    'logical_xor',
+    'logical_not',
+    'uniform_random',
+    'uniform_random_batch_size_like',
+    'gaussian_random',
+    'gaussian_random_batch_size_like',
+    'cumsum',
+    'scatter',
 ] + __activations__
 
 for _OP in set(__all__):
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 6944cca394fbc1ddde09dfeb0bc82e357a3cd225..90d70aa39fdc4d4d3f9062eb6a3eb0cdd014acfc 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -220,7 +220,7 @@ class TestBook(unittest.TestCase):
             seq_data = layers.data(
                 name='seq_data', shape=[10, 10], dtype='float32', lod_level=1)
             seq = layers.fc(input=seq_data, size=20)
-            self.assertIsNotNone(layers.sequence_softmax(x=seq))
+            self.assertIsNotNone(layers.sequence_softmax(seq))
         print(str(program))
 
     def test_softmax(self):
@@ -228,7 +228,7 @@ class TestBook(unittest.TestCase):
         with program_guard(program):
             data = layers.data(name='data', shape=[10], dtype='float32')
             hid = layers.fc(input=data, size=20)
-            self.assertIsNotNone(layers.softmax(x=hid))
+            self.assertIsNotNone(layers.softmax(hid))
         print(str(program))
 
     def test_get_places(self):
diff --git a/python/paddle/fluid/tests/unittests/test_sequence_softmax_op.py b/python/paddle/fluid/tests/unittests/test_sequence_softmax_op.py
index 9e5c1e7a3d0bdf514de11e797d7139f577002c52..d6dc99bb3106feee33daa52bffb386f07cc16de5 100644
--- a/python/paddle/fluid/tests/unittests/test_sequence_softmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sequence_softmax_op.py
@@ -16,11 +16,15 @@
 import unittest
 import numpy as np
 from op_test import OpTest
 from test_softmax_op import stable_softmax
+import paddle.fluid.core as core
 
 
 class TestSequenceSoftmaxOp(OpTest):
     def setUp(self):
         self.op_type = "sequence_softmax"
+        self.use_cudnn = False
+        self.init_op_type()
+
         x = np.random.uniform(0.1, 1, (11, 1)).astype("float32")
         lod = [[0, 4, 5, 8, 11]]
 
@@ -34,12 +38,31 @@ class TestSequenceSoftmaxOp(OpTest):
         self.inputs = {"X": (x, lod)}
         self.outputs = {"Out": out}
+        self.attrs = {'use_cudnn': self.use_cudnn, }
+
+    def init_op_type(self):
+        pass
 
     def test_check_output(self):
-        self.check_output()
+        if self.use_cudnn:
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, atol=1e-5)
+        else:
+            self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(["X"], "Out", max_relative_error=0.01)
+        if self.use_cudnn:
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ["X"], "Out", max_relative_error=0.01)
+        else:
+            self.check_grad(["X"], "Out", max_relative_error=0.01)
+
+
+# ----------------cudnn SequenceSoftmax----------------
+class TestSequenceSoftmaxCUDNNOp(TestSequenceSoftmaxOp):
+    def init_op_type(self):
+        self.use_cudnn = True
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py
index 8f8312edca7e2d98eb4e881f671c6afdda01c57a..4f20da2b926823db9e7ec92c95178b6d3d1feec9 100644
--- a/python/paddle/fluid/tests/unittests/test_softmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py
@@ -15,6 +15,7 @@
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core
 
 
 def stable_softmax(x):
@@ -27,18 +28,37 @@ def stable_softmax(x):
 class TestSoftmaxOp(OpTest):
     def setUp(self):
         self.op_type = "softmax"
+        self.use_cudnn = False
+        self.init_op_type()
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32")
         }
         self.outputs = {
             'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
         }
+        self.attrs = {'use_cudnn': self.use_cudnn, }
+
+    def init_op_type(self):
+        pass
 
     def test_check_output(self):
-        self.check_output()
+        if self.use_cudnn:
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, atol=1e-5)
+        else:
+            self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out')
+        if self.use_cudnn:
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ["X"], "Out", max_relative_error=0.01)
+        else:
+            self.check_grad(["X"], "Out", max_relative_error=0.01)
+
+
+class TestSoftmaxCUDNNOp(TestSoftmaxOp):
+    def init_op_type(self):
+        self.use_cudnn = True
 
 
 if __name__ == "__main__":
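
Reviewer note: below is a minimal usage sketch of the `fluid.layers.softmax` wrapper this diff adds to python/paddle/fluid/layers/nn.py. It is illustrative only (not part of the diff); the Program/Executor boilerplate assumes the fluid API of this release, and the variable names are placeholders.

import numpy as np
import paddle.fluid as fluid

# Build a tiny program that routes through the new softmax layer.
data = fluid.layers.data(name='data', shape=[10], dtype='float32')
hid = fluid.layers.fc(input=data, size=20)
# use_cudnn defaults to True; per GetExpectedKernelType above, the op
# falls back to the plain kernel when no cuDNN handle is available.
probs = fluid.layers.softmax(hid, use_cudnn=True)

place = fluid.CPUPlace()  # or fluid.CUDAPlace(0) to exercise the cuDNN path
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
out, = exe.run(fluid.default_main_program(),
               feed={'data': np.random.rand(3, 10).astype('float32')},
               fetch_list=[probs])
print(out.shape)  # (3, 20); each row sums to 1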