Commit cd7cabb5, authored by 李滨

Merge branch 'liyin/mace-depthwise' into 'master'

Refactor depthwise conv

See merge request !1052
......@@ -24,6 +24,49 @@ namespace ops {
namespace arm {
namespace fp32 {
void Conv2dBase::CalOutputShapeAndInputPadSize(
const std::vector<index_t> &input_shape,
const std::vector<index_t> &filter_shape,
std::vector<index_t> *output_shape,
std::vector<int> *in_pad_size) {
if (paddings_.empty()) {
CalcNCHWPaddingAndOutputSize(input_shape.data(),
filter_shape.data(),
dilations_.data(),
strides_.data(),
padding_type_,
output_shape->data(),
in_pad_size->data());
} else {
*in_pad_size = paddings_;
CalcNCHWOutputSize(input_shape.data(),
filter_shape.data(),
paddings_.data(),
dilations_.data(),
strides_.data(),
RoundType::FLOOR,
output_shape->data());
}
}
void Conv2dBase::CalOutputBoundaryWithoutUsingInputPad(
const std::vector<index_t> &output_shape,
const std::vector<int> in_pad_size,
std::vector<index_t> *out_bound) {
const int pad_top = in_pad_size[0] >> 1;
const int pad_bottom = in_pad_size[0] - pad_top;
const int pad_left = in_pad_size[1] >> 1;
const int pad_right = in_pad_size[1] - pad_left;
const index_t height = output_shape[2];
const index_t width = output_shape[3];
*out_bound = {
pad_top == 0 ? 0 : (pad_top - 1) / strides_[0] + 1,
pad_bottom == 0 ? height : height - ((pad_bottom - 1) / strides_[0] + 1),
pad_left == 0 ? 0 : (pad_left - 1) / strides_[1] + 1,
pad_right == 0 ? width : width - ((pad_right - 1) / strides_[1] + 1),
};
}
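For readers of this hunk: the new helper splits the total pad into a before/after pair and then derives the first and last output rows/columns that can be computed without reading padded input. Below is a minimal standalone sketch of that arithmetic (an editorial example, not part of the patch); the helper names and the 3x3, stride-1 numbers are illustrative only.

```cpp
// Editorial sketch of the valid-region arithmetic used by
// CalOutputBoundaryWithoutUsingInputPad, for one spatial dimension.
#include <cstdio>

// First output index whose receptive field does not touch the "before" pad.
long ValidStart(int pad_before, int stride) {
  return pad_before == 0 ? 0 : (pad_before - 1) / stride + 1;
}
// One past the last output index that avoids the "after" pad.
long ValidStop(long out_size, int pad_after, int stride) {
  return pad_after == 0 ? out_size : out_size - ((pad_after - 1) / stride + 1);
}

int main() {
  // Illustrative configuration: 3x3 kernel, stride 1, total pad of 2 rows.
  const int total_pad_h = 2, stride_h = 1;
  const long out_h = 6;
  const int pad_top = total_pad_h >> 1;
  const int pad_bottom = total_pad_h - pad_top;
  // Prints "valid h range: [1, 5)" for this configuration.
  std::printf("valid h range: [%ld, %ld)\n",
              ValidStart(pad_top, stride_h),
              ValidStop(out_h, pad_bottom, stride_h));
  return 0;
}
```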
void Conv2dBase::CalOutputShapeAndPadSize(const Tensor *input,
const Tensor *filter,
const int out_tile_height,
......@@ -46,24 +89,11 @@ void Conv2dBase::CalOutputShapeAndPadSize(const Tensor *input,
const index_t filter_w = filter->dim(3);
std::vector<int> paddings(2);
if (paddings_.empty()) {
CalcNCHWPaddingAndOutputSize(input->shape().data(),
filter->shape().data(),
dilations_.data(),
strides_.data(),
padding_type_,
output_shape->data(),
paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input->shape().data(),
filter->shape().data(),
paddings_.data(),
dilations_.data(),
strides_.data(),
RoundType::FLOOR,
output_shape->data());
}
CalOutputShapeAndInputPadSize(input->shape(),
filter->shape(),
output_shape,
&paddings);
const index_t out_height = (*output_shape)[2];
const index_t out_width = (*output_shape)[3];
const index_t
......@@ -96,9 +126,9 @@ MaceStatus Conv2dBase::ResizeOutAndPadInOut(const OpContext *context,
const int out_tile_height,
const int out_tile_width,
std::unique_ptr<const Tensor>
*padded_input,
std::unique_ptr<Tensor>
*padded_output) {
std::vector<index_t> output_shape;
std::vector<int> in_pad_size;
std::vector<int> out_pad_size;
......@@ -152,8 +182,9 @@ MaceStatus Conv2dBase::ResizeOutAndPadInOut(const OpContext *context,
}
if (is_out_padded) {
std::unique_ptr<Tensor>
padded_out =
make_unique<Tensor>(scratch_buffer->Scratch(padded_out_size),
DataType::DT_FLOAT);
padded_out->Resize({batch, out_channels, padded_out_height,
padded_out_width});
*padded_output = std::move(padded_out);
......
......@@ -49,6 +49,18 @@ class Conv2dBase {
Tensor *output) = 0;
protected:
void CalOutputShapeAndInputPadSize(const std::vector<index_t> &input_shape,
const std::vector<index_t> &filter_shape,
std::vector<index_t> *output_shape,
std::vector<int> *in_pad_size);
void CalOutputBoundaryWithoutUsingInputPad(const std::vector<index_t>
&output_shape,
const std::vector<int>
in_pad_size,
std::vector<index_t>
*out_bound);
void CalOutputShapeAndPadSize(const Tensor *input,
const Tensor *filter,
const int out_tile_height,
......
......@@ -12,15 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/utils/macros.h"
#include "mace/ops/arm/depthwise_conv2d_neon.h"
#include "mace/ops/arm/fp32/depthwise_conv_2d_3x3.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
namespace {
void DepthwiseConv2dPixel(const float *in_base,
......@@ -49,42 +47,58 @@ void DepthwiseConv2dPixel(const float *in_base,
}
} // namespace
// Ho = 2, Wo = 4, Co = 1
void DepthwiseConv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
const int *pad_hw,
const index_t valid_h_start,
const index_t valid_h_stop,
const index_t valid_w_start,
const index_t valid_w_stop,
float *output) {
#if !defined(MACE_ENABLE_NEON)
MACE_UNUSED(valid_w_start);
MACE_UNUSED(valid_w_stop);
#endif
MaceStatus DepthwiseConv2dK3x3S1::Compute(const mace::OpContext *context,
const mace::Tensor *input,
const mace::Tensor *filter,
mace::Tensor *output) {
MACE_UNUSED(context);
std::vector<index_t> out_shape(4);
std::vector<int> paddings(2);
auto &in_shape = input->shape();
auto &filter_shape = filter->shape();
CalOutputShapeAndInputPadSize(in_shape, filter_shape, &out_shape, &paddings);
out_shape[1] *= filter_shape[1];
MACE_RETURN_IF_ERROR(output->Resize(out_shape));
output->Clear();
const int pad_top = paddings[0] / 2;
const int pad_left = paddings[1] / 2;
const index_t multiplier = out_shape[1] / in_shape[1];
const index_t in_image_size = in_shape[2] * in_shape[3];
const index_t out_image_size = out_shape[2] * out_shape[3];
const index_t in_batch_size = in_shape[1] * in_image_size;
const index_t out_batch_size = out_shape[1] * out_image_size;
std::vector<index_t> out_bounds;
CalOutputBoundaryWithoutUsingInputPad(out_shape, paddings, &out_bounds);
Tensor::MappingGuard in_guard(input);
Tensor::MappingGuard filter_guard(filter);
Tensor::MappingGuard out_guard(output);
auto filter_data = filter->data<float>();
auto input_data = input->data<float>();
auto output_data = output->mutable_data<float>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < in_shape[0]; ++b) {
for (index_t m = 0; m < out_shape[1]; ++m) {
index_t c = m / multiplier;
index_t multi_index = m % multiplier;
const float *in_base = input + b * in_batch_size + c * in_image_size;
const float *filter_ptr = filter + multi_index * in_shape[1] * 9 + c * 9;
float *out_base = output + b * out_batch_size + m * out_image_size;
const index_t c = m / multiplier;
const index_t multi_index = m % multiplier;
const float *in_base = input_data + b * in_batch_size + c * in_image_size;
const float
*filter_ptr = filter_data + multi_index * in_shape[1] * 9 + c * 9;
float *out_base = output_data + b * out_batch_size + m * out_image_size;
index_t h, w;
const index_t pad_top = pad_hw[0];
const index_t pad_left = pad_hw[1];
const index_t out_width = out_shape[3];
const index_t in_height = in_shape[2];
const index_t in_width = in_shape[3];
const index_t valid_h_start = out_bounds[0];
const index_t valid_h_stop = out_bounds[1];
const index_t valid_w_start = out_bounds[2];
const index_t valid_w_stop = out_bounds[3];
// top
for (h = 0; h < valid_h_start; ++h) {
for (w = 0; w < out_shape[3]; ++w) {
......@@ -94,7 +108,6 @@ void DepthwiseConv2dNeonK3x3S1(const float *input,
}
}
#if defined(MACE_ENABLE_NEON)
// load filter (1 outch x 3 height x 3 width): vf_outch_height
float32x4_t vf00, vf01, vf02;
vf00 = vld1q_f32(filter_ptr);
......@@ -208,15 +221,7 @@ void DepthwiseConv2dNeonK3x3S1(const float *input,
3, out_base);
}
} // h
#else
for (index_t ih = valid_h_start; ih < valid_h_stop; ++ih) {
for (index_t iw = 0; iw < out_shape[3]; ++iw) {
DepthwiseConv2dPixel(in_base, filter_ptr, ih, iw, ih - pad_top,
iw - pad_left, out_width, in_height, in_width, 3,
3, out_base);
}
}
#endif
// bottom
for (; h < out_shape[2]; ++h) {
......@@ -228,42 +233,64 @@ void DepthwiseConv2dNeonK3x3S1(const float *input,
}
} // m
} // b
return MaceStatus::MACE_SUCCESS;
}
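The rewritten Compute derives the output channel count from the depthwise multiplier and then walks a (multiplier, in_channels, 3, 3) filter layout: m / multiplier picks the input channel and m % multiplier picks which kernel of that channel to apply. A tiny editorial sketch of that indexing (the layout is as assumed by the code above; the channel counts are illustrative):

```cpp
// Editorial sketch of the channel/filter indexing used in Compute above.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t in_channels = 4, multiplier = 2;
  const int64_t out_channels = in_channels * multiplier;  // out_shape[1]
  for (int64_t m = 0; m < out_channels; ++m) {
    const int64_t c = m / multiplier;            // input channel to read
    const int64_t multi_index = m % multiplier;  // kernel index within c
    // Same offset expression as filter_ptr above, for a 3x3 (9-tap) kernel.
    const int64_t filter_offset = multi_index * in_channels * 9 + c * 9;
    std::printf("m=%lld c=%lld multi_index=%lld filter_offset=%lld\n",
                static_cast<long long>(m), static_cast<long long>(c),
                static_cast<long long>(multi_index),
                static_cast<long long>(filter_offset));
  }
  return 0;
}
```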
void DepthwiseConv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
const int *pad_hw,
const index_t valid_h_start,
const index_t valid_h_stop,
const index_t valid_w_start,
const index_t valid_w_stop,
float *output) {
#if !defined(MACE_ENABLE_NEON)
MACE_UNUSED(valid_w_start);
MACE_UNUSED(valid_w_stop);
#endif
MaceStatus DepthwiseConv2dK3x3S2::Compute(const mace::OpContext *context,
const mace::Tensor *input,
const mace::Tensor *filter,
mace::Tensor *output) {
MACE_UNUSED(context);
std::vector<index_t> out_shape(4);
std::vector<int> paddings(2);
auto &in_shape = input->shape();
auto &filter_shape = filter->shape();
CalOutputShapeAndInputPadSize(in_shape, filter_shape, &out_shape, &paddings);
out_shape[1] *= in_shape[1];
MACE_RETURN_IF_ERROR(output->Resize(out_shape));
output->Clear();
const int pad_top = paddings[0] / 2;
const int pad_left = paddings[1] / 2;
const index_t multiplier = out_shape[1] / in_shape[1];
const index_t in_image_size = in_shape[2] * in_shape[3];
const index_t out_image_size = out_shape[2] * out_shape[3];
const index_t in_batch_size = in_shape[1] * in_image_size;
const index_t out_batch_size = out_shape[1] * out_image_size;
std::vector<index_t> out_bounds;
CalOutputBoundaryWithoutUsingInputPad(out_shape, paddings, &out_bounds);
Tensor::MappingGuard in_guard(input);
Tensor::MappingGuard filter_guard(filter);
Tensor::MappingGuard out_guard(output);
auto filter_data = filter->data<float>();
auto input_data = input->data<float>();
auto output_data = output->mutable_data<float>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < in_shape[0]; ++b) {
for (index_t m = 0; m < out_shape[1]; ++m) {
index_t c = m / multiplier;
index_t multi_index = m % multiplier;
const float *in_base = input + b * in_batch_size + c * in_image_size;
const float *filter_ptr = filter + multi_index * in_shape[1] * 9 + c * 9;
float *out_base = output + b * out_batch_size + m * out_image_size;
const float *in_base = input_data + b * in_batch_size + c * in_image_size;
const float
*filter_ptr = filter_data + multi_index * in_shape[1] * 9 + c * 9;
float *out_base = output_data + b * out_batch_size + m * out_image_size;
index_t h, w;
const index_t pad_top = pad_hw[0];
const index_t pad_left = pad_hw[1];
const index_t out_width = out_shape[3];
const index_t in_height = in_shape[2];
const index_t in_width = in_shape[3];
const index_t valid_h_start = out_bounds[0];
const index_t valid_h_stop = out_bounds[1];
const index_t valid_w_start = out_bounds[2];
const index_t valid_w_stop = out_bounds[3];
// top
for (h = 0; h < valid_h_start; ++h) {
for (w = 0; w < out_width; ++w) {
......@@ -273,7 +300,6 @@ void DepthwiseConv2dNeonK3x3S2(const float *input,
}
}
#if defined(MACE_ENABLE_NEON)
// load filter (1 outch x 3 height x 3 width): vf_outch_height
float32x4_t vf00, vf01, vf02;
vf00 = vld1q_f32(filter_ptr);
......@@ -359,15 +385,7 @@ void DepthwiseConv2dNeonK3x3S2(const float *input,
3, 3, out_base);
}
} // h
#else
for (index_t ih = valid_h_start; ih < valid_h_stop; ++ih) {
for (index_t iw = 0; iw < out_width; ++iw) {
DepthwiseConv2dPixel(in_base, filter_ptr, ih, iw, ih * 2 - pad_top,
iw * 2 - pad_left, out_width, in_height,
in_width, 3, 3, out_base);
}
}
#endif
// bottom
for (; h < out_shape[2]; ++h) {
......@@ -379,7 +397,11 @@ void DepthwiseConv2dNeonK3x3S2(const float *input,
}
} // m
} // b
return MaceStatus::MACE_SUCCESS;
}
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
......@@ -12,37 +12,51 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_DEPTHWISE_CONV2D_NEON_H_
#define MACE_OPS_ARM_DEPTHWISE_CONV2D_NEON_H_
#ifndef MACE_OPS_ARM_FP32_DEPTHWISE_CONV_2D_3X3_H_
#define MACE_OPS_ARM_FP32_DEPTHWISE_CONV_2D_3X3_H_
#include "mace/core/types.h"
#include <vector>
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/conv_2d.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
void DepthwiseConv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
const int *pad_hw,
const index_t valid_h_start,
const index_t valid_h_stop,
const index_t valid_w_start,
const index_t valid_w_stop,
float *output);
void DepthwiseConv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
const int *pad_hw,
const index_t valid_h_start,
const index_t valid_h_stop,
const index_t valid_w_start,
const index_t valid_w_stop,
float *output);
class DepthwiseConv2dK3x3S1 : public Conv2dBase {
public:
DepthwiseConv2dK3x3S1(const std::vector<int> paddings,
const Padding padding_type)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {}
virtual ~DepthwiseConv2dK3x3S1() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
};
class DepthwiseConv2dK3x3S2 : public Conv2dBase {
public:
DepthwiseConv2dK3x3S2(const std::vector<int> paddings,
const Padding padding_type)
: Conv2dBase({2, 2}, {1, 1}, paddings, padding_type) {}
virtual ~DepthwiseConv2dK3x3S2() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
};
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_DEPTHWISE_CONV2D_NEON_H_
#endif // MACE_OPS_ARM_FP32_DEPTHWISE_CONV_2D_3X3_H_
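For context, the op-side change later in this MR constructs these delegators lazily and simply forwards the tensors to Compute. A hedged usage sketch follows; RunDepthwise3x3S1 is a hypothetical wrapper introduced only for illustration, while the constructor and Compute() signatures match the declarations above.

```cpp
// Editorial usage sketch of the new delegator classes (not part of the patch).
#include <vector>

#include "mace/ops/arm/fp32/depthwise_conv_2d_3x3.h"
#include "mace/utils/memory.h"

namespace mace {
namespace ops {

MaceStatus RunDepthwise3x3S1(const OpContext *context,
                             const Tensor *input,
                             const Tensor *filter,
                             Tensor *output,
                             const std::vector<int> &paddings,
                             Padding padding_type) {
  // The delegator owns stride/dilation/padding policy; output Resize() and
  // Clear() happen inside Compute(), so the caller only passes tensors.
  auto dw_conv =
      make_unique<arm::fp32::DepthwiseConv2dK3x3S1>(paddings, padding_type);
  return dw_conv->Compute(context, input, filter, output);
}

}  // namespace ops
}  // namespace mace
```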
......@@ -12,14 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "mace/ops/ref/depthwise_conv_2d.h"
#if defined(MACE_ENABLE_NEON)
#include "mace/ops/arm/fp32/depthwise_conv_2d_3x3.h"
#endif // MACE_ENABLE_NEON
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/quantization_util.h"
// We reuse TensorFlow Lite's optimized depthwiseconv_uint8 and parallelized it
......@@ -30,7 +33,6 @@
#include "mace/core/future.h"
#include "mace/core/operator.h"
#include "mace/ops/activation.h"
#include "mace/ops/arm/depthwise_conv2d_neon.h"
#include "mace/ops/conv_pool_2d_base.h"
#include "mace/public/mace.h"
#include "mace/utils/memory.h"
......@@ -50,20 +52,20 @@ class DepthwiseConv2dOpBase : public ConvPool2dOpBase {
: ConvPool2dOpBase(context),
activation_(ops::StringToActivationType(
Operation::GetOptionalArg<std::string>("activation",
"NOOP"))),
"NOOP"))),
relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f)) {}
"leakyrelu_coefficient", 0.0f)) {}
protected:
const ActivationType activation_;
const float relux_max_limit_;
const float leakyrelu_coefficient_;
};
template <DeviceType D, class T>
template<DeviceType D, class T>
class DepthwiseConv2dOp;
template <>
template<>
class DepthwiseConv2dOp<DeviceType::CPU, float> : public DepthwiseConv2dOpBase {
public:
explicit DepthwiseConv2dOp(OpConstructContext *context)
......@@ -82,133 +84,60 @@ class DepthwiseConv2dOp<DeviceType::CPU, float> : public DepthwiseConv2dOpBase {
MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output);
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
std::vector<index_t> filter_shape
{filter->dim(0) * filter->dim(1), filter->dim(1), filter->dim(2),
filter->dim(3)};
#ifdef MACE_ENABLE_NEON
const index_t filter_h = filter->dim(2);
const index_t filter_w = filter->dim(3);
const index_t stride_h = strides_[0];
const index_t stride_w = strides_[1];
const index_t dilation_h = dilations_[0];
const index_t dilation_w = dilations_[1];
if (paddings_.empty()) {
CalcNCHWPaddingAndOutputSize(input->shape().data(),
filter_shape.data(),
dilations_.data(),
strides_.data(),
padding_type_,
output_shape.data(),
paddings.data());
if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1
&& dilation_h == 1 && dilation_w == 1) {
if (conv2d_delegator_.get() == nullptr) {
conv2d_delegator_ =
make_unique<arm::fp32::DepthwiseConv2dK3x3S1>(paddings_,
padding_type_);
}
conv2d_delegator_->Compute(context, input, filter, output);
} else if (filter_h == 3 && filter_w == 3 && stride_h == 2 && stride_w == 2
&& dilation_h == 1 && dilation_w == 1) {
if (conv2d_delegator_.get() == nullptr) {
conv2d_delegator_ =
make_unique<arm::fp32::DepthwiseConv2dK3x3S2>(paddings_,
padding_type_);
}
conv2d_delegator_->Compute(context, input, filter, output);
} else {
paddings = paddings_;
CalcNCHWOutputSize(input->shape().data(),
filter_shape.data(),
paddings_.data(),
dilations_.data(),
strides_.data(),
RoundType::FLOOR,
output_shape.data());
if (ref_conv2d_delegator_.get() == nullptr) {
ref_conv2d_delegator_ =
make_unique<ref::DepthwiseConv2d<float>>(strides_,
dilations_,
paddings_,
padding_type_);
}
ref_conv2d_delegator_->Compute(context, input, filter, output);
}
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
output->Clear();
index_t batch = output->dim(0);
index_t channels = output->dim(1);
index_t height = output->dim(2);
index_t width = output->dim(3);
index_t input_batch = input->dim(0);
index_t input_channels = input->dim(1);
index_t input_height = input->dim(2);
index_t input_width = input->dim(3);
index_t filter_h = filter_shape[2];
index_t filter_w = filter_shape[3];
MACE_CHECK(filter_shape[0] == channels, filter_shape[0], " != ", channels);
MACE_CHECK(filter_shape[1] == input_channels, filter_shape[1], " != ",
input_channels);
index_t stride_h = strides_[0];
index_t stride_w = strides_[1];
index_t dilation_h = dilations_[0];
index_t dilation_w = dilations_[1];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
int pad_top = paddings[0] >> 1;
int pad_bottom = paddings[0] - pad_top;
int pad_left = paddings[1] >> 1;
int pad_right = paddings[1] - pad_left;
index_t valid_h_start = pad_top == 0 ? 0 : (pad_top - 1) / stride_h + 1;
index_t valid_h_stop = pad_bottom == 0
? height
: height - ((pad_bottom - 1) / stride_h + 1);
index_t valid_w_start = pad_left == 0 ? 0 : (pad_left - 1) / stride_w + 1;
index_t valid_w_stop = pad_right == 0
? width
: width - ((pad_right - 1) / stride_w + 1);
std::function<void(const float *input, float *output)> conv_func;
#else
if (ref_conv2d_delegator_.get() == nullptr) {
ref_conv2d_delegator_ =
make_unique<ref::DepthwiseConv2d<float>>(strides_,
dilations_,
paddings_,
padding_type_);
}
ref_conv2d_delegator_->Compute(context, input, filter, output);
#endif // MACE_ENABLE_NEON
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard filter_guard(filter);
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard output_guard(output);
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
auto output_data = output->mutable_data<float>();
const int pad_hw[2] = {pad_top, pad_left};
const index_t input_shape[4] =
{batch, input_channels, input_height, input_width};
// make host compiler happy
MACE_UNUSED(pad_hw);
MACE_UNUSED(input_shape);
if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1
&& dilation_h == 1 && dilation_w == 1) {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dNeonK3x3S1(input,
filter_data,
input_shape,
output_shape.data(),
pad_hw,
valid_h_start,
valid_h_stop,
valid_w_start,
valid_w_stop,
output);
};
} else if (filter_h == 3 && filter_w == 3 && stride_h == 2 && stride_w == 2
&& dilation_h == 1 && dilation_w == 1) {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dNeonK3x3S2(input,
filter_data,
input_shape,
output_shape.data(),
pad_hw,
valid_h_start,
valid_h_stop,
valid_w_start,
valid_w_stop,
output);
};
} else {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dGeneral(input,
filter_data,
input_shape,
output_shape.data(),
filter_shape.data(),
strides_.data(),
dilations_.data(),
pad_hw,
output);
};
}
conv_func(input_data, output_data);
const index_t batch = output->dim(0);
const index_t channels = output->dim(1);
const index_t height = output->dim(2);
const index_t width = output->dim(3);
if (bias_data != nullptr) {
#pragma omp parallel for collapse(2)
......@@ -229,55 +158,10 @@ class DepthwiseConv2dOp<DeviceType::CPU, float> : public DepthwiseConv2dOpBase {
}
private:
void DepthwiseConv2dGeneral(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
const index_t *filter_shape,
const int *stride_hw,
const int *dilation_hw,
const int *pad_hw,
float *output) {
const index_t multiplier = filter_shape[0] / filter_shape[1];
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < in_shape[0]; ++b) {
for (index_t m = 0; m < filter_shape[0]; ++m) {
for (index_t h = 0; h < out_shape[2]; ++h) {
for (index_t w = 0; w < out_shape[3]; ++w) {
const index_t out_channels = filter_shape[0];
const index_t in_channels = filter_shape[1];
const index_t filter_height = filter_shape[2];
const index_t filter_width = filter_shape[3];
const index_t in_height = in_shape[2];
const index_t in_width = in_shape[3];
const index_t out_height = out_shape[2];
const index_t out_width = out_shape[3];
index_t out_offset =
((b * out_channels + m) * out_height + h) * out_width + w;
index_t c = m / multiplier;
index_t o = m % multiplier;
float sum = 0;
for (index_t kh = 0; kh < filter_height; ++kh) {
for (index_t kw = 0; kw < filter_width; ++kw) {
index_t ih = h * stride_hw[0] + kh * dilation_hw[0] - pad_hw[0];
index_t iw = w * stride_hw[1] + kw * dilation_hw[1] - pad_hw[1];
if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
index_t in_offset =
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t filter_offset =
(((o * in_channels) + c) * filter_height + kh)
* filter_width + kw;
sum += input[in_offset] * filter[filter_offset];
}
}
}
output[out_offset] = sum;
}
}
}
}
}
#ifdef MACE_ENABLE_NEON
std::unique_ptr<arm::fp32::Conv2dBase> conv2d_delegator_;
#endif // MACE_ENABLE_NEON
std::unique_ptr<ref::DepthwiseConv2d<float>> ref_conv2d_delegator_;
protected:
MACE_OP_INPUT_TAGS(INPUT, FILTER, BIAS);
......@@ -542,7 +426,6 @@ class DepthwiseConv2dOp<DeviceType::GPU, T> : public DepthwiseConv2dOpBase {
};
#endif // MACE_ENABLE_OPENCL
void RegisterDepthwiseConv2d(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "DepthwiseConv2d",
DepthwiseConv2dOp, DeviceType::CPU, float);
......
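To summarize the Run() changes above: the CPU op now caches a single delegator and chooses it from the filter size, stride and dilation, falling back to the portable reference kernel otherwise. A compact editorial restatement of that dispatch rule (SelectKernel and the enum are hypothetical helpers, not in the patch):

```cpp
// Editorial restatement of the kernel selection in
// DepthwiseConv2dOp<DeviceType::CPU, float>::Run.
enum class DepthwiseKernel { kNeon3x3S1, kNeon3x3S2, kReference };

DepthwiseKernel SelectKernel(long filter_h, long filter_w,
                             int stride_h, int stride_w,
                             int dilation_h, int dilation_w) {
  // The NEON 3x3 kernels cover stride 1 and stride 2 with dilation 1;
  // every other shape uses ref::DepthwiseConv2d<float>.
  if (filter_h == 3 && filter_w == 3 && dilation_h == 1 && dilation_w == 1) {
    if (stride_h == 1 && stride_w == 1) return DepthwiseKernel::kNeon3x3S1;
    if (stride_h == 2 && stride_w == 2) return DepthwiseKernel::kNeon3x3S2;
  }
  return DepthwiseKernel::kReference;
}
```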
......@@ -64,7 +64,7 @@ class Conv2d<float> {
paddings_(paddings),
padding_type_(padding_type) {}
~Conv2d() {}
// Always row-major after transpose
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
......
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ref/depthwise_conv_2d.h"
#include <vector>
namespace mace {
namespace ops {
namespace ref {
MaceStatus DepthwiseConv2d<float>::Compute(const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) {
MACE_UNUSED(context);
const std::vector<index_t> in_shape = input->shape();
const std::vector<index_t> filter_shape = filter->shape();
std::vector<index_t> out_shape(4);
std::vector<int> paddings(2);
if (paddings_.empty()) {
CalcNCHWPaddingAndOutputSize(input->shape().data(),
filter->shape().data(),
dilations_.data(),
strides_.data(),
padding_type_,
out_shape.data(),
paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input->shape().data(),
filter->shape().data(),
paddings_.data(),
dilations_.data(),
strides_.data(),
RoundType::FLOOR,
out_shape.data());
}
out_shape[1] *= in_shape[1];
const index_t pad_top = paddings[0] >> 1;
const index_t pad_left = paddings[1] >> 1;
output->Resize(out_shape);
const index_t multiplier = filter_shape[0];
const index_t in_image_size = in_shape[2] * in_shape[3];
const index_t out_image_size = out_shape[2] * out_shape[3];
const index_t in_batch_size = in_shape[1] * in_image_size;
const index_t out_batch_size = out_shape[1] * out_image_size;
const index_t filter_size = filter_shape[2] * filter_shape[3];
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard filter_guard(filter);
Tensor::MappingGuard output_guard(output);
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto output_data = output->mutable_data<float>();
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < in_shape[0]; b++) {
for (index_t m = 0; m < out_shape[1]; ++m) {
const index_t c = m / multiplier;
const index_t multi_index = m % multiplier;
const index_t in_height = in_shape[2];
const index_t in_width = in_shape[3];
const index_t out_height = out_shape[2];
const index_t out_width = out_shape[3];
const index_t in_channels = in_shape[1];
float *out_ptr_base =
output_data + b * out_batch_size + m * out_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
float sum = 0;
const float *in_ptr_base =
input_data + b * in_batch_size + c * in_image_size;
const float *filter_ptr =
filter_data + multi_index * in_channels * filter_size
+ c * filter_size;
for (index_t kh = 0; kh < filter_shape[2]; ++kh) {
for (index_t kw = 0; kw < filter_shape[3]; ++kw) {
const index_t
ih = -pad_top + h * strides_[0] + kh * dilations_[0];
const index_t
iw = -pad_left + w * strides_[1] + kw * dilations_[1];
if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
sum += in_ptr_base[ih * in_width + iw] * filter_ptr[kw];
}
} // kw
filter_ptr += filter_shape[3];
} // kh
out_ptr_base[h * out_width + w] = sum;
} // w
} // h
} // m
} // b
return MaceStatus::MACE_SUCCESS;
}
} // namespace ref
} // namespace ops
} // namespace mace
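The reference kernel's inner loops map each output pixel back to input coordinates with ih = -pad_top + h * strides_[0] + kh * dilations_[0] (and the analogous expression for width), skipping taps that fall outside the input so padding contributes zero. A small editorial example with illustrative numbers:

```cpp
// Editorial example of the input-index arithmetic in the reference kernel:
// stride 1, dilation 1, 3x3 filter, pad_top = 1, first output row (h = 0).
#include <cstdio>

int main() {
  const long stride = 1, dilation = 1, pad_top = 1, in_height = 5;
  const long h = 0;
  for (long kh = 0; kh < 3; ++kh) {
    const long ih = -pad_top + h * stride + kh * dilation;
    const bool in_range = ih >= 0 && ih < in_height;
    // kh=0 falls in the zero pad and is skipped; kh=1 and kh=2 accumulate.
    std::printf("kh=%ld -> ih=%ld (%s)\n", kh, ih,
                in_range ? "accumulated" : "skipped");
  }
  return 0;
}
```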
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_REF_DEPTHWISE_CONV_2D_H_
#define MACE_OPS_REF_DEPTHWISE_CONV_2D_H_
#include <vector>
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
namespace ops {
namespace ref {
template<typename OUTPUT_TYPE>
class DepthwiseConv2d {
public:
DepthwiseConv2d(const std::vector<int> strides,
const std::vector<int> dilations,
const std::vector<int> paddings,
const Padding padding_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type) {}
~DepthwiseConv2d() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
private:
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
};
template<>
class DepthwiseConv2d<float> {
public:
DepthwiseConv2d(const std::vector<int> strides,
const std::vector<int> dilations,
const std::vector<int> paddings,
const Padding padding_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type) {}
~DepthwiseConv2d() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
private:
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
};
} // namespace ref
} // namespace ops
} // namespace mace
#endif // MACE_OPS_REF_DEPTHWISE_CONV_2D_H_