Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
c5e83404
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c5e83404
编写于
3月 21, 2020
作者:
J
jackzhang235
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add support for 3 dim inputs of mlu subgraph op
上级
ed48feaa
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
108 addition
and
45 deletion
+108
-45
lite/core/mir/mlu_postprocess_pass.cc
lite/core/mir/mlu_postprocess_pass.cc
+36
-18
lite/core/mir/subgraph_cast_display_pass.cc
lite/core/mir/subgraph_cast_display_pass.cc
+2
-2
lite/core/optimizer.h
lite/core/optimizer.h
+2
-2
lite/kernels/mlu/CMakeLists.txt
lite/kernels/mlu/CMakeLists.txt
+1
-0
lite/kernels/mlu/bridges/concat_op.cc
lite/kernels/mlu/bridges/concat_op.cc
+20
-9
lite/kernels/mlu/bridges/transpose_op.cc
lite/kernels/mlu/bridges/transpose_op.cc
+43
-12
lite/kernels/mlu/subgraph_compute.h
lite/kernels/mlu/subgraph_compute.h
+4
-2
未找到文件。
lite/core/mir/mlu_postprocess_pass.cc
浏览文件 @
c5e83404
...
@@ -50,10 +50,9 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
...
@@ -50,10 +50,9 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
op_desc
.
SetAttr
<
int
>
(
"out_dtype"
,
4
);
// FP16
op_desc
.
SetAttr
<
int
>
(
"out_dtype"
,
4
);
// FP16
op_desc
.
SetInput
(
"X"
,
{
cur_node
->
AsArg
().
name
});
op_desc
.
SetInput
(
"X"
,
{
cur_node
->
AsArg
().
name
});
op_desc
.
SetOutput
(
"Out"
,
{
cast_arg_name
});
op_desc
.
SetOutput
(
"Out"
,
{
cast_arg_name
});
}
else
if
(
op_type
==
"
transpose
"
)
{
}
else
if
(
op_type
==
"
layout
"
)
{
// NCHW -> NHWC
// NCHW -> NHWC
op_desc
.
SetAttr
<
std
::
vector
<
int
>>
(
"axis"
,
{
0
,
2
,
3
,
1
});
op_desc
.
SetInput
(
"Input"
,
{
cur_node
->
AsArg
().
name
});
op_desc
.
SetInput
(
"X"
,
{
cur_node
->
AsArg
().
name
});
op_desc
.
SetOutput
(
"Out"
,
{
cast_arg_name
});
op_desc
.
SetOutput
(
"Out"
,
{
cast_arg_name
});
}
else
if
(
op_type
==
"io_copy"
)
{
}
else
if
(
op_type
==
"io_copy"
)
{
op_desc
.
SetInput
(
"Input"
,
{
cur_node
->
AsArg
().
name
});
op_desc
.
SetInput
(
"Input"
,
{
cur_node
->
AsArg
().
name
});
...
@@ -72,8 +71,13 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
...
@@ -72,8 +71,13 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
if
(
PrecisionCompatibleTo
(
*
in_arg_ty
,
*
cur_node
->
AsArg
().
type
))
{
if
(
PrecisionCompatibleTo
(
*
in_arg_ty
,
*
cur_node
->
AsArg
().
type
))
{
is_found
=
true
;
is_found
=
true
;
}
}
}
else
if
(
op_type
==
"transpose"
)
{
}
else
if
(
op_type
==
"layout"
)
{
const
Type
*
in_arg_ty
=
kernel
->
GetInputDeclType
(
"Input"
);
const
Type
*
out_arg_ty
=
kernel
->
GetOutputDeclType
(
"Out"
);
if
(
DataLayoutCompatible
(
*
in_arg_ty
,
*
cur_node
->
AsArg
().
type
)
&&
DataLayoutCompatible
(
*
out_arg_ty
,
*
cast_type
))
{
is_found
=
true
;
is_found
=
true
;
}
}
else
if
(
op_type
==
"io_copy"
)
{
}
else
if
(
op_type
==
"io_copy"
)
{
const
Type
*
in_arg_ty
=
kernel
->
GetInputDeclType
(
"Input"
);
const
Type
*
in_arg_ty
=
kernel
->
GetInputDeclType
(
"Input"
);
const
Type
*
out_arg_ty
=
kernel
->
GetOutputDeclType
(
"Out"
);
const
Type
*
out_arg_ty
=
kernel
->
GetOutputDeclType
(
"Out"
);
...
@@ -89,8 +93,13 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
...
@@ -89,8 +93,13 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
// we pick the kernel
// we pick the kernel
cast_inst
->
AsStmt
(
op_type
,
std
::
move
(
selected_kernels
),
cast_op
);
cast_inst
->
AsStmt
(
op_type
,
std
::
move
(
selected_kernels
),
cast_op
);
auto
&
stmt
=
cast_inst
->
AsStmt
();
auto
&
stmt
=
cast_inst
->
AsStmt
();
if
(
op_type
==
"layout"
)
{
stmt
.
picked_kernel
().
SetContext
(
stmt
.
picked_kernel
().
SetContext
(
ContextScheduler
::
Global
().
NewContext
(
stmt
.
picked_kernel
().
target
()));
ContextScheduler
::
Global
().
NewContext
(
TARGET
(
kX86
)));
}
else
{
stmt
.
picked_kernel
().
SetContext
(
ContextScheduler
::
Global
().
NewContext
(
stmt
.
picked_kernel
().
target
()));
}
break
;
break
;
}
}
}
}
...
@@ -127,10 +136,9 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
...
@@ -127,10 +136,9 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
op_desc
.
SetAttr
<
int
>
(
"out_dtype"
,
5
);
// FP16
op_desc
.
SetAttr
<
int
>
(
"out_dtype"
,
5
);
// FP16
op_desc
.
SetInput
(
"X"
,
{
cast_arg_name
});
op_desc
.
SetInput
(
"X"
,
{
cast_arg_name
});
op_desc
.
SetOutput
(
"Out"
,
{
cur_node
->
AsArg
().
name
});
op_desc
.
SetOutput
(
"Out"
,
{
cur_node
->
AsArg
().
name
});
}
else
if
(
op_type
==
"
transpose
"
)
{
}
else
if
(
op_type
==
"
layout
"
)
{
// NHWC -> NCHW
// NHWC -> NCHW
op_desc
.
SetAttr
<
std
::
vector
<
int
>>
(
"axis"
,
{
0
,
3
,
1
,
2
});
op_desc
.
SetInput
(
"Input"
,
{
cast_arg_name
});
op_desc
.
SetInput
(
"X"
,
{
cast_arg_name
});
op_desc
.
SetOutput
(
"Out"
,
{
cur_node
->
AsArg
().
name
});
op_desc
.
SetOutput
(
"Out"
,
{
cur_node
->
AsArg
().
name
});
}
else
if
(
op_type
==
"io_copy"
)
{
}
else
if
(
op_type
==
"io_copy"
)
{
op_desc
.
SetInput
(
"Input"
,
{
cast_arg_name
});
op_desc
.
SetInput
(
"Input"
,
{
cast_arg_name
});
...
@@ -151,8 +159,13 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
...
@@ -151,8 +159,13 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
if
(
PrecisionCompatibleTo
(
*
in_arg_ty
,
*
cast_type
))
{
if
(
PrecisionCompatibleTo
(
*
in_arg_ty
,
*
cast_type
))
{
is_found
=
true
;
is_found
=
true
;
}
}
}
else
if
(
op_type
==
"transpose"
)
{
}
else
if
(
op_type
==
"layout"
)
{
const
Type
*
in_arg_ty
=
kernel
->
GetInputDeclType
(
"Input"
);
const
Type
*
out_arg_ty
=
kernel
->
GetOutputDeclType
(
"Out"
);
if
(
DataLayoutCompatible
(
*
in_arg_ty
,
*
cast_type
)
&&
DataLayoutCompatible
(
*
out_arg_ty
,
*
cur_node
->
AsArg
().
type
))
{
is_found
=
true
;
is_found
=
true
;
}
}
else
if
(
op_type
==
"io_copy"
)
{
}
else
if
(
op_type
==
"io_copy"
)
{
const
Type
*
in_arg_ty
=
kernel
->
GetInputDeclType
(
"Input"
);
const
Type
*
in_arg_ty
=
kernel
->
GetInputDeclType
(
"Input"
);
const
Type
*
out_arg_ty
=
kernel
->
GetOutputDeclType
(
"Out"
);
const
Type
*
out_arg_ty
=
kernel
->
GetOutputDeclType
(
"Out"
);
...
@@ -168,8 +181,13 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
...
@@ -168,8 +181,13 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
// we pick the kernel
// we pick the kernel
cast_inst
->
AsStmt
(
op_type
,
std
::
move
(
selected_kernels
),
cast_op
);
cast_inst
->
AsStmt
(
op_type
,
std
::
move
(
selected_kernels
),
cast_op
);
auto
&
stmt
=
cast_inst
->
AsStmt
();
auto
&
stmt
=
cast_inst
->
AsStmt
();
if
(
op_type
==
"layout"
)
{
stmt
.
picked_kernel
().
SetContext
(
stmt
.
picked_kernel
().
SetContext
(
ContextScheduler
::
Global
().
NewContext
(
stmt
.
picked_kernel
().
target
()));
ContextScheduler
::
Global
().
NewContext
(
TARGET
(
kX86
)));
}
else
{
stmt
.
picked_kernel
().
SetContext
(
ContextScheduler
::
Global
().
NewContext
(
stmt
.
picked_kernel
().
target
()));
}
break
;
break
;
}
}
}
}
...
@@ -197,8 +215,8 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
...
@@ -197,8 +215,8 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
// layout cast node
// layout cast node
if
(
head_type
->
layout
()
!=
inst_type
->
layout
())
{
if
(
head_type
->
layout
()
!=
inst_type
->
layout
())
{
cur_node
=
InsertCastBefore
(
cur_node
=
InsertCastBefore
(
"
transpose
"
,
"
layout
"
,
name_prefix
+
"
transpose
"
,
name_prefix
+
"
layout
"
,
graph
,
graph
,
cur_node
,
cur_node
,
inst_node
,
inst_node
,
...
@@ -346,8 +364,8 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
...
@@ -346,8 +364,8 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
// layout cast node
// layout cast node
if
(
tail_type
->
layout
()
!=
inst_type
->
layout
())
{
if
(
tail_type
->
layout
()
!=
inst_type
->
layout
())
{
cur_node
=
InsertCastAfter
(
cur_node
=
InsertCastAfter
(
"
transpose
"
,
"
layout
"
,
name_prefix
+
"
transpose
"
,
name_prefix
+
"
layout
"
,
graph
,
graph
,
cur_node
,
cur_node
,
inst_node
,
inst_node
,
...
...
lite/core/mir/subgraph_cast_display_pass.cc
浏览文件 @
c5e83404
...
@@ -53,7 +53,7 @@ class SubgraphCastDisplayPass : public DebugPass {
...
@@ -53,7 +53,7 @@ class SubgraphCastDisplayPass : public DebugPass {
for
(
auto
p_in_stmt_node
:
p_in_arg_node
->
inlinks
)
{
for
(
auto
p_in_stmt_node
:
p_in_arg_node
->
inlinks
)
{
CHECK
(
p_in_stmt_node
->
IsStmt
());
CHECK
(
p_in_stmt_node
->
IsStmt
());
std
::
string
stmt_op_type
=
p_in_stmt_node
->
AsStmt
().
op_type
();
std
::
string
stmt_op_type
=
p_in_stmt_node
->
AsStmt
().
op_type
();
if
(
stmt_op_type
==
"cast"
||
stmt_op_type
==
"
transpose
"
||
if
(
stmt_op_type
==
"cast"
||
stmt_op_type
==
"
layout
"
||
stmt_op_type
==
"io_copy"
)
{
stmt_op_type
==
"io_copy"
)
{
display_debug_info
(
*
p_in_stmt_node
,
stmt_op_type
,
true
,
false
);
display_debug_info
(
*
p_in_stmt_node
,
stmt_op_type
,
true
,
false
);
}
else
{
}
else
{
...
@@ -76,7 +76,7 @@ class SubgraphCastDisplayPass : public DebugPass {
...
@@ -76,7 +76,7 @@ class SubgraphCastDisplayPass : public DebugPass {
for
(
auto
p_out_stmt_node
:
p_out_arg_node
->
outlinks
)
{
for
(
auto
p_out_stmt_node
:
p_out_arg_node
->
outlinks
)
{
CHECK
(
p_out_stmt_node
->
IsStmt
());
CHECK
(
p_out_stmt_node
->
IsStmt
());
std
::
string
stmt_op_type
=
p_out_stmt_node
->
AsStmt
().
op_type
();
std
::
string
stmt_op_type
=
p_out_stmt_node
->
AsStmt
().
op_type
();
if
(
stmt_op_type
==
"cast"
||
stmt_op_type
==
"
transpose
"
||
if
(
stmt_op_type
==
"cast"
||
stmt_op_type
==
"
layout
"
||
stmt_op_type
==
"io_copy"
)
{
stmt_op_type
==
"io_copy"
)
{
display_debug_info
(
*
p_out_stmt_node
,
stmt_op_type
,
false
,
true
);
display_debug_info
(
*
p_out_stmt_node
,
stmt_op_type
,
false
,
true
);
}
else
{
}
else
{
...
...
lite/core/optimizer.h
浏览文件 @
c5e83404
...
@@ -116,12 +116,12 @@ class Optimizer {
...
@@ -116,12 +116,12 @@ class Optimizer {
"argument_type_display_pass"
,
"argument_type_display_pass"
,
"mlu_subgraph_pass"
,
"mlu_subgraph_pass"
,
"mlu_postprocess_pass"
,
// subgraph_cast_display_pass
"runtime_context_assign_pass"
,
"runtime_context_assign_pass"
,
"argument_type_display_pass"
,
"argument_type_display_pass"
,
"mlu_postprocess_pass"
,
"memory_optimize_pass"
}};
"memory_optimize_pass"
}};
if
(
passes
.
size
()
==
1
)
{
if
(
passes
.
size
()
==
1
)
{
...
...
lite/kernels/mlu/CMakeLists.txt
浏览文件 @
c5e83404
...
@@ -6,3 +6,4 @@ add_subdirectory(bridges)
...
@@ -6,3 +6,4 @@ add_subdirectory(bridges)
add_kernel
(
subgraph_compute_mlu MLU basic SRCS subgraph_compute.cc DEPS
${
lite_kernel_deps
}
${
mlu_subgraph_bridges
}
)
add_kernel
(
subgraph_compute_mlu MLU basic SRCS subgraph_compute.cc DEPS
${
lite_kernel_deps
}
${
mlu_subgraph_bridges
}
)
add_kernel
(
io_copy_compute_mlu MLU basic SRCS io_copy_compute.cc DEPS
${
lite_kernel_deps
}
${
math_mlu
}
)
add_kernel
(
io_copy_compute_mlu MLU basic SRCS io_copy_compute.cc DEPS
${
lite_kernel_deps
}
${
math_mlu
}
)
add_kernel
(
calib_compute_mlu MLU basic SRCS calib_compute.cc DEPS
${
lite_kernel_deps
}
${
math_mlu
}
)
add_kernel
(
calib_compute_mlu MLU basic SRCS calib_compute.cc DEPS
${
lite_kernel_deps
}
${
math_mlu
}
)
add_kernel
(
layout_compute_mlu MLU basic SRCS layout_compute.cc DEPS
${
lite_kernel_deps
}
${
math_mlu
}
)
lite/kernels/mlu/bridges/concat_op.cc
浏览文件 @
c5e83404
...
@@ -46,26 +46,37 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
...
@@ -46,26 +46,37 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
input_dims
.
push_back
(
x
->
dims
().
Vectorize
());
input_dims
.
push_back
(
x
->
dims
().
Vectorize
());
}
}
auto
output
=
scope
->
FindVar
(
out_var_name
)
->
GetMutable
<
Tensor
>
();
auto
dims
=
input_dims
[
0
].
size
();
int
axis
=
(
param_axis
<
0
)
?
(
param_axis
+
output
->
dims
().
size
())
:
param_axis
;
int
axis
=
(
param_axis
<
0
)
?
(
param_axis
+
dims
)
:
param_axis
;
int
nhwc_axis
=
-
1
;
if
(
dims
==
4
)
{
int
nchw_to_nhwc_axis_map
[
4
]
=
{
0
,
3
,
1
,
2
};
int
nchw_to_nhwc_axis_map
[
4
]
=
{
0
,
3
,
1
,
2
};
int
nhwc_axis
=
nchw_to_nhwc_axis_map
[
axis
];
nhwc_axis
=
nchw_to_nhwc_axis_map
[
axis
];
}
else
if
(
dims
==
3
)
{
int
nchw_to_nhwc_axis_map
[
3
]
=
{
0
,
2
,
1
};
nhwc_axis
=
nchw_to_nhwc_axis_map
[
axis
];
}
else
{
CHECK
(
0
)
<<
"Unsupport dims in mlu concat"
;
}
std
::
vector
<
int64_t
>
output_dims
;
std
::
vector
<
int64_t
>
output_dims
;
output_dims
.
assign
(
output
->
dims
().
size
()
,
0
);
output_dims
.
assign
(
dims
,
0
);
/* std::cout << string_format("concat axis: %d(NCHW), %d(NHWC)", axis, nhwc_axis) << std::endl; */
/* std::cout << string_format("concat axis: %d(NCHW), %d(NHWC)", axis,
* nhwc_axis) << std::endl; */
for
(
int
i
=
0
;
i
<
output_dims
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
output_dims
.
size
();
++
i
)
{
if
(
i
==
nhwc_axis
)
{
if
(
i
==
nhwc_axis
)
{
for
(
auto
&
dim
:
input_dims
)
output_dims
[
i
]
+=
dim
[
i
];
for
(
auto
&
dim
:
input_dims
)
output_dims
[
i
]
+=
dim
[
i
];
}
else
{
}
else
{
output_dims
[
i
]
=
input_dims
[
0
][
i
];
output_dims
[
i
]
=
input_dims
[
0
][
i
];
}
}
}
}
/* std::cout << string_format("concat output dim: %ld, %ld, %ld, %ld") << std::endl; */
/* std::cout << string_format("concat output dim: %ld, %ld, %ld, %ld") <<
* std::endl; */
auto
*
output
=
scope
->
FindVar
(
out_var_name
)
->
GetMutable
<
Tensor
>
();
output
->
Resize
(
output_dims
);
output
->
Resize
(
output_dims
);
auto
output_tensor
=
graph
->
AddNode
(
auto
output_tensor
=
graph
->
AddNode
(
out_var_name
,
output_dims
,
CNML_TENSOR
,
CNML_NHWC
,
graph
->
FPType
());
out_var_name
,
output_dims
,
CNML_TENSOR
,
CNML_NHWC
,
graph
->
FPType
());
...
...
lite/kernels/mlu/bridges/transpose_op.cc
浏览文件 @
c5e83404
...
@@ -21,18 +21,37 @@ namespace lite {
...
@@ -21,18 +21,37 @@ namespace lite {
namespace
subgraph
{
namespace
subgraph
{
namespace
mlu
{
namespace
mlu
{
std
::
vector
<
int
>
axis_to_4d
(
std
::
vector
<
int
>
axis
)
{
std
::
vector
<
int
>
axis_to_nhwc4d
(
const
std
::
vector
<
int
>&
axis
)
{
if
(
axis
.
size
()
>=
4
)
{
CHECK_EQ
(
axis
.
size
(),
4
);
return
axis
;
std
::
vector
<
int
>
new_axis
(
4
,
0
);
const
std
::
vector
<
int
>
axis_map1
=
{
0
,
2
,
3
,
1
};
const
std
::
vector
<
int
>
axis_map2
=
{
0
,
3
,
1
,
2
};
for
(
size_t
i
=
0
;
i
<
new_axis
.
size
();
++
i
)
{
new_axis
[
i
]
=
axis_map2
[
axis
[
axis_map1
[
i
]]];
}
}
std
::
vector
<
int
>
new_axis
=
{
0
,
1
,
2
,
3
};
return
new_axis
;
int
i
=
0
;
}
for
(
i
=
0
;
i
<
axis
.
size
();
i
++
)
{
new_axis
[
i
]
=
axis
[
i
];
std
::
vector
<
int
>
axis_to_nhw3d
(
const
std
::
vector
<
int
>&
axis
)
{
CHECK_EQ
(
axis
.
size
(),
3
);
std
::
vector
<
int
>
new_axis
(
3
,
0
);
const
std
::
vector
<
int
>
axis_map
=
{
0
,
2
,
1
};
for
(
size_t
i
=
0
;
i
<
new_axis
.
size
();
++
i
)
{
new_axis
[
i
]
=
axis_map
[
axis
[
axis_map
[
i
]]];
}
}
new_axis
.
push_back
(
3
);
return
new_axis
;
return
new_axis
;
}
}
std
::
vector
<
int64_t
>
infer_shape
(
const
std
::
vector
<
int64_t
>&
x_dims
,
const
std
::
vector
<
int
>&
axis_nhwc
)
{
std
::
vector
<
int64_t
>
out_dims
(
x_dims
);
for
(
size_t
i
=
0
;
i
<
out_dims
.
size
();
++
i
)
{
out_dims
[
i
]
=
x_dims
[
axis_nhwc
[
i
]];
}
return
out_dims
;
}
int
TransposeConverter
(
void
*
ctx
,
OpLite
*
op
,
KernelBase
*
kernel
)
{
int
TransposeConverter
(
void
*
ctx
,
OpLite
*
op
,
KernelBase
*
kernel
)
{
CHECK
(
ctx
!=
nullptr
);
CHECK
(
ctx
!=
nullptr
);
CHECK
(
op
!=
nullptr
);
CHECK
(
op
!=
nullptr
);
...
@@ -44,17 +63,29 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
...
@@ -44,17 +63,29 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// Get input vars and op attributes
// Get input vars and op attributes
auto
x_var_name
=
op_info
->
Input
(
"X"
).
front
();
auto
x_var_name
=
op_info
->
Input
(
"X"
).
front
();
// auto x = scope->FindMutableTenso
r(x_var_name)->GetMutable<Tensor>();
auto
x
=
scope
->
FindVa
r
(
x_var_name
)
->
GetMutable
<
Tensor
>
();
// auto x_dims = x->dims
();
auto
x_dims
=
x
->
dims
().
Vectorize
();
auto
out_var_name
=
op_info
->
Output
(
"Out"
).
front
();
auto
out_var_name
=
op_info
->
Output
(
"Out"
).
front
();
auto
output
=
scope
->
FindVar
(
out_var_name
)
->
GetMutable
<
Tensor
>
();
auto
output
=
scope
->
FindVar
(
out_var_name
)
->
GetMutable
<
Tensor
>
();
auto
output_dims
=
output
->
dims
().
Vectorize
();
auto
output_dims
=
output
->
dims
().
Vectorize
();
auto
axis
=
op_info
->
GetAttr
<
std
::
vector
<
int
>>
(
"axis"
);
auto
axis
=
op_info
->
GetAttr
<
std
::
vector
<
int
>>
(
"axis"
);
auto
axis_4d
=
axis_to_4d
(
axis
);
std
::
vector
<
int
>
axis_nhwc
;
if
(
axis
.
size
()
==
4
)
{
axis_nhwc
=
axis_to_nhwc4d
(
axis
);
}
else
if
(
axis
.
size
(
0
==
3
))
{
axis_nhwc
=
axis_to_nhw3d
(
axis
);
}
else
{
CHECK
(
0
)
<<
"Unsupport dim in mlu transpose"
;
}
auto
output_dims_nhwc
=
infer_shape
(
x_dims
,
axis_nhwc
);
output
->
Resize
(
output_dims_nhwc
);
auto
output_tensor
=
graph
->
AddNode
(
auto
output_tensor
=
graph
->
AddNode
(
out_var_name
,
output_dims
,
CNML_TENSOR
,
CNML_NHWC
,
graph
->
FPType
());
out_var_name
,
output_dims
_nhwc
,
CNML_TENSOR
,
CNML_NHWC
,
graph
->
FPType
());
CHECK
(
graph
->
HasNode
(
x_var_name
));
CHECK
(
graph
->
HasNode
(
x_var_name
));
auto
input_tensor
=
graph
->
GetNode
(
x_var_name
);
auto
input_tensor
=
graph
->
GetNode
(
x_var_name
);
...
@@ -63,7 +94,7 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
...
@@ -63,7 +94,7 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
cnmlNdTransposeOpParam_t
transpose_param
{
nullptr
};
cnmlNdTransposeOpParam_t
transpose_param
{
nullptr
};
CNML_CALL
(
cnmlCreateNdTransposeOpParam
(
CNML_CALL
(
cnmlCreateNdTransposeOpParam
(
&
transpose_param
,
axis_
4d
.
data
(),
axis_4d
.
size
()));
&
transpose_param
,
axis_
nhwc
.
data
(),
axis_nhwc
.
size
()));
// Use cnmlCreatexxxOpForward to create op.
// Use cnmlCreatexxxOpForward to create op.
CNML_CALL
(
cnmlCreateNdTransposeProOp
(
&
transpose_op_
,
CNML_CALL
(
cnmlCreateNdTransposeProOp
(
&
transpose_op_
,
...
...
lite/kernels/mlu/subgraph_compute.h
浏览文件 @
c5e83404
...
@@ -97,9 +97,11 @@ class SubgraphEngine : public subgraph::Engine {
...
@@ -97,9 +97,11 @@ class SubgraphEngine : public subgraph::Engine {
for
(
auto
&
inst
:
origin_program_
)
{
for
(
auto
&
inst
:
origin_program_
)
{
auto
op
=
inst
.
op
();
auto
op
=
inst
.
op
();
CHECK
(
op
);
CHECK
(
op
);
std
::
string
op_type
=
op
->
op_info
()
->
Type
();
op
->
CheckShape
();
op
->
CheckShape
();
if
(
op_type
!=
"concat"
)
{
op
->
InferShape
();
op
->
InferShape
();
std
::
string
op_type
=
op
->
op_info
()
->
Type
();
}
if
(
!
bridges
.
Exists
(
op_type
,
TARGET
(
kMLU
)))
{
if
(
!
bridges
.
Exists
(
op_type
,
TARGET
(
kMLU
)))
{
LOG
(
INFO
)
<<
"MLU bridges doesn't support op_type: "
<<
op_type
;
LOG
(
INFO
)
<<
"MLU bridges doesn't support op_type: "
<<
op_type
;
return
subgraph
::
FAILED
;
return
subgraph
::
FAILED
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录