From c5173591c0dec75e829001bec50f5d65fb032cb3 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Tue, 11 Oct 2022 14:36:41 +0800 Subject: [PATCH] change mkldnn interp to normal GetExpectedKernelType (#46685) --- paddle/fluid/operators/interpolate_op.cc | 7 ++++--- paddle/fluid/operators/interpolate_v2_op.cc | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 056b81fd9a2..ac50da83e6b 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -342,10 +342,11 @@ class InterpolateOp : public framework::OperatorWithKernel { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - const auto& interp_method = ctx.Attr("interp_method"); // TODO(danqing): support other interp_method - if (this->CanMKLDNNBeUsed(ctx, data_type) && - (interp_method == "nearest" || interp_method == "bilinear")) { + // (https://github.com/PaddlePaddle/Paddle/pull/30016/files) + // NOTE(jiahy0825): currently only support interp_method = nearest or + // interp_method = bilinear + if (this->CanMKLDNNBeUsed(ctx, data_type)) { return framework::OpKernelType(data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, diff --git a/paddle/fluid/operators/interpolate_v2_op.cc b/paddle/fluid/operators/interpolate_v2_op.cc index e7a362f543b..e9d0d718b9f 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cc +++ b/paddle/fluid/operators/interpolate_v2_op.cc @@ -446,10 +446,11 @@ class InterpolateV2Op : public framework::OperatorWithKernel { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - const auto& interp_method = ctx.Attr("interp_method"); // TODO(danqing): support other interp_method - if (this->CanMKLDNNBeUsed(ctx, data_type) && - (interp_method == "nearest" || interp_method == "bilinear")) { + // (https://github.com/PaddlePaddle/Paddle/pull/30016/files) + // NOTE(jiahy0825): currently only support interp_method = nearest or + // interp_method = bilinear + if (this->CanMKLDNNBeUsed(ctx, data_type)) { return framework::OpKernelType(data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, -- GitLab