Commit 080def5a authored by hjchen2

Refine compiler and winograd conv implementation

Parent 97b2c1a9
@@ -17,8 +17,6 @@ limitations under the License. */
 #include "operators/kernel/conv_kernel.h"
 #include "operators/kernel/central-arm-func/conv_arm_func.h"
-#include <iostream>
 namespace paddle_mobile {
 namespace operators {
...
@@ -885,7 +885,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
   // compute U*V first
   framework::Tensor uv_trans;
   framework::DDim shape =
-      framework::make_ddim(std::vector<int>{4 * out_channel, 8 * tiles, 64});
+      framework::make_ddim(std::vector<int>{out_channel, tiles, 64, 32});
   float *uv_trans_ptr = uv_trans.mutable_data<float>(shape);
   memset(uv_trans_ptr, 0, uv_trans.numel() * sizeof(float));
   const float *input_ptr = input.data<float>();
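The shape change above repacks the intermediate U*V buffer: instead of 4 * out_channel separate images of 8 * tiles rows, the new {out_channel, tiles, 64, 32} layout keeps, for every Winograd position k, a contiguous 4 (output channel) x 8 (tile) block of 32 floats. A minimal sketch of the resulting addressing, matching the offsets used later in the diff; the helper name and signature are ours, not the library's, and out_channel/tiles in the loops below count groups of 4 channels and 8 tiles:

```cpp
#include <cstddef>

// oc: real output channel, t: real tile index, k: Winograd position in
// [0, 64), tile_blocks: number of 8-tile groups. Matches
// (oc >> 2) * tiles * 64 * 32 + (oc & 0x3) * 8 in the output transform below.
inline size_t uv_trans_offset(size_t oc, size_t t, size_t k,
                              size_t tile_blocks) {
  return (oc / 4) * tile_blocks * 64 * 32  // slab of 4 output channels
         + (t / 8) * 64 * 32               // block of 8 tiles
         + k * 32                          // Winograd position
         + (oc % 4) * 8 + (t % 8);         // 4x8 channel-tile interleave
}
```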
@@ -894,17 +894,12 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
 #pragma omp parallel for
   for (int i = 0; i < out_channel; ++i) {
     float *uv_ptr = uv_trans_ptr + (i * tiles * 64 * 32);
-    for (int k = 0; k < 64; ++k) {
     for (int j = 0; j < tiles; ++j) {
+      for (int k = 0; k < 64; ++k) {
         const float *w_ptr = weight_ptr + (i * 64 + k) * in_channel * 4;
         const float *in_ptr = input_ptr + (j * 64 + k) * in_channel * 8;
-        float *out0 = uv_ptr + (8 * j) * 64 + k;  // out channel 0
-        float *out1 = out0 + 8 * tiles * 64;      // out channel 1
-        float *out2 = out1 + 8 * tiles * 64;      // out channel 2
-        float *out3 = out2 + 8 * tiles * 64;      // out channel 3
         int inter_channel = in_channel >> 1;
         int remain_channel = in_channel & 0x1;
-        int steps = 64 * sizeof(float);
         asm volatile(
             "veor q8, q8, q8 \n"
             "veor q9, q9, q9 \n"
@@ -921,6 +916,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
             "loop_2c_%=: \n"
             "vld1.32 {d0-d3}, [%[w_ptr]]! \n"
             "vld1.32 {d4-d7}, [%[in_ptr]]! \n"
+            "vld1.32 {d8-d11}, [%[in_ptr]]! \n"
             "vmla.f32 q8, q2, d0[0] \n"
             "vmla.f32 q9, q3, d0[0] \n"
             "vmla.f32 q10, q2, d0[1] \n"
@@ -930,7 +926,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
             "vmla.f32 q14, q2, d1[1] \n"
             "vmla.f32 q15, q3, d1[1] \n"
-            "vld1.32 {d8-d11}, [%[in_ptr]]! \n"
             "vmla.f32 q8, q4, d2[0] \n"
             "vmla.f32 q9, q5, d2[0] \n"
             "vmla.f32 q10, q4, d2[1] \n"
@@ -966,46 +961,14 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
             "bne loop_c_%= \n"
             "store_res_%=: \n"
-            "vst1.32 {d16[0]}, [%[out0]], %[steps] \n"
-            "vst1.32 {d16[1]}, [%[out0]], %[steps] \n"
-            "vst1.32 {d17[0]}, [%[out0]], %[steps] \n"
-            "vst1.32 {d17[1]}, [%[out0]], %[steps] \n"
-            "vst1.32 {d18[0]}, [%[out0]], %[steps] \n"
-            "vst1.32 {d18[1]}, [%[out0]], %[steps] \n"
-            "vst1.32 {d19[0]}, [%[out0]], %[steps] \n"
-            "vst1.32 {d19[1]}, [%[out0]], %[steps] \n"
-            "vst1.32 {d20[0]}, [%[out1]], %[steps] \n"
-            "vst1.32 {d20[1]}, [%[out1]], %[steps] \n"
-            "vst1.32 {d21[0]}, [%[out1]], %[steps] \n"
-            "vst1.32 {d21[1]}, [%[out1]], %[steps] \n"
-            "vst1.32 {d22[0]}, [%[out1]], %[steps] \n"
-            "vst1.32 {d22[1]}, [%[out1]], %[steps] \n"
-            "vst1.32 {d23[0]}, [%[out1]], %[steps] \n"
-            "vst1.32 {d23[1]}, [%[out1]], %[steps] \n"
-            "vst1.32 {d24[0]}, [%[out2]], %[steps] \n"
-            "vst1.32 {d24[1]}, [%[out2]], %[steps] \n"
-            "vst1.32 {d25[0]}, [%[out2]], %[steps] \n"
-            "vst1.32 {d25[1]}, [%[out2]], %[steps] \n"
-            "vst1.32 {d26[0]}, [%[out2]], %[steps] \n"
-            "vst1.32 {d26[1]}, [%[out2]], %[steps] \n"
-            "vst1.32 {d27[0]}, [%[out2]], %[steps] \n"
-            "vst1.32 {d27[1]}, [%[out2]], %[steps] \n"
-            "vst1.32 {d28[0]}, [%[out3]], %[steps] \n"
-            "vst1.32 {d28[1]}, [%[out3]], %[steps] \n"
-            "vst1.32 {d29[0]}, [%[out3]], %[steps] \n"
-            "vst1.32 {d29[1]}, [%[out3]], %[steps] \n"
-            "vst1.32 {d30[0]}, [%[out3]], %[steps] \n"
-            "vst1.32 {d30[1]}, [%[out3]], %[steps] \n"
-            "vst1.32 {d31[0]}, [%[out3]], %[steps] \n"
-            "vst1.32 {d31[1]}, [%[out3]], %[steps] \n"
-            : [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [out0] "+r"(out0),
-              [out1] "+r"(out1), [out2] "+r"(out2), [out3] "+r"(out3),
+            "vst1.32 {d16-d19}, [%[uv_ptr]]! \n"
+            "vst1.32 {d20-d23}, [%[uv_ptr]]! \n"
+            "vst1.32 {d24-d27}, [%[uv_ptr]]! \n"
+            "vst1.32 {d28-d31}, [%[uv_ptr]]! \n"
+            : [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [uv_ptr] "+r"(uv_ptr),
               [remain_channel] "+r"(remain_channel),
               [inter_channel] "+r"(inter_channel)
-            : [steps] "r"(steps)
+            :
             : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
               "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
       }
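The payoff is in the store path: the old kernel scattered its accumulators into four strided output images with 32 single-lane vst1.32 stores, while the new layout lets the whole 4x8 accumulator block leave the register file in four sequential quad stores. A scalar reference for what one invocation of this GEMM micro-kernel computes, a sketch under the layout above rather than the shipped code:

```cpp
// q8..q15 in the assembly form a 4 (output channel) x 8 (tile) accumulator.
void uv_tile_reference(const float *w_ptr,   // weights, [in_channel][4]
                       const float *in_ptr,  // inputs,  [in_channel][8]
                       float *uv_ptr,        // 32 output floats, [4][8]
                       int in_channel) {
  float acc[4][8] = {};
  for (int c = 0; c < in_channel; ++c) {
    for (int oc = 0; oc < 4; ++oc) {
      for (int t = 0; t < 8; ++t) {
        // the vmla.f32 lattice: one weight lane times one input quad
        acc[oc][t] += w_ptr[c * 4 + oc] * in_ptr[c * 8 + t];
      }
    }
  }
  // Four vst1.32 {dX-dY} stores write the block contiguously, replacing the
  // 32 single-lane stores of the old layout.
  for (int oc = 0; oc < 4; ++oc)
    for (int t = 0; t < 8; ++t) *uv_ptr++ = acc[oc][t];
}
```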
@@ -1027,34 +990,63 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
   int remain_h = out_h - out_h / 6 * 6;
   int remain_w = out_w - out_w / 6 * 6;
   float *output_ptr = output->mutable_data<float>();
-  out_channel = output->dims()[1];
-  int uv_image_size = uv_trans.dims()[1] * 64;
   float transform_matrix[8] = {2.f, 4.f, 8.f, 16.f};
 #pragma omp parallel for
-  for (int oc = 0; oc < out_channel; ++oc) {
+  for (int oc = 0; oc < output->dims()[1]; ++oc) {
     float at_m[48];        // [6][8]
     float output_tmp[36];  // [6][6], temporarily store results
-    const float *uv_ptr = uv_trans_ptr + oc * uv_image_size;
+    // (oc / 4) * tiles * 64 * 32 + (oc & 0x3) * 8
+    const float *uv_ptr =
+        uv_trans_ptr + (oc >> 2) * tiles * 64 * 32 + (oc & 0x3) * 8;
     for (int tile_h = 0; tile_h < h_tiles; ++tile_h) {
       for (int tile_w = 0; tile_w < w_tiles; ++tile_w) {
         float *at_m_ptr = at_m;
+        int tile_indics = tile_h * w_tiles + tile_w;
+        int tile_block = tile_indics >> 3;
+        int block_indics = tile_indics & 0x7;
+        const float *uv_ptr0 = uv_ptr + tile_block * 64 * 32 + block_indics;
+        int steps = 32 * sizeof(float);
         asm volatile(
             "vld1.32 {d0-d1}, [%[tm_ptr]] \n"
             "mov r0, #2 \n"
             "loop_%=: \n"
-            "vld1.32 {d2-d5}, [%[uv_ptr]]! \n"
-            "vld1.32 {d6-d9}, [%[uv_ptr]]! \n"
-            "vld1.32 {d10-d13}, [%[uv_ptr]]! \n"
-            "vld1.32 {d14-d17}, [%[uv_ptr]]! \n"
-            "vtrn.32 q1, q3 \n"
-            "vtrn.32 q2, q4 \n"
-            "vtrn.32 q5, q7 \n"
-            "vtrn.32 q6, q8 \n"
-            "vswp.32 d3, d10 \n"  // q1: m0, q5: m2
-            "vswp.32 d7, d14 \n"  // q3: m1, q7: m3
-            "vswp.32 d5, d12 \n"  // q2: m4, q6: m6
-            "vswp.32 d9, d16 \n"  // q4: m5, q8: m7
+            "vld1.32 {d2[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d6[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d10[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d14[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d4[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d8[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d12[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d16[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d2[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d6[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d10[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d14[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d4[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d8[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d12[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d16[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d3[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d7[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d11[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d15[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d5[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d9[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d13[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d17[0]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d3[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d7[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d11[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d15[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d5[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d9[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d13[1]}, [%[uv_ptr0]], %[steps] \n"
+            "vld1.32 {d17[1]}, [%[uv_ptr0]], %[steps] \n"
"vadd.f32 q9, q3, q5 \n" // m1 + m2 "vadd.f32 q9, q3, q5 \n" // m1 + m2
"vadd.f32 q10, q7, q2 \n" // m3 + m4 "vadd.f32 q10, q7, q2 \n" // m3 + m4
...@@ -1095,8 +1087,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, ...@@ -1095,8 +1087,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
"subs r0, #1 \n" "subs r0, #1 \n"
"bne loop_%= \n" "bne loop_%= \n"
: [uv_ptr] "+r"(uv_ptr), [at_m_ptr] "+r"(at_m_ptr) : [uv_ptr0] "+r"(uv_ptr0), [at_m_ptr] "+r"(at_m_ptr)
: [tm_ptr] "r"((float *)transform_matrix) : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
...
@@ -204,9 +204,15 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels) {
   Otype *output_cmp_data = output_cmp.data<Otype>();
   for (int i = 0; i < output->numel(); ++i) {
     float gap = output_data[i] - output_cmp_data[i];
-    PADDLE_MOBILE_ENFORCE(std::abs(gap / output_data[i]) < 1e-3,
+    PADDLE_MOBILE_ENFORCE(std::abs(gap / (output_data[i] + 1e-5)) < 1e-3,
                           "output[%d] = %d, output_cmp[%d] = %d", i,
                           output_data[i], i, output_cmp_data[i]);
+    // if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
+    //   LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
+    //                  << ", output_cmp_data[" << i << "] = "
+    //                  << output_cmp_data[i];
+    //   return 1;
+    // }
   }
   delete op;
   return 0;
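The test tweak guards the relative-error check against a division by zero when a reference output is exactly 0. A minimal sketch of the check; the symmetric max-based denominator is our variation for illustration, not what the test uses:

```cpp
#include <algorithm>
#include <cmath>

// Relative error with an epsilon floor on the denominator, so an expected
// value at (or near) zero cannot blow up the quotient.
bool close_enough(float actual, float expected, float rel_tol = 1e-3f,
                  float eps = 1e-5f) {
  float denom = std::max(std::abs(expected), eps);
  return std::abs(actual - expected) / denom < rel_tol;
}
```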
@@ -234,82 +240,66 @@ int main(int argc, char *argv[]) {
   LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1";
   paddle_mobile::TestConvOp<float, float, 3, 1, 1>(in_channels, in_height,
                                                    in_width, out_channels);
   // kernel = 7, pad = 0, stride = 2
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=2";
   paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 2>(in_channels, in_height,
                                                       in_width, out_channels);
   // kernel = 7, pad = 1, stride = 2
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=2";
   paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 2>(in_channels, in_height,
                                                       in_width, out_channels);
   // kernel = 7, pad = 3, stride = 2
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=2";
   paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 2>(in_channels, in_height,
                                                       in_width, out_channels);
   // kernel = 7, pad = 0, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=0, stride=1";
   paddle_mobile::TestConvOp<int8_t, int32_t, 7, 0, 1>(in_channels, in_height,
                                                       in_width, out_channels);
   // kernel = 7, pad = 1, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=1, stride=1";
   paddle_mobile::TestConvOp<int8_t, int32_t, 7, 1, 1>(in_channels, in_height,
                                                       in_width, out_channels);
   // kernel = 7, pad = 3, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=1";
   paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 1>(in_channels, in_height,
                                                       in_width, out_channels);
   // kernel = 7, pad = 5, stride = 3
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=5, stride=3";
   paddle_mobile::TestConvOp<int8_t, int32_t, 7, 5, 3>(in_channels, in_height,
                                                       in_width, out_channels);
   // kernel = 7, pad = 3, stride = 4
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=7, pad=3, stride=4";
   paddle_mobile::TestConvOp<int8_t, int32_t, 7, 3, 4>(in_channels, in_height,
                                                       in_width, out_channels);
   // kernel = 3, pad = 0, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=0, stride=1";
   paddle_mobile::TestConvOp<int8_t, int32_t, 3, 0, 1>(in_channels, in_height,
                                                       in_width, out_channels);
   // kernel = 3, pad = 0, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=0, stride=1";
   paddle_mobile::TestConvOp<float, float, 3, 0, 1>(in_channels, in_height,
                                                    in_width, out_channels);
   // kernel = 3, pad = 1, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=3, pad=1, stride=1";
   paddle_mobile::TestConvOp<int8_t, int32_t, 3, 1, 1>(in_channels, in_height,
                                                       in_width, out_channels);
   // kernel = 3, pad = 1, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "float, kernel=3, pad=1, stride=1";
   paddle_mobile::TestConvOp<float, float, 3, 1, 1>(in_channels, in_height,
                                                    in_width, out_channels);
   // kernel = 5, pad = 0, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=0, stride=1";
   paddle_mobile::TestConvOp<int8_t, int32_t, 5, 0, 1>(in_channels, in_height,
                                                       in_width, out_channels);
   // kernel = 5, pad = 0, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=0, stride=1";
   paddle_mobile::TestConvOp<float, float, 5, 0, 1>(in_channels, in_height,
                                                    in_width, out_channels);
   // kernel = 5, pad = 2, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "int8, kernel=5, pad=2, stride=1";
   paddle_mobile::TestConvOp<int8_t, int32_t, 5, 2, 1>(in_channels, in_height,
                                                       in_width, out_channels);
   // kernel = 5, pad = 2, stride = 1
   LOG(paddle_mobile::kLOG_INFO) << "float, kernel=5, pad=2, stride=1";
   paddle_mobile::TestConvOp<float, float, 5, 2, 1>(in_channels, in_height,
...
@@ -69,6 +69,7 @@ build_for_android() {
       -DANDROID_ABI="${ABI}" \
       -DCMAKE_BUILD_TYPE="${MODE}" \
       -DCMAKE_TOOLCHAIN_FILE="${TOOLCHAIN_FILE}" \
+      -DANDROID_TOOLCHAIN='clang' \
       -DANDROID_PLATFORM="${ANDROID_PLATFORM_VERSION}" \
       -DCMAKE_CXX_FLAGS="${CXX_FLAGS}" \
       -DANDROID_STL=c++_static \
@@ -82,6 +83,7 @@ build_for_android() {
       -DANDROID_ABI="${ABI}" \
       -DCMAKE_BUILD_TYPE="${MODE}" \
       -DCMAKE_TOOLCHAIN_FILE="${TOOLCHAIN_FILE}" \
+      -DANDROID_TOOLCHAIN='clang' \
       -DANDROID_PLATFORM="${ANDROID_PLATFORM_VERSION}" \
       -DCMAKE_CXX_FLAGS="${CXX_FLAGS}" \
       -DANDROID_STL=c++_static \
...