提交 067111d4 编写于 作者: J jiaopu 提交者: jackzhang235

fix transpose test

上级 1e3d867e
......@@ -49,6 +49,6 @@ lite_cc_test(test_fc_converter_mlu SRCS fc_op_test.cc DEPS scope optimizer targe
lite_cc_test(test_scale_converter_mlu SRCS scale_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_interp_converter_mlu SRCS interpolate_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_concat_converter_mlu SRCS concat_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
#lite_cc_test(test_transpose_converter_mlu SRCS transpose_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_transpose_converter_mlu SRCS transpose_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_dropout_converter_mlu SRCS dropout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
message(STATUS "+++++ mlu_subgraph_bridges: ${mlu_subgraph_bridges}")
......@@ -106,10 +106,33 @@ void test_transpose(const std::vector<int64_t>& input_shape,
transpose_ref<float>(op);
out_ref->CopyDataFrom(*out);
Tensor input_x;
input_x.Resize(DDim(input_shape));
transpose(x->mutable_data<float>(),
input_x.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3])},
{0, 2, 3, 1});
x->CopyDataFrom(input_x);
LaunchOp(op, {x_var_name}, {out_var_name});
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
Tensor output_trans;
output_trans.Resize(out->dims());
auto os = out->dims();
transpose(out_data,
output_trans.mutable_data<float>(),
{static_cast<int>(os[0]),
static_cast<int>(os[2]),
static_cast<int>(os[3]),
static_cast<int>(os[1])},
{0, 3, 1, 2});
out_data = output_trans.mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册