未验证 提交 ff07f8a2 编写于 作者: H HongyuJia 提交者: GitHub

[CUDNN hardcode] Opt CUDNN hardcode of sequence_softmax (#47319)

* opt cudnn hardcode of sequence_softmax

* fix grad datatype
上级 98beb5af
......@@ -16,6 +16,10 @@ limitations under the License. */
#include <string>
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif
namespace paddle {
namespace operators {
......@@ -34,28 +38,20 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
protected:
// Chooses the kernel type used to run the forward sequence_softmax op.
//
// The kernel data type is taken from input variable "X". The data layout
// defaults to kAnyLayout and is overridden by the "data_format" attribute
// when the op carries one. On CUDA/HIP builds the cuDNN kernel is requested
// whenever platform::CanCUDNNBeUsed(ctx) returns true; in every other case
// the library type is left at the OpKernelType constructor's default.
//
// NOTE(review): platform::CanCUDNNBeUsed is declared outside this view;
// presumably it checks the "use_cudnn" attribute together with GPU placement
// and runtime cuDNN availability -- confirm against gpu_dnn.h.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

  // "data_format" is optional; keep the any-layout default when it is absent.
  phi::DataLayout layout_ = DataLayout::kAnyLayout;
  if (ctx.HasAttr("data_format")) {
    layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // The cuDNN path only exists in GPU builds.
  if (platform::CanCUDNNBeUsed(ctx)) {
    return framework::OpKernelType(input_data_type,
                                   ctx.GetPlace(),
                                   layout_,
                                   framework::LibraryType::kCUDNN);
  }
#endif

  return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
}
};
......@@ -134,28 +130,20 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
protected:
// Chooses the kernel type used to run the sequence_softmax gradient op.
//
// The kernel data type is taken from input variable "Out". The data layout
// defaults to kAnyLayout and is overridden by the "data_format" attribute
// when the op carries one. On CUDA/HIP builds the cuDNN kernel is requested
// whenever platform::CanCUDNNBeUsed(ctx) returns true; in every other case
// the library type is left at the OpKernelType constructor's default.
//
// NOTE(review): platform::CanCUDNNBeUsed is declared outside this view;
// presumably it checks the "use_cudnn" attribute together with GPU placement
// and runtime cuDNN availability -- confirm against gpu_dnn.h.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");

  // "data_format" is optional; keep the any-layout default when it is absent.
  phi::DataLayout layout_ = DataLayout::kAnyLayout;
  if (ctx.HasAttr("data_format")) {
    layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // The cuDNN path only exists in GPU builds.
  if (platform::CanCUDNNBeUsed(ctx)) {
    return framework::OpKernelType(input_data_type,
                                   ctx.GetPlace(),
                                   layout_,
                                   framework::LibraryType::kCUDNN);
  }
#endif

  return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册