Commit 328839bd authored by liuqi

Finish conv2d 3x3 stride 2 neon kernel.

Parent d20d5ad8
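The stride-2 kernel added below computes a 3x3, stride-2, valid-padding convolution over NCHW tensors with a per-output-channel bias. As a reading aid, here is a minimal scalar sketch of that computation; the standalone function name Conv2dRefK3x3S2 and the plain int shape arrays are illustrative assumptions, not code from this commit:

// Reference semantics only: valid padding, stride 2, NCHW layout,
// filter laid out as (c_out, c_in, 3, 3), one bias value per output channel.
void Conv2dRefK3x3S2(const float *input, const int *in_shape,
                     const float *filter, const float *bias,
                     float *output, const int *out_shape) {
  int ic_n = in_shape[1], ih = in_shape[2], iw = in_shape[3];
  int oc_n = out_shape[1], oh = out_shape[2], ow = out_shape[3];
  for (int b = 0; b < out_shape[0]; ++b) {
    for (int oc = 0; oc < oc_n; ++oc) {
      for (int y = 0; y < oh; ++y) {
        for (int x = 0; x < ow; ++x) {
          float sum = bias[oc];
          for (int ic = 0; ic < ic_n; ++ic) {
            for (int ky = 0; ky < 3; ++ky) {
              for (int kx = 0; kx < 3; ++kx) {
                // Stride 2: output (y, x) reads input rows/cols 2y..2y+2, 2x..2x+2.
                sum += input[((b * ic_n + ic) * ih + 2 * y + ky) * iw + 2 * x + kx] *
                       filter[((oc * ic_n + ic) * 3 + ky) * 3 + kx];
              }
            }
          }
          output[((b * oc_n + oc) * oh + y) * ow + x] = sum;
        }
      }
    }
  }
}

The NEON kernel in the diff vectorizes the x loop four outputs at a time.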
@@ -8,221 +8,288 @@
namespace mace {
namespace kernels {
// Shared head/tail for both 3x3 kernels: the head opens the batch /
// output-channel / input-channel loops, seeds each output map with its bias,
// and loads the three filter rows into vectors; the tail advances the filter
// and input pointers and closes the loops.
#define KERNEL_HEAD_CODE \
  int output_batch = output_shape[0]; \
  int output_channels = output_shape[1]; \
  int output_height = output_shape[2]; \
  int output_width = output_shape[3]; \
  int input_batch = input_shape[0]; \
  int input_channels = input_shape[1]; \
  int input_height = input_shape[2]; \
  int input_width = input_shape[3]; \
  int kernel_h = 3; \
  int kernel_w = 3; \
  for (int b = 0; b < output_batch; ++b) { \
    float *output_ptr_base = \
        output + b * output_channels * output_height * output_width; \
    for (int oc = 0; oc < output_channels; ++oc) { \
      const float *filter_ptr = \
          filter + oc * input_channels * kernel_h * kernel_w; \
      const float *input_ptr = \
          input + b * input_channels * input_height * input_width; \
      float *output_ptr = output_ptr_base + oc * output_height * output_width; \
      std::fill(output_ptr, output_ptr + output_height * output_width, \
                bias[oc]); \
      for (int ic = 0; ic < input_channels; ++ic) { \
        float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), \
                                     vld1q_f32(filter_ptr + 3), \
                                     vld1q_f32(filter_ptr + 6)};

#define KERNEL_TAIL_CODE \
        filter_ptr += 9; \
        input_ptr += input_height * input_width; \
      } \
    } \
  }

static const int kRegisterSize = 4;
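// Stride-1 scheme (S1): each pass of the h loop produces two output rows
// from four input rows. For every 4-wide output block, vextq_f32 builds the
// shifted input windows (offsets 0/1/2) required by the three taps of a
// filter row, and vfmaq_laneq_f32 accumulates one tap per multiply-add.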
void Conv2dNeonK3x3S1(const float *input,   // NCHW
                      const index_t *input_shape,
                      const float *filter,  // c_out, c_in, kernel_h, kernel_w
                      const float *bias,    // c_out
                      float *output,        // NCHW
                      const index_t *output_shape) {
  int height_count = (output_shape[2] >> 1) << 1;  // largest even row count
  KERNEL_HEAD_CODE
  const float *row_ptr_v[kRegisterSize] = {
      input_ptr, input_ptr + input_width,
      input_ptr + 2 * input_width, input_ptr + 3 * input_width};
  float *output_ptr_v[] = {output_ptr, output_ptr + output_width};
  for (int h = 0; h < height_count; h += 2) {
    int count = output_width >> 2;
    int remain_count = output_width & 3;
    // Vectorized body: 4 output columns x 2 output rows per iteration.
    for (; count > 0; --count) {
      // First output row: input rows 0-2.
      float32x4_t n_sum0 = vdupq_n_f32(.0f);
      float32x4_t n_row_former = vld1q_f32(row_ptr_v[0]);                  // 0123
      float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + kRegisterSize);  // 4567
      float32x4_t n_row_ext0 = vextq_f32(n_row_former, n_row_latter, 1);   // 1234
      float32x4_t n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 2);   // 2345
      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_former, n_filter_v[0], 0);
      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext0, n_filter_v[0], 1);
      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext1, n_filter_v[0], 2);
      float32x4_t n_row1_former = vld1q_f32(row_ptr_v[1]);
      float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize);
      float32x4_t n_row1_ext0 = vextq_f32(n_row1_former, n_row1_latter, 1);
      float32x4_t n_row1_ext1 = vextq_f32(n_row1_former, n_row1_latter, 2);
      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_former, n_filter_v[1], 0);
      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext0, n_filter_v[1], 1);
      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext1, n_filter_v[1], 2);
      n_row_former = vld1q_f32(row_ptr_v[2]);
      n_row_latter = vld1q_f32(row_ptr_v[2] + kRegisterSize);
      n_row_ext0 = vextq_f32(n_row_former, n_row_latter, 1);
      n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 2);
      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_former, n_filter_v[2], 0);
      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext0, n_filter_v[2], 1);
      n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext1, n_filter_v[2], 2);
      // Second output row: input rows 1-3 (rows 1 and 2 are reused from above).
      float32x4_t n_sum1 = vdupq_n_f32(.0f);
      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_former, n_filter_v[0], 0);
      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext0, n_filter_v[0], 1);
      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext1, n_filter_v[0], 2);
      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_former, n_filter_v[1], 0);
      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_ext0, n_filter_v[1], 1);
      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_ext1, n_filter_v[1], 2);
      n_row1_former = vld1q_f32(row_ptr_v[3]);
      n_row1_latter = vld1q_f32(row_ptr_v[3] + kRegisterSize);
      n_row1_ext0 = vextq_f32(n_row1_former, n_row1_latter, 1);
      n_row1_ext1 = vextq_f32(n_row1_former, n_row1_latter, 2);
      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_former, n_filter_v[2], 0);
      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext0, n_filter_v[2], 1);
      n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext1, n_filter_v[2], 2);
      float32x4_t n_output_row = vld1q_f32(output_ptr_v[0]);
      float32x4_t n_output_row1 = vld1q_f32(output_ptr_v[1]);
      n_output_row = vaddq_f32(n_output_row, n_sum0);
      n_output_row1 = vaddq_f32(n_output_row1, n_sum1);
      vst1q_f32(output_ptr_v[0], n_output_row);
      vst1q_f32(output_ptr_v[1], n_output_row1);
      output_ptr_v[0] += kRegisterSize;
      output_ptr_v[1] += kRegisterSize;
      for (int i = 0; i < kRegisterSize; ++i) {
        row_ptr_v[i] += kRegisterSize;
      }
    }
    // Scalar tail: one output column x 2 rows per iteration, computed as a
    // 4-lane dot product. Lane 3 of each product is garbage (4th input times
    // the next filter row's first tap), so it is overwritten with the current
    // output value before the horizontal add, which then yields output += dot.
    for (; remain_count > 0; --remain_count) {
      float32x4_t n_row_v[] = {vld1q_f32(row_ptr_v[0]),
                               vld1q_f32(row_ptr_v[1]),
                               vld1q_f32(row_ptr_v[2])};
      float32x4_t n_sum0 = vmulq_f32(n_row_v[0], n_filter_v[0]);
      n_sum0 = vmlaq_f32(n_sum0, n_row_v[1], n_filter_v[1]);
      n_sum0 = vmlaq_f32(n_sum0, n_row_v[2], n_filter_v[2]);
      n_sum0 = vsetq_lane_f32(*output_ptr_v[0], n_sum0, 3);
      *output_ptr_v[0] = vaddvq_f32(n_sum0);
      float32x4_t n_row3 = vld1q_f32(row_ptr_v[3]);
      float32x4_t n_sum1 = vmulq_f32(n_row_v[1], n_filter_v[0]);
      n_sum1 = vmlaq_f32(n_sum1, n_row_v[2], n_filter_v[1]);
      n_sum1 = vmlaq_f32(n_sum1, n_row3, n_filter_v[2]);
      n_sum1 = vsetq_lane_f32(*output_ptr_v[1], n_sum1, 3);
      *output_ptr_v[1] = vaddvq_f32(n_sum1);
      ++output_ptr_v[0];
      ++output_ptr_v[1];
      for (int i = 0; i < kRegisterSize; ++i) {
        row_ptr_v[i] += 1;
      }
    }
    output_ptr_v[0] += output_width;
    output_ptr_v[1] += output_width;
    for (int i = 0; i < kRegisterSize; ++i) {
      row_ptr_v[i] += 2 + input_width;
    }
  }
  // Odd output height: process the last remaining single output row.
  if (output_height != height_count) {
    int count = output_width >> 2;
    int remain_count = output_width & 3;
    for (; count > 0; --count) {
      float32x4_t n_sum = vdupq_n_f32(.0f);
      float32x4_t n_row_former = vld1q_f32(row_ptr_v[0]);
      float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + kRegisterSize);
      float32x4_t n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1);
      float32x4_t n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2);
      n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[0], 0);
      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[0], 1);
      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[0], 2);
      n_row_former = vld1q_f32(row_ptr_v[1]);
      n_row_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize);
      n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1);
      n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2);
      n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[1], 0);
      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[1], 1);
      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[1], 2);
      n_row_former = vld1q_f32(row_ptr_v[2]);
      n_row_latter = vld1q_f32(row_ptr_v[2] + kRegisterSize);
      n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1);
      n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2);
      n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[2], 0);
      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[2], 1);
      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[2], 2);
      float32x4_t n_output_row = vld1q_f32(output_ptr_v[0]);
      n_output_row = vaddq_f32(n_output_row, n_sum);
      vst1q_f32(output_ptr_v[0], n_output_row);
      output_ptr_v[0] += kRegisterSize;
      for (int i = 0; i < 3; ++i) {
        row_ptr_v[i] += kRegisterSize;
      }
    }
    for (; remain_count > 0; --remain_count) {
      float32x4_t n_row_v[] = {vld1q_f32(row_ptr_v[0]),
                               vld1q_f32(row_ptr_v[1]),
                               vld1q_f32(row_ptr_v[2])};
      float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]);
      n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]);
      n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]);
      n_sum = vsetq_lane_f32(*output_ptr_v[0], n_sum, 3);
      *output_ptr_v[0] = vaddvq_f32(n_sum);
      ++output_ptr_v[0];
      for (int i = 0; i < 3; ++i) {
        row_ptr_v[i] += 1;
      }
    }
  }
  KERNEL_TAIL_CODE
}
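// Stride-2 scheme (S2): vld2q_f32 de-interleaves 8 consecutive floats into
// even lanes {x0,x2,x4,x6} and odd lanes {x1,x3,x5,x7}; these are exactly the
// tap-0 and tap-1 samples of four stride-2 windows, and shifting the even
// lanes by one element via vextq_f32 ({x2,x4,x6,x8}) supplies the tap-2
// samples.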
void Conv2dNeonK3x3S2(const float *input,   // NCHW
                      const index_t *input_shape,
                      const float *filter,  // c_out, c_in, kernel_h, kernel_w
                      const float *bias,    // c_out
                      float *output,        // NCHW
                      const index_t *output_shape) {
  int tail_step = 2 * (input_shape[3] - output_shape[3]);
  KERNEL_HEAD_CODE
  const float *row_ptr_v[3] = {input_ptr, input_ptr + input_width,
                               input_ptr + 2 * input_width};
  float *output_ptr_inner = output_ptr;
  for (int h = 0; h < output_height; ++h) {
    int count = output_width >> 2;
    int remain_count = output_width & 3;
    // Vectorized body: 4 output columns per iteration, consuming 8 input
    // columns per row via de-interleaving loads.
    for (; count > 0; --count) {
      float32x4_t n_sum = vdupq_n_f32(.0f);
      float32x4x2_t n_row_former = vld2q_f32(row_ptr_v[0]);  // even/odd lanes
      float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + 8);
      float32x4_t n_row_ext = vextq_f32(n_row_former.val[0], n_row_latter, 1);
      n_sum = vfmaq_laneq_f32(n_sum, n_row_former.val[0], n_filter_v[0], 0);
      n_sum = vfmaq_laneq_f32(n_sum, n_row_former.val[1], n_filter_v[0], 1);
      n_sum = vfmaq_laneq_f32(n_sum, n_row_ext, n_filter_v[0], 2);
      float32x4x2_t n_row1_former = vld2q_f32(row_ptr_v[1]);
      float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + 8);
      float32x4_t n_row1_ext = vextq_f32(n_row1_former.val[0], n_row1_latter, 1);
      n_sum = vfmaq_laneq_f32(n_sum, n_row1_former.val[0], n_filter_v[1], 0);
      n_sum = vfmaq_laneq_f32(n_sum, n_row1_former.val[1], n_filter_v[1], 1);
      n_sum = vfmaq_laneq_f32(n_sum, n_row1_ext, n_filter_v[1], 2);
      float32x4x2_t n_row2_former = vld2q_f32(row_ptr_v[2]);
      float32x4_t n_row2_latter = vld1q_f32(row_ptr_v[2] + 8);
      float32x4_t n_row2_ext = vextq_f32(n_row2_former.val[0], n_row2_latter, 1);
      n_sum = vfmaq_laneq_f32(n_sum, n_row2_former.val[0], n_filter_v[2], 0);
      n_sum = vfmaq_laneq_f32(n_sum, n_row2_former.val[1], n_filter_v[2], 1);
      n_sum = vfmaq_laneq_f32(n_sum, n_row2_ext, n_filter_v[2], 2);
      float32x4_t n_output_row = vld1q_f32(output_ptr_inner);
      n_output_row = vaddq_f32(n_output_row, n_sum);
      vst1q_f32(output_ptr_inner, n_output_row);
      output_ptr_inner += kRegisterSize;
      for (int i = 0; i < 3; ++i) {
        row_ptr_v[i] += 2 * kRegisterSize;
      }
    }
    // Scalar tail: one output column per iteration; the input advances by 2.
    for (; remain_count > 0; --remain_count) {
      float32x4_t n_row_v[] = {vld1q_f32(row_ptr_v[0]),
                               vld1q_f32(row_ptr_v[1]),
                               vld1q_f32(row_ptr_v[2])};
      float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]);
      n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]);
      n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]);
      n_sum = vsetq_lane_f32(*output_ptr_inner, n_sum, 3);
      *output_ptr_inner = vaddvq_f32(n_sum);
      ++output_ptr_inner;
      for (int i = 0; i < 3; ++i) {
        row_ptr_v[i] += 2;
      }
    }
    // Skip the row consumed by the vertical stride plus any right margin.
    for (int i = 0; i < 3; ++i) {
      row_ptr_v[i] += tail_step;
    }
  }
  KERNEL_TAIL_CODE
}
#undef KERNEL_HEAD_CODE
#undef KERNEL_TAIL_CODE
} // namespace kernels
} // namespace mace
@@ -76,6 +76,10 @@ BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128, float);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128, float);
......
@@ -174,8 +174,8 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
  // generate random input
  index_t batch = 1 + rand() % 10;
  index_t input_channels = 1 + rand() % 50;
  index_t height = 11 + rand() % 100;
  index_t width = 11 + rand() % 100;
  index_t output_channels = 1 + rand() % 50;
  // Construct graph
  auto& net = test_net();
......