Commit 5941a8a4 authored by 李寅

Merge branch 'conv2d-padding' into 'master'

Change conv2d functor to member variable.

See merge request !52
@@ -6,23 +6,39 @@
 #define MACE_KERNELS_CONV_2D_H_
 
 #include "mace/core/tensor.h"
+#include "mace/kernels/conv_pool_2d_util.h"
 
 namespace mace {
 namespace kernels {
 
-template <DeviceType D, typename T>
+template<DeviceType D, typename T>
 class Conv2dFunctor {
  public:
-  Conv2dFunctor(const int* strides, const int* paddings, const int* dilations)
-      : strides_(strides), paddings_(paddings), dilations_(dilations) {}
+  Conv2dFunctor(const index_t *input_shape,
+                const index_t *filter_shape,
+                const int *strides,
+                const Padding padding,
+                const int *dilations) :
+      strides_(strides),
+      paddings_(2, 0),
+      dilations_(dilations) {
+    CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding, paddings_.data());
+  }
+
+  Conv2dFunctor(const int *strides,
+                const std::vector<int> &paddings,
+                const int *dilations) :
+      strides_(strides),
+      paddings_(paddings),
+      dilations_(dilations) {}
 
-  void operator()(const T* input,   // NCHW
-                  const index_t* input_shape,
-                  const T* filter,  // c_out, c_in, kernel_h, kernel_w
-                  const index_t* filter_shape,
-                  const T* bias,    // c_out
-                  T* output,        // NCHW
-                  const index_t* output_shape) {
+  void operator()(const T *input,   // NCHW
+                  const index_t *input_shape,
+                  const T *filter,  // c_out, c_in, kernel_h, kernel_w
+                  const index_t *filter_shape,
+                  const T *bias,    // c_out
+                  T *output,        // NCHW
+                  const index_t *output_shape) {
     MACE_CHECK_NOTNULL(output);
 
     index_t batch = output_shape[0];
@@ -60,9 +76,9 @@ class Conv2dFunctor {
       for (int h = 0; h < height; ++h) {
         for (int w = 0; w < width; ++w) {
           index_t offset = n * channels * height * width +
                            c * height * width + h * width + w;
           T sum = 0;
-          const T* filter_ptr = filter + c * kernel_size;
+          const T *filter_ptr = filter + c * kernel_size;
           for (int inc = 0; inc < input_channels; ++inc) {
             for (int kh = 0; kh < kernel_h; ++kh) {
               for (int kw = 0; kw < kernel_w; ++kw) {
@@ -71,7 +87,7 @@ class Conv2dFunctor {
                 if (inh < 0 || inh >= input_height || inw < 0 ||
                     inw >= input_width) {
                   MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
                                  inw >= padded_w_start && inw < padded_w_stop,
                              "Out of range read from input: ", inh, ", ",
                              inw);
                   // else padding with 0:
@@ -79,8 +95,8 @@ class Conv2dFunctor {
                 } else {
                   index_t input_offset =
                       n * input_channels * input_height * input_width +
                       inc * input_height * input_width + inh * input_width +
                       inw;
                   sum += input[input_offset] * *filter_ptr;
                 }
                 ++filter_ptr;
@@ -95,20 +111,20 @@ class Conv2dFunctor {
   }
 
  private:
-  const int* strides_;     // [stride_h, stride_w]
-  const int* paddings_;    // [padding_h, padding_w]
-  const int* dilations_;   // [dilation_h, dilation_w]
+  const int *strides_;         // [stride_h, stride_w]
+  std::vector<int> paddings_;  // [padding_h, padding_w]
+  const int *dilations_;       // [dilation_h, dilation_w]
 };
 
-template <>
+template<>
 void Conv2dFunctor<DeviceType::NEON, float>::operator()(
-    const float* input,
-    const index_t* input_shape,
-    const float* filter,
-    const index_t* filter_shape,
-    const float* bias,
-    float* output,
-    const index_t* output_shape);
+    const float *input,
+    const index_t *input_shape,
+    const float *filter,
+    const index_t *filter_shape,
+    const float *bias,
+    float *output,
+    const index_t *output_shape);
 
 }  // namespace kernels
 }  // namespace mace
...
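The key change above is that Conv2dFunctor now owns its padding sizes: the new constructor takes the input and filter shapes plus a Padding policy and fills the paddings_ vector once via CalPaddingSize, instead of borrowing a caller-owned pointer. Below is a minimal sketch of the arithmetic such a helper typically performs, assuming TensorFlow-style SAME/VALID semantics; the names and the total-padding convention are illustrative, not MACE's exact contract.

#include <algorithm>
#include <cstdint>

enum class Padding { VALID, SAME };

// Illustrative CalPaddingSize-style helper: computes total padding per
// spatial dimension for NCHW input and (c_out, c_in, kh, kw) filters.
void CalPaddingSizeSketch(const int64_t *input_shape,
                          const int64_t *filter_shape,
                          const int *dilations,   // [dilation_h, dilation_w]
                          const int *strides,     // [stride_h, stride_w]
                          Padding padding,
                          int *padding_size) {    // [padding_h, padding_w]
  for (int i = 0; i < 2; ++i) {
    const int64_t in = input_shape[2 + i];   // H then W
    const int64_t k = filter_shape[2 + i];   // kernel_h then kernel_w
    // Dilation inflates the effective kernel extent.
    const int64_t k_eff = (k - 1) * dilations[i] + 1;
    if (padding == Padding::VALID) {
      padding_size[i] = 0;
    } else {
      // SAME: output spans ceil(in / stride); pad just enough to cover it.
      const int64_t out = (in + strides[i] - 1) / strides[i];
      padding_size[i] = static_cast<int>(
          std::max<int64_t>(0, (out - 1) * strides[i] + k_eff - in));
    }
  }
}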
@@ -44,8 +44,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
     for (index_t j = 0; j < count; ++j) {
       float32x4_t input_f = vld1q_f32(input_sample_ptr);
-      float32x4_t output_f = new_offset_f;
-      output_f = vfmaq_f32(output_f, input_f, new_scale_f);
+      float32x4_t output_f = vfmaq_f32(new_offset_f, input_f, new_scale_f);
       vst1q_f32(output_sample_ptr, output_f);
       input_sample_ptr += 4;
       output_sample_ptr += 4;
...
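The batch-norm tweak folds the separate register copy and the fused multiply-add into one expression: vfmaq_f32(a, b, c) computes a + b * c lane-wise, so the new line evaluates output = new_offset + input * new_scale directly. A self-contained demonstration of the intrinsic's semantics (build for AArch64, or ARMv7 with -mfpu=neon-vfpv4):

#include <arm_neon.h>
#include <cstdio>

int main() {
  // output = new_offset + input * new_scale, four lanes at a time,
  // via a single fused multiply-add -- the batch-norm inner loop.
  const float in_vals[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float32x4_t input      = vld1q_f32(in_vals);
  float32x4_t new_scale  = vdupq_n_f32(0.5f);
  float32x4_t new_offset = vdupq_n_f32(1.0f);

  float32x4_t output = vfmaq_f32(new_offset, input, new_scale);

  float out[4];
  vst1q_f32(out, output);
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 1.5 2 2.5 3
  return 0;
}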
@@ -81,7 +81,7 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,
   // Keep this alive during kernel execution
   Tensor padded_input;
   if (paddings_[0] > 0 || paddings_[1] > 0) {
-    ConstructInputWithPadding(input, input_shape, paddings_, &padded_input);
+    ConstructInputWithPadding(input, input_shape, paddings_.data(), &padded_input);
     input = padded_input.data<float>();
     input_shape = padded_input.shape().data();
   }
...
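Note the scoping the "Keep this alive during kernel execution" comment guards: padded_input is declared before the if so the padded buffer outlives the branch. After ConstructInputWithPadding runs, the raw input and input_shape pointers are rebound to padded_input's internal storage, so if the Tensor were local to the if block those pointers would dangle the moment the block closed.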
@@ -17,7 +17,12 @@ template<DeviceType D, typename T>
 class Conv2dOp : public ConvPool2dOpBase<D, T> {
  public:
   Conv2dOp(const OperatorDef &op_def, Workspace *ws)
-      : ConvPool2dOpBase<D, T>(op_def, ws) {};
+      : ConvPool2dOpBase<D, T>(op_def, ws),
+        functor_(this->Input(INPUT)->shape().data(),
+                 this->Input(FILTER)->shape().data(),
+                 this->strides_.data(),
+                 this->padding_,
+                 this->dilations_.data()) {}
 
   bool Run() override {
     const Tensor *input = this->Input(INPUT);
@@ -27,21 +32,19 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
     std::vector<index_t> output_shape(4);
     std::vector<int> paddings(2);
-    kernels::CalcPaddingAndOutputSize(
-        input->shape().data(), filter->shape().data(), this->dilations_.data(),
-        this->strides_.data(), this->padding_, output_shape.data(),
-        paddings.data());
+    this->CalOutputSize(input->shape().data(), filter->shape().data(), output_shape.data());
     output->Resize(output_shape);
 
-    auto conv2d = kernels::Conv2dFunctor<D, T>(
-        this->strides_.data(), paddings.data(), this->dilations_.data());
-    conv2d(input->data<T>(), input->shape().data(), filter->data<T>(),
-           filter->shape().data(), bias->data<T>(), output->mutable_data<T>(),
-           output->shape().data());
+    functor_(input->data<T>(), input->shape().data(), filter->data<T>(),
+             filter->shape().data(), bias->data<T>(), output->mutable_data<T>(),
+             output->shape().data());
 
     return true;
   }
 
+ private:
+  kernels::Conv2dFunctor<D, T> functor_;
+
  protected:
   OP_INPUT_TAGS(INPUT, FILTER, BIAS);
   OP_OUTPUT_TAGS(OUTPUT);
...
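The net effect on the op side: all shape- and padding-dependent setup moves into the constructor, and Run() shrinks to a plain forwarding call on the pre-built functor_ member instead of constructing a throwaway Conv2dFunctor on every invocation. The trade-off is that the input and filter shapes must already be known when the op is constructed. A trimmed-down, hypothetical sketch of the pattern (the class and argument lists are placeholders, not MACE's real hierarchy):

#include <utility>

// "Configure once, run many": invariants are resolved in the constructor,
// per-invocation work only forwards tensor pointers to the stored functor.
template <typename Functor>
class ConfiguredOp {
 public:
  template <typename... ConfigArgs>
  explicit ConfiguredOp(ConfigArgs&&... config)
      // Shapes, strides, and padding are resolved exactly once, here.
      : functor_(std::forward<ConfigArgs>(config)...) {}

  template <typename... RunArgs>
  bool Run(RunArgs&&... tensors) {
    // Hot path: no re-construction, no re-derived padding.
    functor_(std::forward<RunArgs>(tensors)...);
    return true;
  }

 private:
  Functor functor_;  // built in the constructor, reused on every Run()
};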