From 225c11a91fbb7c75e347854c6147225d61fc2385 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Wed, 27 Feb 2019 13:12:48 +0800 Subject: [PATCH] polish cudnn related code and fix bug. (#15164) * staged. * polish code * polish code. test=develop * polish code. test=develop * api change. test=develop * fix default value. test=develop * fix default value. test=develop --- cmake/operators.cmake | 4 + paddle/fluid/framework/executor.cc | 1 + paddle/fluid/operators/activation_cudnn.cu.cc | 40 ++++ .../fluid/operators/activation_cudnn_op.cu.cc | 175 ++++++++++++++ paddle/fluid/operators/activation_op.cc | 47 ++-- paddle/fluid/operators/activation_op.h | 214 +++++++++--------- paddle/fluid/platform/CMakeLists.txt | 1 + paddle/fluid/platform/cudnn_desc.h | 124 ++++++++++ paddle/fluid/platform/cudnn_desc_test.cc | 41 ++++ paddle/fluid/platform/dynload/cudnn.h | 1 + .../tests/unittests/test_activation_op.py | 23 ++ 11 files changed, 543 insertions(+), 128 deletions(-) create mode 100644 paddle/fluid/operators/activation_cudnn.cu.cc create mode 100644 paddle/fluid/operators/activation_cudnn_op.cu.cc create mode 100644 paddle/fluid/platform/cudnn_desc.h create mode 100644 paddle/fluid/platform/cudnn_desc_test.cc diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 4e8c49e62b5..11a5b1b4554 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -153,7 +153,11 @@ function(op_library TARGET) # pybind USE_OP_DEVICE_KERNEL for CUDNN list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len) if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0) + if(${TARGET} STREQUAL "activation") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, CUDNN);\n") + else() file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") + endif() endif() # pybind USE_OP_DEVICE_KERNEL for MIOPEN diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 4323883fa5c..c31d0beec30 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/operators/distributed/distributed.h" diff --git a/paddle/fluid/operators/activation_cudnn.cu.cc b/paddle/fluid/operators/activation_cudnn.cu.cc new file mode 100644 index 00000000000..494c02374a9 --- /dev/null +++ b/paddle/fluid/operators/activation_cudnn.cu.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/platform/cudnn_desc.h" + +namespace paddle { +namespace operators { +using framework::Tensor; +using platform::ActivationDescriptor; +using platform::TensorDescriptor; + +template +class CudnnActivationKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + framework::Tensor *X, *Out; + ExtractActivationTensor(context, X, Out); + ActivationDescriptor act_desc; + TensorDescriptor x_desc, out_desc; + x_desc.set(detail::Ref(X)); + out_desc.set(detail::Ref(Out)); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/activation_cudnn_op.cu.cc b/paddle/fluid/operators/activation_cudnn_op.cu.cc new file mode 100644 index 00000000000..a382414d5c4 --- /dev/null +++ b/paddle/fluid/operators/activation_cudnn_op.cu.cc @@ -0,0 +1,175 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/platform/cudnn_desc.h" + +namespace paddle { +namespace operators { +using framework::Tensor; +using platform::ActivationDescriptor; +using platform::TensorDescriptor; +using platform::CUDADeviceContext; + +template +struct CudnnActivationFunctor { + using ELEMENT_TYPE = T; + CudnnActivationFunctor(const CUDADeviceContext& ctx, const T& c, + const cudnnActivationMode_t& m) + : ctx_(ctx), coef_(c), mode_(m) {} + void operator()(const Tensor& x, Tensor* out) { + ActivationDescriptor act_desc; + act_desc.set(mode_, coef_); + TensorDescriptor x_desc, out_desc; + x_desc.set(x); + out_desc.set(detail::Ref(out)); + PADDLE_ENFORCE(platform::dynload::cudnnActivationForward( + ctx_.cudnn_handle(), act_desc.desc(), + platform::CudnnDataType::kOne(), x_desc.desc(), x.data(), + platform::CudnnDataType::kZero(), out_desc.desc(), + out->mutable_data(ctx_.GetPlace()))); + } + const CUDADeviceContext& ctx_; + const T coef_; + const cudnnActivationMode_t mode_; +}; + +template +struct CudnnActivationGradFunctor { + using ELEMENT_TYPE = T; + CudnnActivationGradFunctor(const CUDADeviceContext& ctx, const T& c, + const cudnnActivationMode_t& m) + : ctx_(ctx), coef_(c), mode_(m) {} + void operator()(const Tensor& x, const Tensor& out, const Tensor dout, + Tensor* dx) { + ActivationDescriptor act_desc; + act_desc.set(mode_, coef_); + TensorDescriptor x_desc, out_desc, dout_desc, dx_desc; + x_desc.set(x); + out_desc.set(out); + dout_desc.set(dout); + dx_desc.set(detail::Ref(dx)); + PADDLE_ENFORCE(platform::dynload::cudnnActivationBackward( + ctx_.cudnn_handle(), act_desc.desc(), + platform::CudnnDataType::kOne(), out_desc.desc(), out.data(), + dout_desc.desc(), dout.data(), x_desc.desc(), x.data(), + platform::CudnnDataType::kZero(), dx_desc.desc(), + dx->mutable_data(ctx_.GetPlace()))); + } + const CUDADeviceContext& ctx_; + const T coef_; + const cudnnActivationMode_t mode_; +}; + +template +struct CudnnReluFunctor : public CudnnActivationFunctor { + explicit CudnnReluFunctor(const CUDADeviceContext& ctx) + : CudnnActivationFunctor(ctx, 0.0, CUDNN_ACTIVATION_RELU) {} +}; +template +struct CudnnReluGradFunctor : public CudnnActivationGradFunctor { + explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx) + : CudnnActivationGradFunctor(ctx, 0.0, CUDNN_ACTIVATION_RELU) {} +}; + +template +struct CudnnRelu6Functor : public CudnnActivationFunctor { + explicit CudnnRelu6Functor(const CUDADeviceContext& ctx) + : CudnnActivationFunctor(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {} +}; +template +struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor { + explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx) + : CudnnActivationGradFunctor(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) { + } +}; + +template +struct CudnnSigmoidFunctor : public CudnnActivationFunctor { + explicit CudnnSigmoidFunctor(const CUDADeviceContext& ctx) + : CudnnActivationFunctor(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {} +}; +template +struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor { + explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx) + : CudnnActivationGradFunctor(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {} +}; + +template +struct CudnnTanhFunctor : public CudnnActivationFunctor { + explicit CudnnTanhFunctor(const CUDADeviceContext& ctx) + : CudnnActivationFunctor(ctx, 0.0, CUDNN_ACTIVATION_TANH) {} +}; +template +struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor { + explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx) + : CudnnActivationGradFunctor(ctx, 0.0, CUDNN_ACTIVATION_TANH) {} +}; + +template +class CudnnActivationKernel + : public framework::OpKernel { + public: + using T = typename Functor::ELEMENT_TYPE; + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* X = nullptr; + framework::Tensor* Out = nullptr; + ExtractActivationTensor(context, &X, &Out); + Out->mutable_data(context.GetPlace()); + auto& dev_ctx = context.template device_context(); + Functor functor(dev_ctx); + functor(detail::Ref(X), Out); + } +}; + +template +class CudnnActivationGradKernel + : public framework::OpKernel { + public: + using T = typename Functor::ELEMENT_TYPE; + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor *X, *Out, *dOut; + X = Out = dOut = nullptr; + framework::Tensor* dX = nullptr; + ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX); + dX->mutable_data(context.GetPlace()); + auto& dev_ctx = context.template device_context(); + Functor functor(dev_ctx); + functor(detail::Ref(X), detail::Ref(Out), detail::Ref(dOut), dX); + } +}; + +} // namespace operators +} // namespace paddle + +namespace plat = paddle::platform; +namespace ops = paddle::operators; + +#define FOR_EACH_CUDNN_OP_FUNCTOR(__macro) \ + __macro(relu, CudnnReluFunctor, CudnnReluGradFunctor); \ + __macro(relu6, CudnnRelu6Functor, CudnnRelu6GradFunctor); \ + __macro(sigmoid, CudnnTanhFunctor, CudnnTanhGradFunctor); \ + __macro(tanh, CudnnTanhFunctor, CudnnTanhGradFunctor) + +#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \ + REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace, \ + ops::CudnnActivationKernel>, \ + ops::CudnnActivationKernel>); \ + REGISTER_OP_KERNEL( \ + act_type##_grad, CUDNN, plat::CUDAPlace, \ + ops::CudnnActivationGradKernel>, \ + ops::CudnnActivationGradKernel>); + +FOR_EACH_CUDNN_OP_FUNCTOR(REGISTER_ACTIVATION_CUDNN_KERNEL); diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 65efe2966ce..2feb8e4c478 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -16,29 +16,36 @@ limitations under the License. */ #include #include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h" #include "paddle/fluid/platform/port.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cudnn_helper.h" +#endif namespace paddle { namespace operators { using paddle::framework::Tensor; -#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \ - class OP_NAME##OpMaker \ - : public ::paddle::framework::OpProtoAndCheckerMaker { \ - public: \ - void Make() override { \ - AddInput("X", "Input of " #OP_NAME " operator"); \ - AddOutput("Out", "Output of " #OP_NAME " operator"); \ - AddAttr("use_mkldnn", \ - "(bool, default false) Only used in mkldnn kernel") \ - .SetDefault(false); \ - AddAttr( \ - "is_test", \ - "(bool, default false) Set to true for inference only, false " \ - "for training. Some layers may run faster when this is true.") \ - .SetDefault(false); \ - AddComment(OP_COMMENT); \ - } \ +#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \ + class OP_NAME##OpMaker \ + : public ::paddle::framework::OpProtoAndCheckerMaker { \ + public: \ + void Make() override { \ + AddInput("X", "Input of " #OP_NAME " operator"); \ + AddOutput("Out", "Output of " #OP_NAME " operator"); \ + AddAttr("use_mkldnn", \ + "(bool, default false) Only used in mkldnn kernel") \ + .SetDefault(false); \ + AddAttr("use_cudnn", \ + "(bool, default false) Only used in cudnn kernel, need " \ + "install cudnn") \ + .SetDefault(false); \ + AddAttr( \ + "is_test", \ + "(bool, default false) Set to true for inference only, false " \ + "for training. Some layers may run faster when this is true.") \ + .SetDefault(false); \ + AddComment(OP_COMMENT); \ + } \ } #define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \ @@ -67,6 +74,12 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, const std::string& name) { framework::LibraryType library{framework::LibraryType::kPlain}; framework::DataLayout layout = framework::DataLayout::kAnyLayout; +#ifdef PADDLE_WITH_CUDA + auto it1 = oper.Attrs().find("use_cudnn"); + if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) { + library = framework::LibraryType::kCUDNN; + } +#endif #ifdef PADDLE_WITH_MKLDNN auto it = oper.Attrs().find("use_mkldnn"); if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() && diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index c7df3ea58a9..0f640601367 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -41,53 +41,115 @@ static std::unordered_set InplaceOpSet = { "floor", "reciprocal", "relu6", "soft_relu", "hard_sigmoid", }; +static bool IsInplace(const std::string& op) { + bool inplace = InplaceOpSet.count(op); + // for op_grad + const int kGradSuffixLen = 4; + if (op.size() > kGradSuffixLen && + op.compare(op.size() - kGradSuffixLen - 1, kGradSuffixLen, "grad")) { + inplace = + InplaceOpSet.count(op.substr(0, op.size() - (kGradSuffixLen + 1))); + } + return inplace; +} + /* The following operator can be used to process SelectedRows, because the * output of those operator for zero is zero too. */ static std::unordered_set CanBeUsedBySelectedRows = { "abs", "abs_grad", "square", "square_grad", "sqrt", "sqrt_grad"}; -static bool IsInplace(std::string op) { return InplaceOpSet.count(op); } - -template -class ActivationKernel - : public framework::OpKernel { - public: - using T = typename Functor::ELEMENT_TYPE; - - void Compute(const framework::ExecutionContext& context) const override { +inline void ExtractActivationTensor(const framework::ExecutionContext& context, + const framework::Tensor** X, + framework::Tensor** Out) { + auto x_var = context.InputVar("X"); + auto out_var = context.OutputVar("Out"); + PADDLE_ENFORCE(x_var != nullptr, + "Cannot get input Variable X, variable name = %s", + context.op().Input("X")); + PADDLE_ENFORCE(out_var != nullptr, + "Cannot get output Variable Out, variable name = %s", + context.op().Output("Out")); + if (CanBeUsedBySelectedRows.count(context.op().Type())) { + *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var); + *Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( + out_var); + } else { + *X = context.Input("X"); + *Out = context.Output("Out"); + } + + PADDLE_ENFORCE(*Out != nullptr, + "Cannot get output tensor Out, variable name = %s", + context.op().Output("Out")); +} + +inline void ExtractActivationGradTensor( + const framework::ExecutionContext& context, const framework::Tensor** X, + const framework::Tensor** Out, const framework::Tensor** dOut, + framework::Tensor** dX) { + auto out_var = context.InputVar("Out"); + auto out_grad_var = context.InputVar(framework::GradVarName("Out")); + auto x_grad_var = context.OutputVar(framework::GradVarName("X")); + PADDLE_ENFORCE(out_var != nullptr, + "Cannot get input Variable Out, variable name = %s", + context.op().Input("Out")); + PADDLE_ENFORCE(out_grad_var != nullptr, + "Cannot get input Variable %s, variable name = %s", + framework::GradVarName("Out"), + context.op().Input(framework::GradVarName("Out"))); + PADDLE_ENFORCE(x_grad_var != nullptr, + "Cannot get output Variable %s, variable name = %s", + framework::GradVarName("X"), + context.op().Output(framework::GradVarName("X"))); + + if (CanBeUsedBySelectedRows.count(context.op().Type())) { + *Out = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var); + *dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar( + *out_grad_var); + *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( + x_grad_var); + } else { + *Out = context.Input("Out"); + *dOut = context.Input(framework::GradVarName("Out")); + *dX = context.Output(framework::GradVarName("X")); + } + PADDLE_ENFORCE(*dX != nullptr, + "Cannot get output tensor %s, variable name = %s", + framework::GradVarName("X"), + context.op().Output(framework::GradVarName("X"))); + + bool inplace = IsInplace(context.op().Type()); + if (!inplace) { auto x_var = context.InputVar("X"); - auto out_var = context.OutputVar("Out"); PADDLE_ENFORCE(x_var != nullptr, - "Cannot get input Variable X, variable name = %s", + "Cannot get input tensor X, variable name = %s", context.op().Input("X")); - PADDLE_ENFORCE(out_var != nullptr, - "Cannot get output Variable Out, variable name = %s", - context.op().Output("Out")); - - framework::Tensor X, *Out; - if (CanBeUsedBySelectedRows.count(context.op().Type())) { - X = detail::Ref( - paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var), - "Cannot get input Tensor X, variable name = %s", - context.op().Input("X")); - Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( - out_var); + *X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var); } else { - X = detail::Ref(context.Input("X"), - "Cannot get input Tensor X, variable name = %s", - context.op().Input("X")); - Out = context.Output("Out"); + *X = context.Input("X"); } + } else { + VLOG(10) << " Inplace activation of Op : " << context.op().Type(); + *X = *dX; + } +} - PADDLE_ENFORCE(Out != nullptr, - "Cannot get output tensor Out, variable name = %s", - context.op().Output("Out")); +template +class ActivationKernel + : public framework::OpKernel { + public: + using T = typename Functor::ELEMENT_TYPE; + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* X = nullptr; + framework::Tensor* Out = nullptr; + ExtractActivationTensor(context, &X, &Out); Out->mutable_data(context.GetPlace()); - auto x = framework::EigenVector::Flatten(X); - auto out = framework::EigenVector::Flatten(*Out); + + auto x = framework::EigenVector::Flatten(detail::Ref(X)); + auto out = framework::EigenVector::Flatten(detail::Ref(Out)); auto* place = context.template device_context().eigen_device(); Functor functor; @@ -106,55 +168,15 @@ class ActivationGradKernel public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { - auto out_var = context.InputVar("Out"); - auto out_grad_var = context.InputVar(framework::GradVarName("Out")); - auto x_grad_var = context.OutputVar(framework::GradVarName("X")); - PADDLE_ENFORCE(out_var != nullptr, - "Cannot get input Variable Out, variable name = %s", - context.op().Input("Out")); - PADDLE_ENFORCE(out_grad_var != nullptr, - "Cannot get input Variable %s, variable name = %s", - framework::GradVarName("Out"), - context.op().Input(framework::GradVarName("Out"))); - PADDLE_ENFORCE(x_grad_var != nullptr, - "Cannot get output Variable %s, variable name = %s", - framework::GradVarName("X"), - context.op().Output(framework::GradVarName("X"))); - - framework::Tensor Out, dOut, *dX; - if (CanBeUsedBySelectedRows.count(context.op().Type())) { - Out = detail::Ref( - paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var), - "Cannot get input Tensor Out, variable name = %s", - context.op().Input("Out")); - dOut = - detail::Ref(paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar( - *out_grad_var), - "Cannot get input Tensor %s, variable name = %s", - framework::GradVarName("Out"), - context.op().Input(framework::GradVarName("Out"))); - dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( - x_grad_var); - } else { - Out = detail::Ref(context.Input("Out"), - "Cannot get input Tensor Out, variable name = %s", - context.op().Input("Out")); - dOut = detail::Ref( - context.Input(framework::GradVarName("Out")), - "Cannot get input Tensor %s, variable name = %s", - framework::GradVarName("Out"), - context.op().Input(framework::GradVarName("Out"))); - dX = context.Output(framework::GradVarName("X")); - } - PADDLE_ENFORCE(dX != nullptr, - "Cannot get output tensor %s, variable name = %s", - framework::GradVarName("X"), - context.op().Output(framework::GradVarName("X"))); + const framework::Tensor *X, *Out, *dOut; + framework::Tensor* dX = nullptr; + X = Out = dOut = nullptr; + ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX); dX->mutable_data(context.GetPlace()); - - auto dout = framework::EigenVector::Flatten(dOut); - auto out = framework::EigenVector::Flatten(Out); - auto dx = framework::EigenVector::Flatten(*dX); + auto dout = framework::EigenVector::Flatten(detail::Ref(dOut)); + auto out = framework::EigenVector::Flatten(detail::Ref(Out)); + auto dx = framework::EigenVector::Flatten(detail::Ref(dX)); + auto x = framework::EigenVector::Flatten(detail::Ref(X)); auto* place = context.template device_context().eigen_device(); Functor functor; @@ -162,27 +184,7 @@ class ActivationGradKernel for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } - bool inplace = functor.Inplace(); - if (!inplace) { - auto x_var = context.InputVar("X"); - PADDLE_ENFORCE(x_var != nullptr, - "Cannot get input tensor X, variable name = %s", - context.op().Input("X")); - framework::Tensor X; - if (CanBeUsedBySelectedRows.count(context.op().Type())) { - X = detail::Ref( - paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var)); - } else { - X = detail::Ref(context.Input("X")); - } - - auto x = framework::EigenVector::Flatten(X); - functor(*place, x, out, dout, dx); - } else { - VLOG(10) << " Inplace activation "; - auto x = framework::EigenVector::Flatten(*dX); - functor(*place, x, out, dout, dx); - } + functor(*place, x, out, dout, dx); } }; @@ -214,7 +216,6 @@ struct SigmoidFunctor : public BaseActivationFunctor { template struct SigmoidGradFunctor : public BaseActivationFunctor { - bool Inplace() const { return IsInplace("sigmoid"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -269,7 +270,6 @@ struct ExpFunctor : public BaseActivationFunctor { template struct ExpGradFunctor : public BaseActivationFunctor { - bool Inplace() const { return IsInplace("exp"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -288,7 +288,6 @@ struct ReluFunctor : public BaseActivationFunctor { template struct ReluGradFunctor : public BaseActivationFunctor { - bool Inplace() const { return IsInplace("relu"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -331,7 +330,6 @@ struct TanhFunctor : public BaseActivationFunctor { template struct TanhGradFunctor : public BaseActivationFunctor { - bool Inplace() const { return IsInplace("tanh"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -437,7 +435,6 @@ struct SqrtFunctor : public BaseActivationFunctor { template struct SqrtGradFunctor : public BaseActivationFunctor { - bool Inplace() const { return IsInplace("sqrt"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -456,7 +453,6 @@ struct CeilFunctor : public BaseActivationFunctor { template struct ZeroGradFunctor : public BaseActivationFunctor { - bool Inplace() const { return IsInplace("ceil"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -573,7 +569,6 @@ struct ReciprocalFunctor : public BaseActivationFunctor { template struct ReciprocalGradFunctor : public BaseActivationFunctor { - bool Inplace() const { return IsInplace("reciprocal"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -673,7 +668,6 @@ struct Relu6GradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } - bool Inplace() const { return IsInplace("relu6"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -755,7 +749,6 @@ struct SoftReluGradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } - bool Inplace() const { return IsInplace("soft_relu"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { @@ -936,7 +929,6 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"slope", &slope}, {"offset", &offset}}; } - bool Inplace() { return IsInplace("hard_sigmoid"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 1838506c893..9220d35707b 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -82,6 +82,7 @@ nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_ cc_test(init_test SRCS init_test.cc DEPS device_context) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) +nv_test(cudnn_desc_test SRCS cudnn_desc_test.cc DEPS dynload_cuda) nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context) cc_library(timer SRCS timer.cc) diff --git a/paddle/fluid/platform/cudnn_desc.h b/paddle/fluid/platform/cudnn_desc.h new file mode 100644 index 00000000000..1062b403f28 --- /dev/null +++ b/paddle/fluid/platform/cudnn_desc.h @@ -0,0 +1,124 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace platform { +using framework::Tensor; + +template +cudnnDataType_t ToCudnnDataType(const T& t) { + auto type = framework::ToDataType(t); + return ToCudnnDataType(type); +} + +template <> +cudnnDataType_t ToCudnnDataType(const framework::proto::VarType::Type& t) { + cudnnDataType_t type = CUDNN_DATA_FLOAT; + switch (t) { + case framework::proto::VarType::FP16: + type = CUDNN_DATA_HALF; + break; + case framework::proto::VarType::FP32: + type = CUDNN_DATA_FLOAT; + break; + case framework::proto::VarType::FP64: + type = CUDNN_DATA_DOUBLE; + break; + default: + break; + } + return type; +} + +class ActivationDescriptor { + public: + using T = cudnnActivationStruct; + struct Deleter { + void operator()(T* t) { + if (t != nullptr) { + PADDLE_ENFORCE(dynload::cudnnDestroyActivationDescriptor(t)); + t = nullptr; + } + } + }; + ActivationDescriptor() { + T* raw_ptr; + PADDLE_ENFORCE(dynload::cudnnCreateActivationDescriptor(&raw_ptr)); + desc_.reset(raw_ptr); + } + template + void set(cudnnActivationMode_t mode, const T& coef) { + CUDNN_ENFORCE(dynload::cudnnSetActivationDescriptor( + desc_.get(), mode, CUDNN_NOT_PROPAGATE_NAN, static_cast(coef))); + } + + T* desc() { return desc_.get(); } + T* desc() const { return desc_.get(); } + + private: + std::unique_ptr desc_; +}; + +class TensorDescriptor { + public: + using T = cudnnTensorStruct; + struct Deleter { + void operator()(T* t) { + if (t != nullptr) { + PADDLE_ENFORCE(dynload::cudnnDestroyTensorDescriptor(t)); + t = nullptr; + } + } + }; + TensorDescriptor() { + T* raw_ptr; + PADDLE_ENFORCE(dynload::cudnnCreateTensorDescriptor(&raw_ptr)); + desc_.reset(raw_ptr); + } + T* desc() { return desc_.get(); } + T* desc() const { return desc_.get(); } + void set(const Tensor& tensor, const int groups = 1) { + auto dims = framework::vectorize2int(tensor.dims()); + std::vector strides(dims.size()); + strides[dims.size() - 1] = 1; + for (int i = dims.size() - 2; i >= 0; i--) { + strides[i] = dims[i + 1] * strides[i + 1]; + } + std::vector dims_with_group(dims.begin(), dims.end()); + if (groups > 1) { + dims_with_group[1] = dims_with_group[1] / groups; + } + PADDLE_ENFORCE(dynload::cudnnSetTensorNdDescriptor( + desc_.get(), ToCudnnDataType(tensor.type()), dims_with_group.size(), + dims_with_group.data(), strides.data())); + } + + private: + std::unique_ptr desc_; +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/cudnn_desc_test.cc b/paddle/fluid/platform/cudnn_desc_test.cc new file mode 100644 index 00000000000..a60102a5489 --- /dev/null +++ b/paddle/fluid/platform/cudnn_desc_test.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/cudnn_desc.h" +#include + +namespace paddle { +namespace platform { + +TEST(TensorDescriptor, Empty) { + ActivationDescriptor a; + TensorDescriptor t; + TensorDescriptor t1; + TensorDescriptor *t11 = new TensorDescriptor(); + delete t11; + std::unique_ptr tt(new TensorDescriptor()); +} + +TEST(TensorDescriptor, Normal) { + framework::Tensor tt; + tt.Resize({2, 3, 4}); + tt.mutable_data(platform::CPUPlace()); + + TensorDescriptor desc; + desc.set(tt); + EXPECT_TRUE(desc.desc() != nullptr); +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index 2f4f8101e4b..3008c166938 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -99,6 +99,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); __macro(cudnnDestroy); \ __macro(cudnnSetStream); \ __macro(cudnnActivationForward); \ + __macro(cudnnActivationBackward); \ __macro(cudnnConvolutionForward); \ __macro(cudnnConvolutionBackwardBias); \ __macro(cudnnGetConvolutionForwardWorkspaceSize); \ diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 55c43ef115a..d5a83854099 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -26,6 +26,7 @@ class TestActivation(OpTest): self.op_type = "exp" self.dtype = np.float32 self.init_dtype() + self.init_kernel_type() x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) out = np.exp(x) @@ -44,6 +45,9 @@ class TestActivation(OpTest): def init_dtype(self): self.dtype = np.float32 + def init_kernel_type(self): + pass + class TestSigmoid(TestActivation): def setUp(self): @@ -601,6 +605,25 @@ class TestSwish(TestActivation): self.check_grad(['X'], 'Out', max_relative_error=0.008) +#------------------ Test Cudnn Activation---------------------- +def create_test_act_cudnn_class(parent, atol=1e-3, grad_atol=1e-3): + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + class TestActCudnn(parent): + def init_kernel_type(self): + self.attrs = {"use_cudnn": True} + + cls_name = "{0}_{1}".format(parent.__name__, "cudnn") + TestActCudnn.__name__ = cls_name + globals()[cls_name] = TestActCudnn + + +create_test_act_cudnn_class(TestRelu) +create_test_act_cudnn_class(TestRelu6) +create_test_act_cudnn_class(TestSigmoid) +create_test_act_cudnn_class(TestTanh) + + #------------------ Test Fp16 ---------------------- def create_test_act_fp16_class(parent, atol=1e-3, -- GitLab