From a74dfefbb72f8946278410d236fc7f5ab862aefd Mon Sep 17 00:00:00 2001 From: --get Date: Fri, 22 May 2020 03:34:31 +0000 Subject: [PATCH] (feat): add int32 io copy support --- lite/core/mir/mlu_postprocess_pass.cc | 6 +++++- lite/core/op_registry.cc | 2 ++ lite/core/op_registry.h | 6 ++++++ lite/kernels/mlu/io_copy_compute.cc | 17 +++++++++++++++++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc index 61a56fa7e1..3f3e727983 100644 --- a/lite/core/mir/mlu_postprocess_pass.cc +++ b/lite/core/mir/mlu_postprocess_pass.cc @@ -21,6 +21,7 @@ #include #include "lite/core/mir/graph_visualize_pass.h" #include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/subgraph/subgraph_detector.h" #include "lite/operators/subgraph_op.h" namespace paddle { @@ -674,7 +675,8 @@ std::pair CheckInputAndInsert(Scope* scope, } if (!PrecisionCompatible(*tensor_type, *subgraph_type) && - tensor_type->precision() != PRECISION(kInt8)) { + tensor_type->precision() != PRECISION(kInt8) && + tensor_type->precision() != PRECISION(kInt32)) { auto cast_op = block_desc->AddOp(); auto cast_arg_name = string_format("%s/cast", cur_node.c_str()); scope->Var(cast_arg_name); @@ -915,6 +917,8 @@ void MLUPostprocessPass::Apply(const std::unique_ptr& graph) { } } } + // std::vector> subgraphs({graph->NodeTopologicalOrder()}); + // SubgraphVisualizer(graph.get(), subgraphs)(); } } // namespace mir diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc index 3e185f9d8a..74a74b8074 100644 --- a/lite/core/op_registry.cc +++ b/lite/core/op_registry.cc @@ -153,6 +153,8 @@ KernelRegistry::KernelRegistry() INIT_FOR(kMLU, kInt8, kNCHW); INIT_FOR(kMLU, kInt16, kNHWC); INIT_FOR(kMLU, kInt16, kNCHW); + INIT_FOR(kMLU, kInt32, kNHWC); + INIT_FOR(kMLU, kInt32, kNCHW); INIT_FOR(kHost, kAny, kNCHW); INIT_FOR(kHost, kAny, kNHWC); diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h index f46451ec04..914de6e711 100644 --- a/lite/core/op_registry.h +++ b/lite/core/op_registry.h @@ -321,6 +321,12 @@ class KernelRegistry final { DATALAYOUT(kNHWC)> *, // KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget * // >; diff --git a/lite/kernels/mlu/io_copy_compute.cc b/lite/kernels/mlu/io_copy_compute.cc index d11279f767..ff8a7ddf6e 100644 --- a/lite/kernels/mlu/io_copy_compute.cc +++ b/lite/kernels/mlu/io_copy_compute.cc @@ -134,6 +134,23 @@ REGISTER_LITE_KERNEL( DATALAYOUT(kAny))}) .Finalize(); +REGISTER_LITE_KERNEL( + io_copy, + kMLU, + kInt32, + kNHWC, + paddle::lite::kernels::mlu::IoCopyHostToMluCompute, + host_to_device_kInt32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kInt32), + DATALAYOUT(kAny))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kMLU), + PRECISION(kInt32), + DATALAYOUT(kAny))}) + .Finalize(); + REGISTER_LITE_KERNEL( io_copy, kMLU, -- GitLab