Commit 5941a8a4 authored by 李寅

Merge branch 'conv2d-padding' into 'master'

Change conv2d functor to member variable.

See merge request !52
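The change turns Conv2dOp's Conv2dFunctor from a local variable rebuilt on every Run() call into a member constructed once with the op, which also lets the functor precompute its paddings at construction time. Below is a minimal sketch of the before/after pattern; Functor and its members are illustrative stand-ins, not the actual MACE types.

// Minimal sketch of the refactor; `Functor` is an illustrative
// stand-in, not the actual MACE type.
struct Functor {
  explicit Functor(int kernel) : padding_((kernel - 1) / 2) {}  // setup once
  void operator()(const float *in, float *out) const { /* uses padding_ */ }
  int padding_;
};

class OpBefore {  // old pattern: functor rebuilt on every Run()
 public:
  bool Run(const float *in, float *out) {
    Functor functor(3);  // padding recomputed on each call
    functor(in, out);
    return true;
  }
};

class OpAfter {  // new pattern: functor is a member, built once
 public:
  OpAfter() : functor_(3) {}
  bool Run(const float *in, float *out) {
    functor_(in, out);  // reuses precomputed state
    return true;
  }
 private:
  Functor functor_;
};

The win is that per-call setup work (here, the padding computation) happens once at op construction instead of on every invocation.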
@@ -6,23 +6,39 @@
 #define MACE_KERNELS_CONV_2D_H_
 
 #include "mace/core/tensor.h"
 #include "mace/kernels/conv_pool_2d_util.h"
 
 namespace mace {
 namespace kernels {
 
-template <DeviceType D, typename T>
+template<DeviceType D, typename T>
 class Conv2dFunctor {
  public:
-  Conv2dFunctor(const int* strides, const int* paddings, const int* dilations)
-      : strides_(strides), paddings_(paddings), dilations_(dilations) {}
-
-  void operator()(const T* input,  // NCHW
-                  const index_t* input_shape,
-                  const T* filter,  // c_out, c_in, kernel_h, kernel_w
-                  const index_t* filter_shape,
-                  const T* bias,  // c_out
-                  T* output,      // NCHW
-                  const index_t* output_shape) {
+  Conv2dFunctor(const index_t *input_shape,
+                const index_t *filter_shape,
+                const int *strides,
+                const Padding padding,
+                const int *dilations) :
+      strides_(strides),
+      paddings_(2, 0),
+      dilations_(dilations) {
+    CalPaddingSize(input_shape, filter_shape, dilations_, strides_, padding,
+                   paddings_.data());
+  }
+
+  Conv2dFunctor(const int *strides,
+                const std::vector<int> &paddings,
+                const int *dilations) :
+      strides_(strides),
+      paddings_(paddings),
+      dilations_(dilations) {}
+
+  void operator()(const T *input,  // NCHW
+                  const index_t *input_shape,
+                  const T *filter,  // c_out, c_in, kernel_h, kernel_w
+                  const index_t *filter_shape,
+                  const T *bias,  // c_out
+                  T *output,      // NCHW
+                  const index_t *output_shape) {
     MACE_CHECK_NOTNULL(output);
 
     index_t batch = output_shape[0];
@@ -62,7 +78,7 @@ class Conv2dFunctor {
           index_t offset = n * channels * height * width +
                            c * height * width + h * width + w;
           T sum = 0;
-          const T* filter_ptr = filter + c * kernel_size;
+          const T *filter_ptr = filter + c * kernel_size;
           for (int inc = 0; inc < input_channels; ++inc) {
             for (int kh = 0; kh < kernel_h; ++kh) {
               for (int kw = 0; kw < kernel_w; ++kw) {
@@ -95,20 +111,20 @@ class Conv2dFunctor {
   }
 
  private:
-  const int* strides_;    // [stride_h, stride_w]
-  const int* paddings_;   // [padding_h, padding_w]
-  const int* dilations_;  // [dilation_h, dilation_w]
+  const int *strides_;         // [stride_h, stride_w]
+  std::vector<int> paddings_;  // [padding_h, padding_w]
+  const int *dilations_;       // [dilation_h, dilation_w]
 };
 
-template <>
+template<>
 void Conv2dFunctor<DeviceType::NEON, float>::operator()(
-    const float* input,
-    const index_t* input_shape,
-    const float* filter,
-    const index_t* filter_shape,
-    const float* bias,
-    float* output,
-    const index_t* output_shape);
+    const float *input,
+    const index_t *input_shape,
+    const float *filter,
+    const index_t *filter_shape,
+    const float *bias,
+    float *output,
+    const index_t *output_shape);
 
 }  // namespace kernels
 }  // namespace mace
......
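The first new constructor above fills paddings_ by calling CalPaddingSize, whose body is not part of this diff. For SAME-style padding the computation typically looks like the sketch below, given here only as a hedged illustration of what such a helper computes, not MACE's actual implementation.

#include <algorithm>

// Hedged sketch of a typical SAME-padding computation; the name and
// signature are illustrative, not MACE's CalPaddingSize.
void CalPaddingSizeSketch(int in_h, int in_w, int k_h, int k_w,
                          int stride_h, int stride_w,
                          int dilation_h, int dilation_w,
                          int *padding_h, int *padding_w) {
  // Effective kernel extent after dilation.
  int ek_h = (k_h - 1) * dilation_h + 1;
  int ek_w = (k_w - 1) * dilation_w + 1;
  // SAME output size: ceil(in / stride).
  int out_h = (in_h + stride_h - 1) / stride_h;
  int out_w = (in_w + stride_w - 1) / stride_w;
  // Total padding so the last window still fits inside the input.
  *padding_h = std::max(0, (out_h - 1) * stride_h + ek_h - in_h);
  *padding_w = std::max(0, (out_w - 1) * stride_w + ek_w - in_w);
}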
@@ -44,8 +44,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
       for (index_t j = 0; j < count; ++j) {
         float32x4_t input_f = vld1q_f32(input_sample_ptr);
-        float32x4_t output_f = new_offset_f;
-        output_f = vfmaq_f32(output_f, input_f, new_scale_f);
+        float32x4_t output_f = vfmaq_f32(new_offset_f, input_f, new_scale_f);
         vst1q_f32(output_sample_ptr, output_f);
         input_sample_ptr += 4;
         output_sample_ptr += 4;
......
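The batch-norm hunk above replaces a register copy plus a fused multiply-add with a single vfmaq_f32 call. vfmaq_f32(a, b, c) computes a + b * c in each float lane, so both forms produce output = new_offset + input * new_scale; the self-contained equivalent below shows the reduction.

#include <arm_neon.h>

// vfmaq_f32(a, b, c) returns a + b * c in each of the four float lanes.
static inline float32x4_t batch_norm_step(float32x4_t input_f,
                                          float32x4_t new_scale_f,
                                          float32x4_t new_offset_f) {
  // Equivalent to the removed two-line version:
  //   float32x4_t output_f = new_offset_f;
  //   output_f = vfmaq_f32(output_f, input_f, new_scale_f);
  return vfmaq_f32(new_offset_f, input_f, new_scale_f);
}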
@@ -81,7 +81,7 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float *input,
   // Keep this alive during kernel execution
   Tensor padded_input;
   if (paddings_[0] > 0 || paddings_[1] > 0) {
-    ConstructInputWithPadding(input, input_shape, paddings_, &padded_input);
+    ConstructInputWithPadding(input, input_shape, paddings_.data(), &padded_input);
     input = padded_input.data<float>();
     input_shape = padded_input.shape().data();
   }
......
@@ -17,7 +17,12 @@ template<DeviceType D, typename T>
 class Conv2dOp : public ConvPool2dOpBase<D, T> {
  public:
   Conv2dOp(const OperatorDef &op_def, Workspace *ws)
-      : ConvPool2dOpBase<D, T>(op_def, ws) {};
+      : ConvPool2dOpBase<D, T>(op_def, ws),
+        functor_(this->Input(INPUT)->shape().data(),
+                 this->Input(FILTER)->shape().data(),
+                 this->strides_.data(),
+                 this->padding_,
+                 this->dilations_.data()) {}
 
   bool Run() override {
     const Tensor *input = this->Input(INPUT);
@@ -27,21 +32,19 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
     std::vector<index_t> output_shape(4);
-    std::vector<int> paddings(2);
-    kernels::CalcPaddingAndOutputSize(
-        input->shape().data(), filter->shape().data(), this->dilations_.data(),
-        this->strides_.data(), this->padding_, output_shape.data(),
-        paddings.data());
+    this->CalOutputSize(input->shape().data(), filter->shape().data(), output_shape.data());
     output->Resize(output_shape);
 
-    auto conv2d = kernels::Conv2dFunctor<D, T>(
-        this->strides_.data(), paddings.data(), this->dilations_.data());
-    conv2d(input->data<T>(), input->shape().data(), filter->data<T>(),
+    functor_(input->data<T>(), input->shape().data(), filter->data<T>(),
              filter->shape().data(), bias->data<T>(), output->mutable_data<T>(),
              output->shape().data());
 
     return true;
   }
 
+ private:
+  kernels::Conv2dFunctor<D, T> functor_;
+
  protected:
   OP_INPUT_TAGS(INPUT, FILTER, BIAS);
   OP_OUTPUT_TAGS(OUTPUT);
......
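One consequence of this refactor worth noting: the new Conv2dOp constructor dereferences this->Input(INPUT)->shape() and this->Input(FILTER)->shape() at op-construction time, so input and filter shapes must already be known when the workspace builds the op. Graphs whose shapes are only resolved at Run() time would have to fall back to the second Conv2dFunctor constructor, which accepts precomputed paddings.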