提交 c5e83404 编写于 作者: J jackzhang235

add support for 3 dim inputs of mlu subgraph op

上级 ed48feaa
......@@ -50,10 +50,9 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
op_desc.SetAttr<int>("out_dtype", 4); // FP16
op_desc.SetInput("X", {cur_node->AsArg().name});
op_desc.SetOutput("Out", {cast_arg_name});
} else if (op_type == "transpose") {
} else if (op_type == "layout") {
// NCHW -> NHWC
op_desc.SetAttr<std::vector<int>>("axis", {0, 2, 3, 1});
op_desc.SetInput("X", {cur_node->AsArg().name});
op_desc.SetInput("Input", {cur_node->AsArg().name});
op_desc.SetOutput("Out", {cast_arg_name});
} else if (op_type == "io_copy") {
op_desc.SetInput("Input", {cur_node->AsArg().name});
......@@ -72,8 +71,13 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
if (PrecisionCompatibleTo(*in_arg_ty, *cur_node->AsArg().type)) {
is_found = true;
}
} else if (op_type == "transpose") {
is_found = true;
} else if (op_type == "layout") {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (DataLayoutCompatible(*in_arg_ty, *cur_node->AsArg().type) &&
DataLayoutCompatible(*out_arg_ty, *cast_type)) {
is_found = true;
}
} else if (op_type == "io_copy") {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
......@@ -89,8 +93,13 @@ Node* MLUPostprocessPass::InsertCastBefore(const std::string& op_type,
// we pick the kernel
cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
auto& stmt = cast_inst->AsStmt();
stmt.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
if (op_type == "layout") {
stmt.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(TARGET(kX86)));
} else {
stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
stmt.picked_kernel().target()));
}
break;
}
}
......@@ -127,10 +136,9 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
op_desc.SetAttr<int>("out_dtype", 5); // FP16
op_desc.SetInput("X", {cast_arg_name});
op_desc.SetOutput("Out", {cur_node->AsArg().name});
} else if (op_type == "transpose") {
} else if (op_type == "layout") {
// NHWC -> NCHW
op_desc.SetAttr<std::vector<int>>("axis", {0, 3, 1, 2});
op_desc.SetInput("X", {cast_arg_name});
op_desc.SetInput("Input", {cast_arg_name});
op_desc.SetOutput("Out", {cur_node->AsArg().name});
} else if (op_type == "io_copy") {
op_desc.SetInput("Input", {cast_arg_name});
......@@ -151,8 +159,13 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
if (PrecisionCompatibleTo(*in_arg_ty, *cast_type)) {
is_found = true;
}
} else if (op_type == "transpose") {
is_found = true;
} else if (op_type == "layout") {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
if (DataLayoutCompatible(*in_arg_ty, *cast_type) &&
DataLayoutCompatible(*out_arg_ty, *cur_node->AsArg().type)) {
is_found = true;
}
} else if (op_type == "io_copy") {
const Type* in_arg_ty = kernel->GetInputDeclType("Input");
const Type* out_arg_ty = kernel->GetOutputDeclType("Out");
......@@ -168,8 +181,13 @@ Node* MLUPostprocessPass::InsertCastAfter(const std::string& op_type,
// we pick the kernel
cast_inst->AsStmt(op_type, std::move(selected_kernels), cast_op);
auto& stmt = cast_inst->AsStmt();
stmt.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(stmt.picked_kernel().target()));
if (op_type == "layout") {
stmt.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(TARGET(kX86)));
} else {
stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
stmt.picked_kernel().target()));
}
break;
}
}
......@@ -197,8 +215,8 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
// layout cast node
if (head_type->layout() != inst_type->layout()) {
cur_node = InsertCastBefore(
"transpose",
name_prefix + "transpose",
"layout",
name_prefix + "layout",
graph,
cur_node,
inst_node,
......@@ -346,8 +364,8 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
// layout cast node
if (tail_type->layout() != inst_type->layout()) {
cur_node = InsertCastAfter(
"transpose",
name_prefix + "transpose",
"layout",
name_prefix + "layout",
graph,
cur_node,
inst_node,
......
......@@ -53,7 +53,7 @@ class SubgraphCastDisplayPass : public DebugPass {
for (auto p_in_stmt_node : p_in_arg_node->inlinks) {
CHECK(p_in_stmt_node->IsStmt());
std::string stmt_op_type = p_in_stmt_node->AsStmt().op_type();
if (stmt_op_type == "cast" || stmt_op_type == "transpose" ||
if (stmt_op_type == "cast" || stmt_op_type == "layout" ||
stmt_op_type == "io_copy") {
display_debug_info(*p_in_stmt_node, stmt_op_type, true, false);
} else {
......@@ -76,7 +76,7 @@ class SubgraphCastDisplayPass : public DebugPass {
for (auto p_out_stmt_node : p_out_arg_node->outlinks) {
CHECK(p_out_stmt_node->IsStmt());
std::string stmt_op_type = p_out_stmt_node->AsStmt().op_type();
if (stmt_op_type == "cast" || stmt_op_type == "transpose" ||
if (stmt_op_type == "cast" || stmt_op_type == "layout" ||
stmt_op_type == "io_copy") {
display_debug_info(*p_out_stmt_node, stmt_op_type, false, true);
} else {
......
......@@ -116,12 +116,12 @@ class Optimizer {
"argument_type_display_pass",
"mlu_subgraph_pass",
"mlu_postprocess_pass",
// subgraph_cast_display_pass
"runtime_context_assign_pass",
"argument_type_display_pass",
"mlu_postprocess_pass",
"memory_optimize_pass"}};
if (passes.size() == 1) {
......
......@@ -6,3 +6,4 @@ add_subdirectory(bridges)
add_kernel(subgraph_compute_mlu MLU basic SRCS subgraph_compute.cc DEPS ${lite_kernel_deps} ${mlu_subgraph_bridges})
add_kernel(io_copy_compute_mlu MLU basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps} ${math_mlu})
add_kernel(calib_compute_mlu MLU basic SRCS calib_compute.cc DEPS ${lite_kernel_deps} ${math_mlu})
add_kernel(layout_compute_mlu MLU basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} ${math_mlu})
......@@ -46,26 +46,37 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
input_dims.push_back(x->dims().Vectorize());
}
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
int axis = (param_axis < 0) ? (param_axis + output->dims().size()) : param_axis;
int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2};
int nhwc_axis = nchw_to_nhwc_axis_map[axis];
auto dims = input_dims[0].size();
int axis = (param_axis < 0) ? (param_axis + dims) : param_axis;
int nhwc_axis = -1;
if (dims == 4) {
int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2};
nhwc_axis = nchw_to_nhwc_axis_map[axis];
} else if (dims == 3) {
int nchw_to_nhwc_axis_map[3] = {0, 2, 1};
nhwc_axis = nchw_to_nhwc_axis_map[axis];
} else {
CHECK(0) << "Unsupport dims in mlu concat";
}
std::vector<int64_t> output_dims;
output_dims.assign(output->dims().size(), 0);
output_dims.assign(dims, 0);
/* std::cout << string_format("concat axis: %d(NCHW), %d(NHWC)", axis,
* nhwc_axis) << std::endl; */
/* std::cout << string_format("concat axis: %d(NCHW), %d(NHWC)", axis, nhwc_axis) << std::endl; */
for (int i = 0; i < output_dims.size(); ++i) {
if (i == nhwc_axis) {
for (auto &dim : input_dims) output_dims[i] += dim[i];
for (auto& dim : input_dims) output_dims[i] += dim[i];
} else {
output_dims[i] = input_dims[0][i];
}
}
/* std::cout << string_format("concat output dim: %ld, %ld, %ld, %ld") << std::endl; */
/* std::cout << string_format("concat output dim: %ld, %ld, %ld, %ld") <<
* std::endl; */
auto* output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
output->Resize(output_dims);
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
......
......@@ -21,18 +21,37 @@ namespace lite {
namespace subgraph {
namespace mlu {
std::vector<int> axis_to_4d(std::vector<int> axis) {
if (axis.size() >= 4) {
return axis;
std::vector<int> axis_to_nhwc4d(const std::vector<int>& axis) {
CHECK_EQ(axis.size(), 4);
std::vector<int> new_axis(4, 0);
const std::vector<int> axis_map1 = {0, 2, 3, 1};
const std::vector<int> axis_map2 = {0, 3, 1, 2};
for (size_t i = 0; i < new_axis.size(); ++i) {
new_axis[i] = axis_map2[axis[axis_map1[i]]];
}
std::vector<int> new_axis = {0, 1, 2, 3};
int i = 0;
for (i = 0; i < axis.size(); i++) {
new_axis[i] = axis[i];
return new_axis;
}
std::vector<int> axis_to_nhw3d(const std::vector<int>& axis) {
CHECK_EQ(axis.size(), 3);
std::vector<int> new_axis(3, 0);
const std::vector<int> axis_map = {0, 2, 1};
for (size_t i = 0; i < new_axis.size(); ++i) {
new_axis[i] = axis_map[axis[axis_map[i]]];
}
new_axis.push_back(3);
return new_axis;
}
std::vector<int64_t> infer_shape(const std::vector<int64_t>& x_dims,
const std::vector<int>& axis_nhwc) {
std::vector<int64_t> out_dims(x_dims);
for (size_t i = 0; i < out_dims.size(); ++i) {
out_dims[i] = x_dims[axis_nhwc[i]];
}
return out_dims;
}
int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
......@@ -44,17 +63,29 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// Get input vars and op attributes
auto x_var_name = op_info->Input("X").front();
// auto x = scope->FindMutableTensor(x_var_name)->GetMutable<Tensor>();
// auto x_dims = x->dims();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto x_dims = x->dims().Vectorize();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
auto axis = op_info->GetAttr<std::vector<int>>("axis");
auto axis_4d = axis_to_4d(axis);
std::vector<int> axis_nhwc;
if (axis.size() == 4) {
axis_nhwc = axis_to_nhwc4d(axis);
} else if (axis.size(0 == 3)) {
axis_nhwc = axis_to_nhw3d(axis);
} else {
CHECK(0) << "Unsupport dim in mlu transpose";
}
auto output_dims_nhwc = infer_shape(x_dims, axis_nhwc);
output->Resize(output_dims_nhwc);
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
out_var_name, output_dims_nhwc, CNML_TENSOR, CNML_NHWC, graph->FPType());
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
......@@ -63,7 +94,7 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
cnmlNdTransposeOpParam_t transpose_param{nullptr};
CNML_CALL(cnmlCreateNdTransposeOpParam(
&transpose_param, axis_4d.data(), axis_4d.size()));
&transpose_param, axis_nhwc.data(), axis_nhwc.size()));
// Use cnmlCreatexxxOpForward to create op.
CNML_CALL(cnmlCreateNdTransposeProOp(&transpose_op_,
......
......@@ -97,9 +97,11 @@ class SubgraphEngine : public subgraph::Engine {
for (auto& inst : origin_program_) {
auto op = inst.op();
CHECK(op);
op->CheckShape();
op->InferShape();
std::string op_type = op->op_info()->Type();
op->CheckShape();
if (op_type != "concat") {
op->InferShape();
}
if (!bridges.Exists(op_type, TARGET(kMLU))) {
LOG(INFO) << "MLU bridges doesn't support op_type: " << op_type;
return subgraph::FAILED;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册