From d6c69d7c8f297baaf0a87cd41d06097e6fbd0789 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Tue, 11 Oct 2022 14:34:21 +0800 Subject: [PATCH] [Opt Code] Opt GetExpectedKernelType code of sum (#46678) * refine sum_op mkldnn code * refine sum_op mkldnn code --- paddle/fluid/operators/sum_op.cc | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 8cf6a095e2..9390158904 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -39,9 +39,6 @@ class SumOp : public framework::OperatorWithKernel { auto x_vars = ctx.MultiInputVar("X"); auto x_vars_name = ctx.InputNames("X"); - framework::LibraryType library{framework::LibraryType::kPlain}; - framework::DataLayout layout{framework::DataLayout::kAnyLayout}; - PADDLE_ENFORCE_GT( x_vars.size(), 0, @@ -80,8 +77,7 @@ class SumOp : public framework::OperatorWithKernel { auto data_type = static_cast(dtype); #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type) && + if (this->CanMKLDNNBeUsed(ctx, data_type) && (data_type == framework::proto::VarType::FP32 || data_type == framework::proto::VarType::BF16) && ctx.OutputVar("Out")->IsType()) { @@ -96,25 +92,19 @@ class SumOp : public framework::OperatorWithKernel { } } #endif - - return framework::OpKernelType( - data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } else if (x_vars[0]->IsType()) { for (auto& var : x_vars) { auto& value = var->Get().value(); if (value.IsInitialized()) { return framework::OpKernelType( framework::TransToProtoVarType(value.dtype()), - ctx.device_context(), - layout, - library); + ctx.device_context()); } } // if input sparse vars are not initialized, use an default kernel type. return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context(), - layout, - library); + ctx.device_context()); } else if (x_vars[0]->IsType()) { for (auto& x_var : x_vars) { auto& array = x_var->Get(); @@ -122,9 +112,7 @@ class SumOp : public framework::OperatorWithKernel { if (each.numel() != 0 && each.IsInitialized()) { return framework::OpKernelType( framework::TransToProtoVarType(each.dtype()), - ctx.device_context(), - layout, - library); + ctx.device_context()); } } } -- GitLab