Commit a74dfefb — authored by: --get; committed by: jackzhang235

(feat): add int32 io copy support

Parent commit: 3cc18277
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <vector> #include <vector>
#include "lite/core/mir/graph_visualize_pass.h" #include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h" #include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/subgraph/subgraph_detector.h"
#include "lite/operators/subgraph_op.h" #include "lite/operators/subgraph_op.h"
namespace paddle { namespace paddle {
...@@ -674,7 +675,8 @@ std::pair<bool, std::string> CheckInputAndInsert(Scope* scope, ...@@ -674,7 +675,8 @@ std::pair<bool, std::string> CheckInputAndInsert(Scope* scope,
} }
if (!PrecisionCompatible(*tensor_type, *subgraph_type) && if (!PrecisionCompatible(*tensor_type, *subgraph_type) &&
tensor_type->precision() != PRECISION(kInt8)) { tensor_type->precision() != PRECISION(kInt8) &&
tensor_type->precision() != PRECISION(kInt32)) {
auto cast_op = block_desc->AddOp<cpp::OpDesc>(); auto cast_op = block_desc->AddOp<cpp::OpDesc>();
auto cast_arg_name = string_format("%s/cast", cur_node.c_str()); auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
scope->Var(cast_arg_name); scope->Var(cast_arg_name);
...@@ -915,6 +917,8 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -915,6 +917,8 @@ void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} }
} }
} }
// std::vector<std::vector<Node*>> subgraphs({graph->NodeTopologicalOrder()});
// SubgraphVisualizer(graph.get(), subgraphs)();
} }
} // namespace mir } // namespace mir
......
...@@ -153,6 +153,8 @@ KernelRegistry::KernelRegistry() ...@@ -153,6 +153,8 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kMLU, kInt8, kNCHW); INIT_FOR(kMLU, kInt8, kNCHW);
INIT_FOR(kMLU, kInt16, kNHWC); INIT_FOR(kMLU, kInt16, kNHWC);
INIT_FOR(kMLU, kInt16, kNCHW); INIT_FOR(kMLU, kInt16, kNCHW);
INIT_FOR(kMLU, kInt32, kNHWC);
INIT_FOR(kMLU, kInt32, kNCHW);
INIT_FOR(kHost, kAny, kNCHW); INIT_FOR(kHost, kAny, kNCHW);
INIT_FOR(kHost, kAny, kNHWC); INIT_FOR(kHost, kAny, kNHWC);
......
...@@ -321,6 +321,12 @@ class KernelRegistry final { ...@@ -321,6 +321,12 @@ class KernelRegistry final {
DATALAYOUT(kNHWC)> *, // DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU), KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt16), PRECISION(kInt16),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt32),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt32),
DATALAYOUT(kNCHW)> * // DATALAYOUT(kNCHW)> * //
>; >;
......
...@@ -134,6 +134,23 @@ REGISTER_LITE_KERNEL( ...@@ -134,6 +134,23 @@ REGISTER_LITE_KERNEL(
DATALAYOUT(kAny))}) DATALAYOUT(kAny))})
.Finalize(); .Finalize();
// Register the host->MLU (device) io_copy kernel specialized for 32-bit
// integer tensors, alias "host_to_device_kInt32". This fills the
// (kMLU, kInt32, kNHWC) registry slot added by this change; the bound
// input/output tensor types use DATALAYOUT(kAny), so the copy itself is
// layout-agnostic — NHWC here only selects the registry bucket.
REGISTER_LITE_KERNEL(
io_copy,
kMLU,
kInt32,
kNHWC,
paddle::lite::kernels::mlu::IoCopyHostToMluCompute<PRECISION(kInt32)>,
host_to_device_kInt32)
// Input lives on the host side as an int32 tensor of any layout.
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
DATALAYOUT(kAny))})
// Output is the device-side (MLU) int32 tensor, layout unchanged.
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kMLU),
PRECISION(kInt32),
DATALAYOUT(kAny))})
.Finalize();
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
io_copy, io_copy,
kMLU, kMLU,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册