提交 e14852b3 编写于 作者: 李寅

Merge branch 'conv2d3x3-bug' into 'master'

Fix bug : construct padding input before send data to neon kernel.

See merge request !42
...@@ -31,9 +31,9 @@ static inline void ConstructInputWithPadding(const float* input, ...@@ -31,9 +31,9 @@ static inline void ConstructInputWithPadding(const float* input,
// Skip the padded top rows // Skip the padded top rows
output_ptr += padded_top * output_width; output_ptr += padded_top * output_width;
for (; batch > 0; --batch) { for (int i = 0; i < batch; ++i) {
for (; channels > 0; --channels) { for (int j = 0; j < channels; ++j) {
for(; height > 0; --height) { for (int k = 0; k < height; ++k) {
memcpy(output_ptr + padded_left, input, width * sizeof(float)); memcpy(output_ptr + padded_left, input, width * sizeof(float));
input += width; input += width;
output_ptr += output_width; output_ptr += output_width;
......
...@@ -199,11 +199,11 @@ TEST_F(Conv2dOpTest, ConvNxNS12) { ...@@ -199,11 +199,11 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
srand(time(NULL)); srand(time(NULL));
// generate random input // generate random input
index_t batch = 1 + rand() % 5; index_t batch = 1 + rand() % 10;
index_t input_channels = 3 + rand() % 50; index_t input_channels = 1 + rand() % 50;
index_t height = 10 + rand() % 100; index_t height = 7 + rand() % 100;
index_t width = 10 + rand() % 100; index_t width = 7 + rand() % 100;
index_t output_channels = 3 + rand() % 50; index_t output_channels = 1 + rand() % 50;
// Construct graph // Construct graph
auto& net = test_net(); auto& net = test_net();
OpDefBuilder("Conv2d", "Conv2dTest") OpDefBuilder("Conv2d", "Conv2dTest")
...@@ -236,11 +236,11 @@ TEST_F(Conv2dOpTest, ConvNxNS12) { ...@@ -236,11 +236,11 @@ TEST_F(Conv2dOpTest, ConvNxNS12) {
// Run NEON // Run NEON
net.RunOp(DeviceType::NEON); net.RunOp(DeviceType::NEON);
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-5); ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-3);
}; };
for (int kernel_size : {1}) { // TODO(liu1i10) 3x3 for (int kernel_size : {1, 3}) {
for (int stride : {1, 2}) { for (int stride : {1, 2}) {
func(kernel_size, kernel_size, stride, stride, VALID); func(kernel_size, kernel_size, stride, stride, VALID);
func(kernel_size, kernel_size, stride, stride, SAME); func(kernel_size, kernel_size, stride, stride, SAME);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册