Unverified commit d6c69d7c, authored by HongyuJia, committed by GitHub

[Opt Code] Opt GetExpectedKernelType code of sum (#46678)

* refine sum_op mkldnn code

* refine sum_op mkldnn code
Parent ee1aec62
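
Note on why the change is behavior-preserving (an editorial sketch, not part of the commit): the deleted locals `library` and `layout` were initialized to `framework::LibraryType::kPlain` and `framework::DataLayout::kAnyLayout` and never reassigned, and the `framework::OpKernelType` constructor is assumed here to default its trailing layout/library parameters to those same values, so the shorter calls build the same kernel key; for the same reason, the `library == framework::LibraryType::kPlain` conjunct dropped from the MKLDNN check was always true. The self-contained C++ sketch below mirrors that pattern with hypothetical stand-in types (`KernelType`, `DataLayout`, `LibraryType` here are illustrative, not Paddle's real definitions):

#include <cassert>

// Hypothetical stand-ins for the framework enums; they only echo the two
// values the deleted locals were initialized to.
enum class DataLayout { kAnyLayout, kMKLDNN };
enum class LibraryType { kPlain, kMKLDNN };

// Hypothetical stand-in for framework::OpKernelType: the trailing parameters
// are assumed to default to the values the deleted locals always held.
struct KernelType {
  int data_type;
  DataLayout layout;
  LibraryType library;
  explicit KernelType(int dt,
                      DataLayout l = DataLayout::kAnyLayout,
                      LibraryType lib = LibraryType::kPlain)
      : data_type(dt), layout(l), library(lib) {}
};

int main() {
  // Old style: forward never-modified locals explicitly.
  LibraryType library{LibraryType::kPlain};
  DataLayout layout{DataLayout::kAnyLayout};
  KernelType before(/*dt=*/5, layout, library);

  // New style: rely on the constructor defaults, as the commit now does.
  KernelType after(/*dt=*/5);

  // Both call styles describe the same kernel type.
  assert(before.layout == after.layout && before.library == after.library);
  return 0;
}

Any C++11 compiler will build and run this; the assert documents that the two call styles agree.
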
@@ -39,9 +39,6 @@ class SumOp : public framework::OperatorWithKernel {
     auto x_vars = ctx.MultiInputVar("X");
     auto x_vars_name = ctx.InputNames("X");
-    framework::LibraryType library{framework::LibraryType::kPlain};
-    framework::DataLayout layout{framework::DataLayout::kAnyLayout};
     PADDLE_ENFORCE_GT(
         x_vars.size(),
         0,
@@ -80,8 +77,7 @@ class SumOp : public framework::OperatorWithKernel {
       auto data_type = static_cast<framework::proto::VarType::Type>(dtype);
 #ifdef PADDLE_WITH_MKLDNN
-      if (library == framework::LibraryType::kPlain &&
-          this->CanMKLDNNBeUsed(ctx, data_type) &&
+      if (this->CanMKLDNNBeUsed(ctx, data_type) &&
           (data_type == framework::proto::VarType::FP32 ||
            data_type == framework::proto::VarType::BF16) &&
           ctx.OutputVar("Out")->IsType<framework::LoDTensor>()) {
@@ -96,25 +92,19 @@ class SumOp : public framework::OperatorWithKernel {
         }
       }
 #endif
-      return framework::OpKernelType(
-          data_type, ctx.GetPlace(), layout, library);
+      return framework::OpKernelType(data_type, ctx.GetPlace());
     } else if (x_vars[0]->IsType<phi::SelectedRows>()) {
       for (auto& var : x_vars) {
         auto& value = var->Get<phi::SelectedRows>().value();
         if (value.IsInitialized()) {
           return framework::OpKernelType(
               framework::TransToProtoVarType(value.dtype()),
-              ctx.device_context(),
-              layout,
-              library);
+              ctx.device_context());
         }
       }
       // if input sparse vars are not initialized, use an default kernel type.
       return framework::OpKernelType(framework::proto::VarType::FP32,
-                                     ctx.device_context(),
-                                     layout,
-                                     library);
+                                     ctx.device_context());
     } else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
       for (auto& x_var : x_vars) {
         auto& array = x_var->Get<framework::LoDTensorArray>();
@@ -122,9 +112,7 @@ class SumOp : public framework::OperatorWithKernel {
         if (each.numel() != 0 && each.IsInitialized()) {
           return framework::OpKernelType(
               framework::TransToProtoVarType(each.dtype()),
-              ctx.device_context(),
-              layout,
-              library);
+              ctx.device_context());
         }
       }
     }