未验证 提交 c5173591 编写于 作者: H HongyuJia 提交者: GitHub

change mkldnn interp to normal GetExpectedKernelType (#46685)

上级 b4d7ef9d
...@@ -342,10 +342,11 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -342,10 +342,11 @@ class InterpolateOp : public framework::OperatorWithKernel {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
const auto& interp_method = ctx.Attr<std::string>("interp_method");
// TODO(danqing): support other interp_method // TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx, data_type) && // (https://github.com/PaddlePaddle/Paddle/pull/30016/files)
(interp_method == "nearest" || interp_method == "bilinear")) { // NOTE(jiahy0825): currently only support interp_method = nearest or
// interp_method = bilinear
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type, return framework::OpKernelType(data_type,
ctx.GetPlace(), ctx.GetPlace(),
framework::DataLayout::kMKLDNN, framework::DataLayout::kMKLDNN,
......
...@@ -446,10 +446,11 @@ class InterpolateV2Op : public framework::OperatorWithKernel { ...@@ -446,10 +446,11 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
const auto& interp_method = ctx.Attr<std::string>("interp_method");
// TODO(danqing): support other interp_method // TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx, data_type) && // (https://github.com/PaddlePaddle/Paddle/pull/30016/files)
(interp_method == "nearest" || interp_method == "bilinear")) { // NOTE(jiahy0825): currently only support interp_method = nearest or
// interp_method = bilinear
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
return framework::OpKernelType(data_type, return framework::OpKernelType(data_type,
ctx.GetPlace(), ctx.GetPlace(),
framework::DataLayout::kMKLDNN, framework::DataLayout::kMKLDNN,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册