diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
index d68b91544ecad4c53bce0146dc89cfbb0778a338..5b4b9aef88637326d892e222ef3ec7e2ccf5084c 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
@@ -16,6 +16,10 @@ limitations under the License. */
 
 #include <string>
 
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#endif
+
 namespace paddle {
 namespace operators {
 
@@ -34,28 +38,20 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    // choose cudnn kernel if the runtime supported.
-    bool use_cudnn =
-        ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
-    bool runtime_cudnn_support = false;
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    phi::DataLayout layout_ = DataLayout::kAnyLayout;
+    if (ctx.HasAttr("data_format")) {
+      layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
+    }
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    if (platform::is_gpu_place(ctx.GetPlace())) {
-      auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
-      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      return framework::OpKernelType(input_data_type,
+                                     ctx.GetPlace(),
+                                     layout_,
+                                     framework::LibraryType::kCUDNN);
     }
 #endif
-    framework::LibraryType library_ = framework::LibraryType::kPlain;
-    if (use_cudnn && runtime_cudnn_support) {
-      library_ = framework::LibraryType::kCUDNN;
-    }
-    std::string data_format = ctx.HasAttr("data_format")
-                                  ? ctx.Attr<std::string>("data_format")
-                                  : "AnyLayout";
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.GetPlace(),
-        phi::StringToDataLayout(data_format),
-        library_);
+    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
   }
 };
@@ -134,28 +130,20 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    // choose cudnn kernel if the runtime supported.
-    bool use_cudnn =
-        ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
-    bool runtime_cudnn_support = false;
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");
+    phi::DataLayout layout_ = DataLayout::kAnyLayout;
+    if (ctx.HasAttr("data_format")) {
+      layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
+    }
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-    if (platform::is_gpu_place(ctx.GetPlace())) {
-      auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
-      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      return framework::OpKernelType(input_data_type,
+                                     ctx.GetPlace(),
+                                     layout_,
+                                     framework::LibraryType::kCUDNN);
     }
 #endif
-    framework::LibraryType library_ = framework::LibraryType::kPlain;
-    if (use_cudnn && runtime_cudnn_support) {
-      library_ = framework::LibraryType::kCUDNN;
-    }
-    std::string data_format = ctx.HasAttr("data_format")
-                                  ? ctx.Attr<std::string>("data_format")
-                                  : "AnyLayout";
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "Out"),
-        ctx.GetPlace(),
-        phi::StringToDataLayout(data_format),
-        library_);
+    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
   }
 };