Commit cd7cabb5 authored by 李滨

Merge branch 'liyin/mace-depthwise' into 'master'

Refactor depthwise conv

See merge request !1052
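Summary of the change, with a usage sketch: the refactor replaces the free NEON kernels DepthwiseConv2dNeonK3x3S1/S2 with delegator classes derived from arm::fp32::Conv2dBase, and adds a portable ref::DepthwiseConv2d fallback. The sketch below is not part of the diff; it only illustrates how the op dispatches after this change, assuming the constructors and Compute() signatures shown in the diff, that the code sits inside namespace mace::ops, and a hypothetical helper name RunDepthwise (the real op caches the delegators in unique_ptr members instead of constructing them per call).

#include "mace/ops/arm/fp32/depthwise_conv_2d_3x3.h"
#include "mace/ops/ref/depthwise_conv_2d.h"

MaceStatus RunDepthwise(const OpContext *context,
                        const Tensor *input,    // NCHW fp32
                        const Tensor *filter,   // [multiplier, in_ch, kh, kw]
                        Tensor *output,
                        const std::vector<int> &strides,
                        const std::vector<int> &dilations,
                        const std::vector<int> &paddings,
                        Padding padding_type) {
  const bool is_3x3 = filter->dim(2) == 3 && filter->dim(3) == 3;
  const bool no_dilation = dilations[0] == 1 && dilations[1] == 1;
  if (is_3x3 && no_dilation && strides[0] == 1 && strides[1] == 1) {
    // NEON-optimized 3x3 stride-1 kernel; it resizes the output itself.
    arm::fp32::DepthwiseConv2dK3x3S1 conv(paddings, padding_type);
    return conv.Compute(context, input, filter, output);
  }
  if (is_3x3 && no_dilation && strides[0] == 2 && strides[1] == 2) {
    arm::fp32::DepthwiseConv2dK3x3S2 conv(paddings, padding_type);
    return conv.Compute(context, input, filter, output);
  }
  // Portable reference kernel for every other configuration.
  ref::DepthwiseConv2d<float> conv(strides, dilations, paddings, padding_type);
  return conv.Compute(context, input, filter, output);
}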
@@ -24,6 +24,49 @@ namespace ops {
namespace arm {
namespace fp32 {

+void Conv2dBase::CalOutputShapeAndInputPadSize(
+    const std::vector<index_t> &input_shape,
+    const std::vector<index_t> &filter_shape,
+    std::vector<index_t> *output_shape,
+    std::vector<int> *in_pad_size) {
+  if (paddings_.empty()) {
+    CalcNCHWPaddingAndOutputSize(input_shape.data(),
+                                 filter_shape.data(),
+                                 dilations_.data(),
+                                 strides_.data(),
+                                 padding_type_,
+                                 output_shape->data(),
+                                 in_pad_size->data());
+  } else {
+    *in_pad_size = paddings_;
+    CalcNCHWOutputSize(input_shape.data(),
+                       filter_shape.data(),
+                       paddings_.data(),
+                       dilations_.data(),
+                       strides_.data(),
+                       RoundType::FLOOR,
+                       output_shape->data());
+  }
+}
+
+void Conv2dBase::CalOutputBoundaryWithoutUsingInputPad(
+    const std::vector<index_t> &output_shape,
+    const std::vector<int> in_pad_size,
+    std::vector<index_t> *out_bound) {
+  const int pad_top = in_pad_size[0] >> 1;
+  const int pad_bottom = in_pad_size[0] - pad_top;
+  const int pad_left = in_pad_size[1] >> 1;
+  const int pad_right = in_pad_size[1] - pad_left;
+  const index_t height = output_shape[2];
+  const index_t width = output_shape[3];
+  *out_bound = {
+      pad_top == 0 ? 0 : (pad_top - 1) / strides_[0] + 1,
+      pad_bottom == 0 ? height : height - ((pad_bottom - 1) / strides_[0] + 1),
+      pad_left == 0 ? 0 : (pad_left - 1) / strides_[1] + 1,
+      pad_right == 0 ? width : width - ((pad_right - 1) / strides_[1] + 1),
+  };
+}
+
void Conv2dBase::CalOutputShapeAndPadSize(const Tensor *input,
                                          const Tensor *filter,
                                          const int out_tile_height,
@@ -46,24 +89,11 @@ void Conv2dBase::CalOutputShapeAndPadSize(const Tensor *input,
  const index_t filter_w = filter->dim(3);

  std::vector<int> paddings(2);
-  if (paddings_.empty()) {
-    CalcNCHWPaddingAndOutputSize(input->shape().data(),
-                                 filter->shape().data(),
-                                 dilations_.data(),
-                                 strides_.data(),
-                                 padding_type_,
-                                 output_shape->data(),
-                                 paddings.data());
-  } else {
-    paddings = paddings_;
-    CalcNCHWOutputSize(input->shape().data(),
-                       filter->shape().data(),
-                       paddings_.data(),
-                       dilations_.data(),
-                       strides_.data(),
-                       RoundType::FLOOR,
-                       output_shape->data());
-  }
+  CalOutputShapeAndInputPadSize(input->shape(),
+                                filter->shape(),
+                                output_shape,
+                                &paddings);

  const index_t out_height = (*output_shape)[2];
  const index_t out_width = (*output_shape)[3];
  const index_t
@@ -152,7 +182,8 @@ MaceStatus Conv2dBase::ResizeOutAndPadInOut(const OpContext *context,
    }
    if (is_out_padded) {
      std::unique_ptr<Tensor>
-          padded_out = make_unique<Tensor>(scratch_buffer->Scratch(padded_out_size),
+          padded_out =
+          make_unique<Tensor>(scratch_buffer->Scratch(padded_out_size),
                              DataType::DT_FLOAT);
      padded_out->Resize({batch, out_channels, padded_out_height,
                          padded_out_width});
......
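For clarity, a small worked example of what CalOutputBoundaryWithoutUsingInputPad computes: the "valid" output region is the set of output rows/columns whose receptive field stays inside the un-padded input, so the vectorized inner loop there can skip bounds checks. The snippet below is a sketch, not part of the diff; ValidRange is a hypothetical helper mirroring the formula above, and the example sizes (3x3 filter, stride 1, SAME padding, 8x8 input) are assumed.

#include <cstdio>

// Hypothetical helper: given the total pad on one axis, the output extent
// and the stride, return [start, stop) of outputs that never read padded
// input. This mirrors the per-axis formula in CalOutputBoundaryWithoutUsingInputPad.
static void ValidRange(int total_pad, long out_extent, int stride,
                       long *start, long *stop) {
  const int pad_lo = total_pad >> 1;       // pad_top / pad_left
  const int pad_hi = total_pad - pad_lo;   // pad_bottom / pad_right
  *start = pad_lo == 0 ? 0 : (pad_lo - 1) / stride + 1;
  *stop = pad_hi == 0 ? out_extent : out_extent - ((pad_hi - 1) / stride + 1);
}

int main() {
  long start, stop;
  // 3x3 filter, stride 1, SAME padding on an 8x8 input: paddings = {2, 2},
  // output is 8x8. Rows 1..6 can use the fast path; rows 0 and 7 are handled
  // separately (e.g. by DepthwiseConv2dPixel) with explicit boundary checks.
  ValidRange(/*total_pad=*/2, /*out_extent=*/8, /*stride=*/1, &start, &stop);
  std::printf("valid rows: [%ld, %ld)\n", start, stop);  // prints [1, 7)
  return 0;
}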
@@ -49,6 +49,18 @@ class Conv2dBase {
                             Tensor *output) = 0;

 protected:
+  void CalOutputShapeAndInputPadSize(const std::vector<index_t> &input_shape,
+                                     const std::vector<index_t> &filter_shape,
+                                     std::vector<index_t> *output_shape,
+                                     std::vector<int> *in_pad_size);
+
+  void CalOutputBoundaryWithoutUsingInputPad(
+      const std::vector<index_t> &output_shape,
+      const std::vector<int> in_pad_size,
+      std::vector<index_t> *out_bound);
+
  void CalOutputShapeAndPadSize(const Tensor *input,
                                const Tensor *filter,
                                const int out_tile_height,
......
@@ -12,15 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
-#endif

-#include "mace/ops/arm/depthwise_conv2d_neon.h"
+#include "mace/ops/arm/fp32/depthwise_conv_2d_3x3.h"
+#include "mace/utils/macros.h"

namespace mace {
namespace ops {
+namespace arm {
+namespace fp32 {

namespace {
void DepthwiseConv2dPixel(const float *in_base,
@@ -49,42 +47,58 @@ void DepthwiseConv2dPixel(const float *in_base,
  }
}
}  // namespace

-// Ho = 2, Wo = 4, Co = 1
-void DepthwiseConv2dNeonK3x3S1(const float *input,
-                               const float *filter,
-                               const index_t *in_shape,
-                               const index_t *out_shape,
-                               const int *pad_hw,
-                               const index_t valid_h_start,
-                               const index_t valid_h_stop,
-                               const index_t valid_w_start,
-                               const index_t valid_w_stop,
-                               float *output) {
-#if !defined(MACE_ENABLE_NEON)
-  MACE_UNUSED(valid_w_start);
-  MACE_UNUSED(valid_w_stop);
-#endif
+MaceStatus DepthwiseConv2dK3x3S1::Compute(const mace::OpContext *context,
+                                          const mace::Tensor *input,
+                                          const mace::Tensor *filter,
+                                          mace::Tensor *output) {
+  MACE_UNUSED(context);
+  std::vector<index_t> out_shape(4);
+  std::vector<int> paddings(2);
+  auto &in_shape = input->shape();
+  auto &filter_shape = filter->shape();
+  CalOutputShapeAndInputPadSize(in_shape, filter_shape, &out_shape, &paddings);
+  out_shape[1] *= filter_shape[1];
+  MACE_RETURN_IF_ERROR(output->Resize(out_shape));
+  output->Clear();
+
+  const int pad_top = paddings[0] / 2;
+  const int pad_left = paddings[1] / 2;

  const index_t multiplier = out_shape[1] / in_shape[1];
  const index_t in_image_size = in_shape[2] * in_shape[3];
  const index_t out_image_size = out_shape[2] * out_shape[3];
  const index_t in_batch_size = in_shape[1] * in_image_size;
  const index_t out_batch_size = out_shape[1] * out_image_size;

+  std::vector<index_t> out_bounds;
+  CalOutputBoundaryWithoutUsingInputPad(out_shape, paddings, &out_bounds);
+
+  Tensor::MappingGuard in_guard(input);
+  Tensor::MappingGuard filter_guard(filter);
+  Tensor::MappingGuard out_guard(output);
+  auto filter_data = filter->data<float>();
+  auto input_data = input->data<float>();
+  auto output_data = output->mutable_data<float>();
+
#pragma omp parallel for collapse(2) schedule(runtime)
  for (index_t b = 0; b < in_shape[0]; ++b) {
    for (index_t m = 0; m < out_shape[1]; ++m) {
-      index_t c = m / multiplier;
-      index_t multi_index = m % multiplier;
-      const float *in_base = input + b * in_batch_size + c * in_image_size;
-      const float *filter_ptr = filter + multi_index * in_shape[1] * 9 + c * 9;
-      float *out_base = output + b * out_batch_size + m * out_image_size;
+      const index_t c = m / multiplier;
+      const index_t multi_index = m % multiplier;
+      const float *in_base = input_data + b * in_batch_size + c * in_image_size;
+      const float
+          *filter_ptr = filter_data + multi_index * in_shape[1] * 9 + c * 9;
+      float *out_base = output_data + b * out_batch_size + m * out_image_size;
      index_t h, w;
-      const index_t pad_top = pad_hw[0];
-      const index_t pad_left = pad_hw[1];
      const index_t out_width = out_shape[3];
      const index_t in_height = in_shape[2];
      const index_t in_width = in_shape[3];
+      const index_t valid_h_start = out_bounds[0];
+      const index_t valid_h_stop = out_bounds[1];
+      const index_t valid_w_start = out_bounds[2];
+      const index_t valid_w_stop = out_bounds[3];

      // top
      for (h = 0; h < valid_h_start; ++h) {
        for (w = 0; w < out_shape[3]; ++w) {
@@ -94,7 +108,6 @@ void DepthwiseConv2dNeonK3x3S1(const float *input,
        }
      }

-#if defined(MACE_ENABLE_NEON)
      // load filter (1 outch x 3 height x 3 width): vf_outch_height
      float32x4_t vf00, vf01, vf02;
      vf00 = vld1q_f32(filter_ptr);
@@ -208,15 +221,7 @@ void DepthwiseConv2dNeonK3x3S1(const float *input,
                               3, out_base);
        }
      }  // h
-#else
-      for (index_t ih = valid_h_start; ih < valid_h_stop; ++ih) {
-        for (index_t iw = 0; iw < out_shape[3]; ++iw) {
-          DepthwiseConv2dPixel(in_base, filter_ptr, ih, iw, ih - pad_top,
-                               iw - pad_left, out_width, in_height, in_width, 3,
-                               3, out_base);
-        }
-      }
-#endif

      // bottom
      for (; h < out_shape[2]; ++h) {
@@ -228,42 +233,64 @@ void DepthwiseConv2dNeonK3x3S2(const float *input,
      }
    }  // m
  }  // b
+
+  return MaceStatus::MACE_SUCCESS;
}
-void DepthwiseConv2dNeonK3x3S2(const float *input,
-                               const float *filter,
-                               const index_t *in_shape,
-                               const index_t *out_shape,
-                               const int *pad_hw,
-                               const index_t valid_h_start,
-                               const index_t valid_h_stop,
-                               const index_t valid_w_start,
-                               const index_t valid_w_stop,
-                               float *output) {
-#if !defined(MACE_ENABLE_NEON)
-  MACE_UNUSED(valid_w_start);
-  MACE_UNUSED(valid_w_stop);
-#endif
+MaceStatus DepthwiseConv2dK3x3S2::Compute(const mace::OpContext *context,
+                                          const mace::Tensor *input,
+                                          const mace::Tensor *filter,
+                                          mace::Tensor *output) {
+  MACE_UNUSED(context);
+
+  std::vector<index_t> out_shape(4);
+  std::vector<int> paddings(2);
+  auto &in_shape = input->shape();
+  auto &filter_shape = filter->shape();
+
+  CalOutputShapeAndInputPadSize(in_shape, filter_shape, &out_shape, &paddings);
+  out_shape[1] *= in_shape[1];
+  MACE_RETURN_IF_ERROR(output->Resize(out_shape));
+  output->Clear();
+
+  const int pad_top = paddings[0] / 2;
+  const int pad_left = paddings[1] / 2;

  const index_t multiplier = out_shape[1] / in_shape[1];
  const index_t in_image_size = in_shape[2] * in_shape[3];
  const index_t out_image_size = out_shape[2] * out_shape[3];
  const index_t in_batch_size = in_shape[1] * in_image_size;
  const index_t out_batch_size = out_shape[1] * out_image_size;

+  std::vector<index_t> out_bounds;
+  CalOutputBoundaryWithoutUsingInputPad(out_shape, paddings, &out_bounds);
+
+  Tensor::MappingGuard in_guard(input);
+  Tensor::MappingGuard filter_guard(filter);
+  Tensor::MappingGuard out_guard(output);
+  auto filter_data = filter->data<float>();
+  auto input_data = input->data<float>();
+  auto output_data = output->mutable_data<float>();
+
#pragma omp parallel for collapse(2) schedule(runtime)
  for (index_t b = 0; b < in_shape[0]; ++b) {
    for (index_t m = 0; m < out_shape[1]; ++m) {
      index_t c = m / multiplier;
      index_t multi_index = m % multiplier;
-      const float *in_base = input + b * in_batch_size + c * in_image_size;
-      const float *filter_ptr = filter + multi_index * in_shape[1] * 9 + c * 9;
-      float *out_base = output + b * out_batch_size + m * out_image_size;
+      const float *in_base = input_data + b * in_batch_size + c * in_image_size;
+      const float
+          *filter_ptr = filter_data + multi_index * in_shape[1] * 9 + c * 9;
+      float *out_base = output_data + b * out_batch_size + m * out_image_size;
      index_t h, w;
-      const index_t pad_top = pad_hw[0];
-      const index_t pad_left = pad_hw[1];
      const index_t out_width = out_shape[3];
      const index_t in_height = in_shape[2];
      const index_t in_width = in_shape[3];
+      const index_t valid_h_start = out_bounds[0];
+      const index_t valid_h_stop = out_bounds[1];
+      const index_t valid_w_start = out_bounds[2];
+      const index_t valid_w_stop = out_bounds[3];

      // top
      for (h = 0; h < valid_h_start; ++h) {
        for (w = 0; w < out_width; ++w) {
@@ -273,7 +300,6 @@ void DepthwiseConv2dNeonK3x3S2(const float *input,
        }
      }

-#if defined(MACE_ENABLE_NEON)
      // load filter (1 outch x 3 height x 3 width): vf_outch_height
      float32x4_t vf00, vf01, vf02;
      vf00 = vld1q_f32(filter_ptr);
@@ -359,15 +385,7 @@ void DepthwiseConv2dNeonK3x3S2(const float *input,
                               3, 3, out_base);
        }
      }  // h
-#else
-      for (index_t ih = valid_h_start; ih < valid_h_stop; ++ih) {
-        for (index_t iw = 0; iw < out_width; ++iw) {
-          DepthwiseConv2dPixel(in_base, filter_ptr, ih, iw, ih * 2 - pad_top,
-                               iw * 2 - pad_left, out_width, in_height,
-                               in_width, 3, 3, out_base);
-        }
-      }
-#endif

      // bottom
      for (; h < out_shape[2]; ++h) {
@@ -379,7 +397,11 @@ void DepthwiseConv2dNeonK3x3S2(const float *input,
      }
    }  // m
  }  // b
+
+  return MaceStatus::MACE_SUCCESS;
}

+}  // namespace fp32
+}  // namespace arm
}  // namespace ops
}  // namespace mace
@@ -12,37 +12,51 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-#ifndef MACE_OPS_ARM_DEPTHWISE_CONV2D_NEON_H_
-#define MACE_OPS_ARM_DEPTHWISE_CONV2D_NEON_H_
+#ifndef MACE_OPS_ARM_FP32_DEPTHWISE_CONV_2D_3X3_H_
+#define MACE_OPS_ARM_FP32_DEPTHWISE_CONV_2D_3X3_H_

-#include "mace/core/types.h"
-#include "mace/public/mace.h"
+#include <vector>
+
+#include "mace/core/tensor.h"
+#include "mace/core/op_context.h"
+#include "mace/ops/arm/fp32/conv_2d.h"

namespace mace {
namespace ops {
+namespace arm {
+namespace fp32 {

-void DepthwiseConv2dNeonK3x3S1(const float *input,
-                               const float *filter,
-                               const index_t *in_shape,
-                               const index_t *out_shape,
-                               const int *pad_hw,
-                               const index_t valid_h_start,
-                               const index_t valid_h_stop,
-                               const index_t valid_w_start,
-                               const index_t valid_w_stop,
-                               float *output);
-
-void DepthwiseConv2dNeonK3x3S2(const float *input,
-                               const float *filter,
-                               const index_t *in_shape,
-                               const index_t *out_shape,
-                               const int *pad_hw,
-                               const index_t valid_h_start,
-                               const index_t valid_h_stop,
-                               const index_t valid_w_start,
-                               const index_t valid_w_stop,
-                               float *output);
+class DepthwiseConv2dK3x3S1 : public Conv2dBase {
+ public:
+  DepthwiseConv2dK3x3S1(const std::vector<int> paddings,
+                        const Padding padding_type)
+      : Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {}
+  virtual ~DepthwiseConv2dK3x3S1() {}
+
+  MaceStatus Compute(
+      const OpContext *context,
+      const Tensor *input,
+      const Tensor *filter,
+      Tensor *output);
+};
+
+class DepthwiseConv2dK3x3S2 : public Conv2dBase {
+ public:
+  DepthwiseConv2dK3x3S2(const std::vector<int> paddings,
+                        const Padding padding_type)
+      : Conv2dBase({2, 2}, {1, 1}, paddings, padding_type) {}
+  virtual ~DepthwiseConv2dK3x3S2() {}
+
+  MaceStatus Compute(
+      const OpContext *context,
+      const Tensor *input,
+      const Tensor *filter,
+      Tensor *output);
+};

+}  // namespace fp32
+}  // namespace arm
}  // namespace ops
}  // namespace mace

-#endif  // MACE_OPS_ARM_DEPTHWISE_CONV2D_NEON_H_
+#endif  // MACE_OPS_ARM_FP32_DEPTHWISE_CONV_2D_3X3_H_
@@ -12,14 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
-#include <arm_neon.h>
-#endif
-
#include <algorithm>
#include <memory>
#include <string>
#include <vector>

+#include "mace/ops/ref/depthwise_conv_2d.h"
+#if defined(MACE_ENABLE_NEON)
+#include "mace/ops/arm/fp32/depthwise_conv_2d_3x3.h"
+#endif  // MACE_ENABLE_NEON
+
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/quantization_util.h"
// We reuse TensorFlow Lite's optimized depthwiseconv_uint8 and parallelized it
@@ -30,7 +33,6 @@
#include "mace/core/future.h"
#include "mace/core/operator.h"
#include "mace/ops/activation.h"
-#include "mace/ops/arm/depthwise_conv2d_neon.h"
#include "mace/ops/conv_pool_2d_base.h"
#include "mace/public/mace.h"
#include "mace/utils/memory.h"
@@ -60,10 +62,10 @@ class DepthwiseConv2dOpBase : public ConvPool2dOpBase {
  const float leakyrelu_coefficient_;
};

-template <DeviceType D, class T>
+template<DeviceType D, class T>
class DepthwiseConv2dOp;

-template <>
+template<>
class DepthwiseConv2dOp<DeviceType::CPU, float> : public DepthwiseConv2dOpBase {
 public:
  explicit DepthwiseConv2dOp(OpConstructContext *context)
@@ -82,133 +84,60 @@ class DepthwiseConv2dOp<DeviceType::CPU, float> : public DepthwiseConv2dOpBase {
    MACE_CHECK_NOTNULL(filter);
    MACE_CHECK_NOTNULL(output);

-    std::vector<index_t> output_shape(4);
-    std::vector<int> paddings(2);
-    std::vector<index_t> filter_shape
-        {filter->dim(0) * filter->dim(1), filter->dim(1), filter->dim(2),
-         filter->dim(3)};
-
-    if (paddings_.empty()) {
-      CalcNCHWPaddingAndOutputSize(input->shape().data(),
-                                   filter_shape.data(),
-                                   dilations_.data(),
-                                   strides_.data(),
-                                   padding_type_,
-                                   output_shape.data(),
-                                   paddings.data());
-    } else {
-      paddings = paddings_;
-      CalcNCHWOutputSize(input->shape().data(),
-                         filter_shape.data(),
-                         paddings_.data(),
-                         dilations_.data(),
-                         strides_.data(),
-                         RoundType::FLOOR,
-                         output_shape.data());
-    }
-    MACE_RETURN_IF_ERROR(output->Resize(output_shape));
-    output->Clear();
-
-    index_t batch = output->dim(0);
-    index_t channels = output->dim(1);
-    index_t height = output->dim(2);
-    index_t width = output->dim(3);
-
-    index_t input_batch = input->dim(0);
-    index_t input_channels = input->dim(1);
-    index_t input_height = input->dim(2);
-    index_t input_width = input->dim(3);
-
-    index_t filter_h = filter_shape[2];
-    index_t filter_w = filter_shape[3];
-    MACE_CHECK(filter_shape[0] == channels, filter_shape[0], " != ", channels);
-    MACE_CHECK(filter_shape[1] == input_channels, filter_shape[1], " != ",
-               input_channels);
-
-    index_t stride_h = strides_[0];
-    index_t stride_w = strides_[1];
-
-    index_t dilation_h = dilations_[0];
-    index_t dilation_w = dilations_[1];
-
-    MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
-
-    int pad_top = paddings[0] >> 1;
-    int pad_bottom = paddings[0] - pad_top;
-    int pad_left = paddings[1] >> 1;
-    int pad_right = paddings[1] - pad_left;
-    index_t valid_h_start = pad_top == 0 ? 0 : (pad_top - 1) / stride_h + 1;
-    index_t valid_h_stop = pad_bottom == 0
-                           ? height
-                           : height - ((pad_bottom - 1) / stride_h + 1);
-    index_t valid_w_start = pad_left == 0 ? 0 : (pad_left - 1) / stride_w + 1;
-    index_t valid_w_stop = pad_right == 0
-                           ? width
-                           : width - ((pad_right - 1) / stride_w + 1);
-
-    std::function<void(const float *input, float *output)> conv_func;
-
-    Tensor::MappingGuard input_guard(input);
-    Tensor::MappingGuard filter_guard(filter);
+#ifdef MACE_ENABLE_NEON
+    const index_t filter_h = filter->dim(2);
+    const index_t filter_w = filter->dim(3);
+    const index_t stride_h = strides_[0];
+    const index_t stride_w = strides_[1];
+    const index_t dilation_h = dilations_[0];
+    const index_t dilation_w = dilations_[1];
+    if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1
+        && dilation_h == 1 && dilation_w == 1) {
+      if (conv2d_delegator_.get() == nullptr) {
+        conv2d_delegator_ =
+            make_unique<arm::fp32::DepthwiseConv2dK3x3S1>(paddings_,
+                                                          padding_type_);
+      }
+      conv2d_delegator_->Compute(context, input, filter, output);
+    } else if (filter_h == 3 && filter_w == 3 && stride_h == 2 && stride_w == 2
+        && dilation_h == 1 && dilation_w == 1) {
+      if (conv2d_delegator_.get() == nullptr) {
+        conv2d_delegator_ =
+            make_unique<arm::fp32::DepthwiseConv2dK3x3S2>(paddings_,
+                                                          padding_type_);
+      }
+      conv2d_delegator_->Compute(context, input, filter, output);
+    } else {
+      if (ref_conv2d_delegator_.get() == nullptr) {
+        ref_conv2d_delegator_ =
+            make_unique<ref::DepthwiseConv2d<float>>(strides_,
+                                                     dilations_,
+                                                     paddings_,
+                                                     padding_type_);
+      }
+      ref_conv2d_delegator_->Compute(context, input, filter, output);
+    }
+#else
+    if (ref_conv2d_delegator_.get() == nullptr) {
+      ref_conv2d_delegator_ =
+          make_unique<ref::DepthwiseConv2d<float>>(strides_,
+                                                   dilations_,
+                                                   paddings_,
+                                                   padding_type_);
+    }
+    ref_conv2d_delegator_->Compute(context, input, filter, output);
+#endif  // MACE_ENABLE_NEON
+
    Tensor::MappingGuard bias_guard(bias);
    Tensor::MappingGuard output_guard(output);
-    auto input_data = input->data<float>();
-    auto filter_data = filter->data<float>();
    auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
    auto output_data = output->mutable_data<float>();

-    const int pad_hw[2] = {pad_top, pad_left};
-    const index_t input_shape[4] =
-        {batch, input_channels, input_height, input_width};
-
-    // make host compiler happy
-    MACE_UNUSED(pad_hw);
-    MACE_UNUSED(input_shape);
-
-    if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1
-        && dilation_h == 1 && dilation_w == 1) {
-      conv_func = [=](const float *input, float *output) {
-        DepthwiseConv2dNeonK3x3S1(input,
-                                  filter_data,
-                                  input_shape,
-                                  output_shape.data(),
-                                  pad_hw,
-                                  valid_h_start,
-                                  valid_h_stop,
-                                  valid_w_start,
-                                  valid_w_stop,
-                                  output);
-      };
-    } else if (filter_h == 3 && filter_w == 3 && stride_h == 2 && stride_w == 2
-        && dilation_h == 1 && dilation_w == 1) {
-      conv_func = [=](const float *input, float *output) {
-        DepthwiseConv2dNeonK3x3S2(input,
-                                  filter_data,
-                                  input_shape,
-                                  output_shape.data(),
-                                  pad_hw,
-                                  valid_h_start,
-                                  valid_h_stop,
-                                  valid_w_start,
-                                  valid_w_stop,
-                                  output);
-      };
-    } else {
-      conv_func = [=](const float *input, float *output) {
-        DepthwiseConv2dGeneral(input,
-                               filter_data,
-                               input_shape,
-                               output_shape.data(),
-                               filter_shape.data(),
-                               strides_.data(),
-                               dilations_.data(),
-                               pad_hw,
-                               output);
-      };
-    }
-
-    conv_func(input_data, output_data);
+    const index_t batch = output->dim(0);
+    const index_t channels = output->dim(1);
+    const index_t height = output->dim(2);
+    const index_t width = output->dim(3);

    if (bias_data != nullptr) {
#pragma omp parallel for collapse(2)
@@ -229,55 +158,10 @@ class DepthwiseConv2dOp<DeviceType::CPU, float> : public DepthwiseConv2dOpBase {
  }

 private:
-  void DepthwiseConv2dGeneral(const float *input,
-                              const float *filter,
-                              const index_t *in_shape,
-                              const index_t *out_shape,
-                              const index_t *filter_shape,
-                              const int *stride_hw,
-                              const int *dilation_hw,
-                              const int *pad_hw,
-                              float *output) {
-    const index_t multiplier = filter_shape[0] / filter_shape[1];
-#pragma omp parallel for collapse(2)
-    for (index_t b = 0; b < in_shape[0]; ++b) {
-      for (index_t m = 0; m < filter_shape[0]; ++m) {
-        for (index_t h = 0; h < out_shape[2]; ++h) {
-          for (index_t w = 0; w < out_shape[3]; ++w) {
-            const index_t out_channels = filter_shape[0];
-            const index_t in_channels = filter_shape[1];
-            const index_t filter_height = filter_shape[2];
-            const index_t filter_width = filter_shape[3];
-            const index_t in_height = in_shape[2];
-            const index_t in_width = in_shape[3];
-            const index_t out_height = out_shape[2];
-            const index_t out_width = out_shape[3];
-            index_t out_offset =
-                ((b * out_channels + m) * out_height + h) * out_width + w;
-            index_t c = m / multiplier;
-            index_t o = m % multiplier;
-            float sum = 0;
-            for (index_t kh = 0; kh < filter_height; ++kh) {
-              for (index_t kw = 0; kw < filter_width; ++kw) {
-                index_t ih = h * stride_hw[0] + kh * dilation_hw[0] - pad_hw[0];
-                index_t iw = w * stride_hw[1] + kw * dilation_hw[1] - pad_hw[1];
-                if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
-                  index_t in_offset =
-                      ((b * in_channels + c) * in_height + ih) * in_width + iw;
-                  index_t filter_offset =
-                      (((o * in_channels) + c) * filter_height + kh)
-                          * filter_width + kw;
-                  sum += input[in_offset] * filter[filter_offset];
-                }
-              }
-            }
-            output[out_offset] = sum;
-          }
-        }
-      }
-    }
-  }
+#ifdef MACE_ENABLE_NEON
+  std::unique_ptr<arm::fp32::Conv2dBase> conv2d_delegator_;
+#endif  // MACE_ENABLE_NEON
+  std::unique_ptr<ref::DepthwiseConv2d<float>> ref_conv2d_delegator_;

 protected:
  MACE_OP_INPUT_TAGS(INPUT, FILTER, BIAS);
@@ -542,7 +426,6 @@ class DepthwiseConv2dOp<DeviceType::GPU, T> : public DepthwiseConv2dOpBase {
};
#endif  // MACE_ENABLE_OPENCL

void RegisterDepthwiseConv2d(OpRegistryBase *op_registry) {
  MACE_REGISTER_OP(op_registry, "DepthwiseConv2d",
                   DepthwiseConv2dOp, DeviceType::CPU, float);
......
@@ -64,7 +64,7 @@ class Conv2d<float> {
        paddings_(paddings),
        padding_type_(padding_type) {}
  ~Conv2d() {}
// Always row-major after transpose
  MaceStatus Compute(
      const OpContext *context,
      const Tensor *input,
......
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ref/depthwise_conv_2d.h"
#include <vector>
namespace mace {
namespace ops {
namespace ref {
MaceStatus DepthwiseConv2d<float>::Compute(const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) {
MACE_UNUSED(context);
const std::vector<index_t> in_shape = input->shape();
const std::vector<index_t> filter_shape = filter->shape();
std::vector<index_t> out_shape(4);
std::vector<int> paddings(2);
if (paddings_.empty()) {
CalcNCHWPaddingAndOutputSize(input->shape().data(),
filter->shape().data(),
dilations_.data(),
strides_.data(),
padding_type_,
out_shape.data(),
paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input->shape().data(),
filter->shape().data(),
paddings_.data(),
dilations_.data(),
strides_.data(),
RoundType::FLOOR,
out_shape.data());
}
out_shape[1] *= in_shape[1];
const index_t pad_top = paddings[0] >> 1;
const index_t pad_left = paddings[1] >> 1;
output->Resize(out_shape);
const index_t multiplier = filter_shape[0];
const index_t in_image_size = in_shape[2] * in_shape[3];
const index_t out_image_size = out_shape[2] * out_shape[3];
const index_t in_batch_size = in_shape[1] * in_image_size;
const index_t out_batch_size = out_shape[1] * out_image_size;
const index_t filter_size = filter_shape[2] * filter_shape[3];
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard filter_guard(filter);
Tensor::MappingGuard output_guard(output);
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto output_data = output->mutable_data<float>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < in_shape[0]; b++) {
for (index_t m = 0; m < out_shape[1]; ++m) {
const index_t c = m / multiplier;
const index_t multi_index = m % multiplier;
const index_t in_height = in_shape[2];
const index_t in_width = in_shape[3];
const index_t out_height = out_shape[2];
const index_t out_width = out_shape[3];
const index_t in_channels = in_shape[1];
float *out_ptr_base =
output_data + b * out_batch_size + m * out_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
float sum = 0;
const float *in_ptr_base =
input_data + b * in_batch_size + c * in_image_size;
const float *filter_ptr =
filter_data + multi_index * in_channels * filter_size
+ c * filter_size;
for (index_t kh = 0; kh < filter_shape[2]; ++kh) {
for (index_t kw = 0; kw < filter_shape[3]; ++kw) {
const index_t
ih = -pad_top + h * strides_[0] + kh * dilations_[0];
const index_t
iw = -pad_left + w * strides_[1] + kw * dilations_[1];
if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
sum += in_ptr_base[ih * in_width + iw] * filter_ptr[kw];
}
} // kw
filter_ptr += filter_shape[3];
} // kh
out_ptr_base[h * out_width + w] = sum;
} // w
} // h
} // m
} // b
return MaceStatus::MACE_SUCCESS;
}
} // namespace ref
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_REF_DEPTHWISE_CONV_2D_H_
#define MACE_OPS_REF_DEPTHWISE_CONV_2D_H_
#include <vector>
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
namespace ops {
namespace ref {
template<typename OUTPUT_TYPE>
class DepthwiseConv2d {
public:
DepthwiseConv2d(const std::vector<int> strides,
const std::vector<int> dilations,
const std::vector<int> paddings,
const Padding padding_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type) {}
~DepthwiseConv2d() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
private:
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
};
template<>
class DepthwiseConv2d<float> {
public:
DepthwiseConv2d(const std::vector<int> strides,
const std::vector<int> dilations,
const std::vector<int> paddings,
const Padding padding_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type) {}
~DepthwiseConv2d() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
private:
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
};
} // namespace ref
} // namespace ops
} // namespace mace
#endif // MACE_OPS_REF_DEPTHWISE_CONV_2D_H_
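A short note on the channel mapping used by this reference kernel, with a sketch: the filter is laid out as [multiplier, in_channels, kh, kw], output channels number multiplier * in_channels (hence out_shape[1] *= in_shape[1] above), and output channel m reads input channel c = m / multiplier with filter slice multi_index = m % multiplier. The code below is an assumed example, not part of the diff; the concrete sizes (multiplier 2, 8 input channels) are made up for illustration.

#include <cstdio>

int main() {
  const long multiplier = 2;     // filter->dim(0)
  const long in_channels = 8;    // input->dim(1) == filter->dim(1)
  const long out_channels = multiplier * in_channels;  // 16 output channels
  for (long m = 0; m < out_channels; ++m) {
    const long c = m / multiplier;            // input channel that is read
    const long multi_index = m % multiplier;  // which filter slice is applied
    if (m < 4) {
      // Prints: 0 <- (0,0), 1 <- (0,1), 2 <- (1,0), 3 <- (1,1), ...
      std::printf("out channel %ld <- input channel %ld, filter slice %ld\n",
                  m, c, multi_index);
    }
  }
  return 0;
}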