Unverified commit ff07f8a2, authored by HongyuJia, committed via GitHub.

[CUDNN hardcode] Opt CUDNN hardcode of sequence_softmax (#47319)

* opt cudnn hardcode of sequence_softmax

* fix grad datatype
Parent commit: 98beb5af
...@@ -16,6 +16,10 @@ limitations under the License. */ ...@@ -16,6 +16,10 @@ limitations under the License. */
#include <string> #include <string>
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -34,28 +38,20 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { ...@@ -34,28 +38,20 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
protected: protected:
// Chooses the execution kernel for SequenceSoftmax.
//
// The data type is inferred from the forward input "X". The layout defaults
// to kAnyLayout and is overridden only when the op carries an explicit
// "data_format" attribute. On CUDA/HIP builds the cuDNN kernel is selected
// whenever platform::CanCUDNNBeUsed(ctx) allows it; otherwise the plain
// kernel is used.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
  phi::DataLayout layout_ = DataLayout::kAnyLayout;
  if (ctx.HasAttr("data_format")) {
    layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // NOTE(review): CanCUDNNBeUsed replaces the previous hand-rolled check
  // (reading the "use_cudnn" attribute and probing dev_ctx.cudnn_handle());
  // presumably it folds both the attribute and runtime availability into a
  // single helper — confirm against paddle/fluid/platform/device/gpu/gpu_dnn.h.
  if (platform::CanCUDNNBeUsed(ctx)) {
    return framework::OpKernelType(input_data_type,
                                   ctx.GetPlace(),
                                   layout_,
                                   framework::LibraryType::kCUDNN);
  }
#endif
  return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
}
}; };
...@@ -134,28 +130,20 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { ...@@ -134,28 +130,20 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
protected: protected:
// Chooses the execution kernel for SequenceSoftmaxGrad.
//
// The data type is inferred from "Out" (the forward result), not from the
// gradient input — see the commit note "fix grad datatype". The layout
// defaults to kAnyLayout and is overridden only when the op carries an
// explicit "data_format" attribute. On CUDA/HIP builds the cuDNN kernel is
// selected whenever platform::CanCUDNNBeUsed(ctx) allows it; otherwise the
// plain kernel is used.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");
  phi::DataLayout layout_ = DataLayout::kAnyLayout;
  if (ctx.HasAttr("data_format")) {
    layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // NOTE(review): CanCUDNNBeUsed replaces the previous hand-rolled check
  // (reading the "use_cudnn" attribute and probing dev_ctx.cudnn_handle());
  // confirm its exact semantics against gpu_dnn.h.
  if (platform::CanCUDNNBeUsed(ctx)) {
    return framework::OpKernelType(input_data_type,
                                   ctx.GetPlace(),
                                   layout_,
                                   framework::LibraryType::kCUDNN);
  }
#endif
  return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
}
}; };
......
Markdown is supported.
0% uploaded.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.