From b4d7ef9dea7928418a3428b5d452a7b05c956ff3 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Tue, 11 Oct 2022 14:35:57 +0800 Subject: [PATCH] [Opt Code] Opt GetExpectedKernelType code of conv_op (#46681) * refine conv_op mkldnn code * fix customized_type_value --- paddle/fluid/operators/conv_op.cc | 119 ++++++++++++++---------------- 1 file changed, 57 insertions(+), 62 deletions(-) diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index f6132910712..ce335cff52d 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -189,35 +189,9 @@ std::vector ConvOp::ComputeOutputShape( framework::OpKernelType ConvOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; - framework::LibraryType library{framework::LibraryType::kPlain}; - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - std::string data_format = - "AnyLayout"; // todo enable data layout when it's ready - framework::DataLayout layout = framework::StringToDataLayout(data_format); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::CanCUDNNBeUsed(ctx)) { - library = framework::LibraryType::kCUDNN; - } -#endif -#ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; - customized_type_value = - (input_data_type == framework::DataTypeTrait::DataType() || - input_data_type == framework::DataTypeTrait::DataType()) - ? OperatorWithKernel::IndicateVarDataType(ctx, "Filter") == - framework::DataTypeTrait::DataType() - ? kConvMKLDNNINT8WS8 - : kConvMKLDNNINT8 - : kConvMKLDNNFP32; - } -#endif + // todo enable data layout when it's ready + // (https://github.com/PaddlePaddle/Paddle/pull/20042) if (input_data_type != framework::proto::VarType::INT8 && input_data_type != framework::proto::VarType::UINT8 && @@ -234,28 +208,53 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( paddle::framework::DataTypeToString(input_data_type), paddle::framework::DataTypeToString(filter_data_type))); } -// #ifndef PADDLE_WITH_ASCEND_CL -// if (input_data_type == framework::proto::VarType::FP16) { -// PADDLE_ENFORCE_EQ( -// library, framework::LibraryType::kCUDNN, -// platform::errors::InvalidArgument( -// "float16 can only be used when CUDNN or NPU is used")); -// } -// #endif + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::CanCUDNNBeUsed(ctx)) { #if PADDLE_WITH_CUDA - if (input_data_type == framework::proto::VarType::BF16 && - library == framework::LibraryType::kCUDNN) { - PADDLE_ENFORCE_GE( - platform::DnnVersion(), - 8100, - platform::errors::InvalidArgument( - "bfloat16 can only be used when CUDNN_VERSION >= 8100")); - } + if (input_data_type == framework::proto::VarType::BF16) { + PADDLE_ENFORCE_GE( + platform::DnnVersion(), + 8100, + platform::errors::InvalidArgument( + "bfloat16 can only be used when CUDNN_VERSION >= 8100")); + } #endif // PADDLE_WITH_CUDA + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kAnyLayout, + framework::LibraryType::kCUDNN); + } +#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP - auto type = framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library, customized_type_value); - return type; +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + int customized_type_value = + (input_data_type == framework::DataTypeTrait::DataType() || + input_data_type == framework::DataTypeTrait::DataType()) + ? OperatorWithKernel::IndicateVarDataType(ctx, "Filter") == + framework::DataTypeTrait::DataType() + ? kConvMKLDNNINT8WS8 + : kConvMKLDNNINT8 + : kConvMKLDNNFP32; + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN, + customized_type_value); + } +#endif + + // #ifndef PADDLE_WITH_ASCEND_CL + // if (input_data_type == framework::proto::VarType::FP16) { + // PADDLE_ENFORCE_EQ( + // library, framework::LibraryType::kCUDNN, + // platform::errors::InvalidArgument( + // "float16 can only be used when CUDNN or NPU is used")); + // } + // #endif + + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } framework::OpKernelType ConvOp::GetKernelTypeForVar( @@ -502,32 +501,28 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType ConvOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; - framework::LibraryType library_{framework::LibraryType::kPlain}; // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - std::string data_format = "AnyLayout"; - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::CanCUDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kCUDNN; + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kAnyLayout, + framework::LibraryType::kCUDNN); } #endif #ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - const std::string data_format = ctx.Attr("data_format"); - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; - customized_type_value = kConvMKLDNNFP32; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN, + kConvMKLDNNFP32); } #endif - auto type = framework::OpKernelType( - data_type, ctx.GetPlace(), layout_, library_, customized_type_value); - return type; + return framework::OpKernelType(data_type, ctx.GetPlace()); } framework::OpKernelType ConvOpGrad::GetKernelTypeForVar( -- GitLab