未验证 提交 b4d7ef9d 编写于 作者: H HongyuJia 提交者: GitHub

[Opt Code] Opt GetExpectedKernelType code of conv_op (#46681)

* refine conv_op mkldnn code

* fix customized_type_value
上级 d6c69d7c
......@@ -189,35 +189,9 @@ std::vector<int64_t> ConvOp::ComputeOutputShape(
framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
framework::LibraryType library{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
std::string data_format =
"AnyLayout"; // todo enable data layout when it's ready
framework::DataLayout layout = framework::StringToDataLayout(data_format);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library = framework::LibraryType::kCUDNN;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
customized_type_value =
(input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType())
? OperatorWithKernel::IndicateVarDataType(ctx, "Filter") ==
framework::DataTypeTrait<int8_t>::DataType()
? kConvMKLDNNINT8WS8
: kConvMKLDNNINT8
: kConvMKLDNNFP32;
}
#endif
// todo enable data layout when it's ready
// (https://github.com/PaddlePaddle/Paddle/pull/20042)
if (input_data_type != framework::proto::VarType::INT8 &&
input_data_type != framework::proto::VarType::UINT8 &&
......@@ -234,17 +208,11 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
paddle::framework::DataTypeToString(input_data_type),
paddle::framework::DataTypeToString(filter_data_type)));
}
// #ifndef PADDLE_WITH_ASCEND_CL
// if (input_data_type == framework::proto::VarType::FP16) {
// PADDLE_ENFORCE_EQ(
// library, framework::LibraryType::kCUDNN,
// platform::errors::InvalidArgument(
// "float16 can only be used when CUDNN or NPU is used"));
// }
// #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
#if PADDLE_WITH_CUDA
if (input_data_type == framework::proto::VarType::BF16 &&
library == framework::LibraryType::kCUDNN) {
if (input_data_type == framework::proto::VarType::BF16) {
PADDLE_ENFORCE_GE(
platform::DnnVersion(),
8100,
......@@ -252,10 +220,41 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
"bfloat16 can only be used when CUDNN_VERSION >= 8100"));
}
#endif // PADDLE_WITH_CUDA
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
auto type = framework::OpKernelType(
input_data_type, ctx.GetPlace(), layout, library, customized_type_value);
return type;
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
int customized_type_value =
(input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType())
? OperatorWithKernel::IndicateVarDataType(ctx, "Filter") ==
framework::DataTypeTrait<int8_t>::DataType()
? kConvMKLDNNINT8WS8
: kConvMKLDNNINT8
: kConvMKLDNNFP32;
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
}
#endif
// #ifndef PADDLE_WITH_ASCEND_CL
// if (input_data_type == framework::proto::VarType::FP16) {
// PADDLE_ENFORCE_EQ(
// library, framework::LibraryType::kCUDNN,
// platform::errors::InvalidArgument(
// "float16 can only be used when CUDNN or NPU is used"));
// }
// #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
framework::OpKernelType ConvOp::GetKernelTypeForVar(
......@@ -502,32 +501,28 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
framework::LibraryType library_{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
const std::string data_format = ctx.Attr<std::string>("data_format");
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
customized_type_value = kConvMKLDNNFP32;
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
kConvMKLDNNFP32);
}
#endif
auto type = framework::OpKernelType(
data_type, ctx.GetPlace(), layout_, library_, customized_type_value);
return type;
return framework::OpKernelType(data_type, ctx.GetPlace());
}
framework::OpKernelType ConvOpGrad::GetKernelTypeForVar(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册