Commit fffb2fe6 authored by J jackzhang235, committed by MaxwellDing

move mlu passes forward

Parent eb78c1d8
......@@ -105,13 +105,8 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
// we pick the kernel
cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
auto& stmt = cast_inst->AsStmt();
if (op_type == "layout") {
stmt.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(TARGET(kX86)));
} else {
stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
stmt.picked_kernel().target()));
}
stmt.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
break;
}
}
......@@ -185,7 +180,8 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (DataLayoutCompatible(*in_arg_ty, *cast_type) &&
DataLayoutCompatible(*out_arg_ty, *cur_node->AsArg().type)) {
DataLayoutCompatible(*out_arg_ty, *cur_node->AsArg().type) &&
PrecisionCompatibleTo(*in_arg_ty, *cast_type)) {
is_found = true;
}
} else if (op_type == "io_copy") {
......@@ -203,13 +199,8 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
// we pick the kernel
cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
auto& stmt = cast_inst->AsStmt();
if (op_type == "layout") {
stmt.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(TARGET(kX86)));
} else {
stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
stmt.picked_kernel().target()));
}
stmt.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
break;
}
}
......@@ -517,6 +508,50 @@ void MLUPostprocessPass::GatherAndModifyFirstConvNodes(SSAGraph* graph) {
}
}
void MLUPostprocessPass::ModifyInputOutputDataType(SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) {
if (node.IsStmt() && node.AsStmt().op_type() == "subgraph") {
for (auto& in_node : node.inlinks) {
const auto* in_node_type = in_node->AsArg().type;
VLOG(4) << "MLU subgraph input type: " << in_node->AsArg().name
<< *in_node_type;
if (in_node->AsArg().is_weight || in_node->AsArg().is_persist) {
CHECK(in_node_type->target() == TARGET(kHost) &&
in_node_type->precision() == PRECISION(kAny) &&
in_node_type->layout() == DATALAYOUT(kNCHW))
<< "MLU subgraph unexpected persistent input type!";
in_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kMLU), PRECISION(kAny), DATALAYOUT(kNHWC));
} else {
CHECK((in_node_type->target() == TARGET(kHost) ||
in_node_type->target() == TARGET(kX86)) &&
in_node_type->precision() == PRECISION(kFloat) &&
in_node_type->layout() == DATALAYOUT(kNCHW))
<< "MLU subgraph unexpected common input type!";
}
}
for (auto& out_node : node.outlinks) {
const auto* out_node_type = out_node->AsArg().type;
VLOG(4) << "MLU subgraph output type: " << out_node->AsArg().name
<< *out_node_type;
if (out_node->AsArg().is_weight || out_node->AsArg().is_persist) {
CHECK(out_node_type->target() == TARGET(kHost) &&
out_node_type->precision() == PRECISION(kAny) &&
out_node_type->layout() == DATALAYOUT(kNCHW))
<< "MLU subgraph unexpected persistent input type!";
out_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kMLU), PRECISION(kAny), DATALAYOUT(kNHWC));
} else {
CHECK(out_node_type->precision() == PRECISION(kFloat))
<< "MLU subgraph unexpected common output type!";
out_node->AsArg().type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
}
}
}
}
}
void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
for (auto& node : graph->mutable_nodes()) {
if (!node.IsStmt()) continue;
......@@ -562,14 +597,16 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
}
void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// currently for non-persistent input and output args, mlu subgraph op
// only supports float16/float32 data type
// in two situations as follows:
// 1: feed->arg_in->subgraph->... 2: ...->subgraph->arg_out->fetch;
// arg_in and arg_out are assumed to be NHWC, which the user should be aware of.
// Thus here we change these args' layout to NHWC
#ifdef LITE_WITH_MLU
ModifyInputOutputDataType(graph.get());
if (lite::TargetWrapperMlu::InputLayout() == DATALAYOUT(kNHWC)) {
ModifyLayout(graph.get());
}
......
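For context, the comment in Apply above states the contract this pass assumes: non-persistent MLU subgraph inputs and outputs carry float data, and the feed/fetch args are treated as NHWC. The following is a minimal usage sketch under that contract, assuming an MLU-enabled build and the standard Paddle-Lite full (Cxx) API; the model path and shapes are illustrative, not part of this commit.

#include <algorithm>
#include <vector>

#include "paddle_api.h"  // Paddle-Lite full API header (assumed include path)

using namespace paddle::lite_api;  // NOLINT

int main() {
  // Prefer the MLU place; keep host/x86 float places as fallbacks so ops
  // outside the MLU subgraph still find kernels.
  CxxConfig config;
  config.set_model_dir("./mobilenet_v1");  // illustrative model path
  config.set_valid_places({
      Place{TARGET(kMLU), PRECISION(kFloat), DATALAYOUT(kNHWC)},
      Place{TARGET(kX86), PRECISION(kFloat)},
      Place{TARGET(kHost), PRECISION(kFloat)},
  });
  auto predictor = CreatePaddlePredictor<CxxConfig>(config);

  // Situation 1 from the comment: feed -> arg_in -> subgraph.
  // The feed arg is assumed to hold NHWC float data.
  auto input = predictor->GetInput(0);
  input->Resize({1, 224, 224, 3});  // N, H, W, C
  auto* in_data = input->mutable_data<float>();
  std::fill(in_data, in_data + 1 * 224 * 224 * 3, 0.f);

  predictor->Run();

  // Situation 2: subgraph -> arg_out -> fetch.
  // ModifyInputOutputDataType keeps the fetched tensor as float on the host.
  auto output = predictor->GetOutput(0);
  const auto* out_data = output->data<float>();
  (void)out_data;
  return 0;
}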
......@@ -79,6 +79,8 @@ class MLUPostprocessPass : public ProgramPass {
const Type** arg_type,
SSAGraph* graph);
void ModifyInputOutputDataType(SSAGraph* graph);
void ModifyLayout(SSAGraph* graph);
bool NeedInsert(Node* node, const Type* inst_type);
......
......@@ -180,6 +180,8 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kHost, kInt64, kAny);
INIT_FOR(kX86, kFloat, kNCHW);
INIT_FOR(kX86, kFP16, kNCHW);
INIT_FOR(kX86, kInt8, kNCHW);
INIT_FOR(kX86, kAny, kNCHW);
INIT_FOR(kX86, kAny, kAny);
INIT_FOR(kX86, kInt64, kNCHW);
......
......@@ -120,6 +120,9 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kX86),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kX86),
PRECISION(kFP16),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kX86),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
......
......@@ -90,8 +90,12 @@ class Optimizer {
"xpu_subgraph_pass",
"bm_subgraph_pass",
"rknpu_subgraph_pass",
"mlu_subgraph_pass",
"static_kernel_pick_pass", // pick original kernel from graph
"variable_place_inference_pass", // inference arg/var's
"mlu_postprocess_pass",
// info(target/precision/layout/device)
// using kernel info
"argument_type_display_pass", // debug pass: show arg-type-node's
......@@ -121,13 +125,9 @@ class Optimizer {
"variable_place_inference_pass", //
"argument_type_display_pass",
"mlu_subgraph_pass",
"runtime_context_assign_pass",
"argument_type_display_pass",
"mlu_postprocess_pass",
"memory_optimize_pass"}};
if (passes.size() == 1) {
......
......@@ -102,8 +102,14 @@ REGISTER_LITE_KERNEL(
kNHWC,
paddle::lite::kernels::mlu::IoCopyHostToMluCompute<PRECISION(kFloat)>,
host_to_device_kFloat)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kMLU))})
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(
......@@ -113,8 +119,14 @@ REGISTER_LITE_KERNEL(
kNHWC,
paddle::lite::kernels::mlu::IoCopyHostToMluCompute<PRECISION(kFP16)>,
host_to_device_kFP16)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kMLU))})
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(
......@@ -124,8 +136,14 @@ REGISTER_LITE_KERNEL(
kNHWC,
paddle::lite::kernels::mlu::IoCopyMluToHostCompute<PRECISION(kFloat)>,
device_to_host_kFloat)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kMLU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(
......@@ -135,6 +153,12 @@ REGISTER_LITE_KERNEL(
kNHWC,
paddle::lite::kernels::mlu::IoCopyMluToHostCompute<PRECISION(kFP16)>,
device_to_host_kFP16)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kMLU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
......@@ -24,9 +24,9 @@ namespace mlu {} // namespace mlu
REGISTER_LITE_KERNEL(
layout,
kMLU,
kX86,
kFloat,
kNHWC,
kNCHW,
paddle::lite::kernels::mlu::LayoutNhwcToNchwCompute<PRECISION(kFloat)>,
def_layout_nhwc2nchw_fp32)
.BindInput("Input",
......@@ -41,9 +41,9 @@ REGISTER_LITE_KERNEL(
REGISTER_LITE_KERNEL(
layout,
kMLU,
kX86,
kFP16,
kNHWC,
kNCHW,
paddle::lite::kernels::mlu::LayoutNhwcToNchwCompute<PRECISION(kFP16)>,
def_layout_nhwc2nchw_fp16)
.BindInput("Input",
......@@ -58,9 +58,9 @@ REGISTER_LITE_KERNEL(
REGISTER_LITE_KERNEL(
layout,
kMLU,
kX86,
kFloat,
kNHWC,
kNCHW,
paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute<PRECISION(kFloat)>,
def_layout_nchw2nhwc_fp32)
.BindInput("Input",
......@@ -75,9 +75,9 @@ REGISTER_LITE_KERNEL(
REGISTER_LITE_KERNEL(
layout,
kMLU,
kX86,
kFP16,
kNHWC,
kNCHW,
paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute<PRECISION(kFP16)>,
def_layout_nchw2nhwc_fp16)
.BindInput("Input",
......@@ -92,11 +92,11 @@ REGISTER_LITE_KERNEL(
REGISTER_LITE_KERNEL(
layout,
kMLU,
kX86,
kInt8,
kNHWC,
kNCHW,
paddle::lite::kernels::mlu::LayoutNchwToNhwcCompute<PRECISION(kInt8)>,
def_layout_nchw2nhwc_fp32_int8)
def_layout_nchw2nhwc_int8)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt8),
......
......@@ -73,7 +73,7 @@ inline void LayoutTransCompute(const int dim,
template <PrecisionType Precision>
class LayoutNchwToNhwcCompute
: public KernelLite<TARGET(kMLU), Precision, DATALAYOUT(kNHWC)> {
: public KernelLite<TARGET(kX86), Precision, DATALAYOUT(kNCHW)> {
public:
using param_t = operators::LayoutParam;
......@@ -122,7 +122,7 @@ class LayoutNchwToNhwcCompute
template <PrecisionType Precision>
class LayoutNhwcToNchwCompute
: public KernelLite<TARGET(kMLU), Precision, DATALAYOUT(kNHWC)> {
: public KernelLite<TARGET(kX86), Precision, DATALAYOUT(kNCHW)> {
public:
using param_t = operators::LayoutParam;
......
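The re-registered layout kernels above still compute a plain NCHW<->NHWC axis permutation. A self-contained sketch of the nchw2nhwc direction, independent of the LayoutTransCompute helper referenced in the hunk and for illustration only:

#include <cstddef>

// Copy a float tensor from NCHW order into NHWC order; the reverse direction
// swaps the roles of the src and dst indices.
void NchwToNhwc(const float* src, float* dst,
                size_t n, size_t c, size_t h, size_t w) {
  for (size_t in = 0; in < n; ++in)
    for (size_t ic = 0; ic < c; ++ic)
      for (size_t ih = 0; ih < h; ++ih)
        for (size_t iw = 0; iw < w; ++iw)
          dst[((in * h + ih) * w + iw) * c + ic] =
              src[((in * c + ic) * h + ih) * w + iw];
}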
......@@ -36,8 +36,14 @@ REGISTER_LITE_KERNEL(
kNHWC,
paddle::lite::kernels::mlu::SubgraphCompute<PRECISION(kFloat)>,
def_kFloat)
.BindInput("Inputs", {LiteType::GetTensorTy(TARGET(kMLU))})
.BindOutput("Outputs", {LiteType::GetTensorTy(TARGET(kMLU))})
.BindInput("Inputs",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Outputs",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
REGISTER_LITE_KERNEL(
......@@ -47,6 +53,12 @@ REGISTER_LITE_KERNEL(
kNHWC,
paddle::lite::kernels::mlu::SubgraphCompute<PRECISION(kFP16)>,
def_FP16)
.BindInput("Inputs", {LiteType::GetTensorTy(TARGET(kMLU))})
.BindOutput("Outputs", {LiteType::GetTensorTy(TARGET(kMLU))})
.BindInput("Inputs",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Outputs",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();