未验证 提交 d6c69d7c 编写于 作者: H HongyuJia 提交者: GitHub

[Opt Code] Opt GetExpectedKernelType code of sum (#46678)

* refine sum_op mkldnn code

* refine sum_op mkldnn code
上级 ee1aec62
......@@ -39,9 +39,6 @@ class SumOp : public framework::OperatorWithKernel {
auto x_vars = ctx.MultiInputVar("X");
auto x_vars_name = ctx.InputNames("X");
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
PADDLE_ENFORCE_GT(
x_vars.size(),
0,
......@@ -80,8 +77,7 @@ class SumOp : public framework::OperatorWithKernel {
auto data_type = static_cast<framework::proto::VarType::Type>(dtype);
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type) &&
if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(data_type == framework::proto::VarType::FP32 ||
data_type == framework::proto::VarType::BF16) &&
ctx.OutputVar("Out")->IsType<framework::LoDTensor>()) {
......@@ -96,25 +92,19 @@ class SumOp : public framework::OperatorWithKernel {
}
}
#endif
return framework::OpKernelType(
data_type, ctx.GetPlace(), layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace());
} else if (x_vars[0]->IsType<phi::SelectedRows>()) {
for (auto& var : x_vars) {
auto& value = var->Get<phi::SelectedRows>().value();
if (value.IsInitialized()) {
return framework::OpKernelType(
framework::TransToProtoVarType(value.dtype()),
ctx.device_context(),
layout,
library);
ctx.device_context());
}
}
// if input sparse vars are not initialized, use an default kernel type.
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context(),
layout,
library);
ctx.device_context());
} else if (x_vars[0]->IsType<framework::LoDTensorArray>()) {
for (auto& x_var : x_vars) {
auto& array = x_var->Get<framework::LoDTensorArray>();
......@@ -122,9 +112,7 @@ class SumOp : public framework::OperatorWithKernel {
if (each.numel() != 0 && each.IsInitialized()) {
return framework::OpKernelType(
framework::TransToProtoVarType(each.dtype()),
ctx.device_context(),
layout,
library);
ctx.device_context());
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册