diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc index 61a56fa7e12b6189d6c7d48d16cf30cc21a8c350..3f3e727983d1e677008bc3f28c27a68f8db0472d 100644 --- a/lite/core/mir/mlu_postprocess_pass.cc +++ b/lite/core/mir/mlu_postprocess_pass.cc @@ -21,6 +21,7 @@ #include #include "lite/core/mir/graph_visualize_pass.h" #include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/subgraph/subgraph_detector.h" #include "lite/operators/subgraph_op.h" namespace paddle { @@ -674,7 +675,8 @@ std::pair CheckInputAndInsert(Scope* scope, } if (!PrecisionCompatible(*tensor_type, *subgraph_type) && - tensor_type->precision() != PRECISION(kInt8)) { + tensor_type->precision() != PRECISION(kInt8) && + tensor_type->precision() != PRECISION(kInt32)) { auto cast_op = block_desc->AddOp(); auto cast_arg_name = string_format("%s/cast", cur_node.c_str()); scope->Var(cast_arg_name); @@ -915,6 +917,8 @@ void MLUPostprocessPass::Apply(const std::unique_ptr& graph) { } } } + // std::vector> subgraphs({graph->NodeTopologicalOrder()}); + // SubgraphVisualizer(graph.get(), subgraphs)(); } } // namespace mir diff --git a/lite/core/op_registry.cc b/lite/core/op_registry.cc index 3e185f9d8acc838d75ac4770ed28cfbecc072837..74a74b80747a379bdac1bf69aba9527a977ed456 100644 --- a/lite/core/op_registry.cc +++ b/lite/core/op_registry.cc @@ -153,6 +153,8 @@ KernelRegistry::KernelRegistry() INIT_FOR(kMLU, kInt8, kNCHW); INIT_FOR(kMLU, kInt16, kNHWC); INIT_FOR(kMLU, kInt16, kNCHW); + INIT_FOR(kMLU, kInt32, kNHWC); + INIT_FOR(kMLU, kInt32, kNCHW); INIT_FOR(kHost, kAny, kNCHW); INIT_FOR(kHost, kAny, kNHWC); diff --git a/lite/core/op_registry.h b/lite/core/op_registry.h index f46451ec049fe439614986f3db89d3786b54794b..914de6e711e95e42c7246ff596e75d721684985e 100644 --- a/lite/core/op_registry.h +++ b/lite/core/op_registry.h @@ -321,6 +321,12 @@ class KernelRegistry final { DATALAYOUT(kNHWC)> *, // KernelRegistryForTarget *, // + KernelRegistryForTarget *, // + KernelRegistryForTarget * // >; diff --git a/lite/kernels/mlu/io_copy_compute.cc b/lite/kernels/mlu/io_copy_compute.cc index d11279f767d77b04a78d51c3b283ed638070b64d..ff8a7ddf6e4c465f288ba42b5b2537294a9d7ffd 100644 --- a/lite/kernels/mlu/io_copy_compute.cc +++ b/lite/kernels/mlu/io_copy_compute.cc @@ -134,6 +134,23 @@ REGISTER_LITE_KERNEL( DATALAYOUT(kAny))}) .Finalize(); +REGISTER_LITE_KERNEL( + io_copy, + kMLU, + kInt32, + kNHWC, + paddle::lite::kernels::mlu::IoCopyHostToMluCompute, + host_to_device_kInt32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kHost), + PRECISION(kInt32), + DATALAYOUT(kAny))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kMLU), + PRECISION(kInt32), + DATALAYOUT(kAny))}) + .Finalize(); + REGISTER_LITE_KERNEL( io_copy, kMLU,