Commit 53b6e51b authored by dingminghui, committed by jackzhang235

fix(place): add x86 nhwc place in subgraph_pass

Since the input tensor is modified to NHWC, which is not compatible with
the cast kernel registered in NCHW layout, type_layout_cast_pass would
insert a layout instruction to transform the data. Use an NHWC-compatible
cast kernel to avoid this extra layout cast.
Parent 64c5aae6
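The reasoning above can be summarized with a small sketch. The snippet below is illustrative only (it reuses names from the diff but is not part of the commit): once an x86 NHWC place such as the one constructed here is among the graph's valid places, the kernel picker can choose a cast kernel whose declared input is layout-compatible with the NHWC feed output, so type_layout_cast_pass has no gap to bridge with a layout instruction.

// Illustrative sketch, not part of this commit: the place that the
// subgraph pass below appends when it is missing from the valid places.
// Before: feed (NHWC) -> layout(NHWC->NCHW) -> cast (NCHW kernel) -> ...
// After : feed (NHWC) -> cast (layout-compatible kernel)          -> ...
const Place x86_nhwc_fp32{TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNHWC)};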
@@ -69,7 +69,8 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
for (auto& kernel : kernels) {
if (op_type == "cast") {
const Type* in_arg_ty = kernel->GetInputDeclType("X");
if (PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type)) {
if (PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type) &&
DataLayoutCompatible(*in_arg_ty, *cur_node->AsArg().type)) {
is_found = true;
}
} else if (op_type == "layout") {
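The DataLayoutCompatible helper called above is not shown in this diff. A minimal sketch of such a predicate, assuming it treats DATALAYOUT(kAny) as a wildcard (which is what makes the kAny binding on the cast kernel's input, further down, acceptable for NHWC arguments), could look like this:

// Hypothetical sketch of a layout-compatibility predicate; the real
// DataLayoutCompatible lives elsewhere in the codebase. Two types are
// taken to be layout-compatible when their layouts match or when either
// side declares the wildcard layout kAny.
static bool DataLayoutCompatibleSketch(const Type& kernel_ty, const Type& arg_ty) {
  return kernel_ty.layout() == arg_ty.layout() ||
         kernel_ty.layout() == DATALAYOUT(kAny) ||
         arg_ty.layout() == DATALAYOUT(kAny);
}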
@@ -564,6 +565,16 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
old_type->precision(),
paddle::lite_api::DataLayoutType::kNHWC,
old_type->device());
// modify the feed inst to NHWC when set_mlu_input_layout(kNHWC) is
// invoked, to keep it consistent with the actual data layout
auto place = node.AsStmt().place();
place.layout = DATALAYOUT(kNHWC);
std::vector<Place> valid_places = {place};
auto updated_op_info = *node.AsStmt().op_info();
node.AsStmt().ResetOp(updated_op_info, valid_places, nullptr);
auto kernel = &(node.AsStmt().picked_kernel());
VLOG(4) << "kernel info: " << kernel->name();
node.AsStmt().op()->AttachKernel(kernel);
}
}
}
......
@@ -95,7 +95,26 @@ void MLUSubgraphPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
++it;
}
}
// add x86 NHWC place
std::vector<paddle::lite_api::PrecisionType> precisions{PRECISION(kFloat),
PRECISION(kFP16)};
if (lite::TargetWrapperMlu::UseFirstConv())
precisions.emplace_back(PRECISION(kInt8));
for (auto& prec : precisions) {
auto is_x86_nhwc = [prec](const Place& it) {
return it.layout == DATALAYOUT(kNHWC) && it.target == TARGET(kX86) &&
it.precision == prec;
};
if (std::find_if(v_places.cbegin(), v_places.cend(), is_x86_nhwc) ==
v_places.end()) {
v_places.emplace_back(Place{TARGET(kX86), prec, DATALAYOUT(kNHWC)});
}
}
graph->SetValidPlaces(v_places);
VLOG(4) << "valid places after modified:";
for (auto& p : v_places) {
VLOG(4) << p.DebugString();
}
#endif
std::unordered_set<std::string> supported_lists;
......
@@ -182,6 +182,9 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kX86, kFloat, kNCHW);
INIT_FOR(kX86, kFP16, kNCHW);
INIT_FOR(kX86, kInt8, kNCHW);
INIT_FOR(kX86, kFloat, kNHWC);
INIT_FOR(kX86, kFP16, kNHWC);
INIT_FOR(kX86, kInt8, kNHWC);
INIT_FOR(kX86, kAny, kNCHW);
INIT_FOR(kX86, kAny, kAny);
INIT_FOR(kX86, kInt64, kNCHW);
......
@@ -126,6 +126,15 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kX86),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kX86),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kX86),
PRECISION(kFP16),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kX86),
PRECISION(kInt8),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kHost),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
......
@@ -20,7 +20,10 @@ REGISTER_LITE_KERNEL(cast,
kNCHW,
paddle::lite::kernels::x86::CastCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kX86),
PRECISION(kFloat),
DATALAYOUT(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
@@ -31,6 +34,9 @@ REGISTER_LITE_KERNEL(
kNCHW,
paddle::lite::kernels::x86::CastCompute<::paddle::lite::fluid::float16>,
fp16_to_any)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFP16))})
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kX86),
PRECISION(kFP16),
DATALAYOUT(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
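Putting the pieces together: after this change the declared input type of the x86 cast kernel carries DATALAYOUT(kAny), so the compatibility check added in MLUPostprocessPass::InsertCastBefore can accept an NHWC argument. A hedged illustration follows (not part of the diff; it only reuses calls that appear above):

// Illustration only: the declared input type of the cast kernel versus an
// NHWC argument type produced after ModifyLayout. With the declared layout
// set to kAny, a check in the spirit of DataLayoutCompatible is expected to
// accept the NHWC argument, so no layout instruction has to be inserted.
const Type* declared = LiteType::GetTensorTy(
    TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kAny));
const Type* argument = LiteType::GetTensorTy(
    TARGET(kX86), PRECISION(kFloat), DATALAYOUT(kNHWC));
// DataLayoutCompatible(*declared, *argument) should hold here.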