Commit c5e83404 authored by jackzhang235

Add support for 3-dim inputs of the MLU subgraph op

Parent ed48feaa
@@ -50,10 +50,9 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
     op_desc.SetAttr<int>("out_dtype", 4);  // FP16
     op_desc.SetInput("X", {cur_node->AsArg().name});
     op_desc.SetOutput("Out", {cast_arg_name});
-  } else if (op_type == "transpose") {
+  } else if (op_type == "layout") {
     // NCHW -> NHWC
-    op_desc.SetAttr<std::vector<int>>("axis", {0, 2, 3, 1});
-    op_desc.SetInput("X", {cur_node->AsArg().name});
+    op_desc.SetInput("Input", {cur_node->AsArg().name});
     op_desc.SetOutput("Out", {cast_arg_name});
   } else if (op_type == "io_copy") {
     op_desc.SetInput("Input", {cur_node->AsArg().name});
@@ -72,8 +71,13 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
       if (PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type)) {
         is_found = true;
       }
-    } else if (op_type == "transpose") {
+    } else if (op_type == "layout") {
+      const Type* in_arg_ty = kernel->GetInputDeclType("Input");
+      const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
+      if (DataLayoutCompatible(*in_arg_ty, *cur_node->AsArg().type) &&
+          DataLayoutCompatible(*out_arg_ty, *cast_type)) {
       is_found = true;
+      }
     } else if (op_type == "io_copy") {
       const Type* in_arg_ty = kernel->GetInputDeclType("Input");
       const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
@@ -89,8 +93,13 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
       // we pick the kernel
       cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
       auto& stmt = cast_inst->AsStmt();
+      if (op_type == "layout") {
       stmt.picked_kernel().SetContext(
-          ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
+            ContextScheduler::Global().NewContext(TARGET(kX86)));
+      } else {
+        stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
+            stmt.picked_kernel().target()));
+      }
       break;
     }
   }
@@ -127,10 +136,9 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
     op_desc.SetAttr<int>("out_dtype", 5);  // FP16
     op_desc.SetInput("X", {cast_arg_name});
     op_desc.SetOutput("Out", {cur_node->AsArg().name});
-  } else if (op_type == "transpose") {
+  } else if (op_type == "layout") {
     // NHWC -> NCHW
-    op_desc.SetAttr<std::vector<int>>("axis", {0, 3, 1, 2});
-    op_desc.SetInput("X", {cast_arg_name});
+    op_desc.SetInput("Input", {cast_arg_name});
     op_desc.SetOutput("Out", {cur_node->AsArg().name});
   } else if (op_type == "io_copy") {
     op_desc.SetInput("Input", {cast_arg_name});
@@ -151,8 +159,13 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
       if (PrecisionCompatibleTo(*in_arg_ty, *cast_type)) {
         is_found = true;
       }
-    } else if (op_type == "transpose") {
+    } else if (op_type == "layout") {
+      const Type* in_arg_ty = kernel->GetInputDeclType("Input");
+      const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
+      if (DataLayoutCompatible(*in_arg_ty, *cast_type) &&
+          DataLayoutCompatible(*out_arg_ty, *cur_node->AsArg().type)) {
       is_found = true;
+      }
     } else if (op_type == "io_copy") {
       const Type* in_arg_ty = kernel->GetInputDeclType("Input");
       const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
@@ -168,8 +181,13 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
      // we pick the kernel
      cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
      auto& stmt = cast_inst->AsStmt();
+      if (op_type == "layout") {
      stmt.picked_kernel().SetContext(
-          ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
+            ContextScheduler::Global().NewContext(TARGET(kX86)));
+      } else {
+        stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
+            stmt.picked_kernel().target()));
+      }
      break;
    }
  }
@@ -197,8 +215,8 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
   // layout cast node
   if (head_type->layout() != inst_type->layout()) {
     cur_node = InsertCastBefore(
-        "transpose",
-        name_prefix + "transpose",
+        "layout",
+        name_prefix + "layout",
         graph,
         cur_node,
         inst_node,
@@ -346,8 +364,8 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
   // layout cast node
   if (tail_type->layout() != inst_type->layout()) {
     cur_node = InsertCastAfter(
-        "transpose",
-        name_prefix + "transpose",
+        "layout",
+        name_prefix + "layout",
         graph,
         cur_node,
         inst_node,
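Note on the hunks above: mlu_postprocess_pass now inserts a dedicated "layout" op at the subgraph boundary instead of a "transpose" op with an explicit axis attribute, and gives the picked layout kernel an X86 context, so the NCHW/NHWC conversion runs on the host side. As a rough illustration only (not the layout kernel added by this patch; the helper name and signature are hypothetical), the data movement of an NCHW -> NHWC layout cast on a contiguous float buffer looks like this:

// Illustrative sketch of the NCHW -> NHWC rearrangement a layout cast performs.
#include <cstddef>
#include <vector>

std::vector<float> NCHWToNHWC(const std::vector<float>& src,
                              size_t n, size_t c, size_t h, size_t w) {
  std::vector<float> dst(src.size());
  for (size_t in = 0; in < n; ++in)
    for (size_t ic = 0; ic < c; ++ic)
      for (size_t ih = 0; ih < h; ++ih)
        for (size_t iw = 0; iw < w; ++iw)
          // read at the NCHW offset, write at the corresponding NHWC offset
          dst[((in * h + ih) * w + iw) * c + ic] =
              src[((in * c + ic) * h + ih) * w + iw];
  return dst;
}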
@@ -53,7 +53,7 @@ class SubgraphCastDisplayPass : public DebugPass {
       for (auto p_in_stmt_node : p_in_arg_node->inlinks) {
         CHECK(p_in_stmt_node->IsStmt());
         std::string stmt_op_type = p_in_stmt_node->AsStmt().op_type();
-        if (stmt_op_type == "cast" || stmt_op_type == "transpose" ||
+        if (stmt_op_type == "cast" || stmt_op_type == "layout" ||
             stmt_op_type == "io_copy") {
           display_debug_info(*p_in_stmt_node, stmt_op_type, true, false);
         } else {
@@ -76,7 +76,7 @@ class SubgraphCastDisplayPass : public DebugPass {
       for (auto p_out_stmt_node : p_out_arg_node->outlinks) {
         CHECK(p_out_stmt_node->IsStmt());
         std::string stmt_op_type = p_out_stmt_node->AsStmt().op_type();
-        if (stmt_op_type == "cast" || stmt_op_type == "transpose" ||
+        if (stmt_op_type == "cast" || stmt_op_type == "layout" ||
             stmt_op_type == "io_copy") {
           display_debug_info(*p_out_stmt_node, stmt_op_type, false, true);
         } else {
@@ -116,12 +116,12 @@ class Optimizer {
           "argument_type_display_pass",
           "mlu_subgraph_pass",
-          "mlu_postprocess_pass",
+          // subgraph_cast_display_pass
           "runtime_context_assign_pass",
           "argument_type_display_pass",
+          "mlu_postprocess_pass",
           "memory_optimize_pass"}};
     if (passes.size() == 1) {
@@ -6,3 +6,4 @@ add_subdirectory(bridges)
 add_kernel(subgraph_compute_mlu MLU basic SRCS subgraph_compute.cc DEPS ${lite_kernel_deps} ${mlu_subgraph_bridges})
 add_kernel(io_copy_compute_mlu MLU basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps} ${math_mlu})
 add_kernel(calib_compute_mlu MLU basic SRCS calib_compute.cc DEPS ${lite_kernel_deps} ${math_mlu})
+add_kernel(layout_compute_mlu MLU basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} ${math_mlu})
@@ -46,26 +46,37 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     input_dims.push_back(x->dims().Vectorize());
   }
-  auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
-  int axis = (param_axis < 0) ? (param_axis + output->dims().size()) : param_axis;
+  auto dims = input_dims[0].size();
+  int axis = (param_axis < 0) ? (param_axis + dims) : param_axis;
+  int nhwc_axis = -1;
+  if (dims == 4) {
     int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2};
-  int nhwc_axis = nchw_to_nhwc_axis_map[axis];
+    nhwc_axis = nchw_to_nhwc_axis_map[axis];
+  } else if (dims == 3) {
+    int nchw_to_nhwc_axis_map[3] = {0, 2, 1};
+    nhwc_axis = nchw_to_nhwc_axis_map[axis];
+  } else {
+    CHECK(0) << "Unsupport dims in mlu concat";
+  }
   std::vector<int64_t> output_dims;
-  output_dims.assign(output->dims().size(), 0);
+  output_dims.assign(dims, 0);
-  /* std::cout << string_format("concat axis: %d(NCHW), %d(NHWC)", axis, nhwc_axis) << std::endl; */
+  /* std::cout << string_format("concat axis: %d(NCHW), %d(NHWC)", axis,
+   * nhwc_axis) << std::endl; */
   for (int i = 0; i < output_dims.size(); ++i) {
     if (i == nhwc_axis) {
-      for (auto &dim : input_dims) output_dims[i] += dim[i];
+      for (auto& dim : input_dims) output_dims[i] += dim[i];
     } else {
       output_dims[i] = input_dims[0][i];
     }
   }
-  /* std::cout << string_format("concat output dim: %ld, %ld, %ld, %ld") << std::endl; */
+  /* std::cout << string_format("concat output dim: %ld, %ld, %ld, %ld") <<
+   * std::endl; */
+  auto* output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
   output->Resize(output_dims);
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
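Note on the concat hunk: the bridge no longer reads the rank from the output tensor; it takes the rank from the first input and remaps the concat axis into the NHWC order used on the MLU, with {0, 3, 1, 2} for 4-D inputs and {0, 2, 1} for the newly supported 3-D case. A standalone check of that remapping, reproducing the constants from the patch (the helper name is illustrative, not part of the codebase):

// Minimal sketch of the axis maps used in ConcatConverter above.
#include <cassert>

int ConcatAxisToNHWC(int axis, int dims) {
  if (dims == 4) {
    const int map4[4] = {0, 3, 1, 2};  // NCHW axis -> NHWC axis
    return map4[axis];
  }
  const int map3[3] = {0, 2, 1};  // 3-D: the channel dim moves to the end
  return map3[axis];
}

int main() {
  assert(ConcatAxisToNHWC(1, 4) == 3);  // concat over C lands on the last NHWC dim
  assert(ConcatAxisToNHWC(2, 4) == 1);  // concat over H lands on dim 1
  assert(ConcatAxisToNHWC(1, 3) == 2);  // 3-D concat over the channel dim
  return 0;
}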
@@ -21,18 +21,37 @@ namespace lite {
 namespace subgraph {
 namespace mlu {
-std::vector<int> axis_to_4d(std::vector<int> axis) {
-  if (axis.size() >= 4) {
-    return axis;
+std::vector<int> axis_to_nhwc4d(const std::vector<int>& axis) {
+  CHECK_EQ(axis.size(), 4);
+  std::vector<int> new_axis(4, 0);
+  const std::vector<int> axis_map1 = {0, 2, 3, 1};
+  const std::vector<int> axis_map2 = {0, 3, 1, 2};
+  for (size_t i = 0; i < new_axis.size(); ++i) {
+    new_axis[i] = axis_map2[axis[axis_map1[i]]];
   }
-  std::vector<int> new_axis = {0, 1, 2, 3};
-  int i = 0;
-  for (i = 0; i < axis.size(); i++) {
-    new_axis[i] = axis[i];
+  return new_axis;
+}
+
+std::vector<int> axis_to_nhw3d(const std::vector<int>& axis) {
+  CHECK_EQ(axis.size(), 3);
+  std::vector<int> new_axis(3, 0);
+  const std::vector<int> axis_map = {0, 2, 1};
+  for (size_t i = 0; i < new_axis.size(); ++i) {
+    new_axis[i] = axis_map[axis[axis_map[i]]];
   }
-  new_axis.push_back(3);
   return new_axis;
 }
+
+std::vector<int64_t> infer_shape(const std::vector<int64_t>& x_dims,
+                                 const std::vector<int>& axis_nhwc) {
+  std::vector<int64_t> out_dims(x_dims);
+  for (size_t i = 0; i < out_dims.size(); ++i) {
+    out_dims[i] = x_dims[axis_nhwc[i]];
+  }
+  return out_dims;
+}
+
 int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   CHECK(ctx != nullptr);
   CHECK(op != nullptr);
@@ -44,17 +63,29 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // Get input vars and op attributes
   auto x_var_name = op_info->Input("X").front();
-  // auto x = scope->FindMutableTensor(x_var_name)->GetMutable<Tensor>();
-  // auto x_dims = x->dims();
+  auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
+  auto x_dims = x->dims().Vectorize();
   auto out_var_name = op_info->Output("Out").front();
   auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
   auto output_dims = output->dims().Vectorize();
   auto axis = op_info->GetAttr<std::vector<int>>("axis");
-  auto axis_4d = axis_to_4d(axis);
+
+  std::vector<int> axis_nhwc;
+  if (axis.size() == 4) {
+    axis_nhwc = axis_to_nhwc4d(axis);
+  } else if (axis.size() == 3) {
+    axis_nhwc = axis_to_nhw3d(axis);
+  } else {
+    CHECK(0) << "Unsupport dim in mlu transpose";
+  }
+
+  auto output_dims_nhwc = infer_shape(x_dims, axis_nhwc);
+  output->Resize(output_dims_nhwc);
   auto output_tensor = graph->AddNode(
-      out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
+      out_var_name, output_dims_nhwc, CNML_TENSOR, CNML_NHWC, graph->FPType());
   CHECK(graph->HasNode(x_var_name));
   auto input_tensor = graph->GetNode(x_var_name);
@@ -63,7 +94,7 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   cnmlNdTransposeOpParam_t transpose_param{nullptr};
   CNML_CALL(cnmlCreateNdTransposeOpParam(
-      &transpose_param, axis_4d.data(), axis_4d.size()));
+      &transpose_param, axis_nhwc.data(), axis_nhwc.size()));
   // Use cnmlCreatexxxOpForward to create op.
   CNML_CALL(cnmlCreateNdTransposeProOp(&transpose_op_,
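Note on the transpose hunks: the op's "axis" attribute is written against NCHW dim order, while the MLU tensor is laid out NHWC, so axis_to_nhwc4d conjugates the permutation with the NHWC/NCHW index maps (axis_map1 maps an NHWC position to its NCHW position, axis_map2 maps back), and infer_shape then permutes the NHWC input dims to obtain the NHWC output dims. A standalone sketch of the 4-D remapping with the same constants as the patch; the function name is illustrative only:

// Minimal check of the permutation remapping in axis_to_nhwc4d above.
#include <cassert>
#include <vector>

std::vector<int> AxisToNHWC4D(const std::vector<int>& axis) {
  const int to_nchw[4] = {0, 2, 3, 1};  // NHWC position -> NCHW position
  const int to_nhwc[4] = {0, 3, 1, 2};  // NCHW position -> NHWC position
  std::vector<int> out(4);
  for (int i = 0; i < 4; ++i) out[i] = to_nhwc[axis[to_nchw[i]]];
  return out;
}

int main() {
  const std::vector<int> identity = {0, 1, 2, 3};
  assert(AxisToNHWC4D(identity) == identity);  // identity stays identity
  // Swapping H and W in NCHW coordinates ({0, 1, 3, 2}) becomes a swap of
  // dims 1 and 2 in NHWC coordinates ({0, 2, 1, 3}).
  const std::vector<int> swap_hw_nchw = {0, 1, 3, 2};
  const std::vector<int> swap_hw_nhwc = {0, 2, 1, 3};
  assert(AxisToNHWC4D(swap_hw_nchw) == swap_hw_nhwc);
  return 0;
}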
@@ -97,9 +97,11 @@ class SubgraphEngine : public subgraph::Engine {
     for (auto& inst : origin_program_) {
       auto op = inst.op();
       CHECK(op);
+      std::string op_type = op->op_info()->Type();
       op->CheckShape();
+      if (op_type != "concat") {
       op->InferShape();
-      std::string op_type = op->op_info()->Type();
+      }
       if (!bridges.Exists(op_type, TARGET(kMLU))) {
         LOG(INFO) << "MLU bridges doesn't support op_type: " << op_type;
         return subgraph::FAILED;