提交 758fd379 编写于 作者: S Shixiaowei02

update unit test

上级 6ccbc24d
......@@ -280,6 +280,8 @@ TEST(conv_arm_int8, int8_fp32) {
group = oc = ic;
}
LOG(INFO) << "flag_bias: " << flag_bias;
const int dks = dilation * (ks - 1) + 1;
int oh = (ih + 2 * padding - dks) / stride + 1;
int ow = (iw + 2 * padding - dks) / stride + 1;
......@@ -291,7 +293,7 @@ TEST(conv_arm_int8, int8_fp32) {
Tensor input_fp32, input_int8;
Tensor filter_fp32, filter_int8;
Tensor bias_fp32, bias_int8;
Tensor bias_fp32, bias_int32;
Tensor output_int32_ref, output_int32;
Tensor output_fp32_ref, output_fp32;
Tensor output_int8_ref, output_int8;
......@@ -301,7 +303,7 @@ TEST(conv_arm_int8, int8_fp32) {
filter_fp32.Resize(filter_shape);
filter_int8.Resize(filter_shape);
bias_fp32.Resize(bias_shape);
bias_int8.Resize(bias_shape);
bias_int32.Resize(bias_shape);
output_int32.Resize(output_shape);
output_int32_ref.Resize(output_shape);
output_fp32_ref.Resize(output_shape);
......@@ -321,8 +323,7 @@ TEST(conv_arm_int8, int8_fp32) {
float* bias_fp32_data =
bias_fp32.mutable_data<float>();
int8_t* bias_int8_data =
bias_int8.mutable_data<int8_t>();
int* bias_int32_data = bias_int32.mutable_data<int>();
for (int i = 0; i < input_fp32.dims().production();
i++) {
......@@ -354,10 +355,21 @@ TEST(conv_arm_int8, int8_fp32) {
filter_fp32_data, filter_int8_data,
w_scale.data(), axis_size, 1, inner_size);
// lite::arm::math::trans_fp32_bias_to_int32_basic(&bias_fp32,
// &bias_int32, in_scale[0], w_scale);
for (int i = 0; i < bias_int32.dims().production();
i++) {
bias_int32_data[i] = 1;
}
operators::ConvParam param;
param.x = &input_int8;
param.filter = &filter_int8;
param.bias = &bias_int8;
if (flag_bias) {
param.bias = &bias_int32;
} else {
param.bias = nullptr;
}
param.fuse_relu = false;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
......@@ -371,6 +383,7 @@ TEST(conv_arm_int8, int8_fp32) {
output_int32_ref.mutable_data<int>();
// ============ int8gemm_int32 ============
/*
param.output = &output_int32;
std::unique_ptr<KernelContext> ctx_int32(
new KernelContext);
......@@ -388,7 +401,7 @@ TEST(conv_arm_int8, int8_fp32) {
EXPECT_NEAR(output_int32_data[i],
output_int32_ref_data[i], 1e-3);
}
*/
// ============ int8gemm_int8 ============
int8_t* output_int8_ref_data =
output_int8_ref.mutable_data<int8_t>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册