Commit 1bb607bb authored by --get, committed by MaxwellDing

(bugfix): change the 4-D transpose before and after the (flatten or reshape) op to an ND transpose

Parent 99b7f238
@@ -38,16 +38,34 @@ int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // ================== Trans1: NHWC => NCHW ===========================
   auto input_tensor = graph->GetNode(x_var_name);
-  std::vector<int> nhwc_to_nchw_axis = {0, 3, 1, 2};
+  // std::vector<int> nhwc_to_nchw_axis = {0, 3, 1, 2};
+  std::vector<int> trans_1_axis;
+  switch (x->dims().size()) {
+    case 4:
+      trans_1_axis = {0, 3, 1, 2};
+      break;
+    case 3:
+      trans_1_axis = {0, 2, 1};
+      break;
+    case 2:
+      trans_1_axis = {0, 1};
+      break;
+    case 1:
+      trans_1_axis = {0};
+      break;
+    default:
+      break;
+  }
   auto trans1_out = graph->AddNode(x_var_name + ".trans.i",
                                    x->dims().Vectorize(),
                                    CNML_TENSOR,
-                                   CNML_NHWC,
-                                   graph->FPType());
+                                   CNML_NCHW,
+                                   graph->FPType(),
+                                   CNML_NCHW);
   cnmlBaseOp_t trans1_op{nullptr};
   cnmlNdTransposeOpParam_t trans1_param{nullptr};
   CNML_CALL(cnmlCreateNdTransposeOpParam(
-      &trans1_param, nhwc_to_nchw_axis.data(), nhwc_to_nchw_axis.size()));
+      &trans1_param, trans_1_axis.data(), trans_1_axis.size()));
   CNML_CALL(cnmlCreateNdTransposeProOp(&trans1_op,
                                        input_tensor->mlu_tensor(),
                                        trans1_out->mlu_tensor(),
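Review note: both converters now pick the transpose permutation from the tensor rank. The enumerated cases all follow one rule (the channel axis moves to or from dimension 1), and the `default` branch leaves the vector empty, so a rank above 4 would silently build an invalid, empty transpose parameter. A minimal host-side sketch of that rule, with helper names of our own (not part of this patch):

```cpp
#include <vector>

// Hypothetical helpers (not in the patch): generate, for any rank >= 1, the
// permutations the switch statements above enumerate for ranks 1 through 4.
std::vector<int> NHWC2NCHWAxis(int rank) {
  std::vector<int> axis = {0};
  if (rank > 1) axis.push_back(rank - 1);  // channel (last dim) moves to dim 1
  for (int i = 1; i < rank - 1; ++i) axis.push_back(i);
  return axis;  // rank 4 -> {0, 3, 1, 2}; rank 3 -> {0, 2, 1}; rank 2 -> {0, 1}
}

std::vector<int> NCHW2NHWCAxis(int rank) {
  std::vector<int> axis = {0};
  for (int i = 2; i < rank; ++i) axis.push_back(i);
  if (rank > 1) axis.push_back(1);  // channel (dim 1) moves to the end
  return axis;  // rank 4 -> {0, 2, 3, 1}; rank 3 -> {0, 2, 1}; rank 2 -> {0, 1}
}
```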
@@ -59,31 +77,48 @@ int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto trans2_input = graph->AddNode(out_var_name + ".trans.o",
                                      output_dims,
                                      CNML_TENSOR,
-                                     CNML_NHWC,
-                                     graph->FPType());
+                                     CNML_NCHW,
+                                     graph->FPType(),
+                                     CNML_NCHW);
   int cnml_trans2_input_shape[4];
   CNML_CALL(
       cnmlGetTensorShape(trans2_input->mlu_tensor(), cnml_trans2_input_shape));
   cnmlReshapeOpParam_t reshape_param{nullptr};
-  CNML_CALL(
-      cnmlCreateNdReshapeOpParam(&reshape_param, cnml_trans2_input_shape, 4));
+  CNML_CALL(cnmlCreateNdReshapeOpParam(
+      &reshape_param, cnml_trans2_input_shape, output->dims().size()));
   // Use cnmlCreatexxxOpForward to create op.
   CNML_CALL(cnmlCreateReshapeOp(&flatten_op,
                                 reshape_param,
                                 trans1_out->mlu_tensor(),
                                 trans2_input->mlu_tensor()));
   // ======================= Flatten End ===================================
   // ================== Trans2: NCHW => NHWC ===============================
-  std::vector<int> nchw_to_nhwc_axis = {0, 2, 3, 1};
+  // std::vector<int> nchw_to_nhwc_axis = {0, 2, 3, 1};
+  std::vector<int> trans_2_axis;
+  switch (output->dims().size()) {
+    case 4:
+      trans_2_axis = {0, 2, 3, 1};
+      break;
+    case 3:
+      trans_2_axis = {0, 2, 1};
+      break;
+    case 2:
+      trans_2_axis = {0, 1};
+      break;
+    case 1:
+      trans_2_axis = {0};
+      break;
+    default:
+      break;
+  }
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
   cnmlBaseOp_t trans2_op{nullptr};
   cnmlNdTransposeOpParam_t trans2_param{nullptr};
   CNML_CALL(cnmlCreateNdTransposeOpParam(
-      &trans2_param, nchw_to_nhwc_axis.data(), nchw_to_nhwc_axis.size()));
+      &trans2_param, trans_2_axis.data(), trans_2_axis.size()));
   CNML_CALL(cnmlCreateNdTransposeProOp(&trans2_op,
                                        trans2_input->mlu_tensor(),
                                        output_tensor->mlu_tensor(),
@@ -96,15 +131,10 @@ int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   VLOG(6) << "out_var_name: " << out_var_name;
   VLOG(6) << "input dim: " << x->dims();
   VLOG(6) << "output dim: " << output->dims();
-  int tmp_shape[4];
-  cnmlGetTensorShape(trans1_out->mlu_tensor(), tmp_shape);
-  VLOG(6) << "trans1_out shape"
-          << ": " << tmp_shape[0] << " " << tmp_shape[1] << " " << tmp_shape[2]
-          << " " << tmp_shape[3];
-  cnmlGetTensorShape(trans2_input->mlu_tensor(), tmp_shape);
-  VLOG(6) << "trans2_input shape"
-          << ": " << tmp_shape[0] << " " << tmp_shape[1] << " " << tmp_shape[2]
-          << " " << tmp_shape[3];
+  // cnmlPrintTensor(input_tensor->mlu_tensor(), CNML_TENSOR);
+  // cnmlPrintTensor(trans1_out->mlu_tensor(), CNML_TENSOR);
+  // cnmlPrintTensor(trans2_input->mlu_tensor(), CNML_TENSOR);
+  // cnmlPrintTensor(output_tensor->mlu_tensor(), CNML_TENSOR);
   // ============== DEBUG END ===============
   graph->FuseOp(trans1_op);
   graph->FuseOp(flatten_op);
...
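Two details in the hunk above carry the actual fix: the reshape parameter is now built with `cnmlCreateNdReshapeOpParam` using `output->dims().size()` instead of a hard-coded 4, and the surrounding transposes use the rank-dependent axes. This matters because Paddle's flatten op collapses its input to 2-D around `axis`, so the old 4-element permutation was wrong whenever it reached this path. A minimal sketch of the flatten shape rule (helper name is ours, assuming flatten-to-2-D semantics):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not in the patch): the 2-D shape flatten produces
// from an input shape and `axis`, per Paddle's flatten semantics.
std::vector<int64_t> FlattenTo2D(const std::vector<int64_t>& in, int axis) {
  int64_t outer = 1, inner = 1;
  for (int i = 0; i < static_cast<int>(in.size()); ++i) {
    (i < axis ? outer : inner) *= in[i];
  }
  return {outer, inner};  // e.g. {1, 2, 4, 4} with axis = 2 -> {2, 16}
}
```

With a 2-D output, `trans_2_axis` becomes the identity `{0, 1}`, so Trans2 degenerates to a copy instead of an out-of-range 4-D permutation.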
@@ -68,11 +68,7 @@ void test_flatten(std::vector<int64_t> input_shape, int axis) {
   }
 }
-TEST(MLUBridges, flatten) {
-  std::vector<int64_t> input_shape = {1, 2, 4, 4};
-  int axis = 2;
-  test_flatten(input_shape, axis);
-}
+TEST(MLUBridges, flatten) { test_flatten({1, 2, 4, 4}, 2); }
 }  // namespace mlu
 }  // namespace subgraph
 }  // namespace lite
...
@@ -38,16 +38,34 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // ================== Trans1: NHWC => NCHW ===========================
   auto input_tensor = graph->GetNode(x_var_name);
-  std::vector<int> nhwc_to_nchw_axis = {0, 3, 1, 2};
+  // std::vector<int> nhwc_to_nchw_axis = {0, 3, 1, 2};
+  std::vector<int> trans_1_axis;
+  switch (x->dims().size()) {
+    case 4:
+      trans_1_axis = {0, 3, 1, 2};
+      break;
+    case 3:
+      trans_1_axis = {0, 2, 1};
+      break;
+    case 2:
+      trans_1_axis = {0, 1};
+      break;
+    case 1:
+      trans_1_axis = {0};
+      break;
+    default:
+      break;
+  }
   auto trans1_out = graph->AddNode(x_var_name + ".trans.i",
                                    x->dims().Vectorize(),
                                    CNML_TENSOR,
-                                   CNML_NHWC,
-                                   graph->FPType());
+                                   CNML_NCHW,
+                                   graph->FPType(),
+                                   CNML_NCHW);
   cnmlBaseOp_t trans1_op{nullptr};
   cnmlNdTransposeOpParam_t trans1_param{nullptr};
   CNML_CALL(cnmlCreateNdTransposeOpParam(
-      &trans1_param, nhwc_to_nchw_axis.data(), nhwc_to_nchw_axis.size()));
+      &trans1_param, trans_1_axis.data(), trans_1_axis.size()));
   CNML_CALL(cnmlCreateNdTransposeProOp(&trans1_op,
                                        input_tensor->mlu_tensor(),
                                        trans1_out->mlu_tensor(),
@@ -59,8 +77,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto trans2_input = graph->AddNode(out_var_name + ".trans.o",
                                      output_dims,
                                      CNML_TENSOR,
-                                     CNML_NHWC,
-                                     graph->FPType());
+                                     CNML_NCHW,
+                                     graph->FPType(),
+                                     CNML_NCHW);
   cnmlReshapeOpParam_t reshape_param{nullptr};
   int cnml_trans2_input_shape[4];
   CNML_CALL(
@@ -76,13 +95,30 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // ======================= Reshape op End ===================================
   // ================== Trans2: NCHW => NHWC ===============================
-  std::vector<int> nchw_to_nhwc_axis = {0, 2, 3, 1};
+  // std::vector<int> nchw_to_nhwc_axis = {0, 2, 3, 1};
+  std::vector<int> trans_2_axis;
+  switch (output->dims().size()) {
+    case 4:
+      trans_2_axis = {0, 2, 3, 1};
+      break;
+    case 3:
+      trans_2_axis = {0, 2, 1};
+      break;
+    case 2:
+      trans_2_axis = {0, 1};
+      break;
+    case 1:
+      trans_2_axis = {0};
+      break;
+    default:
+      break;
+  }
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
   cnmlBaseOp_t trans2_op{nullptr};
   cnmlNdTransposeOpParam_t trans2_param{nullptr};
   CNML_CALL(cnmlCreateNdTransposeOpParam(
-      &trans2_param, nchw_to_nhwc_axis.data(), nchw_to_nhwc_axis.size()));
+      &trans2_param, trans_2_axis.data(), trans_2_axis.size()));
   CNML_CALL(cnmlCreateNdTransposeProOp(&trans2_op,
                                        trans2_input->mlu_tensor(),
                                        output_tensor->mlu_tensor(),
@@ -100,21 +136,12 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   for (size_t i = 0; i < 4; i++) {
     VLOG(6) << cnml_input_shape[i];
   }
-  int tmp_shape[4];
-  cnmlGetTensorShape(trans1_out->mlu_tensor(), tmp_shape);
-  VLOG(6) << "trans1_out shape"
-          << ": " << tmp_shape[0] << " " << tmp_shape[1] << " " << tmp_shape[2]
-          << " " << tmp_shape[3];
-  cnmlGetTensorShape(trans2_input->mlu_tensor(), tmp_shape);
-  VLOG(6) << "trans2_input shape"
-          << ": " << tmp_shape[0] << " " << tmp_shape[1] << " " << tmp_shape[2]
-          << " " << tmp_shape[3];
+  // cnmlPrintTensor(input_tensor->mlu_tensor(), CNML_TENSOR);
+  // cnmlPrintTensor(trans1_out->mlu_tensor(), CNML_TENSOR);
+  // cnmlPrintTensor(trans2_input->mlu_tensor(), CNML_TENSOR);
+  // cnmlPrintTensor(output_tensor->mlu_tensor(), CNML_TENSOR);
   // =============== DEBUG END =================
-  // CNML_CALL(cnmlCreateReshapeOp_V2(
-  //     &reshape_op,
-  //     input_tensor->mlu_tensor(),
-  //     output_tensor->mlu_tensor()));
   graph->FuseOp(trans1_op);
   graph->FuseOp(reshape_op);
   graph->FuseOp(trans2_op);
...
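The same reasoning explains the reshape fix: the old code always passed a 4-element permutation to the transpose parameter, which is invalid once the reshape output has rank below 4. With rank-dependent axes the transposes degenerate gracefully: rank 3 still swaps the channel dimension, while ranks 1 and 2 get identity permutations. A host-side sketch of what a permutation does to a shape (plain C++, no CNML; names are ours):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Apply a permutation to a shape vector, mirroring what the ND transpose
// does to tensor dimensions on device (host-side illustration only).
std::vector<int64_t> Permute(const std::vector<int64_t>& dims,
                             const std::vector<int>& axis) {
  std::vector<int64_t> out(dims.size());
  for (size_t i = 0; i < axis.size(); ++i) out[i] = dims[axis[i]];
  return out;
}

int main() {
  // Rank 3, NWC -> NCW via {0, 2, 1}:
  for (int64_t d : Permute({1, 4, 8}, {0, 2, 1})) std::cout << d << ' ';
  std::cout << '\n';  // prints: 1 8 4
  // Rank 1 ({0}) and rank 2 ({0, 1}) permutations are identities.
  return 0;
}
```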
@@ -88,11 +88,7 @@ void test_reshape(std::vector<int64_t> input_shape,
   }
 }
-TEST(MLUBridges, reshape) {
-  std::vector<int64_t> input_shape = {1, 2, 4, 4};
-  std::vector<int64_t> out_shape = {1, 4, 2, 4};
-  test_reshape(input_shape, out_shape);
-}
+TEST(MLUBridges, reshape) { test_reshape({1, 2, 4, 4}, {1, 4, 2, 4}); }
 }  // namespace mlu
 }  // namespace subgraph
 }  // namespace lite
...
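Both updated tests still exercise only 4-D shapes, so the lower-rank branches of the new switch statements go untested here. Hypothetical additional cases that the ND path should now support (not part of this patch):

```cpp
// Hypothetical extra cases, appended to the existing test files.
TEST(MLUBridges, flatten_low_rank) {
  test_flatten({2, 4, 4}, 1);  // 3-D input -> {2, 16}
  test_flatten({8, 16}, 1);    // 2-D input -> {8, 16}
}

TEST(MLUBridges, reshape_low_rank) {
  test_reshape({2, 4, 4}, {4, 8});   // 3-D -> 2-D
  test_reshape({8, 16}, {2, 8, 8});  // 2-D -> 3-D
}
```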