提交 12c719a1 编写于 作者: L liuqi

Change filter order to OIHW and fix some warnings.

上级 ab541e32
......@@ -22,7 +22,7 @@ class Conv2dFunctor {
void operator()(const T* input, // NCHW
const index_t* input_shape,
const T* filter, // kernel_h, kernel_w, c_in, c_out
const T* filter, // c_out, c_in, kernel_h, kernel_w
const index_t* filter_shape,
const T* bias, // c_out
T* output, // NCHW
......@@ -39,8 +39,8 @@ class Conv2dFunctor {
index_t input_height = input_shape[2];
index_t input_width = input_shape[3];
int kernel_h = filter_shape[0];
int kernel_w = filter_shape[1];
index_t kernel_h = filter_shape[2];
index_t kernel_w = filter_shape[3];
int stride_h = strides_[0];
int stride_w = strides_[1];
......@@ -53,10 +53,12 @@ class Conv2dFunctor {
// The left-upper most offset of the padded input
int padded_h_start = 0 - paddings_[0] / 2;
int padded_w_start = 0 - paddings_[1] / 2;
int padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2;
int padded_w_stop = input_width + paddings_[1] - paddings_[1] / 2;
index_t padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2;
index_t padded_w_stop = input_width + paddings_[1] - paddings_[1] / 2;
#pragma omp parallel for collpse(2)
index_t kernel_size = input_channels * kernel_h * kernel_w;
#pragma omp parallel for collapse(2)
for (int n = 0; n < batch; ++n) {
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < height; ++h) {
......@@ -65,17 +67,10 @@ class Conv2dFunctor {
c * height * width +
h * width + w;
T sum = 0;
const T* filter_ptr = filter + c * kernel_size;
for (int inc = 0; inc < input_channels; ++inc) {
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
/*
* TODO The tensorflow filter order is HWCiCo.
* We should consider other order for different
* implementaion to optimize memory access.
*/
int filter_offset = kh * kernel_w * input_channels * channels +
kw * input_channels * channels +
inc * channels + c;
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
......@@ -94,8 +89,9 @@ class Conv2dFunctor {
n * input_channels * input_height * input_width +
inc * input_height * input_width +
inh * input_width + inw;
sum += input[input_offset] * filter[filter_offset];
sum += input[input_offset] * *filter_ptr;
}
++filter_ptr;
}
}
output[offset] = sum + bias[c];
......
......@@ -56,10 +56,8 @@ public:
// The left-upper most offset of the padded input
int padded_h_start = 0 - paddings_[0] / 2;
int padded_w_start = 0 - paddings_[1] / 2;
int padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2;
int padded_w_stop = input_width + paddings_[1] - paddings_[0] / 2;
#pragma omp parallel for collpse(2)
#pragma omp parallel for collapse(2)
for (int n = 0; n < batch; ++n) {
for (int c = 0; c < channels; ++c) {
index_t out_offset = n * channels * height * width +
......
......@@ -32,7 +32,7 @@ TEST_F(Conv2dOpTest, Simple_VALID) {
1, 1, 1,
1, 1, 1,
1, 1, 1});
AddInputFromArray<float>("Filter", {3, 3, 2, 1},
AddInputFromArray<float>("Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
......@@ -69,7 +69,7 @@ TEST_F(Conv2dOpTest, Simple_SAME) {
1, 1, 1,
1, 1, 1,
1, 1, 1});
AddInputFromArray<float>("Filter", {3, 3, 2, 1},
AddInputFromArray<float>("Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
......@@ -113,16 +113,11 @@ TEST_F(Conv2dOpTest, Combined) {
1, 1, 1, 1, 1,
1, 1, 1, 1, 1,
1, 1, 1, 1, 1});
AddInputFromArray<float>("Filter", {3, 3, 2, 2},
{1.0f, 0.5f, 1.0f, 0.5f,
1.0f, 0.5f, 1.0f, 0.5f,
1.0f, 0.5f, 1.0f, 0.5f,
1.0f, 0.5f, 1.0f, 0.5f,
1.0f, 0.5f, 1.0f, 0.5f,
1.0f, 0.5f, 1.0f, 0.5f,
1.0f, 0.5f, 1.0f, 0.5f,
1.0f, 0.5f, 1.0f, 0.5f,
1.0f, 0.5f, 1.0f, 0.5f});
AddInputFromArray<float>("Filter", {2, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
AddInputFromArray<float>("Bias", {2}, {0.1f, 0.2f});
// Run
......
......@@ -27,7 +27,7 @@ class ConvPool2dOpBase : public Operator<D, T> {
dilations_(OperatorBase::GetRepeatedArgument<int>("dilations")) {}
void CalcPaddingAndOutputSize(const index_t* input_shape, // NCHW
const index_t* filter_shape, // HWIO
const index_t* filter_shape, // OIHW
std::vector<index_t>* output_shape,
std::vector<int>* padding_size) {
MACE_CHECK(dilations_[0] > 0 && dilations_[1] > 0,
......@@ -44,12 +44,12 @@ class ConvPool2dOpBase : public Operator<D, T> {
*padding_size = {0, 0};
index_t output_height, output_width;
index_t kernel_height = filter_shape[0];
index_t kernel_width = filter_shape[1];
index_t output_channels = filter_shape[3];
index_t kernel_height = filter_shape[2];
index_t kernel_width = filter_shape[3];
index_t output_channels = filter_shape[0];
int k_extent_height = (kernel_height - 1) * dilations_[0] + 1;
int k_extent_width = (kernel_width - 1) * dilations_[1] + 1;
index_t k_extent_height = (kernel_height - 1) * dilations_[0] + 1;
index_t k_extent_width = (kernel_width - 1) * dilations_[1] + 1;
switch (padding_) {
case VALID:
......
......@@ -4,8 +4,6 @@
#include "mace/ops/pooling.h"
#include "mace/proto/mace.pb.h"
#include "mace/kernels/pooling.h"
namespace mace {
......
......@@ -29,10 +29,10 @@ public:
std::vector<index_t> output_shape;
std::vector<int> paddings;
std::vector<index_t> filter_shape = std::vector<index_t>(4);
filter_shape[0] = kernels_[0];
filter_shape[1] = kernels_[1];
filter_shape[2] = in_shape[0];
filter_shape[3] = in_shape[1];
filter_shape[0] = in_shape[1];
filter_shape[1] = in_shape[0];
filter_shape[2] = kernels_[0];
filter_shape[3] = kernels_[1];
this->CalcPaddingAndOutputSize(in_shape.data(), filter_shape.data(),
&output_shape, &paddings);
output->Resize(output_shape);
......@@ -50,8 +50,8 @@ public:
};
protected:
PoolingType pooling_type_;
std::vector<int> kernels_;
PoolingType pooling_type_;
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
......
......@@ -5,7 +5,6 @@
#include "gtest/gtest.h"
#include "mace/core/operator.h"
#include "mace/core/net.h"
#include "mace/ops/ops_test_util.h"
#include "mace/ops/conv_pool_2d_base.h"
#include "mace/kernels/pooling.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册