diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 8cf6a095e23045113235f73e8729c43d4cad34f7..939015890433b829640471995617e71bb60ecbc2 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -39,9 +39,6 @@ class SumOp : public framework::OperatorWithKernel { auto x_vars = ctx.MultiInputVar("X"); auto x_vars_name = ctx.InputNames("X"); - framework::LibraryType library{framework::LibraryType::kPlain}; - framework::DataLayout layout{framework::DataLayout::kAnyLayout}; - PADDLE_ENFORCE_GT( x_vars.size(), 0, @@ -80,8 +77,7 @@ class SumOp : public framework::OperatorWithKernel { auto data_type = static_cast(dtype); #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type) && + if (this->CanMKLDNNBeUsed(ctx, data_type) && (data_type == framework::proto::VarType::FP32 || data_type == framework::proto::VarType::BF16) && ctx.OutputVar("Out")->IsType()) { @@ -96,25 +92,19 @@ class SumOp : public framework::OperatorWithKernel { } } #endif - - return framework::OpKernelType( - data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } else if (x_vars[0]->IsType()) { for (auto& var : x_vars) { auto& value = var->Get().value(); if (value.IsInitialized()) { return framework::OpKernelType( framework::TransToProtoVarType(value.dtype()), - ctx.device_context(), - layout, - library); + ctx.device_context()); } } // if input sparse vars are not initialized, use an default kernel type. return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context(), - layout, - library); + ctx.device_context()); } else if (x_vars[0]->IsType()) { for (auto& x_var : x_vars) { auto& array = x_var->Get(); @@ -122,9 +112,7 @@ class SumOp : public framework::OperatorWithKernel { if (each.numel() != 0 && each.IsInitialized()) { return framework::OpKernelType( framework::TransToProtoVarType(each.dtype()), - ctx.device_context(), - layout, - library); + ctx.device_context()); } } }