Commit 7c9ca067 authored by 李寅

Refactor cpu

Parent 7c1711d8
@@ -91,7 +91,6 @@ extern void Register_Eltwise(OperatorRegistry *op_registry);
 extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
 extern void Register_FullyConnected(OperatorRegistry *op_registry);
 extern void Register_FusedConv2D(OperatorRegistry *op_registry);
-extern void Register_GlobalAvgPooling(OperatorRegistry *op_registry);
 extern void Register_ImageToBuffer(OperatorRegistry *op_registry);
 extern void Register_LocalResponseNorm(OperatorRegistry *op_registry);
 extern void Register_MatMul(OperatorRegistry *op_registry);
@@ -132,7 +131,6 @@ OperatorRegistry::OperatorRegistry() {
   ops::Register_FoldedBatchNorm(this);
   ops::Register_FullyConnected(this);
   ops::Register_FusedConv2D(this);
-  ops::Register_GlobalAvgPooling(this);
   ops::Register_ImageToBuffer(this);
   ops::Register_LocalResponseNorm(this);
   ops::Register_MatMul(this);
...
@@ -318,6 +318,7 @@ class Tensor {
  public:
   explicit MappingGuard(const Tensor *tensor) : tensor_(tensor) {
     if (tensor_ != nullptr) {
+      MACE_CHECK_NOTNULL(tensor_->buffer_);
       tensor_->buffer_->Map(&mapped_image_pitch_);
     }
   }
...
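For context, Tensor::MappingGuard is MACE's RAII helper that maps a tensor's buffer on construction and unmaps it on destruction (mapping is what makes, e.g., OpenCL-backed buffers host-accessible). The added MACE_CHECK_NOTNULL turns a tensor without an allocated buffer into a clear check failure instead of a crash inside Map(). A minimal usage sketch (the function and tensor names are hypothetical):

void ReadTensor(const Tensor *tensor) {
  Tensor::MappingGuard guard(tensor);         // maps tensor's buffer
  const float *data = tensor->data<float>();  // safe host-side access
  // ... read data ...
}                                             // guard unmaps on scope exit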
@@ -121,23 +121,30 @@ void PReLUActivation(const T *input_ptr,
 }
 template <DeviceType D, typename T>
-class ActivationFunctor {
+class ActivationFunctor;
+template <>
+class ActivationFunctor<DeviceType::CPU, float> {
  public:
-  ActivationFunctor(ActivationType type, T relux_max_limit)
+  ActivationFunctor(ActivationType type, float relux_max_limit)
       : activation_(type), relux_max_limit_(relux_max_limit) {}
   void operator()(const Tensor *input,
                   const Tensor *alpha,
                   Tensor *output,
                   StatsFuture *future) {
-    const T *input_ptr = input->data<T>();
-    T *output_ptr = output->mutable_data<T>();
+    const float *input_ptr = input->data<float>();
+    float *output_ptr = output->mutable_data<float>();
     if (activation_ == PRELU) {
       MACE_CHECK_NOTNULL(alpha);
-      const T *alpha_ptr = alpha->data<T>();
-      const index_t outer_size = output->dim(0) * output->dim(1)
-          * output->dim(2);
-      PReLUActivation(input_ptr, outer_size, input->dim(3), 1, alpha_ptr,
+      const float *alpha_ptr = alpha->data<float>();
+      const index_t outer_size = output->dim(0);
+      const index_t inner_size = output->dim(2) * output->dim(3);
+      PReLUActivation(input_ptr,
+                      outer_size,
+                      input->dim(1),
+                      inner_size,
+                      alpha_ptr,
                       output_ptr);
     } else {
       DoActivation(input_ptr, output_ptr, output->size(), activation_,
@@ -145,22 +152,6 @@ class ActivationFunctor {
     }
   }
- private:
-  ActivationType activation_;
-  T relux_max_limit_;
-};
-template <>
-class ActivationFunctor<DeviceType::NEON, float> {
- public:
-  ActivationFunctor(ActivationType type, float relux_max_limit)
-      : activation_(type), relux_max_limit_(relux_max_limit) {}
-  void operator()(const Tensor *input,
-                  const Tensor *alpha,
-                  Tensor *output,
-                  StatsFuture *future);
  private:
   ActivationType activation_;
   float relux_max_limit_;
...
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/activation.h"
namespace mace {
namespace kernels {
void ActivationFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future) {
const float *input_ptr = input->data<float>();
float *output_ptr = output->mutable_data<float>();
if (activation_ == PRELU) {
MACE_CHECK_NOTNULL(alpha);
const float *alpha_ptr = alpha->data<float>();
const index_t outer_size = output->dim(0);
const index_t inner_size = output->dim(2) * output->dim(3);
PReLUActivation(input_ptr, outer_size, input->dim(1), inner_size, alpha_ptr,
output_ptr);
} else {
DoActivation(input_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
}
}
} // namespace kernels
} // namespace mace
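The refactor switches this kernel from NHWC to NCHW: outer_size is the batch, the channel count is input->dim(1), and inner_size is the height * width image. A reference loop for what the PReLUActivation call above is expected to compute, as a sketch of the semantics rather than the actual kernel:

// Per-channel PReLU on an NCHW tensor: one learned slope per channel.
for (index_t b = 0; b < outer_size; ++b) {        // batch
  for (index_t c = 0; c < channels; ++c) {        // input->dim(1)
    for (index_t i = 0; i < inner_size; ++i) {    // height * width
      const index_t pos = (b * channels + c) * inner_size + i;
      const float in = input_ptr[pos];
      output_ptr[pos] = in >= 0 ? in : in * alpha_ptr[c];
    }
  }
}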
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/batch_norm.h"
namespace mace {
namespace kernels {
void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future) {
// Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
// The calculation formula for inference is
// Y = \frac{scale}{\sqrt{var + epsilon}} * X +
//     (offset - \frac{scale * mean}{\sqrt{var + epsilon}})
// new_scale = \frac{scale}{\sqrt{var + epsilon}}
// new_offset = offset - mean * new_scale
// Y = new_scale * X + new_offset
const index_t batch = input->dim(0);
const index_t channels = input->dim(1);
const index_t height = input->dim(2);
const index_t width = input->dim(3);
const float *input_ptr = input->data<float>();
const float *scale_ptr = scale->data<float>();
const float *offset_ptr = offset->data<float>();
float *output_ptr = output->mutable_data<float>();
std::vector<float> new_scale;
std::vector<float> new_offset;
if (!folded_constant_) {
new_scale.resize(channels);
new_offset.resize(channels);
const float *mean_ptr = mean->data<float>();
const float *var_ptr = var->data<float>();
#pragma omp parallel for
for (index_t c = 0; c < channels; ++c) {
new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon);
new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c];
}
}
const float *scale_data = folded_constant_ ? scale_ptr : new_scale.data();
const float *offset_data = folded_constant_ ? offset_ptr : new_offset.data();
index_t channel_size = height * width;
index_t batch_size = channels * channel_size;
// NEON is slower here, so stick to the trivial implementation
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
index_t offset = b * batch_size + c * channel_size;
for (index_t hw = 0; hw < height * width; ++hw) {
output_ptr[offset + hw] =
scale_data[c] * input_ptr[offset + hw] + offset_data[c];
}
}
}
DoActivation(output_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
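As a quick numeric check of the folding above (all values invented for illustration): for a channel with scale = 2, var = 3, epsilon = 1, mean = 5 and offset = 1:

new_scale  = 2 / sqrt(3 + 1) = 1
new_offset = 1 - 5 * 1       = -4
Y          = 1 * X - 4

so inference batch norm reduces to a single multiply-add per value in the inner loop.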
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/arm/conv_winograd.h"
namespace mace {
namespace kernels {
namespace {
void Conv2dNCHW(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
float *output) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
index_t out_offset =
((b * out_channels + m) * out_height + h) * out_width + w;
for (index_t c = 0; c < in_channels; ++c) {
for (index_t kh = 0; kh < filter_height; ++kh) {
for (index_t kw = 0; kw < filter_width; ++kw) {
index_t ih = h * stride_h + kh * dilation_h;
index_t iw = w * stride_w + kw * dilation_w;
index_t in_offset =
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t filter_offset =
(((m * in_channels) + c) * filter_height + kh) * filter_width
+ kw;
output[out_offset] += input[in_offset] * filter[filter_offset];
}
}
}
}
}
}
}
}
} // namespace
extern void Conv2dNeonK1x1S1(const float *input,
const float *filter,
const index_t batch,
const index_t height,
const index_t width,
const index_t in_channels,
const index_t out_channels,
float *output);
extern void Conv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output);
extern void Conv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output);
void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
MACE_CHECK_NOTNULL(input);
MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output);
std::vector<index_t> filter_shape(4);
if (is_filter_transformed_) {
// TOC -> OIHW
filter_shape[0] = filter->dim(1);
filter_shape[1] = filter->dim(2);
filter_shape[2] = filter_shape[3] = 3;
} else {
filter_shape = filter->shape();
}
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
if (paddings_.empty()) {
CalcNCHWPaddingAndOutputSize(input->shape().data(),
filter_shape.data(),
dilations_,
strides_,
padding_type_,
output_shape.data(),
paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input->shape().data(), filter_shape.data(),
paddings_.data(), dilations_, strides_, RoundType::FLOOR,
output_shape.data());
}
output->Resize(output_shape);
output->Clear();
index_t batch = output->dim(0);
index_t channels = output->dim(1);
index_t height = output->dim(2);
index_t width = output->dim(3);
index_t input_batch = input->dim(0);
index_t input_channels = input->dim(1);
index_t input_height = input->dim(2);
index_t input_width = input->dim(3);
index_t filter_h = filter_shape[2];
index_t filter_w = filter_shape[3];
MACE_CHECK(filter_shape[0] == channels, filter_shape[0], " != ", channels);
MACE_CHECK(filter_shape[1] == input_channels, filter_shape[1], " != ",
input_channels);
index_t stride_h = strides_[0];
index_t stride_w = strides_[1];
index_t dilation_h = dilations_[0];
index_t dilation_w = dilations_[1];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
index_t padded_input_height = input_height + paddings[0];
index_t padded_input_width = input_width + paddings[1];
index_t extra_input_height = padded_input_height;
index_t extra_input_width = padded_input_width;
index_t extra_output_height = height;
index_t extra_output_width = width;
int pad_top = paddings[0] >> 1;
int pad_bottom = paddings[0] - pad_top;
int pad_left = paddings[1] >> 1;
int pad_right = paddings[1] - pad_left;
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
auto output_data = output->mutable_data<float>();
std::function<void(const float *input, float *output)> conv_func;
bool use_winograd = is_filter_transformed_ || (filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1
&& input_channels >= 8 && channels >= 8);
bool use_neon_3x3_s1 = filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
bool use_neon_3x3_s2 = filter_h == 3 && filter_w == 3
&& stride_h == 2 && stride_w == 2 && dilation_h == 1 && dilation_w == 1;
bool use_neon_1x1_s1 = filter_h == 1 && filter_w == 1
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
std::vector<index_t> transformed_input_shape;
std::vector<index_t> transformed_output_shape;
std::vector<index_t> transformed_filter_shape;
// When the input feature map is larger than 16x16, use a Winograd output
// tile size of 6 for higher performance.
index_t winograd_out_tile_size = 2;
if (input_height > 16 && input_width > 16) {
winograd_out_tile_size = 6;
}
if (use_winograd) {
extra_output_height = RoundUp<index_t>(height, winograd_out_tile_size);
extra_input_height = std::max(padded_input_height, extra_output_height + 2);
extra_output_width = RoundUp<index_t>(width, winograd_out_tile_size);
extra_input_width = std::max(padded_input_width, extra_output_width + 2);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
index_t tile_height_count = extra_output_height / winograd_out_tile_size;
index_t tile_width_count = extra_output_width / winograd_out_tile_size;
index_t tile_count = tile_height_count * tile_width_count;
index_t in_tile_area =
(winograd_out_tile_size + 2) * (winograd_out_tile_size + 2);
transformed_input_shape.insert(transformed_input_shape.end(),
{in_tile_area, batch, input_channels,
tile_count});
transformed_output_shape.insert(transformed_output_shape.end(),
{in_tile_area, batch, channels,
tile_count});
transformed_filter_shape.insert(transformed_filter_shape.end(),
{in_tile_area, channels, input_channels});
} else if (use_neon_3x3_s1) {
extra_output_height = RoundUp<index_t>(height, 2);
extra_input_height = std::max(padded_input_height, extra_output_height + 2);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width = std::max(padded_input_width, extra_output_width + 2);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
} else if (use_neon_3x3_s2) {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, (extra_output_height - 1) * 2 + 3);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width =
std::max(padded_input_width, (extra_output_width - 1) * 2 + 3);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
}
// decide the scratch size before allocating it
index_t total_scratch_size = 0;
index_t transformed_input_size = 0;
index_t transformed_output_size = 0;
index_t padded_input_size = 0;
index_t padded_output_size = 0;
if (use_winograd) {
transformed_input_size =
std::accumulate(transformed_input_shape.begin(),
transformed_input_shape.end(),
1,
std::multiplies<index_t>()) * sizeof(float);
transformed_output_size =
std::accumulate(transformed_output_shape.begin(),
transformed_output_shape.end(),
1,
std::multiplies<index_t>()) * sizeof(float);
total_scratch_size += transformed_input_size + transformed_output_size;
}
if (extra_input_height != input_height || extra_input_width != input_width) {
padded_input_size =
batch * input_channels * (input_height + pad_top + pad_bottom)
* (input_width + pad_left + pad_right) * sizeof(float);
total_scratch_size += padded_input_size;
}
if (extra_output_height != height || extra_output_width != width) {
padded_output_size =
batch * channels * extra_output_height * extra_output_width
* sizeof(float);
total_scratch_size += padded_output_size;
}
// Init scratch buffer
scratch_->Rewind();
scratch_->GrowSize(total_scratch_size);
Tensor transformed_input(scratch_->Scratch(transformed_input_size), DT_FLOAT);
Tensor transformed_output(scratch_->Scratch(transformed_output_size),
                          DT_FLOAT);
Tensor padded_input(scratch_->Scratch(padded_input_size), DT_FLOAT);
Tensor padded_output(scratch_->Scratch(padded_output_size), DT_FLOAT);
// decide which convolution function to call
if (use_winograd) {
transformed_input.Resize(transformed_input_shape);
transformed_output.Resize(transformed_output_shape);
const float *transformed_filter_ptr;
if (transformed_filter_.dim_size() == 0) {
transformed_filter_.Resize(transformed_filter_shape);
if (is_filter_transformed_) {
transformed_filter_ptr = filter_data;
} else {
switch (winograd_out_tile_size) {
case 2:
TransformFilter4x4(filter_data,
filter_shape[1],
filter_shape[0],
transformed_filter_.mutable_data<float>());
break;
case 6:
TransformFilter8x8(filter_data,
filter_shape[1],
filter_shape[0],
transformed_filter_.mutable_data<float>());
break;
default: MACE_NOT_IMPLEMENTED;
}
transformed_filter_ptr = transformed_filter_.data<float>();
}
} else {
transformed_filter_ptr = transformed_filter_.data<float>();
}
conv_func = [&](const float *pad_input, float *pad_output) {
WinoGradConv3x3s1(pad_input,
transformed_filter_ptr,
batch,
extra_input_height,
extra_input_width,
input_channels,
channels,
winograd_out_tile_size,
transformed_input.mutable_data<float>(),
transformed_output.mutable_data<float>(),
pad_output);
};
} else if (use_neon_3x3_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK3x3S1(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
pad_output);
};
} else if (use_neon_3x3_s2) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK3x3S2(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
pad_output);
};
} else if (use_neon_1x1_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK1x1S1(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
channels,
pad_output);
};
} else {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNCHW(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_output);
};
}
// pad input and output
const Tensor *pad_input_ptr = input;
if (extra_input_height != input_height || extra_input_width != input_width) {
padded_input.Clear();
ConstructNCHWInputWithSpecificPadding(input,
pad_top,
pad_bottom,
pad_left,
pad_right,
&padded_input);
pad_input_ptr = &padded_input;
}
Tensor *pad_output_ptr = output;
if (extra_output_height != height || extra_output_width != width) {
padded_output.Resize({batch, channels, extra_output_height,
extra_output_width});
padded_output.Clear();
pad_output_ptr = &padded_output;
}
const float *pad_input_data = pad_input_ptr->data<float>();
float *pad_output_data = pad_output_ptr->mutable_data<float>();
conv_func(pad_input_data, pad_output_data);
// unpack output
if (extra_output_height != height || extra_output_width != width) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t h = 0; h < height; ++h) {
memcpy(
output_data + b * channels * height * width + c * height * width
+ h * width,
pad_output_data
+ b * channels * extra_output_height * extra_output_width
+ c * extra_output_height * extra_output_width
+ h * extra_output_width,
sizeof(float) * width);
}
}
}
}
if (bias_data != nullptr) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t i = 0; i < height * width; ++i) {
output_data[(b * channels + c) * height * width + i] += bias_data[c];
}
}
}
}
DoActivation(output_data, output_data, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
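The Winograd path computes F(m×m, 3×3): each m×m output tile consumes an (m+2)×(m+2) input tile (3×3 kernel, stride 1), which is why in_tile_area is (winograd_out_tile_size + 2) squared and the extra input extent is the extra output extent plus 2. A sketch of the tiling arithmetic with illustrative sizes, reusing RoundUp from the code above:

const index_t m = 6;                                   // tile size once input > 16x16
const index_t height = 224, width = 224;               // illustrative output size
const index_t out_h = RoundUp<index_t>(height, m);     // 228
const index_t out_w = RoundUp<index_t>(width, m);      // 228
const index_t tile_count = (out_h / m) * (out_w / m);  // 38 * 38 = 1444 tiles
const index_t in_tile_area = (m + 2) * (m + 2);        // 64 values per input tile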
@@ -12,49 +12,46 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#ifndef MACE_KERNELS_GLOBAL_AVG_POOLING_H_
-#define MACE_KERNELS_GLOBAL_AVG_POOLING_H_
-#include "mace/core/future.h"
-#include "mace/core/tensor.h"
+#ifndef MACE_KERNELS_ARM_CONV_2D_NEON_H_
+#define MACE_KERNELS_ARM_CONV_2D_NEON_H_
+#include "mace/core/types.h"
 namespace mace {
 namespace kernels {
-template <DeviceType D, typename T>
-struct GlobalAvgPoolingFunctor {
-  void operator()(const T *input,
-                  const index_t *input_shape,
-                  T *output,
-                  StatsFuture *future) {
-    index_t batch = input_shape[0];
-    index_t channels = input_shape[1];
-    index_t height = input_shape[2];
-    index_t width = input_shape[3];
-    index_t image_size = height * width;
-    index_t input_offset = 0;
-    index_t total_channels = batch * channels;
-    for (int c = 0; c < total_channels; ++c) {
-      T sum = 0;
-      for (int i = 0; i < image_size; ++i) {
-        sum += input[input_offset + i];
-      }
-      output[c] = sum / image_size;
-      input_offset += image_size;
-    }
-  }
-};
-template <>
-void GlobalAvgPoolingFunctor<DeviceType::NEON, float>::operator()(
-    const float *input,
-    const index_t *input_shape,
-    float *output,
-    StatsFuture *future);
+extern void Conv2dNeonK1x1S1(const float *input,
+                             const float *filter,
+                             const index_t batch,
+                             const index_t height,
+                             const index_t width,
+                             const index_t in_channels,
+                             const index_t out_channels,
+                             float *output);
+extern void Conv2dNeonK3x3S1(const float *input,
+                             const float *filter,
+                             const index_t batch,
+                             const index_t in_height,
+                             const index_t in_width,
+                             const index_t in_channels,
+                             const index_t out_height,
+                             const index_t out_width,
+                             const index_t out_channels,
+                             float *output);
+extern void Conv2dNeonK3x3S2(const float *input,
+                             const float *filter,
+                             const index_t batch,
+                             const index_t in_height,
+                             const index_t in_width,
+                             const index_t in_channels,
+                             const index_t out_height,
+                             const index_t out_width,
+                             const index_t out_channels,
+                             float *output);
 } // namespace kernels
 } // namespace mace
-#endif  // MACE_KERNELS_GLOBAL_AVG_POOLING_H_
+#endif  // MACE_KERNELS_ARM_CONV_2D_NEON_H_
@@ -16,7 +16,7 @@
 #include <arm_neon.h>
 #endif
-#include "mace/core/types.h"
+#include "mace/kernels/arm/conv_2d_neon.h"
 #include "mace/kernels/gemm.h"
 namespace mace {
...
@@ -16,7 +16,7 @@
 #include <arm_neon.h>
 #endif
-#include "mace/core/types.h"
+#include "mace/kernels/arm/conv_2d_neon.h"
 namespace mace {
 namespace kernels {
...
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/depthwise_conv2d.h"
#include "mace/kernels/activation.h"
namespace mace {
namespace kernels {
namespace {
void DepthwiseConv2dNCHW(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t multiplier = out_channels / in_channels;
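// With channel multiplier M = out_channels / in_channels, output channel m
// is computed from input channel c = m / M using filter slice o = m % M,
// which is exactly the indexing below.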
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
index_t out_offset =
((b * out_channels + m) * out_height + h) * out_width + w;
index_t c = m / multiplier;
index_t o = m % multiplier;
float sum = 0;
for (index_t kh = 0; kh < filter_height; ++kh) {
for (index_t kw = 0; kw < filter_width; ++kw) {
index_t ih = h * stride_h + kh * dilation_h - pad_top;
index_t iw = w * stride_w + kw * dilation_w - pad_left;
if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
index_t in_offset =
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t filter_offset =
(((o * in_channels) + c) * filter_height + kh) * filter_width
+ kw;
sum += input[in_offset] * filter[filter_offset];
}
}
}
output[out_offset] = sum;
}
}
}
}
}
} // namespace
extern void DepthwiseConv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int pad_top,
const int pad_left,
const index_t valid_h_start,
const index_t valid_h_stop,
const index_t valid_w_start,
const index_t valid_w_stop,
float *output);
void DepthwiseConv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int pad_top,
const int pad_left,
const index_t valid_h_start,
const index_t valid_h_stop,
const index_t valid_w_start,
const index_t valid_w_stop,
float *output);
void DepthwiseConv2dFunctor<DeviceType::NEON,
float>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
MACE_CHECK_NOTNULL(input);
MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output);
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
std::vector<index_t> filter_shape
{filter->dim(0) * filter->dim(1), filter->dim(1), filter->dim(2),
filter->dim(3)};
if (paddings_.empty()) {
CalcNCHWPaddingAndOutputSize(input->shape().data(),
filter_shape.data(),
dilations_,
strides_,
padding_type_,
output_shape.data(),
paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input->shape().data(), filter_shape.data(),
paddings_.data(), dilations_, strides_, RoundType::FLOOR,
output_shape.data());
}
output->Resize(output_shape);
output->Clear();
index_t batch = output->dim(0);
index_t channels = output->dim(1);
index_t height = output->dim(2);
index_t width = output->dim(3);
index_t input_batch = input->dim(0);
index_t input_channels = input->dim(1);
index_t input_height = input->dim(2);
index_t input_width = input->dim(3);
index_t filter_h = filter_shape[2];
index_t filter_w = filter_shape[3];
MACE_CHECK(filter_shape[0] == channels, filter_shape[0], " != ", channels);
MACE_CHECK(filter_shape[1] == input_channels, filter_shape[1], " != ",
input_channels);
index_t stride_h = strides_[0];
index_t stride_w = strides_[1];
index_t dilation_h = dilations_[0];
index_t dilation_w = dilations_[1];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
int pad_top = paddings[0] >> 1;
int pad_bottom = paddings[0] - pad_top;
int pad_left = paddings[1] >> 1;
int pad_right = paddings[1] - pad_left;
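// The "valid" output range is where the filter window lies entirely inside
// the unpadded input, so the NEON kernels can skip bounds checks there;
// for pad > 0, (pad - 1) / stride + 1 is integer ceil(pad / stride).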
int valid_h_start = pad_top == 0 ? 0 : (pad_top - 1) / stride_h + 1;
int valid_h_stop = pad_bottom == 0
? height
: height - ((pad_bottom - 1) / stride_h + 1);
int valid_w_start = pad_left == 0 ? 0 : (pad_left - 1) / stride_w + 1;
int valid_w_stop = pad_right == 0
? width
: width - ((pad_right - 1) / stride_w + 1);
std::function<void(const float *input, float *output)> conv_func;
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
auto output_data = output->mutable_data<float>();
if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1
&& dilation_h == 1 && dilation_w == 1) {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dNeonK3x3S1(input,
filter_data,
batch,
input_height,
input_width,
input_channels,
height,
width,
channels,
pad_top,
pad_left,
valid_h_start,
valid_h_stop,
valid_w_start,
valid_w_stop,
output);
};
} else if (filter_h == 3 && filter_w == 3 && stride_h == 2 && stride_w == 2
&& dilation_h == 1 && dilation_w == 1) {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dNeonK3x3S2(input,
filter_data,
batch,
input_height,
input_width,
input_channels,
height,
width,
channels,
pad_top,
pad_left,
valid_h_start,
valid_h_stop,
valid_w_start,
valid_w_stop,
output);
};
} else {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dNCHW(input,
filter_data,
batch,
input_height,
input_width,
input_channels,
height,
width,
channels,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_top,
pad_left,
output);
};
}
conv_func(input_data, output_data);
if (bias_data != nullptr) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t i = 0; i < height * width; ++i) {
output_data[(b * channels + c) * height * width + i] += bias_data[c];
}
}
}
}
DoActivation(output_data, output_data, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_ARM_DEPTHWISE_CONV2D_NEON_H_
#define MACE_KERNELS_ARM_DEPTHWISE_CONV2D_NEON_H_
#include "mace/core/types.h"
namespace mace {
namespace kernels {
void DepthwiseConv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int pad_top,
const int pad_left,
const index_t valid_h_start,
const index_t valid_h_stop,
const index_t valid_w_start,
const index_t valid_w_stop,
float *output);
void DepthwiseConv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int pad_top,
const int pad_left,
const index_t valid_h_start,
const index_t valid_h_stop,
const index_t valid_w_start,
const index_t valid_w_stop,
float *output);
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_ARM_DEPTHWISE_CONV2D_NEON_H_
@@ -16,7 +16,7 @@
 #include <arm_neon.h>
 #endif
-#include "mace/core/types.h"
+#include "mace/kernels/arm/depthwise_conv2d_neon.h"
 namespace mace {
 namespace kernels {
@@ -60,10 +60,10 @@ void DepthwiseConv2dNeonK3x3S1(const float *input,
                               const index_t out_channels,
                               const int pad_top,
                               const int pad_left,
-                              const int valid_h_start,
-                              const int valid_h_stop,
-                              const int valid_w_start,
-                              const int valid_w_stop,
+                              const index_t valid_h_start,
+                              const index_t valid_h_stop,
+                              const index_t valid_w_start,
+                              const index_t valid_w_stop,
                               float *output) {
   const index_t multiplier = out_channels / in_channels;
   const index_t in_image_size = in_height * in_width;
@@ -277,10 +277,10 @@ void DepthwiseConv2dNeonK3x3S2(const float *input,
                               const index_t out_channels,
                               const int pad_top,
                               const int pad_left,
-                              const int valid_h_start,
-                              const int valid_h_stop,
-                              const int valid_w_start,
-                              const int valid_w_stop,
+                              const index_t valid_h_start,
+                              const index_t valid_h_stop,
+                              const index_t valid_w_start,
+                              const index_t valid_w_stop,
                               float *output) {
   const index_t multiplier = out_channels / in_channels;
   const index_t in_image_size = in_height * in_width;
...
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/fully_connected.h"
#include "mace/kernels/gemm.h"
namespace mace {
namespace kernels {
void FullyConnectedFunctor<DeviceType::NEON,
float>::operator()(const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
std::vector<index_t> output_shape = {input->dim(0), weight->dim(0), 1, 1};
output->Resize(output_shape);
const index_t N = output->dim(0);
const index_t input_size = weight->dim(1);
const index_t output_size = weight->dim(0);
const float *input_ptr = input->data<float>();
const float *weight_ptr = weight->data<float>();
const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
float *output_ptr = output->mutable_data<float>();
Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr);
if (bias_ptr != nullptr) {
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < output_size; ++j) {
      output_ptr[j + i * output_size] += bias_ptr[j];
    }
  }
}
DoActivation(output_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
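For orientation, weight here is laid out [output_size, input_size] and each of the N input samples is treated as a flattened vector, so the Gemv call amounts to N independent matrix-vector products (the shapes below are illustrative, not from the source):

// weight: [10, 1024], input: [2, 1024] flattened, output: [2, 10, 1, 1]
// Gemv(weight_ptr, input_ptr, /*batch=*/2, /*input_size=*/1024,
//      /*output_size=*/10, output_ptr) computes output[i] = weight * input[i]
// for each sample i; bias and activation are then applied element-wise.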
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/local_response_norm.h"
namespace mace {
namespace kernels {
void LocalResponseNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
int depth_radius,
float bias,
float alpha,
float beta,
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t channels = input->dim(1);
const index_t height = input->dim(2);
const index_t width = input->dim(3);
const float *input_ptr = input->data<float>();
float *output_ptr = output->mutable_data<float>();
index_t image_size = height * width;
index_t batch_size = channels * image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t begin_input_c = std::max(static_cast<index_t>(0),
                                       c - depth_radius);
const index_t end_input_c = std::min(channels, c + depth_radius + 1);
index_t pos = b * batch_size;
for (index_t hw = 0; hw < height * width; ++hw, ++pos) {
float accum = 0.f;
for (index_t input_c = begin_input_c; input_c < end_input_c; ++input_c) {
const float input_val = input_ptr[pos + input_c * image_size];
accum += input_val * input_val;
}
const float multiplier = std::pow(bias + alpha * accum, -beta);
output_ptr[pos + c * image_size] =
input_ptr[pos + c * image_size] * multiplier;
}
}
}
}
} // namespace kernels
} // namespace mace
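For reference, the loop implements standard across-channel local response normalization (Krizhevsky et al., 2012); with r = depth_radius and C = channels it computes, per position (b, h, w):

out[b,c,h,w] = in[b,c,h,w] * (bias + alpha * sum_{i = max(0, c-r)}^{min(C-1, c+r)} in[b,i,h,w]^2)^(-beta)

which matches the accum and multiplier computation above.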
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/pooling.h"
namespace mace {
namespace kernels {
namespace {
void MaxPooling(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t channels,
const index_t out_height,
const index_t out_width,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = channels * in_image_size;
const index_t out_batch_size = channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t out_base = b * out_batch_size + c * out_image_size;
const index_t in_base = b * in_batch_size + c * in_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
const index_t out_offset = out_base + h * out_width + w;
float res = std::numeric_limits<float>::lowest();
for (int fh = 0; fh < filter_height; ++fh) {
for (int fw = 0; fw < filter_width; ++fw) {
int inh = h * stride_h + dilation_h * fh - pad_top;
int inw = w * stride_w + dilation_w * fw - pad_left;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
index_t input_offset = in_base + inh * in_width + inw;
res = std::max(res, input[input_offset]);
}
}
}
output[out_offset] = res;
}
}
}
}
}
void AvgPooling(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t channels,
const index_t out_height,
const index_t out_width,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = channels * in_image_size;
const index_t out_batch_size = channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t out_base = b * out_batch_size + c * out_image_size;
const index_t in_base = b * in_batch_size + c * in_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
const index_t out_offset = out_base + h * out_width + w;
float res = 0;
int block_size = 0;
for (int fh = 0; fh < filter_height; ++fh) {
for (int fw = 0; fw < filter_width; ++fw) {
int inh = h * stride_h + dilation_h * fh - pad_top;
int inw = w * stride_w + dilation_w * fw - pad_left;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
index_t input_offset = in_base + inh * in_width + inw;
res += input[input_offset];
++block_size;
}
}
}
output[out_offset] = res / block_size;
}
}
}
}
}
} // namespace
void PoolingFunctor<DeviceType::NEON,
float>::operator()(const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future) {
std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {
input_tensor->dim(1), input_tensor->dim(1), kernels_[0], kernels_[1]};
std::vector<int> paddings(2);
if (paddings_.empty()) {
kernels::CalcNCHWPaddingAndOutputSize(
input_tensor->shape().data(), filter_shape.data(), dilations_,
strides_, padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input_tensor->shape().data(), filter_shape.data(),
paddings_.data(), dilations_, strides_, RoundType::CEIL,
output_shape.data());
}
output_tensor->Resize(output_shape);
const float *input = input_tensor->data<float>();
float *output = output_tensor->mutable_data<float>();
const index_t *input_shape = input_tensor->shape().data();
index_t batch = output_shape[0];
index_t channels = output_shape[1];
index_t height = output_shape[2];
index_t width = output_shape[3];
index_t input_channels = input_shape[1];
index_t input_height = input_shape[2];
index_t input_width = input_shape[3];
index_t in_image_size = input_height * input_width;
int filter_h = kernels_[0];
int filter_w = kernels_[1];
int stride_h = strides_[0];
int stride_w = strides_[1];
int dilation_h = dilations_[0];
int dilation_w = dilations_[1];
int pad_top = paddings[0] / 2;
int pad_left = paddings[1] / 2;
if (pooling_type_ == PoolingType::MAX) {
MaxPooling(input,
batch,
input_height,
input_width,
channels,
height,
width,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_top,
pad_left,
output);
} else if (pooling_type_ == PoolingType::AVG) {
AvgPooling(input,
batch,
input_height,
input_width,
channels,
height,
width,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_top,
pad_left,
output);
} else {
MACE_NOT_IMPLEMENTED;
}
}
} // namespace kernels
} // namespace mace
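The notable choice here is RoundType::CEIL for explicit paddings, versus RoundType::FLOOR in the convolution path: pooling follows the convention of rounding the output extent up, so a window that only partially overlaps the input still emits an output. Assuming the usual size formula with effective kernel k_eff = (kernel - 1) * dilation + 1:

out = round((in + pad_total - k_eff) / stride) + 1,  round = ceil for pooling, floor for conv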
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/softmax.h"
namespace mace {
namespace kernels {
void SoftmaxFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t class_count = input->dim(1);
const index_t class_size = input->dim(2) * input->dim(3);
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
for (index_t b = 0; b < batch; ++b) {
std::vector<float>
max_val(class_size, std::numeric_limits<float>::lowest());
std::vector<float> sum_val(class_size, 0.f);
// calculate the per-position max across classes
for (index_t c = 0; c < class_count; ++c) {
const float *input_ptr = input_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
max_val[k] = std::max(max_val[k], input_ptr[k]);
}
}
// calculate exp(data - max) for each class
#pragma omp parallel for
for (index_t c = 0; c < class_count; ++c) {
const float *input_ptr = input_data + (b * class_count + c) * class_size;
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
output_ptr[k] = ::exp(input_ptr[k] - max_val[k]);
}
}
// accumulate the per-position sum of exp(data - max) across classes
for (index_t c = 0; c < class_count; ++c) {
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
sum_val[k] += output_ptr[k];
}
}
// normalize: exp(data - max) / sum for each class
for (index_t c = 0; c < class_count; ++c) {
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
output_ptr[k] /= sum_val[k];
}
}
}
}
} // namespace kernels
} // namespace mace
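Subtracting the per-position maximum before exponentiating is the standard overflow guard, and it leaves the result unchanged because the common factor e^{-m} cancels:

softmax(x)_c = e^{x_c} / sum_j e^{x_j} = e^{x_c - m} / sum_j e^{x_j - m},  where m = max_j x_j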
@@ -34,21 +34,24 @@ struct BatchNormFunctorBase {
   BatchNormFunctorBase(bool folded_constant,
                        const ActivationType activation,
                        const float relux_max_limit)
       : folded_constant_(folded_constant),
         activation_(activation),
         relux_max_limit_(relux_max_limit) {}
   const bool folded_constant_;
   const ActivationType activation_;
   const float relux_max_limit_;
 };
-template <DeviceType D, typename T>
-struct BatchNormFunctor : BatchNormFunctorBase {
+template<DeviceType D, typename T>
+struct BatchNormFunctor;
+template<>
+struct BatchNormFunctor<DeviceType::CPU, float> : BatchNormFunctorBase {
   BatchNormFunctor(const bool folded_constant,
                    const ActivationType activation,
                    const float relux_max_limit)
       : BatchNormFunctorBase(folded_constant, activation, relux_max_limit) {}
   void operator()(const Tensor *input,
                   const Tensor *scale,
@@ -67,29 +70,29 @@ struct BatchNormFunctor : BatchNormFunctorBase {
     // new_offset = offset - mean * new_scale;
     // Y = new_scale * X + new_offset;
     const index_t batch = input->dim(0);
-    const index_t height = input->dim(1);
-    const index_t width = input->dim(2);
-    const index_t channels = input->dim(3);
+    const index_t channels = input->dim(1);
+    const index_t height = input->dim(2);
+    const index_t width = input->dim(3);
     Tensor::MappingGuard input_mapper(input);
     Tensor::MappingGuard scale_mapper(scale);
     Tensor::MappingGuard offset_mapper(offset);
     Tensor::MappingGuard output_mapper(output);
-    const T *input_ptr = input->data<T>();
-    const T *scale_ptr = scale->data<T>();
-    const T *offset_ptr = offset->data<T>();
-    T *output_ptr = output->mutable_data<T>();
-    std::vector<T> new_scale;
-    std::vector<T> new_offset;
+    const float *input_ptr = input->data<float>();
+    const float *scale_ptr = scale->data<float>();
+    const float *offset_ptr = offset->data<float>();
+    float *output_ptr = output->mutable_data<float>();
+    std::vector<float> new_scale;
+    std::vector<float> new_offset;
     if (!folded_constant_) {
       new_scale.resize(channels);
       new_offset.resize(channels);
       Tensor::MappingGuard mean_mapper(mean);
       Tensor::MappingGuard var_mapper(var);
-      const T *mean_ptr = mean->data<T>();
-      const T *var_ptr = var->data<T>();
+      const float *mean_ptr = mean->data<float>();
+      const float *var_ptr = var->data<float>();
 #pragma omp parallel for
       for (index_t c = 0; c < channels; ++c) {
         new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon);
@@ -97,44 +100,21 @@ struct BatchNormFunctor : BatchNormFunctorBase {
       }
     }
-    const T *scale_data = folded_constant_ ? scale_ptr : new_scale.data();
-    const T *offset_data = folded_constant_ ? offset_ptr : new_offset.data();
-    const int elements = batch * height * width;
-    constexpr int c_tile_size = 4;
-    const int c_tiles = channels / c_tile_size;
-    const index_t remains_start = c_tiles * c_tile_size;
-    if (c_tiles > 0) {
-#pragma omp parallel for collapse(2)
-      for (index_t i = 0; i < elements; ++i) {
-        for (int cb = 0; cb < c_tiles; ++cb) {
-#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
-          static_assert(c_tile_size == 4, "channels tile size must be 4");
-          int c = cb * c_tile_size;
-          int pos = i * channels + c;
-          float32x4_t scales = vld1q_f32(scale_data + c);
-          float32x4_t offsets = vld1q_f32(offset_data + c);
-          float32x4_t in = vld1q_f32(input_ptr + pos);
-          float32x4_t out = vfmaq_f32(offsets, scales, in);
-          vst1q_f32(output_ptr + pos, out);
-#else
-          for (int ci = 0; ci < c_tile_size; ++ci) {
-            int c = cb * c_tile_size + ci;
-            index_t pos = i * channels + c;
-            output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c];
-          }
-#endif
-        }
-      }
-    }
-    if (remains_start < channels) {
-#pragma omp parallel for collapse(2)
-      for (index_t i = 0; i < elements; ++i) {
-        for (index_t c = remains_start; c < channels; ++c) {
-          index_t pos = i * channels + c;
-          output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c];
+    const float *scale_data = folded_constant_ ? scale_ptr : new_scale.data();
+    const float *offset_data = folded_constant_ ? offset_ptr : new_offset.data();
+    index_t channel_size = height * width;
+    index_t batch_size = channels * channel_size;
+    // NEON is slower here, so stick to the trivial implementation
+#pragma omp parallel for collapse(2)
+    for (index_t b = 0; b < batch; ++b) {
+      for (index_t c = 0; c < channels; ++c) {
+        index_t offset = b * batch_size + c * channel_size;
+        for (index_t hw = 0; hw < height * width; ++hw) {
+          output_ptr[offset + hw] =
+              scale_data[c] * input_ptr[offset + hw] + offset_data[c];
         }
       }
     }
@@ -143,29 +123,12 @@ struct BatchNormFunctor : BatchNormFunctorBase {
   }
 };
-template <>
-struct BatchNormFunctor<DeviceType::NEON, float> : BatchNormFunctorBase {
-  BatchNormFunctor(const bool folded_constant,
-                   const ActivationType activation,
-                   const float relux_max_limit)
-      : BatchNormFunctorBase(folded_constant, activation, relux_max_limit) {}
-  void operator()(const Tensor *input,
-                  const Tensor *scale,
-                  const Tensor *offset,
-                  const Tensor *mean,
-                  const Tensor *var,
-                  const float epsilon,
-                  Tensor *output,
-                  StatsFuture *future);
-};
-template <typename T>
+template<typename T>
 struct BatchNormFunctor<DeviceType::OPENCL, T> : BatchNormFunctorBase {
   BatchNormFunctor(const bool folded_constant,
                    const ActivationType activation,
                    const float relux_max_limit)
       : BatchNormFunctorBase(folded_constant, activation, relux_max_limit) {}
   void operator()(const Tensor *input,
                   const Tensor *scale,
                   const Tensor *offset,
...
@@ -26,49 +26,41 @@
 namespace mace {
 namespace kernels {
-template <DeviceType D, typename T>
-struct BiasAddFunctor {
+template<DeviceType D, typename T>
+struct BiasAddFunctor;
+template<>
+struct BiasAddFunctor<DeviceType::CPU, float> {
   void operator()(const Tensor *input,
                   const Tensor *bias,
                   Tensor *output,
                   StatsFuture *future) {
     const index_t batch = input->dim(0);
-    const index_t height = input->dim(1);
-    const index_t width = input->dim(2);
-    const index_t channels = input->dim(3);
+    const index_t channels = input->dim(1);
+    const index_t height = input->dim(2);
+    const index_t width = input->dim(3);
     Tensor::MappingGuard input_mapper(input);
     Tensor::MappingGuard bias_mapper(bias);
     Tensor::MappingGuard output_mapper(output);
-    const T *input_ptr = input->data<T>();
-    const T *bias_ptr = bias->data<T>();
-    T *output_ptr = output->mutable_data<T>();
-#pragma omp parallel for collapse(4)
+    const float *input_ptr = input->data<float>();
+    const float *bias_ptr = bias->data<float>();
+    float *output_ptr = output->mutable_data<float>();
+#pragma omp parallel for collapse(2)
     for (index_t n = 0; n < batch; ++n) {
-      for (index_t h = 0; h < height; ++h) {
-        for (index_t w = 0; w < width; ++w) {
-          for (index_t c = 0; c < channels; ++c) {
-            index_t pos = (((n * height) + h) * width + w) * channels + c;
-            output_ptr[pos] = input_ptr[pos] + bias_ptr[c];
-          }
-        }
+      for (index_t c = 0; c < channels; ++c) {
+        for (index_t hw = 0; hw < height * width; ++hw) {
+          index_t pos = (n * channels + c) * height * width + hw;
+          output_ptr[pos] = input_ptr[pos] + bias_ptr[c];
+        }
       }
     }
   }
 };
-/*
-template <>
-void BiasAddFunctor<DeviceType::NEON, float>::operator()(
-    const Tensor *input,
-    const Tensor *bias,
-    Tensor *output,
-    StatsFuture *future);
-*/
-template <typename T>
+template<typename T>
 struct BiasAddFunctor<DeviceType::OPENCL, T> {
   void operator()(const Tensor *input,
                   const Tensor *bias,
...
@@ -24,7 +24,7 @@
 namespace mace {
 namespace kernels {
-template <DeviceType D, typename T>
+template<DeviceType D, typename T>
 struct ChannelShuffleFunctor {
   explicit ChannelShuffleFunctor(const int groups) : groups_(groups) {}
@@ -39,20 +39,25 @@ struct ChannelShuffleFunctor {
     T *output_ptr = output->mutable_data<T>();
     index_t batch = input->dim(0);
-    index_t height = input->dim(1);
-    index_t width = input->dim(2);
-    index_t channels = input->dim(3);
-    index_t bhw_fuse = batch * height * width;
-    int channels_per_group = channels / groups_;
-#pragma omp parallel for
-    for (int bhw = 0; bhw < bhw_fuse; ++bhw) {
-      for (int c = 0; c < channels; ++c) {
-        index_t channel_base = bhw * channels;
-        output_ptr[channel_base + c] =
-            input_ptr[channel_base + c % groups_ * channels_per_group
-                + c / groups_];
+    index_t channels = input->dim(1);
+    index_t height = input->dim(2);
+    index_t width = input->dim(3);
+    index_t image_size = height * width;
+    index_t batch_size = channels * image_size;
+    index_t channels_per_group = channels / groups_;
+#pragma omp parallel for collapse(2)
+    for (index_t b = 0; b < batch; ++b) {
+      for (index_t c = 0; c < channels; ++c) {
+        const T *input_base = input_ptr + b * batch_size;
+        T *output_base = output_ptr + b * batch_size;
+        index_t g = c % groups_;
+        index_t idx = c / groups_;
+        for (index_t hw = 0; hw < height * width; ++hw) {
+          output_base[c * image_size + hw] =
+              input_base[(g * channels_per_group + idx) * image_size + hw];
+        }
       }
     }
   }
@@ -60,7 +65,7 @@ struct ChannelShuffleFunctor {
   const int groups_;
 };
-template <typename T>
+template<typename T>
 struct ChannelShuffleFunctor<DeviceType::OPENCL, T> {
   explicit ChannelShuffleFunctor(const int groups) : groups_(groups) {}
...
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <arm_neon.h> #include <arm_neon.h>
#endif #endif
#include <algorithm> #include <algorithm>
#include <functional>
#include <memory> #include <memory>
#include <vector> #include <vector>
...@@ -27,165 +28,13 @@ ...@@ -27,165 +28,13 @@
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/kernels/activation.h" #include "mace/kernels/activation.h"
#include "mace/kernels/conv_pool_2d_util.h" #include "mace/kernels/conv_pool_2d_util.h"
#include "mace/kernels/arm/conv_2d_neon.h"
#include "mace/kernels/arm/conv_winograd.h"
#include "mace/utils/utils.h" #include "mace/utils/utils.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template <typename T,
int inc_tile_size,
int c_count,
int h_count,
int w_count>
void Conv2dKernelFunc(const T *input_ptr, // batch start
const T *filter_ptr,
const T *bias_ptr,
T *output_ptr, // batch start
const int h_offset,
const int w_offset,
const int c_offset,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channels,
const int input_channels,
const int width,
const int padded_width) {
T sum[h_count * w_count * c_count] = {0.0f};
if (bias_ptr != nullptr) {
for (int hi = 0; hi < h_count; ++hi) {
for (int wi = 0; wi < w_count; ++wi) {
for (int ci = 0; ci < c_count; ++ci) {
const int sum_idx = (hi * w_count + wi) * c_count + ci;
sum[sum_idx] = bias_ptr[c_offset + ci];
}
}
}
}
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inc = 0;
for (; inc + inc_tile_size <= input_channels; inc += inc_tile_size) {
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
// AArch64 NEON has 32 128-bit general purpose registers
static_assert(inc_tile_size == 4, "input channels tile size must be 4");
float32x4_t in[h_count * w_count]; // NOLINT(runtime/arrays)
#else
T in[h_count * w_count * inc_tile_size]; // NOLINT(runtime/arrays)
#endif
for (int hi = 0; hi < h_count; ++hi) {
for (int wi = 0; wi < w_count; ++wi) {
const int in_idx = hi * w_count + wi;
const int inh = (h_offset + hi) * stride_h + kh * dilation_h;
const int inw = (w_offset + wi) * stride_w + kw * dilation_w;
const int in_offset =
(inh * padded_width + inw) * input_channels + inc;
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
static_assert(inc_tile_size == 4,
"input channels tile size must be 4");
in[in_idx] = vld1q_f32(input_ptr + in_offset);
#else
for (int inci = 0; inci < inc_tile_size; ++inci) {
in[in_idx * inc_tile_size + inci] = input_ptr[in_offset + inci];
}
#endif
}
}
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
static_assert(inc_tile_size == 4, "input channels tile size must be 4");
float32x4_t weights[c_count]; // NOLINT(runtime/arrays)
#else
T weights[c_count * inc_tile_size]; // NOLINT(runtime/arrays)
#endif
for (int ci = 0; ci < c_count; ++ci) {
const int weights_idx = ci;
const int filter_offset =
((kh * kernel_w + kw) * channels + c_offset + ci) *
input_channels +
inc;
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
weights[weights_idx] = vld1q_f32(filter_ptr + filter_offset);
#else
for (int inci = 0; inci < inc_tile_size; ++inci) {
weights[weights_idx * inc_tile_size + inci] =
filter_ptr[filter_offset + inci];
}
#endif
}
for (int hi = 0; hi < h_count; ++hi) {
for (int wi = 0; wi < w_count; ++wi) {
for (int ci = 0; ci < c_count; ++ci) {
const int weights_idx = ci;
const int in_idx = hi * w_count + wi;
const int sum_idx = (hi * w_count + wi) * c_count + ci;
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
float32x4_t tmp = vmulq_f32(in[in_idx], weights[weights_idx]);
sum[sum_idx] += vaddvq_f32(tmp);
#else
for (int inci = 0; inci < inc_tile_size; ++inci) {
sum[sum_idx] += in[in_idx * inc_tile_size + inci] *
weights[weights_idx * inc_tile_size + inci];
}
#endif
}
}
}
}
// handling the remaining input channels
for (; inc < input_channels; ++inc) {
T in[h_count * w_count]; // NOLINT(runtime/arrays)
for (int hi = 0; hi < h_count; ++hi) {
for (int wi = 0; wi < w_count; ++wi) {
const int in_idx = hi * w_count + wi;
const int inh = (h_offset + hi) * stride_h + kh * dilation_h;
const int inw = (w_offset + wi) * stride_w + kw * dilation_w;
const int in_offset =
(inh * padded_width + inw) * input_channels + inc;
in[in_idx] = input_ptr[in_offset];
}
}
T weights[c_count]; // NOLINT(runtime/arrays)
for (int ci = 0; ci < c_count; ++ci) {
const int weights_idx = ci;
const int filter_offset =
((kh * kernel_w + kw) * channels + c_offset + ci) *
input_channels +
inc;
weights[weights_idx] = filter_ptr[filter_offset];
}
for (int hi = 0; hi < h_count; ++hi) {
for (int wi = 0; wi < w_count; ++wi) {
for (int ci = 0; ci < c_count; ++ci) {
const int weights_idx = ci;
const int in_idx = hi * w_count + wi;
const int sum_idx = (hi * w_count + wi) * c_count + ci;
sum[sum_idx] += in[in_idx] * weights[weights_idx];
}
}
}
}
}
}
// save output
for (int hi = 0; hi < h_count; ++hi) {
for (int wi = 0; wi < w_count; ++wi) {
for (int ci = 0; ci < c_count; ++ci) {
const int out_offset =
((h_offset + hi) * width + w_offset + wi) * channels + c_offset +
ci;
const int sum_idx = (hi * w_count + wi) * c_count + ci;
output_ptr[out_offset] = sum[sum_idx];
}
}
}
}
struct Conv2dFunctorBase { struct Conv2dFunctorBase {
Conv2dFunctorBase(const int *strides, Conv2dFunctorBase(const int *strides,
const Padding &padding_type, const Padding &padding_type,
...@@ -193,12 +42,12 @@ struct Conv2dFunctorBase { ...@@ -193,12 +42,12 @@ struct Conv2dFunctorBase {
const int *dilations, const int *dilations,
const ActivationType activation, const ActivationType activation,
const float relux_max_limit) const float relux_max_limit)
: strides_(strides), : strides_(strides),
padding_type_(padding_type), padding_type_(padding_type),
paddings_(paddings), paddings_(paddings),
dilations_(dilations), dilations_(dilations),
activation_(activation), activation_(activation),
relux_max_limit_(relux_max_limit) {} relux_max_limit_(relux_max_limit) {}
const int *strides_; // [stride_h, stride_w] const int *strides_; // [stride_h, stride_w]
const Padding padding_type_; const Padding padding_type_;
...@@ -208,100 +57,11 @@ struct Conv2dFunctorBase { ...@@ -208,100 +57,11 @@ struct Conv2dFunctorBase {
const float relux_max_limit_; const float relux_max_limit_;
}; };
#define MACE_DO_CONV2D(CC, CH, CW) \ template<DeviceType D, typename T>
Conv2dKernelFunc<T, inc_tile_size, CC, CH, CW>( \ struct Conv2dFunctor;
input_ptr, filter_data, bias_data, output_ptr, \
h_offset, w_offset, c_offset, kernel_h, kernel_w, \ template<>
stride_h, stride_w, dilation_h, dilation_w, \ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
channels, input_channels, width, padded_width);
#define MACE_CASE_W_CONV2D(CC, CH) \
switch (w_count) { \
case 1: \
MACE_DO_CONV2D(CC, CH, 1); \
break; \
case 2: \
MACE_DO_CONV2D(CC, CH, 2); \
break; \
case 3: \
MACE_DO_CONV2D(CC, CH, 3); \
break; \
case 4: \
MACE_DO_CONV2D(CC, CH, 4); \
break; \
default: \
LOG(FATAL) << "Unsupported w tile: " << w_count; \
}
#define MACE_CASE_H_CONV2D(CC) \
switch (h_count) { \
case 1: \
MACE_CASE_W_CONV2D(CC, 1); \
break; \
case 2: \
MACE_CASE_W_CONV2D(CC, 2); \
break; \
default: \
LOG(FATAL) << "Unsupported h tile: " << h_count; \
}
#define MACE_CASE_C_CONV2D \
switch (c_count) { \
case 1: \
MACE_CASE_H_CONV2D(1); \
break; \
case 2: \
MACE_CASE_H_CONV2D(2); \
break; \
case 3: \
MACE_CASE_H_CONV2D(3); \
break; \
case 4: \
MACE_CASE_H_CONV2D(4); \
break; \
case 5: \
MACE_CASE_H_CONV2D(5); \
break; \
case 6: \
MACE_CASE_H_CONV2D(6); \
break; \
case 7: \
MACE_CASE_H_CONV2D(7); \
break; \
case 8: \
MACE_CASE_H_CONV2D(8); \
break; \
case 9: \
MACE_CASE_H_CONV2D(9); \
break; \
case 10: \
MACE_CASE_H_CONV2D(10); \
break; \
case 11: \
MACE_CASE_H_CONV2D(11); \
break; \
case 12: \
MACE_CASE_H_CONV2D(12); \
break; \
case 13: \
MACE_CASE_H_CONV2D(13); \
break; \
case 14: \
MACE_CASE_H_CONV2D(14); \
break; \
case 15: \
MACE_CASE_H_CONV2D(15); \
break; \
case 16: \
MACE_CASE_H_CONV2D(16); \
break; \
default: \
LOG(FATAL) << "Unsupported c tile: " << c_count; \
}
template <DeviceType D, typename T>
struct Conv2dFunctor : Conv2dFunctorBase {
Conv2dFunctor(const int *strides, Conv2dFunctor(const int *strides,
const Padding &padding_type, const Padding &padding_type,
const std::vector<int> &paddings, const std::vector<int> &paddings,
...@@ -310,15 +70,62 @@ struct Conv2dFunctor : Conv2dFunctorBase { ...@@ -310,15 +70,62 @@ struct Conv2dFunctor : Conv2dFunctorBase {
const float relux_max_limit, const float relux_max_limit,
const bool is_filter_transformed, const bool is_filter_transformed,
ScratchBuffer *scratch) ScratchBuffer *scratch)
: Conv2dFunctorBase(strides, : Conv2dFunctorBase(strides,
padding_type, padding_type,
paddings, paddings,
dilations, dilations,
activation, activation,
relux_max_limit) {} relux_max_limit),
is_filter_transformed_(is_filter_transformed),
void operator()(const Tensor *input, // NHWC scratch_(scratch) {}
const Tensor *filter, // HWOI or TOI
void Conv2dGeneral(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
float *output) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
index_t out_offset =
((b * out_channels + m) * out_height + h) * out_width + w;
for (index_t c = 0; c < in_channels; ++c) {
for (index_t kh = 0; kh < filter_height; ++kh) {
for (index_t kw = 0; kw < filter_width; ++kw) {
index_t ih = h * stride_h + kh * dilation_h;
index_t iw = w * stride_w + kw * dilation_w;
index_t in_offset =
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t filter_offset =
(((m * in_channels) + c) * filter_height + kh)
* filter_width
+ kw;
output[out_offset] +=
input[in_offset] * filter[filter_offset];
}
}
}
}
}
}
}
}
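Note that Conv2dGeneral accumulates into output (output[out_offset] += ...), which is why the caller clears the output tensor after resizing it. A tiny standalone check of the same loop nest, specialized to a 1x1 filter with stride 1 so the result is easy to verify by hand; everything here is illustrative:

// 1x1 instance of the general NCHW conv loop above: each output
// element is the channel-wise dot product of input and filter.
#include <cstdio>
#include <vector>

int main() {
  const int in_c = 2, out_c = 1, h = 2, w = 2;
  std::vector<float> input(in_c * h * w, 1.0f);    // NCHW, batch 1
  std::vector<float> filter = {2.0f, 3.0f};        // OIHW, 1x1 kernel
  std::vector<float> output(out_c * h * w, 0.0f);  // must start cleared
  for (int m = 0; m < out_c; ++m)
    for (int y = 0; y < h; ++y)
      for (int x = 0; x < w; ++x)
        for (int c = 0; c < in_c; ++c)
          output[(m * h + y) * w + x] +=
              input[(c * h + y) * w + x] * filter[m * in_c + c];
  std::printf("%.0f\n", output[0]);  // 2*1 + 3*1 = 5
  return 0;
}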
void operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias, const Tensor *bias,
Tensor *output, Tensor *output,
StatsFuture *future) { StatsFuture *future) {
...@@ -326,138 +133,374 @@ struct Conv2dFunctor : Conv2dFunctorBase { ...@@ -326,138 +133,374 @@ struct Conv2dFunctor : Conv2dFunctorBase {
MACE_CHECK_NOTNULL(filter); MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output); MACE_CHECK_NOTNULL(output);
std::vector<index_t> filter_shape(4);
if (is_filter_transformed_) {
// TOC -> OIHW
filter_shape[0] = filter->dim(1);
filter_shape[1] = filter->dim(2);
filter_shape[2] = filter_shape[3] = 3;
} else {
filter_shape = filter->shape();
}
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
std::vector<int> paddings(2); std::vector<int> paddings(2);
if (paddings_.empty()) { if (paddings_.empty()) {
kernels::CalcNHWCPaddingAndOutputSize( CalcNCHWPaddingAndOutputSize(input->shape().data(),
input->shape().data(), filter->shape().data(), dilations_, strides_, filter_shape.data(),
padding_type_, output_shape.data(), paddings.data()); dilations_,
strides_,
padding_type_,
output_shape.data(),
paddings.data());
} else { } else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input->shape().data(), filter->shape().data(), CalcNCHWOutputSize(input->shape().data(),
paddings_.data(), dilations_, strides_, RoundType::FLOOR, filter_shape.data(),
output_shape.data()); paddings_.data(),
dilations_,
strides_,
RoundType::FLOOR,
output_shape.data());
} }
output->Resize(output_shape); output->Resize(output_shape);
output->Clear();
int batch = output->dim(0);
int height = output->dim(1); index_t batch = output->dim(0);
int width = output->dim(2); index_t channels = output->dim(1);
int channels = output->dim(3); index_t height = output->dim(2);
index_t width = output->dim(3);
int input_batch = input->dim(0);
int input_height = input->dim(1); index_t input_batch = input->dim(0);
int input_width = input->dim(2); index_t input_channels = input->dim(1);
int input_channels = input->dim(3); index_t input_height = input->dim(2);
index_t input_width = input->dim(3);
int kernel_h = filter->dim(0);
int kernel_w = filter->dim(1); index_t filter_h = filter_shape[2];
MACE_CHECK(filter->dim(2) == channels, filter->dim(2), " != ", channels); index_t filter_w = filter_shape[3];
MACE_CHECK(filter->dim(3) == input_channels, filter->dim(3), " != ", MACE_CHECK(filter_shape[0] == channels, filter_shape[0], " != ", channels);
MACE_CHECK(filter_shape[1] == input_channels, filter_shape[1], " != ",
input_channels); input_channels);
int stride_h = strides_[0]; index_t stride_h = strides_[0];
int stride_w = strides_[1]; index_t stride_w = strides_[1];
int dilation_h = dilations_[0]; index_t dilation_h = dilations_[0];
int dilation_w = dilations_[1]; index_t dilation_w = dilations_[1];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch"); MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
int padded_height = input_height + paddings[0]; index_t padded_input_height = input_height + paddings[0];
int padded_width = input_width + paddings[1]; index_t padded_input_width = input_width + paddings[1];
index_t extra_input_height = padded_input_height;
index_t extra_input_width = padded_input_width;
index_t extra_output_height = height;
index_t extra_output_width = width;
int pad_top = paddings[0] >> 1;
int pad_bottom = paddings[0] - pad_top;
int pad_left = paddings[1] >> 1;
int pad_right = paddings[1] - pad_left;
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard filter_guard(filter);
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard output_guard(output);
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
auto output_data = output->mutable_data<float>();
std::function<void(const float *input, float *output)> conv_func;
bool use_winograd = is_filter_transformed_
|| (filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1
&& input_channels >= 8 && channels >= 8);
bool use_neon_3x3_s1 = filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
bool use_neon_3x3_s2 = filter_h == 3 && filter_w == 3
&& stride_h == 2 && stride_w == 2 && dilation_h == 1 && dilation_w == 1;
bool use_neon_1x1_s1 = filter_h == 1 && filter_w == 1
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
std::vector<index_t> transformed_input_shape;
std::vector<index_t> transformed_output_shape;
std::vector<index_t> transformed_filter_shape;
// When the input feature map is larger than 16x16, use a winograd
// output tile size of 6 to get higher performance.
index_t winograd_out_tile_size = 2;
if (input_height > 16 && input_width > 16) {
winograd_out_tile_size = 6;
}
if (use_winograd) {
extra_output_height = RoundUp<index_t>(height, winograd_out_tile_size);
extra_input_height =
std::max(padded_input_height, extra_output_height + 2);
extra_output_width = RoundUp<index_t>(width, winograd_out_tile_size);
extra_input_width = std::max(padded_input_width, extra_output_width + 2);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
index_t tile_height_count = extra_output_height / winograd_out_tile_size;
index_t tile_width_count = extra_output_width / winograd_out_tile_size;
index_t tile_count = tile_height_count * tile_width_count;
index_t in_tile_area =
(winograd_out_tile_size + 2) * (winograd_out_tile_size + 2);
transformed_input_shape.insert(transformed_input_shape.end(),
{in_tile_area, batch, input_channels,
tile_count});
transformed_output_shape.insert(transformed_output_shape.end(),
{in_tile_area, batch, channels,
tile_count});
transformed_filter_shape.insert(transformed_filter_shape.end(),
{in_tile_area, channels, input_channels});
} else if (use_neon_3x3_s1) {
extra_output_height = RoundUp<index_t>(height, 2);
extra_input_height =
std::max(padded_input_height, extra_output_height + 2);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width = std::max(padded_input_width, extra_output_width + 2);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
} else if (use_neon_3x3_s2) {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, (extra_output_height - 1) * 2 + 3);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width =
std::max(padded_input_width, (extra_output_width - 1) * 2 + 3);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
}
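As a concrete instance of the sizing logic above, take a 7x7 output with the 3x3 stride-1 NEON kernel: the height is rounded up to a multiple of 2 and the width to a multiple of 4, so extra_output becomes 8x8, and the input must cover extra_output + 2 in each dimension; any difference from the SAME-padded input is absorbed into pad_bottom and pad_right. A sketch of the same arithmetic (RoundUp here is a local stand-in for the utils helper):

// Illustrative recomputation of the tile-driven padding above
// (3x3 stride-1 case, 7x7 conv output, SAME-padded 9x9 input).
#include <algorithm>
#include <cstdio>

long RoundUp(long v, long factor) { return (v + factor - 1) / factor * factor; }

int main() {
  const long height = 7, width = 7;
  const long padded_input_h = 9, padded_input_w = 9;
  const long extra_out_h = RoundUp(height, 2);                        // 8
  const long extra_out_w = RoundUp(width, 4);                         // 8
  const long extra_in_h = std::max(padded_input_h, extra_out_h + 2);  // 10
  const long extra_in_w = std::max(padded_input_w, extra_out_w + 2);  // 10
  std::printf("extra out %ldx%ld, extra in %ldx%ld\n",
              extra_out_h, extra_out_w, extra_in_h, extra_in_w);
  return 0;
}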
Tensor padded_input; // decide scratch size before allocating it
// Keep this alive during kernel execution index_t total_scratch_size = 0;
if (paddings[0] > 0 || paddings[1] > 0) { index_t transformed_input_size = 0;
ConstructNHWCInputWithPadding(input, paddings.data(), &padded_input); index_t transformed_output_size = 0;
input = &padded_input; index_t padded_input_size = 0;
index_t padded_output_size = 0;
if (use_winograd) {
transformed_input_size =
std::accumulate(transformed_input_shape.begin(),
transformed_input_shape.end(),
1,
std::multiplies<index_t>()) * sizeof(float);
transformed_output_size =
std::accumulate(transformed_output_shape.begin(),
transformed_output_shape.end(),
1,
std::multiplies<index_t>()) * sizeof(float);
total_scratch_size += transformed_input_size + transformed_output_size;
} }
if (extra_input_height != input_height
|| extra_input_width != input_width) {
padded_input_size =
batch * input_channels * (input_height + pad_top + pad_bottom)
* (input_width + pad_left + pad_right) * sizeof(float);
total_scratch_size += padded_input_size;
}
if (extra_output_height != height || extra_output_width != width) {
padded_output_size =
batch * channels * extra_output_height * extra_output_width
* sizeof(float);
total_scratch_size += padded_output_size;
}
// Init scratch buffer
scratch_->Rewind();
scratch_->GrowSize(total_scratch_size);
Tensor transformed_input(scratch_->Scratch(transformed_input_size),
DT_FLOAT);
Tensor transformed_output(scratch_->Scratch(transformed_output_size),
DT_FLOAT);
Tensor padded_input(scratch_->Scratch(padded_input_size), DT_FLOAT);
Tensor padded_output(scratch_->Scratch(padded_output_size), DT_FLOAT);
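The scratch handling follows a rewind-grow-carve discipline: compute every temporary's size first, grow the buffer once, then slice it with successive Scratch() calls. A hedged sketch of that pattern with a toy byte-offset type; MACE's real ScratchBuffer hands out slices of one underlying buffer that Tensor can wrap, so this is a simplification:

// Toy model of the Rewind/GrowSize/Scratch sequence used above.
#include <cstddef>
#include <vector>

struct ToyScratch {
  std::vector<char> data;
  std::size_t offset = 0;
  void Rewind() { offset = 0; }
  void GrowSize(std::size_t n) { if (n > data.size()) data.resize(n); }
  char *Scratch(std::size_t n) {  // carve n bytes off the front
    char *p = data.data() + offset;
    offset += n;
    return p;
  }
};

int main() {
  const std::size_t a = 1024, b = 2048;
  ToyScratch scratch;
  scratch.Rewind();
  scratch.GrowSize(a + b);  // one allocation covers all temporaries
  char *buf_a = scratch.Scratch(a);
  char *buf_b = scratch.Scratch(b);
  (void)buf_a;
  (void)buf_b;
  return 0;
}

Unused temporaries keep a size of zero, which is why all four tensors can be constructed unconditionally.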
// decide which convolution function to call
if (use_winograd) {
transformed_input.Resize(transformed_input_shape);
transformed_output.Resize(transformed_output_shape);
const float *transformed_filter_ptr;
if (transformed_filter_.dim_size() == 0) {
transformed_filter_.Resize(transformed_filter_shape);
if (is_filter_transformed_) {
transformed_filter_ptr = filter_data;
} else {
switch (winograd_out_tile_size) {
case 2:
TransformFilter4x4(filter_data,
filter_shape[1],
filter_shape[0],
transformed_filter_.mutable_data<float>());
break;
case 6:
TransformFilter8x8(filter_data,
filter_shape[1],
filter_shape[0],
transformed_filter_.mutable_data<float>());
break;
default: MACE_NOT_IMPLEMENTED;
}
transformed_filter_ptr = transformed_filter_.data<float>();
}
} else {
transformed_filter_ptr = transformed_filter_.data<float>();
}
// padded_input.DebugPrint(); conv_func = [&](const float *pad_input, float *pad_output) {
WinoGradConv3x3s1(pad_input,
transformed_filter_ptr,
batch,
extra_input_height,
extra_input_width,
input_channels,
channels,
winograd_out_tile_size,
transformed_input.mutable_data<float>(),
transformed_output.mutable_data<float>(),
pad_output);
};
} else if (use_neon_3x3_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK3x3S1(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
pad_output);
};
} else if (use_neon_3x3_s2) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK3x3S2(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
pad_output);
};
} else if (use_neon_1x1_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK1x1S1(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
channels,
pad_output);
};
} else {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dGeneral(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_output);
};
}
Tensor::MappingGuard input_mapper(input); // pad input and output
Tensor::MappingGuard filter_mapper(filter); const Tensor *pad_input_ptr = input;
Tensor::MappingGuard bias_mapper(bias); if (extra_input_height != input_height
Tensor::MappingGuard output_mapper(output); || extra_input_width != input_width) {
auto input_data = input->data<T>(); padded_input.Clear();
auto filter_data = filter->data<T>(); ConstructNCHWInputWithSpecificPadding(input,
auto bias_data = bias == nullptr ? nullptr : bias->data<T>(); pad_top,
auto output_data = output->mutable_data<T>(); pad_bottom,
pad_left,
pad_right,
&padded_input);
pad_input_ptr = &padded_input;
}
constexpr int inc_tile_size = 4; Tensor *pad_output_ptr = output;
// TODO(heliangliang) Auto tuning these parameters if (extra_output_height != height || extra_output_width != width) {
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) padded_output.Resize({batch, channels, extra_output_height,
const int c_tile_size = 4; extra_output_width});
const int h_tile_size = 2; padded_output.Clear();
const int w_tile_size = 2; pad_output_ptr = &padded_output;
#else }
const int c_tile_size = 4; const float *pad_input_data = pad_input_ptr->data<float>();
const int h_tile_size = 1; float *pad_output_data = pad_output_ptr->mutable_data<float>();
const int w_tile_size = 2;
#endif conv_func(pad_input_data, pad_output_data);
// unpack output
if (extra_output_height != height || extra_output_width != width) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t h = 0; h < height; ++h) {
memcpy(
output_data + b * channels * height * width + c * height * width
+ h * width,
pad_output_data
+ b * channels * extra_output_height * extra_output_width
+ c * extra_output_height * extra_output_width
+ h * extra_output_width,
sizeof(float) * width);
}
}
}
}
const int c_tiles = RoundUpDiv(channels, c_tile_size); if (bias_data != nullptr) {
const int h_tiles = RoundUpDiv(height, h_tile_size); #pragma omp parallel for collapse(2)
const int w_tiles = RoundUpDiv(width, w_tile_size); for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
#pragma omp parallel for collapse(4) for (index_t i = 0; i < height * width; ++i) {
for (int n = 0; n < batch; ++n) { output_data[(b * channels + c) * height * width + i] +=
for (int cb = 0; cb < c_tiles; ++cb) { bias_data[c];
for (int hb = 0; hb < h_tiles; ++hb) {
for (int wb = 0; wb < w_tiles; ++wb) {
const T *input_ptr =
input_data + n * padded_height * padded_width * input_channels;
T *output_ptr = output_data + n * height * width * channels;
const int h_offset = hb * h_tile_size;
const int w_offset = wb * w_tile_size;
const int c_offset = cb * c_tile_size;
const int h_count = std::min(h_tile_size, height - h_offset);
const int w_count = std::min(w_tile_size, width - w_offset);
const int c_count = std::min(c_tile_size, channels - c_offset);
MACE_CASE_C_CONV2D;
} }
} }
} }
} }
DoActivation(output_data, output_data, output->size(), activation_, DoActivation(output_data, output_data, output->size(), activation_,
relux_max_limit_); relux_max_limit_);
} }
};
template <>
struct Conv2dFunctor<DeviceType::NEON, float> : Conv2dFunctorBase {
Conv2dFunctor(const int *strides,
const Padding &padding_type,
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const bool is_filter_transformed,
ScratchBuffer *scratch)
: Conv2dFunctorBase(strides,
padding_type,
paddings,
dilations,
activation,
relux_max_limit),
is_filter_transformed_(is_filter_transformed),
scratch_(scratch) {}
void operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
Tensor transformed_filter_; Tensor transformed_filter_;
bool is_filter_transformed_; bool is_filter_transformed_;
ScratchBuffer *scratch_; ScratchBuffer *scratch_;
}; };
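The unpack step near the end copies row by row because only width elements of each extra_output_width row hold valid data. A standalone sketch of that strided copy with illustrative sizes:

// Row-wise unpack from a padded (extra) output into the real output,
// mirroring the memcpy loop above.
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  const int height = 2, width = 3;
  const int extra_w = 4;  // width rounded up to a multiple of 4
  std::vector<float> padded(height * extra_w);
  for (int i = 0; i < height * extra_w; ++i) padded[i] = float(i);
  std::vector<float> out(height * width);
  for (int h = 0; h < height; ++h)
    std::memcpy(out.data() + h * width,       // dst row, tightly packed
                padded.data() + h * extra_w,  // src row, padded stride
                sizeof(float) * width);       // copy the valid part only
  std::printf("%.0f %.0f\n", out[3], out[5]);  // 4 6
  return 0;
}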
template <typename T> template<typename T>
struct Conv2dFunctor<DeviceType::OPENCL, T> : Conv2dFunctorBase { struct Conv2dFunctor<DeviceType::OPENCL, T> : Conv2dFunctorBase {
Conv2dFunctor(const int *strides, Conv2dFunctor(const int *strides,
const Padding &padding_type, const Padding &padding_type,
...@@ -467,12 +510,12 @@ struct Conv2dFunctor<DeviceType::OPENCL, T> : Conv2dFunctorBase { ...@@ -467,12 +510,12 @@ struct Conv2dFunctor<DeviceType::OPENCL, T> : Conv2dFunctorBase {
const float relux_max_limit, const float relux_max_limit,
const bool is_filter_transformed, const bool is_filter_transformed,
ScratchBuffer *scratch) ScratchBuffer *scratch)
: Conv2dFunctorBase(strides, : Conv2dFunctorBase(strides,
padding_type, padding_type,
paddings, paddings,
dilations, dilations,
activation, activation,
relux_max_limit) {} relux_max_limit) {}
void operator()(const Tensor *input, void operator()(const Tensor *input,
const Tensor *filter, const Tensor *filter,
......
...@@ -368,7 +368,6 @@ void ConstructNCHWInputWithSpecificPadding(const Tensor *input_tensor, ...@@ -368,7 +368,6 @@ void ConstructNCHWInputWithSpecificPadding(const Tensor *input_tensor,
const int pad_left, const int pad_left,
const int pad_right, const int pad_right,
Tensor *output_tensor) { Tensor *output_tensor) {
Tensor::MappingGuard input_mapper(input_tensor);
const float *input = input_tensor->data<float>(); const float *input = input_tensor->data<float>();
const index_t *input_shape = input_tensor->shape().data(); const index_t *input_shape = input_tensor->shape().data();
......
...@@ -25,15 +25,15 @@ ...@@ -25,15 +25,15 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template <DeviceType D, typename T> template<DeviceType D, typename T>
struct DepthToSpaceOpFunctor { struct DepthToSpaceOpFunctor {
explicit DepthToSpaceOpFunctor(const int block_size, bool d2s) explicit DepthToSpaceOpFunctor(const int block_size, bool d2s)
: block_size_(block_size), d2s_(d2s) {} : block_size_(block_size), d2s_(d2s) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) { void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
const int batch_size = input->dim(0); const int batch_size = input->dim(0);
const int input_height = input->dim(1); const int input_depth = input->dim(1);
const int input_width = input->dim(2); const int input_height = input->dim(2);
const int input_depth = input->dim(3); const int input_width = input->dim(3);
index_t output_depth, output_width, output_height; index_t output_depth, output_width, output_height;
...@@ -46,8 +46,8 @@ struct DepthToSpaceOpFunctor { ...@@ -46,8 +46,8 @@ struct DepthToSpaceOpFunctor {
output_width = input_width / block_size_; output_width = input_width / block_size_;
output_height = input_height / block_size_; output_height = input_height / block_size_;
} }
std::vector<index_t> output_shape = {batch_size, output_height, std::vector<index_t> output_shape = {batch_size, output_depth,
output_width, output_depth}; output_height, output_width};
output->Resize(output_shape); output->Resize(output_shape);
...@@ -59,23 +59,22 @@ struct DepthToSpaceOpFunctor { ...@@ -59,23 +59,22 @@ struct DepthToSpaceOpFunctor {
if (d2s_) { if (d2s_) {
#pragma omp parallel for #pragma omp parallel for
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
for (int h = 0; h < output_height; ++h) { for (int d = 0; d < output_depth; ++d) {
const int in_h = h / block_size_; for (int h = 0; h < output_height; ++h) {
const int offset_h = (h % block_size_); const int in_h = h / block_size_;
for (int w = 0; w < output_width; ++w) { const int offset_h = (h % block_size_);
const int in_w = w / block_size_; for (int w = 0; w < output_width; ++w) {
const int offset_w = w % block_size_; const index_t in_w = w / block_size_;
const int offset_d = const index_t offset_w = w % block_size_;
const index_t offset_d =
(offset_h * block_size_ + offset_w) * output_depth; (offset_h * block_size_ + offset_w) * output_depth;
for (int d = 0; d < output_depth; ++d) {
const int in_d = d + offset_d; const index_t in_d = d + offset_d;
const int o_index = const index_t o_index =
((b * output_height + h) * output_width + w) * output_depth + ((b * output_depth + d) * output_height + h) * output_width + w;
d; const index_t i_index =
const int i_index = ((b * input_depth + in_d) * input_height + in_h) * input_width
((b * input_height + in_h) * input_width + in_w) * + in_w;
input_depth +
in_d;
output_ptr[o_index] = input_ptr[i_index]; output_ptr[o_index] = input_ptr[i_index];
} }
} }
...@@ -84,22 +83,23 @@ struct DepthToSpaceOpFunctor { ...@@ -84,22 +83,23 @@ struct DepthToSpaceOpFunctor {
} else { } else {
#pragma omp parallel for #pragma omp parallel for
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
for (int h = 0; h < input_height; ++h) { for (int d = 0; d < input_depth; ++d) {
const int out_h = h / block_size_; for (int h = 0; h < input_height; ++h) {
const int offset_h = (h % block_size_); const int out_h = h / block_size_;
for (int w = 0; w < input_width; ++w) { const int offset_h = (h % block_size_);
const int out_w = w / block_size_; for (int w = 0; w < input_width; ++w) {
const int offset_w = (w % block_size_); const int out_w = w / block_size_;
const int offset_d = const int offset_w = (w % block_size_);
const int offset_d =
(offset_h * block_size_ + offset_w) * input_depth; (offset_h * block_size_ + offset_w) * input_depth;
for (int d = 0; d < input_depth; ++d) {
const int out_d = d + offset_d; const int out_d = d + offset_d;
const int o_index = const index_t o_index =
((b * output_height + out_h) * output_width + out_w) * ((b * output_depth + out_d) * output_height + out_h)
output_depth + * output_width + out_w;
out_d; const index_t i_index =
const int i_index = ((b * input_depth + d) * input_height + h) * input_width
((b * input_height + h) * input_width + w) * input_depth + d; + w;
output_ptr[o_index] = input_ptr[i_index]; output_ptr[o_index] = input_ptr[i_index];
} }
} }
...@@ -112,10 +112,10 @@ struct DepthToSpaceOpFunctor { ...@@ -112,10 +112,10 @@ struct DepthToSpaceOpFunctor {
bool d2s_; bool d2s_;
}; };
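In NCHW, depth_to_space with block size B sends output (d, h, w) to input channel d + ((h % B) * B + w % B) * output_depth at spatial position (h / B, w / B). A worked check of that mapping on the smallest interesting shape; all values are illustrative:

// depth_to_space index mapping above: batch 1, block 2,
// input 4x1x1 (CHW) -> output 1x2x2.
#include <cstdio>
#include <vector>

int main() {
  const int block = 2, in_d = 4, in_h = 1, in_w = 1;
  const int out_d = in_d / (block * block);
  const int out_h = in_h * block, out_w = in_w * block;
  std::vector<float> in = {0, 1, 2, 3};  // one value per input channel
  std::vector<float> out(out_d * out_h * out_w);
  for (int d = 0; d < out_d; ++d)
    for (int h = 0; h < out_h; ++h)
      for (int w = 0; w < out_w; ++w) {
        const int offset_d = ((h % block) * block + w % block) * out_d;
        const int in_c = d + offset_d;
        out[(d * out_h + h) * out_w + w] =
            in[(in_c * in_h + h / block) * in_w + w / block];
      }
  // Channels 0..3 unfold into the 2x2 spatial block: 0 1 / 2 3.
  std::printf("%.0f %.0f %.0f %.0f\n", out[0], out[1], out[2], out[3]);
  return 0;
}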
template <typename T> template<typename T>
struct DepthToSpaceOpFunctor<DeviceType::OPENCL, T> { struct DepthToSpaceOpFunctor<DeviceType::OPENCL, T> {
DepthToSpaceOpFunctor(const int block_size, bool d2s) DepthToSpaceOpFunctor(const int block_size, bool d2s)
: block_size_(block_size), d2s_(d2s) {} : block_size_(block_size), d2s_(d2s) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future); void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
const int block_size_; const int block_size_;
......
...@@ -26,225 +26,12 @@ ...@@ -26,225 +26,12 @@
#include "mace/core/runtime/opencl/cl2_header.h" #include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/kernels/conv_pool_2d_util.h" #include "mace/kernels/conv_pool_2d_util.h"
#include "mace/kernels/activation.h" #include "mace/kernels/activation.h"
#include "mace/kernels/arm/depthwise_conv2d_neon.h"
#include "mace/public/mace.h" #include "mace/public/mace.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template <typename T>
void DepthwiseConv2dKernel(const T *input_ptr,
const T *filter_ptr,
const T *bias_ptr,
T *output_ptr,
int batch,
int height,
int width,
int channels,
int input_height,
int input_width,
int input_channels,
int multiplier,
int padded_h_start,
int padded_h_stop,
int padded_w_start,
int padded_w_stop,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w,
int h_start,
int h_stop,
int w_start,
int w_stop) {
#pragma omp parallel for collapse(4)
for (int n = 0; n < batch; ++n) {
for (int h = h_start; h < h_stop; ++h) {
for (int w = w_start; w < w_stop; ++w) {
for (int c = 0; c < channels; ++c) {
const index_t inc = c / multiplier;
const index_t m = c % multiplier;
T bias_channel = bias_ptr ? bias_ptr[c] : 0;
index_t offset = n * height * width * channels +
h * width * channels + w * channels + c;
output_ptr[offset] = bias_channel;
T sum = 0;
const T *filter_base = filter_ptr + inc * multiplier + m;
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
if (inh < 0 || inh >= input_height || inw < 0 ||
inw >= input_width) {
MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
inw >= padded_w_start && inw < padded_w_stop,
"Out of range read from input: ", padded_h_start,
" <= ", inh, " < ", padded_h_stop, ", ",
padded_w_start, " <= ", inw, " < ", padded_w_stop);
} else {
index_t input_offset =
n * input_height * input_width * input_channels +
inh * input_width * input_channels + inw * input_channels +
inc;
sum += input_ptr[input_offset] * filter_base[0]; // HWIM
}
filter_base += input_channels * multiplier;
}
}
output_ptr[offset] += sum;
}
}
}
}
}
template <typename T>
void DepthwiseConv2dNoOOBCheckKernel(const T *input_ptr,
const T *filter_ptr,
const T *bias_ptr,
T *output_ptr,
int batch,
int height,
int width,
int channels,
int input_height,
int input_width,
int input_channels,
int multiplier,
int padded_h_start,
int padded_h_stop,
int padded_w_start,
int padded_w_stop,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w,
int h_start,
int h_stop,
int w_start,
int w_stop) {
if (multiplier == 1) {
constexpr int c_tile_size = 4;
#pragma omp parallel for collapse(3)
for (int n = 0; n < batch; ++n) {
for (int h = h_start; h < h_stop; ++h) {
for (int w = w_start; w < w_stop; ++w) {
int c;
for (c = 0; c + c_tile_size <= channels; c += c_tile_size) {
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
static_assert(c_tile_size == 4, "channels tile size must be 4");
float32x4_t sum = vdupq_n_f32(0);
if (bias_ptr != nullptr) {
sum = vld1q_f32(bias_ptr + c);
}
#else
T sum[c_tile_size] = {0};
if (bias_ptr != nullptr) {
for (int ci = 0; ci < c_tile_size; ++ci) {
sum[ci] = bias_ptr[c + ci];
}
}
#endif
const T *filter_base = filter_ptr + c;
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
MACE_ASSERT(inh >= 0 && inh < input_height && inw >= 0 &&
inw < input_width);
index_t input_offset =
n * input_height * input_width * input_channels +
inh * input_width * input_channels + inw * input_channels +
c;
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
float32x4_t in = vld1q_f32(input_ptr + input_offset);
float32x4_t weights = vld1q_f32(filter_base);
sum = vfmaq_f32(sum, in, weights);
#else
for (int ci = 0; ci < c_tile_size; ++ci) {
sum[ci] +=
input_ptr[input_offset + ci] * filter_base[ci]; // HWIM
}
#endif
filter_base += input_channels;
}
}
index_t offset = n * height * width * channels +
h * width * channels + w * channels + c;
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
vst1q_f32(output_ptr + offset, sum);
#else
for (int ci = 0; ci < c_tile_size; ++ci) {
output_ptr[offset + ci] = sum[ci];
}
#endif
}
for (; c < channels; ++c) {
T bias_channel = bias_ptr ? bias_ptr[c] : 0;
index_t offset = n * height * width * channels +
h * width * channels + w * channels + c;
output_ptr[offset] = bias_channel;
T sum = 0;
const T *filter_base = filter_ptr + c;
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
MACE_ASSERT(inh >= 0 && inh < input_height && inw >= 0 &&
inw < input_width);
index_t input_offset =
n * input_height * input_width * input_channels +
inh * input_width * input_channels + inw * input_channels +
c;
sum += input_ptr[input_offset] * filter_base[0]; // HWIM
filter_base += input_channels * multiplier;
}
}
output_ptr[offset] += sum;
}
}
}
}
} else {
#pragma omp parallel for collapse(4)
for (int n = 0; n < batch; ++n) {
for (int h = h_start; h < h_stop; ++h) {
for (int w = w_start; w < w_stop; ++w) {
for (int c = 0; c < channels; ++c) {
const index_t inc = c / multiplier;
const index_t m = c % multiplier;
T bias_channel = bias_ptr ? bias_ptr[c] : 0;
index_t offset = n * height * width * channels +
h * width * channels + w * channels + c;
output_ptr[offset] = bias_channel;
T sum = 0;
const T *filter_base = filter_ptr + inc * multiplier + m;
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
MACE_ASSERT(inh >= 0 && inh < input_height && inw >= 0 &&
inw < input_width);
index_t input_offset =
n * input_height * input_width * input_channels +
inh * input_width * input_channels + inw * input_channels +
inc;
sum += input_ptr[input_offset] * filter_base[0]; // HWIM
filter_base += input_channels * multiplier;
}
}
output_ptr[offset] += sum;
}
}
}
}
}
}
struct DepthwiseConv2dFunctorBase { struct DepthwiseConv2dFunctorBase {
DepthwiseConv2dFunctorBase(const int *strides, DepthwiseConv2dFunctorBase(const int *strides,
const Padding padding_type, const Padding padding_type,
...@@ -252,12 +39,12 @@ struct DepthwiseConv2dFunctorBase { ...@@ -252,12 +39,12 @@ struct DepthwiseConv2dFunctorBase {
const int *dilations, const int *dilations,
const ActivationType activation, const ActivationType activation,
const float relux_max_limit) const float relux_max_limit)
: strides_(strides), : strides_(strides),
padding_type_(padding_type), padding_type_(padding_type),
paddings_(paddings), paddings_(paddings),
dilations_(dilations), dilations_(dilations),
activation_(activation), activation_(activation),
relux_max_limit_(relux_max_limit) {} relux_max_limit_(relux_max_limit) {}
const int *strides_; // [stride_h, stride_w] const int *strides_; // [stride_h, stride_w]
const Padding padding_type_; const Padding padding_type_;
...@@ -267,159 +54,246 @@ struct DepthwiseConv2dFunctorBase { ...@@ -267,159 +54,246 @@ struct DepthwiseConv2dFunctorBase {
const float relux_max_limit_; const float relux_max_limit_;
}; };
template <DeviceType D, typename T> template<DeviceType D, typename T>
struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase { struct DepthwiseConv2dFunctor;
template<>
struct DepthwiseConv2dFunctor<DeviceType::CPU, float>
: public DepthwiseConv2dFunctorBase {
DepthwiseConv2dFunctor(const int *strides, DepthwiseConv2dFunctor(const int *strides,
const Padding padding_type, const Padding padding_type,
const std::vector<int> &paddings, const std::vector<int> &paddings,
const int *dilations, const int *dilations,
const ActivationType activation, const ActivationType activation,
const float relux_max_limit) const float relux_max_limit)
: DepthwiseConv2dFunctorBase(strides, : DepthwiseConv2dFunctorBase(strides,
padding_type, padding_type,
paddings, paddings,
dilations, dilations,
activation, activation,
relux_max_limit) {} relux_max_limit) {}
void operator()(const Tensor *input, // NHWC void DepthwiseConv2dGeneral(const float *input,
const Tensor *filter, // HWIM const float *filter,
const Tensor *bias, // O const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t multiplier = out_channels / in_channels;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
index_t out_offset =
((b * out_channels + m) * out_height + h) * out_width + w;
index_t c = m / multiplier;
index_t o = m % multiplier;
float sum = 0;
for (index_t kh = 0; kh < filter_height; ++kh) {
for (index_t kw = 0; kw < filter_width; ++kw) {
index_t ih = h * stride_h + kh * dilation_h - pad_top;
index_t iw = w * stride_w + kw * dilation_w - pad_left;
if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
index_t in_offset =
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t filter_offset =
(((o * in_channels) + c) * filter_height + kh)
* filter_width
+ kw;
sum += input[in_offset] * filter[filter_offset];
}
}
}
output[out_offset] = sum;
}
}
}
}
}
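For each output channel m, the depthwise loop derives the input channel c = m / multiplier and the filter copy o = m % multiplier. A one-screen check of that mapping with illustrative sizes:

// Output-channel decomposition used by DepthwiseConv2dGeneral above;
// in_channels = 3, multiplier = 2.
#include <cstdio>

int main() {
  const int in_channels = 3, multiplier = 2;
  const int out_channels = in_channels * multiplier;
  for (int m = 0; m < out_channels; ++m)
    std::printf("out %d <- in %d, filter copy %d\n",
                m, m / multiplier, m % multiplier);
  return 0;
}

Unlike the general conv path, this kernel handles padding through bounds checks rather than a materialized padded input, so it assigns each output element instead of accumulating into it.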
void operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output, Tensor *output,
StatsFuture *future) { StatsFuture *future) {
MACE_CHECK_NOTNULL(input); MACE_CHECK_NOTNULL(input);
MACE_CHECK_NOTNULL(filter); MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output); MACE_CHECK_NOTNULL(output);
// Create a fake conv_2d filter to calculate the paddings and output size
std::vector<index_t> fake_filter_shape(4);
fake_filter_shape[0] = filter->shape()[0];
fake_filter_shape[1] = filter->shape()[1];
fake_filter_shape[2] = filter->shape()[2] * filter->shape()[3];
fake_filter_shape[3] = 1;
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
std::vector<int> paddings(2); std::vector<int> paddings(2);
std::vector<index_t> filter_shape
{filter->dim(0) * filter->dim(1), filter->dim(1), filter->dim(2),
filter->dim(3)};
if (paddings_.empty()) { if (paddings_.empty()) {
kernels::CalcNHWCPaddingAndOutputSize( CalcNCHWPaddingAndOutputSize(input->shape().data(),
input->shape().data(), fake_filter_shape.data(), dilations_, strides_, filter_shape.data(),
padding_type_, output_shape.data(), paddings.data()); dilations_,
strides_,
padding_type_,
output_shape.data(),
paddings.data());
} else { } else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input->shape().data(), fake_filter_shape.data(), CalcNCHWOutputSize(input->shape().data(),
paddings_.data(), dilations_, strides_, RoundType::FLOOR, filter_shape.data(),
output_shape.data()); paddings_.data(),
dilations_,
strides_,
RoundType::FLOOR,
output_shape.data());
} }
auto input_shape = fake_filter_shape;
output->Resize(output_shape); output->Resize(output_shape);
output->Clear();
index_t batch = output->dim(0); index_t batch = output->dim(0);
index_t height = output->dim(1); index_t channels = output->dim(1);
index_t width = output->dim(2); index_t height = output->dim(2);
index_t channels = output->dim(3); index_t width = output->dim(3);
index_t input_batch = input->dim(0); index_t input_batch = input->dim(0);
index_t input_height = input->dim(1); index_t input_channels = input->dim(1);
index_t input_width = input->dim(2); index_t input_height = input->dim(2);
index_t input_channels = input->dim(3); index_t input_width = input->dim(3);
index_t kernel_h = filter->dim(0); index_t filter_h = filter_shape[2];
index_t kernel_w = filter->dim(1); index_t filter_w = filter_shape[3];
index_t multiplier = filter->dim(3); MACE_CHECK(filter_shape[0] == channels, filter_shape[0], " != ", channels);
MACE_CHECK(filter->dim(2) == input_channels, filter->dim(2), "!=", MACE_CHECK(filter_shape[1] == input_channels, filter_shape[1], " != ",
input_channels); input_channels);
MACE_CHECK(channels == input_channels * multiplier);
int stride_h = strides_[0]; index_t stride_h = strides_[0];
int stride_w = strides_[1]; index_t stride_w = strides_[1];
int dilation_h = dilations_[0]; index_t dilation_h = dilations_[0];
int dilation_w = dilations_[1]; index_t dilation_w = dilations_[1];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch"); MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
// The left-upper most offset of the padded input int pad_top = paddings[0] >> 1;
int paddings_top = paddings[0] / 2; int pad_bottom = paddings[0] - pad_top;
int paddings_bottom = paddings[0] - paddings_top; int pad_left = paddings[1] >> 1;
int paddings_left = paddings[1] / 2; int pad_right = paddings[1] - pad_left;
int paddings_right = paddings[1] - paddings_left;
index_t valid_h_start = pad_top == 0 ? 0 : (pad_top - 1) / stride_h + 1;
int padded_h_start = 0 - paddings_top; index_t valid_h_stop = pad_bottom == 0
int padded_w_start = 0 - paddings_left;
index_t padded_h_stop = input_height + paddings_bottom;
index_t padded_w_stop = input_width + paddings_right;
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard filter_mapper(filter);
Tensor::MappingGuard bias_mapper(bias);
Tensor::MappingGuard output_mapper(output);
const T *input_ptr = input->data<T>();
const T *filter_ptr = filter->data<T>();
const T *bias_ptr = bias == nullptr ? nullptr : bias->data<T>();
T *output_ptr = output->mutable_data<T>();
int valid_h_start =
paddings_top == 0 ? 0 : (paddings_top - 1) / stride_h + 1;
int valid_h_stop = paddings_bottom == 0
? height ? height
: height - ((paddings_bottom - 1) / stride_h + 1); : height - ((pad_bottom - 1) / stride_h + 1);
int valid_w_start = index_t valid_w_start = pad_left == 0 ? 0 : (pad_left - 1) / stride_w + 1;
paddings_left == 0 ? 0 : (paddings_left - 1) / stride_w + 1; index_t valid_w_stop = pad_right == 0
int valid_w_stop = paddings_right == 0
? width ? width
: width - ((paddings_right - 1) / stride_w + 1); : width - ((pad_right - 1) / stride_w + 1);
// Calculate border elements with out-of-boundary checking std::function<void(const float *input, float *output)> conv_func;
if (valid_h_start > 0) {
DepthwiseConv2dKernel<T>( Tensor::MappingGuard input_guard(input);
input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width, Tensor::MappingGuard filter_guard(filter);
channels, input_height, input_width, input_channels, multiplier, Tensor::MappingGuard bias_guard(bias);
padded_h_start, padded_h_stop, padded_w_start, padded_w_stop, Tensor::MappingGuard output_guard(output);
kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w, 0, auto input_data = input->data<float>();
valid_h_start, 0, width); auto filter_data = filter->data<float>();
} auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
if (valid_h_stop < height) { auto output_data = output->mutable_data<float>();
DepthwiseConv2dKernel<T>(
input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width, if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1
channels, input_height, input_width, input_channels, multiplier, && dilation_h == 1 && dilation_w == 1) {
padded_h_start, padded_h_stop, padded_w_start, padded_w_stop, conv_func = [=](const float *input, float *output) {
kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w, DepthwiseConv2dNeonK3x3S1(input,
std::max(valid_h_start, valid_h_stop), height, 0, width); filter_data,
} batch,
if (valid_w_start > 0) { input_height,
DepthwiseConv2dKernel<T>( input_width,
input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width, input_channels,
channels, input_height, input_width, input_channels, multiplier, height,
padded_h_start, padded_h_stop, padded_w_start, padded_w_stop, width,
kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w, channels,
valid_h_start, valid_h_stop, 0, valid_w_start); pad_top,
} pad_left,
if (valid_w_stop < width) { valid_h_start,
DepthwiseConv2dKernel<T>( valid_h_stop,
input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width, valid_w_start,
channels, input_height, input_width, input_channels, multiplier, valid_w_stop,
padded_h_start, padded_h_stop, padded_w_start, padded_w_stop, output);
kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w, };
valid_h_start, valid_h_stop, std::max(valid_w_start, valid_w_stop), } else if (filter_h == 3 && filter_w == 3 && stride_h == 2 && stride_w == 2
width); && dilation_h == 1 && dilation_w == 1) {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dNeonK3x3S2(input,
filter_data,
batch,
input_height,
input_width,
input_channels,
height,
width,
channels,
pad_top,
pad_left,
valid_h_start,
valid_h_stop,
valid_w_start,
valid_w_stop,
output);
};
} else {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dGeneral(input,
filter_data,
batch,
input_height,
input_width,
input_channels,
height,
width,
channels,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_top,
pad_left,
output);
};
} }
// Calculate border elements without out-of-boundary checking conv_func(input_data, output_data);
DepthwiseConv2dNoOOBCheckKernel<T>(
input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width,
channels, input_height, input_width, input_channels, multiplier,
padded_h_start, padded_h_stop, padded_w_start, padded_w_stop, kernel_h,
kernel_w, stride_h, stride_w, dilation_h, dilation_w, valid_h_start,
valid_h_stop, valid_w_start, valid_w_stop);
output_ptr = output->mutable_data<T>(); if (bias_data != nullptr) {
DoActivation(output_ptr, output_ptr, output->size(), activation_, #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t i = 0; i < height * width; ++i) {
output_data[(b * channels + c) * height * width + i] +=
bias_data[c];
}
}
}
}
DoActivation(output_data, output_data, output->size(), activation_,
relux_max_limit_); relux_max_limit_);
} }
}; };
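The valid_h/valid_w bounds computed above delimit the output region whose receptive field lies entirely inside the unpadded input, letting the NEON kernels skip boundary checks there: the first valid row is (pad_top - 1) / stride + 1 when pad_top > 0, and symmetrically at the bottom. A small sketch of that bound under illustrative sizes:

// Valid output interval for the NEON depthwise kernels above.
#include <cstdio>

int main() {
  const int stride = 1, out_height = 8;
  const int pad_top = 1, pad_bottom = 1;
  const int valid_start = pad_top == 0 ? 0 : (pad_top - 1) / stride + 1;
  const int valid_stop =
      pad_bottom == 0 ? out_height
                      : out_height - ((pad_bottom - 1) / stride + 1);
  std::printf("rows [%d, %d) need no boundary checks\n",
              valid_start, valid_stop);  // rows [1, 7)
  return 0;
}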
template <> template<typename T>
struct DepthwiseConv2dFunctor<DeviceType::NEON, float> struct DepthwiseConv2dFunctor<DeviceType::OPENCL, T>
: DepthwiseConv2dFunctorBase { : DepthwiseConv2dFunctorBase {
DepthwiseConv2dFunctor(const int *strides, DepthwiseConv2dFunctor(const int *strides,
const Padding padding_type, const Padding padding_type,
...@@ -439,29 +313,6 @@ struct DepthwiseConv2dFunctor<DeviceType::NEON, float> ...@@ -439,29 +313,6 @@ struct DepthwiseConv2dFunctor<DeviceType::NEON, float>
const Tensor *bias, const Tensor *bias,
Tensor *output, Tensor *output,
StatsFuture *future); StatsFuture *future);
};
template <typename T>
struct DepthwiseConv2dFunctor<DeviceType::OPENCL, T>
: DepthwiseConv2dFunctorBase {
DepthwiseConv2dFunctor(const int *strides,
const Padding padding_type,
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
const float relux_max_limit)
: DepthwiseConv2dFunctorBase(strides,
padding_type,
paddings,
dilations,
activation,
relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_; cl::Kernel kernel_;
uint32_t kwg_size_; uint32_t kwg_size_;
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/kernels/activation.h" #include "mace/kernels/activation.h"
#include "mace/kernels/opencl/helper.h" #include "mace/kernels/opencl/helper.h"
#include "mace/kernels/gemm.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
...@@ -41,7 +42,10 @@ struct FullyConnectedBase { ...@@ -41,7 +42,10 @@ struct FullyConnectedBase {
}; };
template <DeviceType D, typename T> template <DeviceType D, typename T>
struct FullyConnectedFunctor : FullyConnectedBase { struct FullyConnectedFunctor;
template <>
struct FullyConnectedFunctor<DeviceType::CPU, float>: FullyConnectedBase {
FullyConnectedFunctor(const BufferType weight_type, FullyConnectedFunctor(const BufferType weight_type,
const ActivationType activation, const ActivationType activation,
const float relux_max_limit) const float relux_max_limit)
...@@ -52,33 +56,25 @@ struct FullyConnectedFunctor : FullyConnectedBase { ...@@ -52,33 +56,25 @@ struct FullyConnectedFunctor : FullyConnectedBase {
const Tensor *bias, const Tensor *bias,
Tensor *output, Tensor *output,
StatsFuture *future) { StatsFuture *future) {
std::vector<index_t> output_shape = {input->dim(0), 1, 1, weight->dim(0)}; std::vector<index_t> output_shape = {input->dim(0), weight->dim(0), 1, 1};
output->Resize(output_shape); output->Resize(output_shape);
const index_t N = output->dim(0); const index_t N = output->dim(0);
const index_t input_size = weight->dim(1); const index_t input_size = weight->dim(1);
const index_t output_size = weight->dim(0); const index_t output_size = weight->dim(0);
Tensor::MappingGuard guard_input(input); Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_weight(weight); Tensor::MappingGuard guard_weight(weight);
Tensor::MappingGuard guard_bias(bias); Tensor::MappingGuard guard_bias(bias);
Tensor::MappingGuard guard_output(output); Tensor::MappingGuard guard_output(output);
const T *input_ptr = input->data<T>(); const float *input_ptr = input->data<float>();
const T *weight_ptr = weight->data<T>(); const float *weight_ptr = weight->data<float>();
const T *bias_ptr = bias == nullptr ? nullptr : bias->data<T>(); const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
T *output_ptr = output->mutable_data<T>(); float *output_ptr = output->mutable_data<float>();
#pragma omp parallel for collapse(2) Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr);
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
for (int out_idx = 0; out_idx < output_size; ++out_idx) { for (int j = 0; j < output_size; ++j) {
T sum = 0; if (bias_ptr != nullptr) output_ptr[j + i * output_size] += bias_ptr[j];
if (bias_ptr != nullptr) sum = bias_ptr[out_idx];
index_t input_offset = i * input_size;
index_t weight_offset = out_idx * input_size;
for (int in_idx = 0; in_idx < input_size; ++in_idx) {
sum += input_ptr[input_offset] * weight_ptr[weight_offset];
input_offset++;
weight_offset++;
}
output_ptr[i * output_size + out_idx] = sum;
} }
} }
...@@ -87,20 +83,6 @@ struct FullyConnectedFunctor : FullyConnectedBase { ...@@ -87,20 +83,6 @@ struct FullyConnectedFunctor : FullyConnectedBase {
} }
}; };
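Flattened to input_size = C*H*W per batch item, the fully connected layer is a plain matrix-vector product: out[i][j] = dot(weight row j, input row i) + bias[j]. A naive reference loop for the Gemv-plus-bias path above, with illustrative data:

// Reference for the Gemv-based fully connected computation.
#include <cstdio>
#include <vector>

int main() {
  const int N = 1, input_size = 3, output_size = 2;
  std::vector<float> input = {1, 2, 3};
  std::vector<float> weight = {1, 0, 0,   // row 0
                               0, 0, 1};  // row 1
  std::vector<float> bias = {0.5f, 0.5f};
  std::vector<float> out(N * output_size, 0.0f);
  for (int i = 0; i < N; ++i)
    for (int j = 0; j < output_size; ++j) {
      float sum = 0.0f;
      for (int k = 0; k < input_size; ++k)
        sum += weight[j * input_size + k] * input[i * input_size + k];
      out[i * output_size + j] = sum + bias[j];
    }
  std::printf("%.1f %.1f\n", out[0], out[1]);  // 1.5 3.5
  return 0;
}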
template <>
struct FullyConnectedFunctor<DeviceType::NEON, float> : FullyConnectedBase {
FullyConnectedFunctor(const BufferType weight_type,
const ActivationType activation,
const float relux_max_limit)
: FullyConnectedBase(weight_type, activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
};
template <typename T> template <typename T>
struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase { struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase {
FullyConnectedFunctor(const BufferType weight_type, FullyConnectedFunctor(const BufferType weight_type,
......
...@@ -747,7 +747,7 @@ void Gemv(const float *m_ptr, ...@@ -747,7 +747,7 @@ void Gemv(const float *m_ptr,
for (index_t h = 0; h < remain_h; ++h) { for (index_t h = 0; h < remain_h; ++h) {
float32x4_t vsum0 = vdupq_n_f32(0.f); float32x4_t vsum0 = vdupq_n_f32(0.f);
const float *m_ptr0 = m_ptr + (h + remain_start_height) * width; const float *m_ptr0 = m_ptr + (h + remain_start_height) * width;
const float *v_ptr0 = v_ptr; const float *v_ptr0 = v_ptr + b * width;
for (index_t w = 0; w < width_d4; ++w) { for (index_t w = 0; w < width_d4; ++w) {
float32x4_t vm = vld1q_f32(m_ptr0); float32x4_t vm = vld1q_f32(m_ptr0);
float32x4_t vv = vld1q_f32(v_ptr0); float32x4_t vv = vld1q_f32(v_ptr0);
...@@ -761,7 +761,7 @@ void Gemv(const float *m_ptr, ...@@ -761,7 +761,7 @@ void Gemv(const float *m_ptr,
m_ptr0++; m_ptr0++;
v_ptr0++; v_ptr0++;
} }
out_ptr[remain_start_height + h] = sum; out_ptr[remain_start_height + h + b * height] = sum;
} }
} }
#else #else
......
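The two one-line fixes in the Gemv hunk above make the scalar remainder path batch-aware: vector b starts at v_ptr + b * width and its results land at out_ptr + b * height, matching the vectorized path. A compact restatement of that batched layout with toy data:

// Batched GEMV layout assumed by the fix above: one height x width
// matrix shared across the batch, vectors and outputs batch-major.
#include <cstdio>
#include <vector>

void BatchedGemv(const float *m, const float *v, int batch, int height,
                 int width, float *out) {
  for (int b = 0; b < batch; ++b)
    for (int h = 0; h < height; ++h) {
      const float *v0 = v + b * width;  // per-batch vector start
      float sum = 0.0f;
      for (int w = 0; w < width; ++w) sum += m[h * width + w] * v0[w];
      out[b * height + h] = sum;        // per-batch output start
    }
}

int main() {
  std::vector<float> m = {1, 1};        // 1x2 matrix
  std::vector<float> v = {1, 2, 3, 4};  // batch 2, width 2
  std::vector<float> out(2);
  BatchedGemv(m.data(), v.data(), 2, 1, 2, out.data());
  std::printf("%.0f %.0f\n", out[0], out[1]);  // 3 7
  return 0;
}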
...@@ -17,8 +17,11 @@ ...@@ -17,8 +17,11 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template <DeviceType D, typename T> template<DeviceType D, typename T>
struct LocalResponseNormFunctor { struct LocalResponseNormFunctor;
template<>
struct LocalResponseNormFunctor<DeviceType::CPU, float> {
void operator()(const Tensor *input, void operator()(const Tensor *input,
int depth_radius, int depth_radius,
float bias, float bias,
...@@ -27,48 +30,39 @@ struct LocalResponseNormFunctor { ...@@ -27,48 +30,39 @@ struct LocalResponseNormFunctor {
Tensor *output, Tensor *output,
StatsFuture *future) { StatsFuture *future) {
      const index_t batch = input->dim(0);
      const index_t channels = input->dim(1);
      const index_t height = input->dim(2);
      const index_t width = input->dim(3);

      const float *input_ptr = input->data<float>();
      float *output_ptr = output->mutable_data<float>();

      index_t image_size = height * width;
      index_t batch_size = channels * image_size;

#pragma omp parallel for collapse(2)
      for (index_t b = 0; b < batch; ++b) {
        for (index_t c = 0; c < channels; ++c) {
          const int begin_input_c = std::max(static_cast<index_t>(0),
                                             c - depth_radius);
          const int end_input_c = std::min(channels, c + depth_radius + 1);

          index_t pos = b * batch_size;
          for (index_t hw = 0; hw < height * width; ++hw, ++pos) {
            float accum = 0.f;
            for (int input_c = begin_input_c; input_c < end_input_c;
                 ++input_c) {
              const float input_val = input_ptr[pos + input_c * image_size];
              accum += input_val * input_val;
            }
            const float multiplier = std::pow(bias + alpha * accum, -beta);
            output_ptr[pos + c * image_size] =
                input_ptr[pos + c * image_size] * multiplier;
          }
        }
      }
  }
};
template <>
struct LocalResponseNormFunctor<DeviceType::NEON, float> {
void operator()(const Tensor *input,
int depth_radius,
float bias,
float alpha,
float beta,
Tensor *output,
StatsFuture *future);
};
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
......
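In formula form, with r = depth_radius and C = channels, the CPU kernel above computes, for every batch b and pixel (h, w):

$$ \mathrm{out}_{b,c,h,w} = \mathrm{in}_{b,c,h,w} \cdot \Bigl(\mathrm{bias} + \alpha \sum_{c'=\max(0,\,c-r)}^{\min(C-1,\,c+r)} \mathrm{in}_{b,c',h,w}^{2}\Bigr)^{-\beta} $$

In NCHW the c' terms sit image_size floats apart, which is exactly the `pos + input_c * image_size` stride in the loop.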
...@@ -57,7 +57,10 @@ struct PoolingFunctorBase { ...@@ -57,7 +57,10 @@ struct PoolingFunctorBase {
}; };
template <DeviceType D, typename T>
struct PoolingFunctor;

template <>
struct PoolingFunctor<DeviceType::CPU, float>: PoolingFunctorBase {
PoolingFunctor(const PoolingType pooling_type, PoolingFunctor(const PoolingType pooling_type,
const int *kernels, const int *kernels,
const int *strides, const int *strides,
...@@ -68,43 +71,141 @@ struct PoolingFunctor : PoolingFunctorBase { ...@@ -68,43 +71,141 @@ struct PoolingFunctor : PoolingFunctorBase {
pooling_type, kernels, strides, padding_type, paddings, dilations) { pooling_type, kernels, strides, padding_type, paddings, dilations) {
} }
void MaxPooling(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t channels,
const index_t out_height,
const index_t out_width,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = channels * in_image_size;
const index_t out_batch_size = channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t out_base = b * out_batch_size + c * out_image_size;
const index_t in_base = b * in_batch_size + c * in_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
const index_t out_offset = out_base + h * out_width + w;
float res = std::numeric_limits<float>::lowest();
for (int fh = 0; fh < filter_height; ++fh) {
for (int fw = 0; fw < filter_width; ++fw) {
int inh = h * stride_h + dilation_h * fh - pad_top;
int inw = w * stride_w + dilation_w * fw - pad_left;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
index_t input_offset = in_base + inh * in_width + inw;
res = std::max(res, input[input_offset]);
}
}
}
output[out_offset] = res;
}
}
}
}
}
void AvgPooling(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t channels,
const index_t out_height,
const index_t out_width,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = channels * in_image_size;
const index_t out_batch_size = channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t out_base = b * out_batch_size + c * out_image_size;
const index_t in_base = b * in_batch_size + c * in_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
const index_t out_offset = out_base + h * out_width + w;
float res = 0;
int block_size = 0;
for (int fh = 0; fh < filter_height; ++fh) {
for (int fw = 0; fw < filter_width; ++fw) {
int inh = h * stride_h + dilation_h * fh - pad_top;
int inw = w * stride_w + dilation_w * fw - pad_left;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
index_t input_offset = in_base + inh * in_width + inw;
res += input[input_offset];
++block_size;
}
}
}
output[out_offset] = res / block_size;
}
}
}
}
}
void operator()(const Tensor *input_tensor, void operator()(const Tensor *input_tensor,
Tensor *output_tensor, Tensor *output_tensor,
StatsFuture *future) { StatsFuture *future) {
      std::vector<index_t> output_shape(4);
      std::vector<index_t> filter_shape = {
          input_tensor->dim(1), input_tensor->dim(1), kernels_[0], kernels_[1]};

      std::vector<int> paddings(2);
      if (paddings_.empty()) {
        kernels::CalcNCHWPaddingAndOutputSize(
            input_tensor->shape().data(), filter_shape.data(), dilations_,
            strides_, padding_type_, output_shape.data(), paddings.data());
      } else {
        paddings = paddings_;
        CalcNCHWOutputSize(input_tensor->shape().data(),
                           filter_shape.data(),
                           paddings_.data(),
                           dilations_,
                           strides_,
                           RoundType::CEIL,
                           output_shape.data());
      }
      output_tensor->Resize(output_shape);

      Tensor::MappingGuard input_guard(input_tensor);
      Tensor::MappingGuard output_guard(output_tensor);
      const float *input = input_tensor->data<float>();
      float *output = output_tensor->mutable_data<float>();
      const index_t *input_shape = input_tensor->shape().data();

      index_t batch = output_shape[0];
      index_t channels = output_shape[1];
      index_t height = output_shape[2];
      index_t width = output_shape[3];

      index_t input_height = input_shape[2];
      index_t input_width = input_shape[3];

      int filter_h = kernels_[0];
      int filter_w = kernels_[1];
      int stride_h = strides_[0];
      int stride_w = strides_[1];
...@@ -112,84 +213,47 @@ struct PoolingFunctor : PoolingFunctorBase { ...@@ -112,84 +213,47 @@ struct PoolingFunctor : PoolingFunctorBase {
int dilation_h = dilations_[0]; int dilation_h = dilations_[0];
int dilation_w = dilations_[1]; int dilation_w = dilations_[1];
      int pad_top = paddings[0] / 2;
      int pad_left = paddings[1] / 2;
      if (pooling_type_ == PoolingType::MAX) {
        MaxPooling(input,
                   batch,
                   input_height,
                   input_width,
                   channels,
                   height,
                   width,
                   filter_h,
                   filter_w,
                   stride_h,
                   stride_w,
                   dilation_h,
                   dilation_w,
                   pad_top,
                   pad_left,
                   output);
      } else if (pooling_type_ == PoolingType::AVG) {
        AvgPooling(input,
                   batch,
                   input_height,
                   input_width,
                   channels,
                   height,
                   width,
                   filter_h,
                   filter_w,
                   stride_h,
                   stride_w,
                   dilation_h,
                   dilation_w,
                   pad_top,
                   pad_left,
                   output);
      } else {
        MACE_NOT_IMPLEMENTED;
      }
  }
};
template <typename T> template <typename T>
......
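Two behaviors of the new CPU pooling path are worth spelling out: CalcNCHWOutputSize with RoundType::CEIL keeps partially covered border windows as outputs, and AvgPooling divides by block_size, the number of in-bounds taps, rather than the full filter area. A minimal sketch; the output-size formula is an assumption inferred from the CEIL round type, and the real helpers also fill the batch/channel dims and validate arguments:

#include <cmath>
#include <cstdio>
#include <cstdint>
typedef int64_t index_t;

// Assumed CEIL rule: ceil((in + pad_total - effective_filter) / stride) + 1,
// where effective_filter = dilation * (filter - 1) + 1.
index_t PooledSizeCeil(index_t in, int filter, int stride, int dilation,
                       int pad_total) {
  const int effective_filter = dilation * (filter - 1) + 1;
  return static_cast<index_t>(std::ceil(
             (in + pad_total - effective_filter) /
             static_cast<double>(stride))) + 1;
}

int main() {
  printf("out = %lld\n",
         static_cast<long long>(PooledSizeCeil(7, 3, 2, 1, 2)));  // 4
  // Border averaging: a 2x2 window anchored at (-1, -1) over a 2x2 input
  // covers one in-bounds tap, so AvgPooling divides by 1, not by 4.
  const float input[4] = {1.f, 2.f, 3.f, 4.f};
  float res = 0.f;
  int block_size = 0;
  for (int fh = 0; fh < 2; ++fh)
    for (int fw = 0; fw < 2; ++fw) {
      const int inh = fh - 1, inw = fw - 1;
      if (inh >= 0 && inh < 2 && inw >= 0 && inw < 2) {
        res += input[inh * 2 + inw];
        ++block_size;
      }
    }
  printf("avg = %.1f\n", res / block_size);  // 1.0
  return 0;
}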
...@@ -38,15 +38,15 @@ inline float CalculateResizeScale(index_t in_size, ...@@ -38,15 +38,15 @@ inline float CalculateResizeScale(index_t in_size,
index_t out_size, index_t out_size,
bool align_corners) { bool align_corners) {
  return (align_corners && out_size > 1)
         ? (in_size - 1) / static_cast<float>(out_size - 1)
         : in_size / static_cast<float>(out_size);
}
inline void ComputeInterpolationWeights( inline void ComputeInterpolationWeights(
const index_t out_size, const index_t out_size,
const index_t in_size, const index_t in_size,
const float scale, const float scale,
CachedInterpolation *interpolation) { CachedInterpolation *interpolation) {
interpolation[out_size].lower = 0; interpolation[out_size].lower = 0;
interpolation[out_size].upper = 0; interpolation[out_size].upper = 0;
for (index_t i = out_size - 1; i >= 0; --i) { for (index_t i = out_size - 1; i >= 0; --i) {
...@@ -68,8 +68,7 @@ inline float ComputeLerp(const float top_left, ...@@ -68,8 +68,7 @@ inline float ComputeLerp(const float top_left,
return top + (bottom - top) * y_lerp; return top + (bottom - top) * y_lerp;
} }
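As a worked example of ComputeLerp: with top_left = 0, top_right = 1, bottom_left = 2, bottom_right = 3, x_lerp = 0.5 and y_lerp = 0.25, the horizontal blends give top = 0.5 and bottom = 2.5, and the final value is 0.5 + (2.5 - 0.5) * 0.25 = 1.0.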
inline void ResizeImage(const float *images,
                        const index_t batch_size,
                        const index_t in_height,
                        const index_t in_width,
                        const index_t out_height,
                        const index_t out_width,
                        const index_t channels,
                        const std::vector<CachedInterpolation> &xs_vec,
                        const std::vector<CachedInterpolation> &ys,
                        float *output) {
  const CachedInterpolation *xs = xs_vec.data();

#pragma omp parallel for collapse(2)
  for (index_t b = 0; b < batch_size; ++b) {
    for (index_t c = 0; c < channels; ++c) {
      const float *channel_input_ptr =
          images + (b * channels + c) * in_height * in_width;
      float *channel_output_ptr =
          output + (b * channels + c) * out_height * out_width;
      for (index_t y = 0; y < out_height; ++y) {
        const float *y_lower_input_ptr =
            channel_input_ptr + ys[y].lower * in_width;
        const float *y_upper_input_ptr =
            channel_input_ptr + ys[y].upper * in_width;
        const float ys_lerp = ys[y].lerp;

        for (index_t x = 0; x < out_width; ++x) {
          const float xs_lerp = xs[x].lerp;
          const float top_left = y_lower_input_ptr[xs[x].lower];
          const float top_right = y_lower_input_ptr[xs[x].upper];
          const float bottom_left = y_upper_input_ptr[xs[x].lower];
          const float bottom_right = y_upper_input_ptr[xs[x].upper];
          channel_output_ptr[y * out_width + x] =
              ComputeLerp(top_left, top_right, bottom_left,
                          bottom_right, xs_lerp, ys_lerp);
        }
      }
    }
  }
...@@ -120,7 +112,7 @@ void ResizeImage(const T *images, ...@@ -120,7 +112,7 @@ void ResizeImage(const T *images,
struct ResizeBilinearFunctorBase { struct ResizeBilinearFunctorBase {
ResizeBilinearFunctorBase(const std::vector<index_t> &size, ResizeBilinearFunctorBase(const std::vector<index_t> &size,
bool align_corners) bool align_corners)
: align_corners_(align_corners) { : align_corners_(align_corners) {
MACE_CHECK(size.size() == 2); MACE_CHECK(size.size() == 2);
out_height_ = size[0]; out_height_ = size[0];
out_width_ = size[1]; out_width_ = size[1];
...@@ -132,38 +124,43 @@ struct ResizeBilinearFunctorBase { ...@@ -132,38 +124,43 @@ struct ResizeBilinearFunctorBase {
index_t out_width_; index_t out_width_;
}; };
template<DeviceType D, typename T>
struct ResizeBilinearFunctor;

template<>
struct ResizeBilinearFunctor<DeviceType::CPU, float>
    : ResizeBilinearFunctorBase {
  ResizeBilinearFunctor(const std::vector<index_t> &size, bool align_corners)
      : ResizeBilinearFunctorBase(size, align_corners) {}

  void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
    const index_t batch = input->dim(0);
    const index_t channels = input->dim(1);
    const index_t in_height = input->dim(2);
    const index_t in_width = input->dim(3);

    index_t out_height = out_height_;
    index_t out_width = out_width_;
    MACE_CHECK(out_height > 0 && out_width > 0);
    std::vector<index_t> out_shape{batch, channels, out_height, out_width};
    output->Resize(out_shape);

    Tensor::MappingGuard input_mapper(input);
    Tensor::MappingGuard output_mapper(output);
    const float *input_data = input->data<float>();
    float *output_data = output->mutable_data<float>();

    if (out_height == in_height && out_width == in_width) {
      std::copy(input_data,
                input_data + batch * channels * in_height * in_width,
                output_data);
      return;
    }

    float height_scale =
        CalculateResizeScale(in_height, out_height, align_corners_);
    float width_scale =
        CalculateResizeScale(in_width, out_width, align_corners_);

    std::vector<CachedInterpolation> ys(out_height + 1);
    std::vector<CachedInterpolation> xs(out_width + 1);
...@@ -177,11 +174,11 @@ struct ResizeBilinearFunctor : ResizeBilinearFunctorBase { ...@@ -177,11 +174,11 @@ struct ResizeBilinearFunctor : ResizeBilinearFunctorBase {
} }
}; };
template <typename T> template<typename T>
struct ResizeBilinearFunctor<DeviceType::OPENCL, T> struct ResizeBilinearFunctor<DeviceType::OPENCL, T>
: ResizeBilinearFunctorBase { : ResizeBilinearFunctorBase {
ResizeBilinearFunctor(const std::vector<index_t> &size, bool align_corners) ResizeBilinearFunctor(const std::vector<index_t> &size, bool align_corners)
: ResizeBilinearFunctorBase(size, align_corners) {} : ResizeBilinearFunctorBase(size, align_corners) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future); void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
......
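The loop body of ComputeInterpolationWeights is elided above, so the following is only a sketch of how such cached weights are conventionally filled (TensorFlow-style resize; treat the exact rounding as an assumption, and the struct name merely mirrors CachedInterpolation):

#include <algorithm>
#include <cstdint>
typedef int64_t index_t;

struct CachedInterpolationSketch {
  index_t lower;  // floor of the source coordinate
  index_t upper;  // next source index, clamped to in_size - 1
  float lerp;     // fractional weight toward `upper`
};

CachedInterpolationSketch InterpAt(index_t i, index_t in_size, float scale) {
  const float in = i * scale;
  CachedInterpolationSketch c;
  c.lower = static_cast<index_t>(in);
  c.upper = std::min(c.lower + 1, in_size - 1);
  c.lerp = in - c.lower;
  return c;
}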
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <limits>
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h" #include "mace/core/runtime/opencl/cl2_header.h"
...@@ -29,50 +30,66 @@ ...@@ -29,50 +30,66 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template <DeviceType D, typename T> template<DeviceType D, typename T>
struct SoftmaxFunctor { struct SoftmaxFunctor;
void operator()(const Tensor *logits, Tensor *output, StatsFuture *future) {
Tensor::MappingGuard logits_guard(logits); template<>
struct SoftmaxFunctor<DeviceType::CPU, float> {
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t class_count = input->dim(1);
const index_t class_size = input->dim(2) * input->dim(3);
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output); Tensor::MappingGuard output_guard(output);
const T *logits_ptr = logits->data<T>(); const float *input_data = input->data<float>();
T *output_ptr = output->mutable_data<T>(); float *output_data = output->mutable_data<float>();
auto &logits_shape = logits->shape();
const index_t batch_size = for (index_t b = 0; b < batch; ++b) {
std::accumulate(logits_shape.begin(), logits_shape.end() - 1, 1, std::vector<float>
std::multiplies<index_t>()); max_val(class_size, std::numeric_limits<float>::lowest());
const index_t num_classes = logits_shape.back(); std::vector<float> sum_val(class_size, 0.f);
#pragma omp parallel // calculate max for each class
{ for (index_t c = 0; c < class_count; ++c) {
// Allocate per thread buffer const float
std::vector<T> exp_data(num_classes); *input_ptr = input_data + (b * class_count + c) * class_size;
#pragma omp for for (index_t k = 0; k < class_size; ++k) {
for (index_t i = 0; i < batch_size; ++i) { max_val[k] = std::max(max_val[k], input_ptr[k]);
const index_t pos = i * num_classes; }
T max_value = logits_ptr[pos]; }
for (index_t c = 1; c < num_classes; ++c) {
max_value = std::max(max_value, logits_ptr[pos + c]); // calculate data - max for each class
#pragma omp parallel for
for (index_t c = 0; c < class_count; ++c) {
const float
*input_ptr = input_data + (b * class_count + c) * class_size;
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
output_ptr[k] = ::exp(input_ptr[k] - max_val[k]);
} }
// TODO(liuqi): check overflow? }
T sum = 0;
for (index_t c = 0; c < num_classes; ++c) { // calculate sum for each class
exp_data[c] = ::exp((logits_ptr[pos + c] - max_value)); for (index_t c = 0; c < class_count; ++c) {
sum += exp_data[c]; float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
sum_val[k] += output_ptr[k];
} }
for (index_t c = 0; c < num_classes; ++c) { }
output_ptr[pos + c] = exp_data[c] / sum;
// calculate (data - max) / sum for each class
for (index_t c = 0; c < class_count; ++c) {
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
output_ptr[k] /= sum_val[k];
} }
} }
} }
} }
}; };
template<typename T>
struct SoftmaxFunctor<DeviceType::OPENCL, T> { struct SoftmaxFunctor<DeviceType::OPENCL, T> {
void operator()(const Tensor *logits, Tensor *output, StatsFuture *future); void operator()(const Tensor *logits, Tensor *output, StatsFuture *future);
......
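The CPU functor above is the standard max-shifted softmax over the channel dimension, applied independently at every spatial position k:

$$ \mathrm{softmax}(x)_{c} = \frac{e^{x_c - m}}{\sum_{c'=0}^{C-1} e^{x_{c'} - m}}, \qquad m = \max_{c'} x_{c'} $$

Subtracting m scales numerator and denominator by the same factor, so the result is unchanged, but every exponent becomes nonpositive and ::exp cannot overflow; logits such as {1000, 1001} would overflow a direct exp, yet become {-1, 0} after the shift.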
...@@ -35,11 +35,6 @@ void Register_Activation(OperatorRegistry *op_registry) { ...@@ -35,11 +35,6 @@ void Register_Activation(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
ActivationOp<DeviceType::OPENCL, half>); ActivationOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
ActivationOp<DeviceType::NEON, float>);
} }
} // namespace ops } // namespace ops
......
...@@ -31,9 +31,21 @@ void ReluBenchmark( ...@@ -31,9 +31,21 @@ void ReluBenchmark(
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels}); if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
  if (D == DeviceType::CPU) {
OpDefBuilder("Activation", "ReluBM")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage", BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
...@@ -43,11 +55,7 @@ void ReluBenchmark( ...@@ -43,11 +55,7 @@ void ReluBenchmark(
.AddStringArg("activation", "RELU") .AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
  } else {
    MACE_NOT_IMPLEMENTED;
  }
// Warm-up // Warm-up
...@@ -93,7 +101,11 @@ void ReluxBenchmark( ...@@ -93,7 +101,11 @@ void ReluxBenchmark(
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels}); if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
}
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage", BufferToImage<D, float>(&net, "Input", "InputImage",
...@@ -157,10 +169,23 @@ void PreluBenchmark( ...@@ -157,10 +169,23 @@ void PreluBenchmark(
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels}); if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, float>("Alpha", {channels}); net.AddRandomInput<D, float>("Alpha", {channels});
  if (D == DeviceType::CPU) {
OpDefBuilder("Activation", "PreluBM")
.Input("Input")
.Input("Alpha")
.Output("Output")
.AddStringArg("activation", "PRELU")
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage", BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Alpha", "AlphaImage", BufferToImage<D, float>(&net, "Alpha", "AlphaImage",
...@@ -173,12 +198,7 @@ void PreluBenchmark( ...@@ -173,12 +198,7 @@ void PreluBenchmark(
.AddStringArg("activation", "PRELU") .AddStringArg("activation", "PRELU")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
  } else {
    MACE_NOT_IMPLEMENTED;
  }
// Warm-up // Warm-up
...@@ -224,7 +244,11 @@ void TanhBenchmark( ...@@ -224,7 +244,11 @@ void TanhBenchmark(
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels}); if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
}
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage", BufferToImage<D, float>(&net, "Input", "InputImage",
...@@ -286,7 +310,11 @@ void SigmoidBenchmark( ...@@ -286,7 +310,11 @@ void SigmoidBenchmark(
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels}); if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
}
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage", BufferToImage<D, float>(&net, "Input", "InputImage",
......
...@@ -269,16 +269,11 @@ void TestSimplePrelu() { ...@@ -269,16 +269,11 @@ void TestSimplePrelu() {
net.RunOp(D); net.RunOp(D);
} }
  if (D == DeviceType::CPU) {
auto expected = CreateTensor<float>( auto expected = CreateTensor<float>(
{2, 2, 2, 2}, {2, 2, 2, 2},
{-14, 7, -12, 6, -15, -15, -12, -12, -6, 3, -4, 2, -3, -3, 0, 0}); {-14, 7, -12, 6, -15, -15, -12, -12, -6, 3, -4, 2, -3, -3, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
} else {
auto expected = CreateTensor<float>(
{2, 2, 2, 2},
{-14, 7, -12, 6, -10, -15, -8, -12, -6, 3, -4, 2, -2, -3, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
} }
} }
} // namespace } // namespace
...@@ -287,10 +282,6 @@ TEST_F(ActivationOpTest, CPUSimplePrelu) { ...@@ -287,10 +282,6 @@ TEST_F(ActivationOpTest, CPUSimplePrelu) {
TestSimplePrelu<DeviceType::CPU>(); TestSimplePrelu<DeviceType::CPU>();
} }
TEST_F(ActivationOpTest, NEONSimplePrelu) {
TestSimplePrelu<DeviceType::NEON>();
}
TEST_F(ActivationOpTest, OPENCLSimplePrelu) { TEST_F(ActivationOpTest, OPENCLSimplePrelu) {
TestSimplePrelu<DeviceType::OPENCL>(); TestSimplePrelu<DeviceType::OPENCL>();
} }
......
...@@ -35,12 +35,6 @@ void Register_AddN(OperatorRegistry *op_registry) { ...@@ -35,12 +35,6 @@ void Register_AddN(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
AddNOp<DeviceType::OPENCL, half>); AddNOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("AddN")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
AddNOp<DeviceType::NEON, float>);
} }
} // namespace ops } // namespace ops
......
...@@ -35,11 +35,6 @@ void Register_BatchNorm(OperatorRegistry *op_registry) { ...@@ -35,11 +35,6 @@ void Register_BatchNorm(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
BatchNormOp<DeviceType::OPENCL, half>); BatchNormOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
BatchNormOp<DeviceType::NEON, float>);
} }
} // namespace ops } // namespace ops
......
...@@ -30,13 +30,29 @@ void BatchNorm( ...@@ -30,13 +30,29 @@ void BatchNorm(
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, channels}); if (D == DeviceType::CPU) {
net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, T>("Scale", {channels}); net.AddRandomInput<D, T>("Scale", {channels});
net.AddRandomInput<D, T>("Offset", {channels}); net.AddRandomInput<D, T>("Offset", {channels});
net.AddRandomInput<D, T>("Mean", {channels}); net.AddRandomInput<D, T>("Mean", {channels});
net.AddRandomInput<D, T>("Var", {channels}, true); net.AddRandomInput<D, T>("Var", {channels}, true);
  if (D == DeviceType::CPU) {
OpDefBuilder("BatchNorm", "BatchNormBM")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage", BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Scale", "ScaleImage", BufferToImage<D, float>(&net, "Scale", "ScaleImage",
...@@ -57,15 +73,7 @@ void BatchNorm( ...@@ -57,15 +73,7 @@ void BatchNorm(
.Output("Output") .Output("Output")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
  } else {
    MACE_NOT_IMPLEMENTED;
  }
// tuning // tuning
......
...@@ -34,7 +34,22 @@ void Simple() { ...@@ -34,7 +34,22 @@ void Simple() {
net.AddInputFromArray<D, float>("Mean", {1}, {10}); net.AddInputFromArray<D, float>("Mean", {1}, {10});
net.AddInputFromArray<D, float>("Var", {1}, {11.67f}); net.AddInputFromArray<D, float>("Var", {1}, {11.67f});
  if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage", BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Scale", "ScaleImage", BufferToImage<D, float>(&net, "Scale", "ScaleImage",
...@@ -61,18 +76,6 @@ void Simple() { ...@@ -61,18 +76,6 @@ void Simple() {
// Transfer output // Transfer output
ImageToBuffer<D, float>(&net, "OutputImage", "Output", ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} }
// Check // Check
...@@ -97,17 +100,7 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) { ...@@ -97,17 +100,7 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
index_t height = 64; index_t height = 64;
index_t width = 64; index_t width = 64;
// Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::OPENCL, float>( net.AddRandomInput<DeviceType::OPENCL, float>(
...@@ -117,9 +110,30 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) { ...@@ -117,9 +110,30 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels}); net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels}); net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
// Construct graph
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -170,15 +184,6 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) { ...@@ -170,15 +184,6 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) {
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-1)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::OPENCL, float>( net.AddRandomInput<DeviceType::OPENCL, float>(
...@@ -188,9 +193,29 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) { ...@@ -188,9 +193,29 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) {
net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels}); net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels}); net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-1)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -242,15 +267,6 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) { ...@@ -242,15 +267,6 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::OPENCL, float>( net.AddRandomInput<DeviceType::OPENCL, float>(
...@@ -260,9 +276,29 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) { ...@@ -260,9 +276,29 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels}); net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels}); net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -313,15 +349,6 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) { ...@@ -313,15 +349,6 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) {
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-1)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::OPENCL, float>( net.AddRandomInput<DeviceType::OPENCL, float>(
...@@ -331,9 +358,29 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) { ...@@ -331,9 +358,29 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) {
net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels}); net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels}); net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-1)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -375,63 +422,6 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) { ...@@ -375,63 +422,6 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-1, 1e-2); ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-1, 1e-2);
} }
TEST_F(BatchNormOpTest, NEONTest) {
srand(time(NULL));
unsigned int seed;
// generate random input
index_t batch = 1 + rand_r(&seed) % 10;
index_t channels = 3 + rand_r(&seed) % 50;
index_t height = 64;
index_t width = 64;
// Construct graph
OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Var", {channels});
// run cpu
net.RunOp();
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNeon")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputNeon")
.Finalize(net.NewOperatorDef());
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("InputNeon", "Input");
// Run on neon
net.RunOp(DeviceType::NEON);
net.Sync();
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("OutputExptected",
"Output");
ExpectTensorNear<float>(*net.GetOutput("OutputExptected"),
*net.GetOutput("OutputNeon"),
1e-5, 1e-4);
}
} // namespace test } // namespace test
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
...@@ -29,10 +29,22 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) { ...@@ -29,10 +29,22 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) {
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, channels}); if (D == DeviceType::CPU) {
net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, T>("Bias", {channels}, true); net.AddRandomInput<D, T>("Bias", {channels}, true);
  if (D == DeviceType::CPU) {
OpDefBuilder("BiasAdd", "BiasAddBM")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage", BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Bias", "BiasImage", BufferToImage<D, T>(&net, "Bias", "BiasImage",
...@@ -43,11 +55,7 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) { ...@@ -43,11 +55,7 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) {
.Output("Output") .Output("Output")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
  } else {
    MACE_NOT_IMPLEMENTED;
  }
// Warm-up // Warm-up
......
...@@ -31,7 +31,23 @@ void BiasAddSimple() { ...@@ -31,7 +31,23 @@ void BiasAddSimple() {
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}); {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
net.AddInputFromArray<D, float>("Bias", {1}, {0.5f}); net.AddInputFromArray<D, float>("Bias", {1}, {0.5f});
  if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputNCHW")
.Input("Bias")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage", BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Bias", "BiasImage", BufferToImage<D, float>(&net, "Bias", "BiasImage",
...@@ -49,13 +65,7 @@ void BiasAddSimple() { ...@@ -49,13 +65,7 @@ void BiasAddSimple() {
ImageToBuffer<D, float>(&net, "OutputImage", "Output", ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
  } else {
    MACE_NOT_IMPLEMENTED;
  }
// Check // Check
...@@ -81,22 +91,33 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) { ...@@ -81,22 +91,33 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) {
index_t height = 64 + rand_r(&seed) % 50; index_t height = 64 + rand_r(&seed) % 50;
index_t width = 64 + rand_r(&seed) % 50; index_t width = 64 + rand_r(&seed) % 50;
// Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::OPENCL, float>( net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels}); "Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {channels}, true); net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
// Construct graph
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputNCHW")
.Input("Bias")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -130,22 +151,32 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) { ...@@ -130,22 +151,32 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) {
index_t height = 103 + rand_r(&seed) % 100; index_t height = 103 + rand_r(&seed) % 100;
index_t width = 113 + rand_r(&seed) % 100; index_t width = 113 + rand_r(&seed) % 100;
// Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::OPENCL, float>( net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels}); "Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {channels}, true); net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
// Construct graph
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputNCHW")
.Input("Bias")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
......
...@@ -34,7 +34,14 @@ class ChannelShuffleOp : public Operator<D, T> { ...@@ -34,7 +34,14 @@ class ChannelShuffleOp : public Operator<D, T> {
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
    int channels;
    if (D == OPENCL) {
      channels = input->dim(3);
    } else if (D == CPU) {
      channels = input->dim(1);
    } else {
      MACE_NOT_IMPLEMENTED;
    }
    MACE_CHECK(channels % group_ == 0,
               "input channels must be an integral multiple of group. ",
               channels);
......
...@@ -29,9 +29,20 @@ void ChannelShuffle( ...@@ -29,9 +29,20 @@ void ChannelShuffle(
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels}); if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, height, channels, width});
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
if (D == DeviceType::OPENCL) { if (D == DeviceType::CPU) {
OpDefBuilder("Softmax", "SoftmaxBM")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage", BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
...@@ -41,10 +52,7 @@ void ChannelShuffle( ...@@ -41,10 +52,7 @@ void ChannelShuffle(
.AddIntArg("group", group) .AddIntArg("group", group)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
  } else {
    MACE_NOT_IMPLEMENTED;
  }
// Warm-up // Warm-up
......
...@@ -22,21 +22,31 @@ namespace test { ...@@ -22,21 +22,31 @@ namespace test {
class ChannelShuffleOpTest : public OpsTestBase {}; class ChannelShuffleOpTest : public OpsTestBase {};
TEST_F(ChannelShuffleOpTest, C8G4_CPU) { TEST_F(ChannelShuffleOpTest, C8G4_CPU) {
// Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("ChannelShuffle", "ChannelShuffleTest")
.Input("Input")
.Output("Output")
.AddIntArg("group", 4)
.Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddInputFromArray<DeviceType::CPU, float>( net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 1, 2, 8}, "Input", {1, 1, 2, 8},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
// Construct graph
OpDefBuilder("ChannelShuffle", "ChannelShuffleTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntArg("group", 4)
.Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
auto expected = CreateTensor<float>( auto expected = CreateTensor<float>(
......
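The tests and benchmarks in this commit all follow the same pattern: keep inputs in NHWC, transform to NCHW for the CPU op, then transform the result back before checking. The index mapping TransformDataFormat applies is, in essence (a plain-loop sketch, not the OpsTestNet implementation):

#include <cstdint>
typedef int64_t index_t;

// dst is NCHW, src is NHWC; both hold N*H*W*C floats.
void NHWCToNCHW(const float *src, index_t N, index_t H, index_t W, index_t C,
                float *dst) {
  for (index_t n = 0; n < N; ++n)
    for (index_t c = 0; c < C; ++c)
      for (index_t h = 0; h < H; ++h)
        for (index_t w = 0; w < W; ++w)
          dst[((n * C + c) * H + h) * W + w] =
              src[((n * H + h) * W + w) * C + c];
}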
...@@ -33,11 +33,6 @@ void Register_Concat(OperatorRegistry *op_registry) { ...@@ -33,11 +33,6 @@ void Register_Concat(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
ConcatOp<DeviceType::OPENCL, half>); ConcatOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
ConcatOp<DeviceType::NEON, float>);
} }
} // namespace ops } // namespace ops
......
...@@ -35,12 +35,6 @@ void Register_Conv2D(OperatorRegistry *op_registry) { ...@@ -35,12 +35,6 @@ void Register_Conv2D(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
Conv2dOp<DeviceType::OPENCL, half>); Conv2dOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
Conv2dOp<DeviceType::NEON, float>);
} }
} // namespace ops } // namespace ops
......
...@@ -41,21 +41,34 @@ void Conv2d(int iters, ...@@ -41,21 +41,34 @@ void Conv2d(int iters,
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
  if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width}); net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
net.AddRandomInput<D, float>("Filter", net.AddRandomInput<D, float>("Filter",
{output_channels, channels, kernel_h, {output_channels, channels, kernel_h,
kernel_w}); kernel_w});
net.AddRandomInput<D, float>("Bias", {output_channels}); net.AddRandomInput<D, float>("Bias", {output_channels});
  } else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels}); net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
net.AddRandomInput<D, float>("Filter", net.AddRandomInput<D, float>("Filter",
{kernel_h, kernel_w, output_channels, {kernel_h, kernel_w, output_channels,
channels}); channels});
net.AddRandomInput<D, float>("Bias", {output_channels}); net.AddRandomInput<D, float>("Bias", {output_channels});
} else {
MACE_NOT_IMPLEMENTED;
} }
  if (D == DeviceType::CPU) {
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {dilation, dilation})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage", BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Filter", "FilterImage", BufferToImage<D, T>(&net, "Filter", "FilterImage",
...@@ -73,16 +86,7 @@ void Conv2d(int iters, ...@@ -73,16 +86,7 @@ void Conv2d(int iters,
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
  } else {
    MACE_NOT_IMPLEMENTED;
  }
net.Setup(D); net.Setup(D);
...@@ -134,7 +138,6 @@ void Conv2d(int iters, ...@@ -134,7 +138,6 @@ void Conv2d(int iters,
#define BM_CONV_2D(N, C, H, W, KH, KW, S, D, P, OC) \ #define BM_CONV_2D(N, C, H, W, KH, KW, S, D, P, OC) \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, CPU); \ BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, CPU); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, NEON); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, OPENCL); \ BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, OPENCL); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, OPENCL); BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, OPENCL);
......
...@@ -38,7 +38,32 @@ void TestNHWCSimple3x3VALID() { ...@@ -38,7 +38,32 @@ void TestNHWCSimple3x3VALID() {
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, T>("Bias", {1}, {0.1f}); net.AddInputFromArray<D, T>("Bias", {1}, {0.1f});
  if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage", BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Filter", "FilterImage", BufferToImage<D, T>(&net, "Filter", "FilterImage",
...@@ -63,18 +88,7 @@ void TestNHWCSimple3x3VALID() { ...@@ -63,18 +88,7 @@ void TestNHWCSimple3x3VALID() {
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("Conv2D", "Conv2dTest") MACE_NOT_IMPLEMENTED;
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} }
auto expected = CreateTensor<float>({1, 1, 1, 1}, {18.1f}); auto expected = CreateTensor<float>({1, 1, 1, 1}, {18.1f});
...@@ -95,7 +109,32 @@ void TestNHWCSimple3x3SAME() { ...@@ -95,7 +109,32 @@ void TestNHWCSimple3x3SAME() {
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, T>("Bias", {1}, {0.1f}); net.AddInputFromArray<D, T>("Bias", {1}, {0.1f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage", BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Filter", "FilterImage", BufferToImage<D, T>(&net, "Filter", "FilterImage",
...@@ -120,18 +159,7 @@ void TestNHWCSimple3x3SAME() { ...@@ -120,18 +159,7 @@ void TestNHWCSimple3x3SAME() {
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("Conv2D", "Conv2dTest") MACE_NOT_IMPLEMENTED;
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} }
auto expected = CreateTensor<float>( auto expected = CreateTensor<float>(
...@@ -166,7 +194,32 @@ void TestNHWCSimple3x3WithoutBias() { ...@@ -166,7 +194,32 @@ void TestNHWCSimple3x3WithoutBias() {
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage", BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Filter", "FilterImage", BufferToImage<D, T>(&net, "Filter", "FilterImage",
...@@ -187,18 +240,7 @@ void TestNHWCSimple3x3WithoutBias() { ...@@ -187,18 +240,7 @@ void TestNHWCSimple3x3WithoutBias() {
ImageToBuffer<D, T>(&net, "OutputImage", "Output", ImageToBuffer<D, T>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("Conv2D", "Conv2dTest") MACE_NOT_IMPLEMENTED;
.Input("Input")
.Input("Filter")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} }
// Check // Check
...@@ -234,7 +276,32 @@ void TestNHWCCombined3x3() { ...@@ -234,7 +276,32 @@ void TestNHWCCombined3x3() {
1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f}); 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f});
net.AddInputFromArray<D, T>("Bias", {2}, {0.1f, 0.2f}); net.AddInputFromArray<D, T>("Bias", {2}, {0.1f, 0.2f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2DTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage", BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Filter", "FilterImage", BufferToImage<D, T>(&net, "Filter", "FilterImage",
...@@ -258,18 +325,7 @@ void TestNHWCCombined3x3() { ...@@ -258,18 +325,7 @@ void TestNHWCCombined3x3() {
ImageToBuffer<D, T>(&net, "OutputImage", "Output", ImageToBuffer<D, T>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("Conv2D", "Conv2DTest") MACE_NOT_IMPLEMENTED;
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} }
// Check // Check
...@@ -309,7 +365,31 @@ void TestConv1x1() { ...@@ -309,7 +365,31 @@ void TestConv1x1() {
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}); {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f});
net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f}); net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2DTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage", BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Filter", "FilterImage", BufferToImage<D, float>(&net, "Filter", "FilterImage",
...@@ -332,17 +412,7 @@ void TestConv1x1() { ...@@ -332,17 +412,7 @@ void TestConv1x1() {
ImageToBuffer<D, float>(&net, "OutputImage", "Output", ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("Conv2D", "Conv2DTest") MACE_NOT_IMPLEMENTED;
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} }
// Check // Check
...@@ -377,27 +447,44 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape, ...@@ -377,27 +447,44 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape,
index_t width = shape[1]; index_t width = shape[1];
index_t input_channels = shape[2] + (rand_r(&seed) % 10); index_t input_channels = shape[2] + (rand_r(&seed) % 10);
index_t output_channels = shape[3] + (rand_r(&seed) % 10); index_t output_channels = shape[3] + (rand_r(&seed) % 10);
// Construct graph
OpsTestNet net; OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
// Construct graph
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("InputNCHW")
.Input("Filter") .Input("FilterOIHW")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w}) .AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type) .AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -470,15 +557,6 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape, ...@@ -470,15 +557,6 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape,
index_t output_channels = filter_shape[3]; index_t output_channels = filter_shape[3];
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {dilations[0], dilations[1]})
.Finalize(net.NewOperatorDef());
std::vector<float> float_input_data; std::vector<float> float_input_data;
GenerateRandomRealTypeData({batch, height, width, input_channels}, GenerateRandomRealTypeData({batch, height, width, input_channels},
...@@ -497,8 +575,33 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape, ...@@ -497,8 +575,33 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape,
float_filter_data); float_filter_data);
net.AddInputFromArray<D, float>("Bias", {output_channels}, float_bias_data); net.AddInputFromArray<D, float>("Bias", {output_channels}, float_bias_data);
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {dilations[0], dilations[1]})
.Finalize(net.NewOperatorDef());
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -605,27 +708,44 @@ void TestDilationConvNxN(const std::vector<index_t> &shape, ...@@ -605,27 +708,44 @@ void TestDilationConvNxN(const std::vector<index_t> &shape,
index_t width = shape[1]; index_t width = shape[1];
index_t input_channels = shape[2]; index_t input_channels = shape[2];
index_t output_channels = shape[3]; index_t output_channels = shape[3];
// Construct graph
OpsTestNet net; OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
// Construct graph
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("InputNCHW")
.Input("Filter") .Input("FilterOIHW")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w}) .AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type) .AddIntArg("padding", type)
.AddIntsArg("dilations", {dilation_rate, dilation_rate}) .AddIntsArg("dilations", {dilation_rate, dilation_rate})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -692,17 +812,8 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape, ...@@ -692,17 +812,8 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape,
index_t width = shape[1]; index_t width = shape[1];
index_t input_channels = shape[2]; index_t input_channels = shape[2];
index_t output_channels = shape[3]; index_t output_channels = shape[3];
// Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntsArg("padding_values", paddings)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels}); net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
...@@ -710,8 +821,33 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape, ...@@ -710,8 +821,33 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape,
"Filter", {kernel_h, kernel_w, output_channels, input_channels}); "Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels}); net.AddRandomInput<D, T>("Bias", {output_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
// Construct graph
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntsArg("padding_values", paddings)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -763,83 +899,6 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedPad4) { ...@@ -763,83 +899,6 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedPad4) {
TestArbitraryPadConvNxN<DeviceType::OPENCL, float>({107, 113, 5, 7}, {4, 4}); TestArbitraryPadConvNxN<DeviceType::OPENCL, float>({107, 113, 5, 7}, {4, 4});
} }
static void TestNeonArbitraryPadConvNxN(const std::vector<index_t> &shape,
const std::vector<int> &paddings) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w) {
srand(time(NULL));
// generate random input
index_t batch = 1;
index_t height = shape[0];
index_t width = shape[1];
index_t input_channels = shape[2];
index_t output_channels = shape[3];
// Construct graph
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTestCPU")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntsArg("padding_values", paddings)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input",
{batch, height, width,
input_channels});
net.AddRandomInput<DeviceType::CPU, float>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<DeviceType::CPU, float>("Bias", {output_channels});
// run cpu
net.RunOp();
// run neon
OpDefBuilder("Conv2D", "Conv2dTestNEON")
.Input("InputNeon")
.Input("FilterNeon")
.Input("Bias")
.Output("OutputNeon")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntsArg("padding_values", paddings)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
.Finalize(net.NewOperatorDef());
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("InputNeon", "Input");
net.FillHWOIInputToOIHWInput<DeviceType::CPU, float>("FilterNeon",
"Filter");
// Run on device
net.RunOp(DeviceType::NEON);
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("OutputExptected",
"Output");
ExpectTensorNear<float>(*net.GetOutput("OutputExptected"),
*net.GetOutput("OutputNeon"),
1e-5, 1e-3);
};
for (int kernel_size : {1, 3, 5}) {
for (int stride : {1, 2}) {
if (stride <= kernel_size) {
func(kernel_size, kernel_size, stride, stride);
}
}
}
}
TEST_F(Conv2dOpTest, NEONTest) {
TestNeonArbitraryPadConvNxN({32, 34, 32, 64}, {0, 0});
TestNeonArbitraryPadConvNxN({32, 32, 32, 64}, {1, 1});
TestNeonArbitraryPadConvNxN({128, 128, 16, 16}, {2, 2});
TestNeonArbitraryPadConvNxN({107, 113, 5, 7}, {4, 4});
}
} // namespace test } // namespace test
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
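Every CPU path above wraps the op in a pair of TransformDataFormat calls: NHWC test data goes in as NCHW, and the NCHW result comes back out as NHWC. A self-contained sketch of that permutation on a flat row-major buffer (an editor's illustration, not the MACE implementation):

#include <cstdint>
#include <vector>

// Permute a flat NHWC buffer into NCHW order.
std::vector<float> NHWCToNCHW(const std::vector<float> &src, int64_t n,
                              int64_t h, int64_t w, int64_t c) {
  std::vector<float> dst(src.size());
  for (int64_t b = 0; b < n; ++b)
    for (int64_t y = 0; y < h; ++y)
      for (int64_t x = 0; x < w; ++x)
        for (int64_t ch = 0; ch < c; ++ch)
          dst[((b * c + ch) * h + y) * w + x] =     // NCHW offset
              src[((b * h + y) * w + x) * c + ch];  // NHWC offset
  return dst;
}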
@@ -37,10 +37,17 @@ class DepthToSpaceOp : public Operator<D, T> {
     Tensor *output = this->Output(OUTPUT);
     MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
-    int input_depth = input->dim(3);
+    int input_depth;
+    if (D == CPU) {
+      input_depth = input->dim(1);
+    } else if (D == OPENCL) {
+      input_depth = input->dim(3);
+    } else {
+      MACE_NOT_IMPLEMENTED;
+    }
     MACE_CHECK(input_depth % (block_size_ * block_size_) == 0,
                "input depth should be dividable by block_size * block_size",
-               input->dim(3));
+               input_depth);
     MACE_CHECK((input_depth % 4) == 0,
                "input channel should be dividable by 4");
     functor_(input, output, future);
......
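The new dispatch reads the channel count from dim(1) under NCHW and dim(3) under NHWC, and the divisibility check exists because each b x b output tile is assembled from b * b input channels. A sketch of the NCHW rearrangement (a common channel ordering is assumed here; conventions differ between frameworks, so treat this as an illustration only):

#include <cstdint>
#include <vector>

// DepthToSpace on an NCHW buffer (single batch): channels shrink by b*b,
// height and width grow by b. Requires c % (b * b) == 0.
std::vector<float> DepthToSpaceNCHW(const std::vector<float> &src, int64_t c,
                                    int64_t h, int64_t w, int64_t b) {
  const int64_t oc = c / (b * b);
  std::vector<float> dst(src.size());
  for (int64_t ch = 0; ch < oc; ++ch)
    for (int64_t y = 0; y < h * b; ++y)
      for (int64_t x = 0; x < w * b; ++x) {
        // Assumed "DCR" ordering: the spatial offset selects channel groups.
        const int64_t in_ch = ((y % b) * b + (x % b)) * oc + ch;
        dst[(ch * h * b + y) * (w * b) + x] =
            src[(in_ch * h + y / b) * w + x / b];
      }
  return dst;
}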
@@ -29,9 +29,20 @@ void DepthToSpace(
   OpsTestNet net;
   // Add input data
-  net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
+  if (D == DeviceType::CPU) {
+    net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
+  } else if (D == DeviceType::OPENCL) {
+    net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
+  } else {
+    MACE_NOT_IMPLEMENTED;
+  }
-  if (D == DeviceType::OPENCL) {
+  if (D == DeviceType::CPU) {
+    OpDefBuilder("DepthToSpace", "DepthToSpaceBM")
+        .Input("Input")
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+  } else if (D == DeviceType::OPENCL) {
     BufferToImage<D, float>(&net, "Input", "InputImage",
                             kernels::BufferType::IN_OUT_CHANNEL);
@@ -41,10 +52,7 @@ void DepthToSpace(
         .AddIntArg("block_size", block_size)
         .Finalize(net.NewOperatorDef());
   } else {
-    OpDefBuilder("DepthToSpace", "DepthToSpaceBM")
-        .Input("Input")
-        .Output("Output")
-        .Finalize(net.NewOperatorDef());
+    MACE_NOT_IMPLEMENTED;
   }
   // Warm-up
......
@@ -36,11 +36,21 @@ void RunDepthToSpace(const bool d2s,
   const char *ops_test_name = (d2s) ? "DepthToSpaceTest" : "SpaceToDepthTest";
   // Construct graph
   if (D == DeviceType::CPU) {
+    net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                    NHWC,
+                                                    "InputNCHW",
+                                                    NCHW);
     OpDefBuilder(ops_name, ops_test_name)
-        .Input("Input")
-        .Output("Output")
+        .Input("InputNCHW")
+        .Output("OutputNCHW")
         .AddIntArg("block_size", block_size)
         .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+    net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                    NCHW,
+                                                    "Output",
+                                                    NHWC);
   } else {
     BufferToImage<D, float>(&net, "Input", "InputImage",
@@ -50,9 +60,10 @@ void RunDepthToSpace(const bool d2s,
         .Output("OutputImage")
         .AddIntArg("block_size", block_size)
         .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
   }
-  // Run
-  net.RunOp(D);
   if (D == DeviceType::OPENCL) {
     ImageToBuffer<DeviceType::OPENCL, float>(&net, "OutputImage", "Output",
@@ -176,22 +187,31 @@ void RandomTest(const bool d2s, const int block_size,
   const char *ops_test_name = (d2s) ? "DepthToSpaceTest" : "SpaceToDepthTest";
   // Add input data
-  net.AddRandomInput<D, float>("Input1", shape);
+  net.AddRandomInput<D, float>("Input", shape);
+  net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                  NHWC,
+                                                  "InputNCHW",
+                                                  NCHW);
   OpDefBuilder(ops_name, ops_test_name)
-      .Input("Input1")
+      .Input("InputNCHW")
       .AddIntArg("block_size", block_size)
-      .Output("Output")
+      .Output("OutputNCHW")
      .Finalize(net.NewOperatorDef());
   // Run
   net.RunOp();
-  BufferToImage<D, T>(&net, "Input1", "InputImg1",
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                  NCHW,
+                                                  "Output",
+                                                  NHWC);
+  BufferToImage<D, T>(&net, "Input", "InputImg",
                       kernels::BufferType::IN_OUT_CHANNEL);
   OpDefBuilder(ops_name, ops_test_name)
-      .Input("InputImg1")
+      .Input("InputImg")
       .AddIntArg("block_size", block_size)
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
       .Output("OutputImg")
......
@@ -35,12 +35,6 @@ void Register_DepthwiseConv2d(OperatorRegistry *op_registry) {
                         .TypeConstraint<half>("T")
                         .Build(),
                     DepthwiseConv2dOp<DeviceType::OPENCL, half>);
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthwiseConv2d")
-                        .Device(DeviceType::NEON)
-                        .TypeConstraint<float>("T")
-                        .Build(),
-                    DepthwiseConv2dOp<DeviceType::NEON, float>);
 }
 }  // namespace ops
......
@@ -40,21 +40,34 @@ void DepthwiseConv2d(int iters,
   OpsTestNet net;
   // Add input data
-  if (D == DeviceType::NEON) {
+  if (D == DeviceType::CPU) {
     net.AddRandomInput<D, float>("Input",
                                  {batch, input_channels, height, width});
     net.AddRandomInput<D, float>(
         "Filter", {multiplier, input_channels, kernel_h, kernel_w});
     net.AddRandomInput<D, float>("Bias", {input_channels * multiplier});
-  } else {
+  } else if (D == DeviceType::OPENCL) {
     net.AddRandomInput<D, float>("Input",
                                  {batch, height, width, input_channels});
     net.AddRandomInput<D, float>(
         "Filter", {kernel_h, kernel_w, input_channels, multiplier});
     net.AddRandomInput<D, float>("Bias", {input_channels * multiplier});
+  } else {
+    MACE_NOT_IMPLEMENTED;
   }
-  if (D == DeviceType::OPENCL) {
+  if (D == DeviceType::CPU) {
+    OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
+        .Input("Input")
+        .Input("Filter")
+        .Input("Bias")
+        .Output("Output")
+        .AddIntsArg("strides", {stride, stride})
+        .AddIntArg("padding", padding)
+        .AddIntsArg("dilations", {1, 1})
+        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+        .Finalize(net.NewOperatorDef());
+  } else if (D == DeviceType::OPENCL) {
     BufferToImage<D, T>(&net, "Input", "InputImage",
                         kernels::BufferType::IN_OUT_CHANNEL);
     BufferToImage<D, T>(&net, "Filter", "FilterImage",
@@ -72,16 +85,7 @@ void DepthwiseConv2d(int iters,
         .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
         .Finalize(net.NewOperatorDef());
   } else {
-    OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
-        .Input("Input")
-        .Input("Filter")
-        .Input("Bias")
-        .Output("Output")
-        .AddIntsArg("strides", {stride, stride})
-        .AddIntArg("padding", padding)
-        .AddIntsArg("dilations", {1, 1})
-        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
-        .Finalize(net.NewOperatorDef());
+    MACE_NOT_IMPLEMENTED;
   }
   net.Setup(D);
@@ -131,8 +135,7 @@ void DepthwiseConv2d(int iters,
 #define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, M)                 \
   BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, CPU);    \
   BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, OPENCL); \
-  BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, half, OPENCL);  \
-  BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, NEON);
+  BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, half, OPENCL);
 BM_DEPTHWISE_CONV_2D(1, 32, 112, 112, 3, 3, 1, SAME, 1);
 BM_DEPTHWISE_CONV_2D(1, 32, 56, 56, 3, 3, 2, VALID, 1);
......
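The CPU benchmark feeds a {multiplier, input_channels, kernel_h, kernel_w} filter because depthwise convolution filters each input channel independently and widens it by "multiplier", giving input_channels * multiplier output channels. A sketch under those shapes (stride 1, VALID, batch 1; the output-channel ordering is an assumption, not taken from MACE):

#include <cstdint>

// Depthwise convolution, NCHW, batch 1, stride 1, VALID padding.
// Filter layout matches the benchmark: {multiplier, in_channels, kh, kw}.
void DepthwiseConv(const float *in, const float *filter, float *out,
                   int64_t c, int64_t h, int64_t w, int64_t m, int64_t kh,
                   int64_t kw) {
  const int64_t oh = h - kh + 1, ow = w - kw + 1;
  for (int64_t ch = 0; ch < c; ++ch)
    for (int64_t mi = 0; mi < m; ++mi)
      for (int64_t y = 0; y < oh; ++y)
        for (int64_t x = 0; x < ow; ++x) {
          float acc = 0.f;
          for (int64_t ky = 0; ky < kh; ++ky)
            for (int64_t kx = 0; kx < kw; ++kx)
              acc += in[(ch * h + y + ky) * w + (x + kx)] *
                     filter[((mi * c + ch) * kh + ky) * kw + kx];
          out[((ch * m + mi) * oh + y) * ow + x] = acc;  // assumed ordering
        }
}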
@@ -35,7 +35,31 @@ void SimpleValidTest() {
   net.AddInputFromArray<D, float>(
      "Filter", {2, 2, 2, 1}, {1.0f, 2.0f, 2.0f, 4.0f, 3.0f, 6.0f, 4.0f, 8.0f});
   net.AddInputFromArray<D, float>("Bias", {2}, {.1f, .2f});
-  if (D == DeviceType::OPENCL) {
+  if (D == DeviceType::CPU) {
+    net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                    NHWC,
+                                                    "InputNCHW",
+                                                    NCHW);
+    net.TransformDataFormat<DeviceType::CPU, float>("Filter",
+                                                    HWIO,
+                                                    "FilterOIHW",
+                                                    OIHW);
+    OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
+        .Input("InputNCHW")
+        .Input("FilterOIHW")
+        .Input("Bias")
+        .Output("OutputNCHW")
+        .AddIntsArg("strides", {1, 1})
+        .AddIntArg("padding", Padding::VALID)
+        .AddIntsArg("dilations", {1, 1})
+        .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+    net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                    NCHW,
+                                                    "Output",
+                                                    NHWC);
+  } else if (D == DeviceType::OPENCL) {
     BufferToImage<D, T>(&net, "Input", "InputImage",
                         kernels::BufferType::IN_OUT_CHANNEL);
     BufferToImage<D, T>(&net, "Filter", "FilterImage",
@@ -60,17 +84,7 @@ void SimpleValidTest() {
                         kernels::BufferType::IN_OUT_CHANNEL);
   } else {
-    OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
-        .Input("Input")
-        .Input("Filter")
-        .Input("Bias")
-        .Output("Output")
-        .AddIntsArg("strides", {1, 1})
-        .AddIntArg("padding", Padding::VALID)
-        .AddIntsArg("dilations", {1, 1})
-        .Finalize(net.NewOperatorDef());
-    // Run
-    net.RunOp(D);
+    MACE_NOT_IMPLEMENTED;
   }
   // Check
@@ -144,7 +158,33 @@ void ComplexValidTest() {
        0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74});
   net.AddInputFromArray<D, float>("Bias", {6},
                                   {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
-  if (D == DeviceType::OPENCL) {
+  if (D == DeviceType::CPU) {
+    net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                    NHWC,
+                                                    "InputNCHW",
+                                                    NCHW);
+    net.TransformDataFormat<DeviceType::CPU, float>("Filter",
+                                                    HWIO,
+                                                    "FilterOIHW",
+                                                    OIHW);
+    OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
+        .Input("InputNCHW")
+        .Input("FilterOIHW")
+        .Input("Bias")
+        .Output("OutputNCHW")
+        .AddIntsArg("strides", {2, 2})
+        .AddIntArg("padding", Padding::SAME)
+        .AddIntsArg("dilations", {1, 1})
+        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+        .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+    net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                    NCHW,
+                                                    "Output",
+                                                    NHWC);
+  } else if (D == DeviceType::OPENCL) {
     BufferToImage<D, T>(&net, "Input", "InputImage",
                         kernels::BufferType::IN_OUT_CHANNEL);
     BufferToImage<D, T>(&net, "Filter", "FilterImage",
@@ -169,18 +209,7 @@ void ComplexValidTest() {
                         kernels::BufferType::IN_OUT_CHANNEL);
   } else {
-    OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
-        .Input("Input")
-        .Input("Filter")
-        .Input("Bias")
-        .Output("Output")
-        .AddIntsArg("strides", {2, 2})
-        .AddIntArg("padding", Padding::SAME)
-        .AddIntsArg("dilations", {1, 1})
-        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
-        .Finalize(net.NewOperatorDef());
-    // Run
-    net.RunOp(D);
+    MACE_NOT_IMPLEMENTED;
   }
   // Check
@@ -224,7 +253,7 @@ TEST_F(DepthwiseConv2dOpTest, ComplexOpenCLHalf) {
 }
 namespace {
-template<DeviceType D, typename T>
+template<typename T>
 void TestNxNS12(const index_t height, const index_t width) {
   testing::internal::LogToStderr();
   auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
@@ -238,16 +267,28 @@ void TestNxNS12(const index_t height, const index_t width) {
     OpsTestNet net;
     // Add input data
-    net.AddRandomInput<D, float>("Input",
-                                 {batch, height, width, input_channels});
-    net.AddRandomInput<D, float>(
+    net.AddRandomInput<DeviceType::OPENCL, float>("Input",
+                                                  {batch, height, width,
+                                                   input_channels});
+    net.AddRandomInput<DeviceType::OPENCL, float>(
         "Filter", {kernel_h, kernel_w, input_channels, multiplier});
-    net.AddRandomInput<D, float>("Bias", {multiplier * input_channels});
+    net.AddRandomInput<DeviceType::OPENCL, float>("Bias",
+                                                  {multiplier
+                                                       * input_channels});
+    net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                    NHWC,
+                                                    "InputNCHW",
+                                                    NCHW);
+    net.TransformDataFormat<DeviceType::CPU, float>("Filter",
+                                                    HWIO,
+                                                    "FilterOIHW",
+                                                    OIHW);
     OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
-        .Input("Input")
-        .Input("Filter")
+        .Input("InputNCHW")
+        .Input("FilterOIHW")
         .Input("Bias")
-        .Output("Output")
+        .Output("OutputNCHW")
         .AddIntsArg("strides", {stride_h, stride_w})
         .AddIntArg("padding", type)
         .AddIntsArg("dilations", {1, 1})
@@ -256,48 +297,41 @@ void TestNxNS12(const index_t height, const index_t width) {
     // Run on cpu
     net.RunOp();
+    net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                    NCHW,
+                                                    "Output",
+                                                    NHWC);
     // Check
     Tensor expected;
     expected.Copy(*net.GetOutput("Output"));
-    if (D == DeviceType::OPENCL) {
-      BufferToImage<D, T>(&net, "Input", "InputImage",
-                          kernels::BufferType::IN_OUT_CHANNEL);
-      BufferToImage<D, T>(&net, "Filter", "FilterImage",
-                          kernels::BufferType::DW_CONV2D_FILTER);
-      BufferToImage<D, T>(&net, "Bias", "BiasImage",
-                          kernels::BufferType::ARGUMENT);
-      OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
-          .Input("InputImage")
-          .Input("FilterImage")
-          .Input("BiasImage")
-          .Output("OutputImage")
-          .AddIntsArg("strides", {stride_h, stride_w})
-          .AddIntArg("padding", type)
-          .AddIntsArg("dilations", {1, 1})
-          .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
-          .Finalize(net.NewOperatorDef());
-      net.RunOp(D);
-      // Transfer output
-      ImageToBuffer<D, float>(&net, "OutputImage", "DeviceOutput",
-                              kernels::BufferType::IN_OUT_CHANNEL);
-    } else {
-      OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
-          .Input("Input")
-          .Input("Filter")
-          .Input("Bias")
-          .Output("DeviceOutput")
-          .AddIntsArg("strides", {stride_h, stride_w})
-          .AddIntArg("padding", type)
-          .AddIntsArg("dilations", {1, 1})
-          .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
-          .Finalize(net.NewOperatorDef());
-      // Run
-      net.RunOp(D);
-    }
+    BufferToImage<DeviceType::OPENCL, T>(&net, "Input", "InputImage",
+                                         kernels::BufferType::IN_OUT_CHANNEL);
+    BufferToImage<DeviceType::OPENCL, T>(&net, "Filter", "FilterImage",
+                                         kernels::BufferType::DW_CONV2D_FILTER);
+    BufferToImage<DeviceType::OPENCL, T>(&net, "Bias", "BiasImage",
+                                         kernels::BufferType::ARGUMENT);
+    OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
+        .Input("InputImage")
+        .Input("FilterImage")
+        .Input("BiasImage")
+        .Output("OutputImage")
+        .AddIntsArg("strides", {stride_h, stride_w})
+        .AddIntArg("padding", type)
+        .AddIntsArg("dilations", {1, 1})
+        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+        .Finalize(net.NewOperatorDef());
+    net.RunOp(DeviceType::OPENCL);
+    // Transfer output
+    ImageToBuffer<DeviceType::OPENCL, float>(&net,
+                                             "OutputImage",
+                                             "DeviceOutput",
+                                             kernels::BufferType::IN_OUT_CHANNEL);
     // Check
     if (DataTypeToEnum<T>::value == DT_FLOAT) {
@@ -319,109 +353,27 @@ void TestNxNS12(const index_t height, const index_t width) {
 }  // namespace
 TEST_F(DepthwiseConv2dOpTest, OpenCLSimpleNxNS12) {
-  TestNxNS12<DeviceType::OPENCL, float>(4, 4);
+  TestNxNS12<float>(4, 4);
 }
 TEST_F(DepthwiseConv2dOpTest, OpenCLSimpleNxNS12Half) {
-  TestNxNS12<DeviceType::OPENCL, half>(4, 4);
+  TestNxNS12<half>(4, 4);
 }
 TEST_F(DepthwiseConv2dOpTest, OpenCLAlignedNxNS12) {
-  TestNxNS12<DeviceType::OPENCL, float>(128, 128);
+  TestNxNS12<float>(128, 128);
 }
 TEST_F(DepthwiseConv2dOpTest, OpenCLAlignedNxNS12Half) {
-  TestNxNS12<DeviceType::OPENCL, half>(128, 128);
+  TestNxNS12<half>(128, 128);
 }
 TEST_F(DepthwiseConv2dOpTest, OpenCLUnalignedNxNS12) {
-  TestNxNS12<DeviceType::OPENCL, float>(107, 113);
+  TestNxNS12<float>(107, 113);
 }
 TEST_F(DepthwiseConv2dOpTest, OpenCLUnalignedNxNS12Half) {
-  TestNxNS12<DeviceType::OPENCL, half>(107, 113);
+  TestNxNS12<half>(107, 113);
 }
-namespace {
-void TestNEONNxNS12(const index_t height,
-                    const index_t width,
-                    const index_t input_channels,
-                    const index_t multiplier) {
-  testing::internal::LogToStderr();
-  auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
-                  Padding type) {
-    // generate random input
-    index_t batch = 1;
-    // Construct graph
-    OpsTestNet net;
-    // Add input data
-    net.AddRandomInput<CPU, float>("Input",
-                                   {batch, height, width, input_channels});
-    net.AddRandomInput<CPU, float>(
-        "Filter", {kernel_h, kernel_w, input_channels, multiplier});
-    net.AddRandomInput<CPU, float>("Bias", {multiplier * input_channels});
-    OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
-        .Input("Input")
-        .Input("Filter")
-        .Input("Bias")
-        .Output("Output")
-        .AddIntsArg("strides", {stride_h, stride_w})
-        .AddIntArg("padding", type)
-        .AddIntsArg("dilations", {1, 1})
-        .AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
-        .Finalize(net.NewOperatorDef());
-    // Run on cpu
-    net.RunOp();
-    // Check
-    Tensor expected;
-    expected.Copy(*net.GetOutput("Output"));
-    OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
-        .Input("InputNeon")
-        .Input("FilterNeon")
-        .Input("Bias")
-        .Output("OutputNeon")
-        .AddIntsArg("strides", {stride_h, stride_w})
-        .AddIntArg("padding", type)
-        .AddIntsArg("dilations", {1, 1})
-        .AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
-        .Finalize(net.NewOperatorDef());
-    net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("InputNeon", "Input");
-    net.FillHWIOInputToOIHWInput<DeviceType::CPU, float>("FilterNeon",
-                                                         "Filter");
-    // Run
-    net.RunOp(NEON);
-    net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("OutputExptected",
-                                                         "Output");
-    // Check
-    ExpectTensorNear<float>(*net.GetOutput("OutputExptected"),
-                            *net.GetOutput("OutputNeon"),
-                            1e-5, 1e-3);
-  };
-  for (int kernel_size : {1, 3, 5}) {
-    for (int stride : {1, 2}) {
-      if (kernel_size > stride) {
-        func(kernel_size, kernel_size, stride, stride, VALID);
-        func(kernel_size, kernel_size, stride, stride, SAME);
-      }
-    }
-  }
-}
-}  // namespace
-TEST_F(DepthwiseConv2dOpTest, NEONTest) {
-  TestNEONNxNS12(4, 4, 32, 1);
-  TestNEONNxNS12(64, 64, 32, 1);
-  TestNEONNxNS12(112, 112, 32, 1);
-  TestNEONNxNS12(128, 128, 15, 1);
-  TestNEONNxNS12(107, 113, 15, 1);
-}
 }  // namespace test
......
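For depthwise filters the tests transform HWIO to OIHW rather than the HWOI layout used for regular convolution filters. The analogous flat-buffer permutation, again as an illustrative sketch rather than MACE's TransformDataFormat:

#include <cstdint>
#include <vector>

// Permute a flat HWIO filter into OIHW order.
std::vector<float> HWIOToOIHW(const std::vector<float> &src, int64_t h,
                              int64_t w, int64_t i, int64_t o) {
  std::vector<float> dst(src.size());
  for (int64_t y = 0; y < h; ++y)
    for (int64_t x = 0; x < w; ++x)
      for (int64_t in = 0; in < i; ++in)
        for (int64_t out = 0; out < o; ++out)
          dst[((out * i + in) * h + y) * w + x] =     // OIHW offset
              src[((y * w + x) * i + in) * o + out];  // HWIO offset
  return dst;
}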
@@ -35,11 +35,6 @@ void Register_Eltwise(OperatorRegistry *op_registry) {
                         .TypeConstraint<half>("T")
                         .Build(),
                     EltwiseOp<DeviceType::OPENCL, half>);
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise")
-                        .Device(DeviceType::NEON)
-                        .TypeConstraint<float>("T")
-                        .Build(),
-                    EltwiseOp<DeviceType::NEON, float>);
 }
 }  // namespace ops
......
@@ -35,11 +35,6 @@ void Register_FoldedBatchNorm(OperatorRegistry *op_registry) {
                         .TypeConstraint<half>("T")
                         .Build(),
                     FoldedBatchNormOp<DeviceType::OPENCL, half>);
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
-                        .Device(DeviceType::NEON)
-                        .TypeConstraint<float>("T")
-                        .Build(),
-                    FoldedBatchNormOp<DeviceType::NEON, float>);
 }
 }  // namespace ops
......
@@ -36,7 +36,7 @@ void CalculateScaleOffset(const std::vector<float> &gamma,
   }
 }
-template <DeviceType D>
+template<DeviceType D>
 void Simple() {
   OpsTestNet net;
@@ -49,7 +49,18 @@ void Simple() {
   net.AddInputFromArray<D, float>("Scale", {1}, scale);
   net.AddInputFromArray<D, float>("Offset", {1}, offset);
-  if (D == DeviceType::OPENCL) {
+  if (D == DeviceType::CPU) {
+    net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW);
+    OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
+        .Input("InputNCHW")
+        .Input("Scale")
+        .Input("Offset")
+        .Output("OutputNCHW")
+        .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+    net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC);
+  } else if (D == DeviceType::OPENCL) {
     BufferToImage<D, float>(&net, "Input", "InputImage",
                             kernels::BufferType::IN_OUT_CHANNEL);
     BufferToImage<D, float>(&net, "Scale", "ScaleImage",
@@ -58,33 +69,24 @@ void Simple() {
                             kernels::BufferType::ARGUMENT);
     OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
         .Input("InputImage")
         .Input("ScaleImage")
         .Input("OffsetImage")
         .Output("OutputImage")
         .Finalize(net.NewOperatorDef());
     // Run
     net.RunOp(D);
     // Transfer output
     ImageToBuffer<D, float>(&net, "OutputImage", "Output",
                             kernels::BufferType::IN_OUT_CHANNEL);
-  } else {
-    OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
-        .Input("Input")
-        .Input("Scale")
-        .Input("Offset")
-        .Output("Output")
-        .Finalize(net.NewOperatorDef());
-    // Run
-    net.RunOp(D);
   }
   // Check
   auto expected =
       CreateTensor<float>({1, 6, 2, 1}, {-3.8543, -3.8543, -1.5125, -1.5125,
                                          0.8291, 0.8291, 3.1708, 3.1708,
                                          5.5125, 5.5125, 7.8543, 7.8543});
   ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-4);
 }
@@ -92,100 +94,8 @@ void Simple() {
 TEST_F(FoldedBatchNormOpTest, SimpleCPU) { Simple<DeviceType::CPU>(); }
-/*
-TEST_F(FoldedBatchNormOpTest, SimpleNEON) {
-  Simple<DeviceType::NEON>();
-}
-*/
 TEST_F(FoldedBatchNormOpTest, SimpleOPENCL) { Simple<DeviceType::OPENCL>(); }
-/*
-TEST_F(FoldedBatchNormOpTest, SimpleRandomNeon) {
-  srand(time(NULL));
-  // generate random input
-  index_t batch = 1 + rand() % 10;
-  index_t channels = 3 + rand() % 50;
-  index_t height = 64;
-  index_t width = 64;
-  // Construct graph
-  OpsTestNet net;
-  OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
-      .Input("Input")
-      .Input("Scale")
-      .Input("Offset")
-      .Input("Mean")
-      .Input("Var")
-      .Input("Epsilon")
-      .Output("Output")
-      .Finalize(net.NewOperatorDef());
-  // Add input data
-  net.AddRandomInput<DeviceType::CPU, float>("Input", {batch, channels, height,
-                                                       width});
-  net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Var", {channels}, true);
-  net.AddInputFromArray<DeviceType::CPU, float>("Epsilon", {}, {1e-3});
-  // run cpu
-  net.RunOp();
-  // Check
-  Tensor expected;
-  expected.Copy(*net.GetOutput("Output"));
-  // Run NEON
-  net.RunOp(DeviceType::NEON);
-  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
-}
-TEST_F(FoldedBatchNormOpTest, ComplexRandomNeon) {
-  srand(time(NULL));
-  // generate random input
-  index_t batch = 1 + rand() % 10;
-  index_t channels = 3 + rand() % 50;
-  index_t height = 103;
-  index_t width = 113;
-  // Construct graph
-  OpsTestNet net;
-  OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
-      .Input("Input")
-      .Input("Scale")
-      .Input("Offset")
-      .Input("Mean")
-      .Input("Var")
-      .Input("Epsilon")
-      .Output("Output")
-      .Finalize(net.NewOperatorDef());
-  // Add input data
-  net.AddRandomInput<DeviceType::CPU, float>("Input", {batch, channels, height,
-                                                       width});
-  net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Var", {channels}, true);
-  net.AddInputFromArray<DeviceType::CPU, float>("Epsilon", {}, {1e-3});
-  // run cpu
-  net.RunOp();
-  // Check
-  Tensor expected;
-  expected.Copy(*net.GetOutput("Output"));
-  // Run NEON
-  net.RunOp(DeviceType::NEON);
-  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
-}
-*/
 TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) {
   // generate random input
   static unsigned int seed = time(NULL);
@@ -196,22 +106,33 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) {
   // Construct graph
   OpsTestNet net;
-  OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
-      .Input("Input")
-      .Input("Scale")
-      .Input("Offset")
-      .Output("Output")
-      .Finalize(net.NewOperatorDef());
   // Add input data
   net.AddRandomInput<DeviceType::OPENCL, float>(
       "Input", {batch, height, width, channels});
   net.AddRandomInput<DeviceType::OPENCL, float>("Scale", {channels});
   net.AddRandomInput<DeviceType::OPENCL, float>("Offset", {channels});
+  net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                  NHWC,
+                                                  "InputNCHW",
+                                                  NCHW);
+  OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
+      .Input("InputNCHW")
+      .Input("Scale")
+      .Input("Offset")
+      .Output("OutputNCHW")
+      .Finalize(net.NewOperatorDef());
   // run cpu
   net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                  NCHW,
+                                                  "Output",
+                                                  NHWC);
   // Check
   Tensor expected;
   expected.Copy(*net.GetOutput("Output"));
@@ -225,11 +146,11 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) {
                                                 kernels::BufferType::ARGUMENT);
   OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
       .Input("InputImage")
       .Input("ScaleImage")
       .Input("OffsetImage")
       .Output("OutputImage")
       .Finalize(net.NewOperatorDef());
   // Run on opencl
   net.RunOp(DeviceType::OPENCL);
@@ -250,22 +171,33 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) {
   // Construct graph
   OpsTestNet net;
-  OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
-      .Input("Input")
-      .Input("Scale")
-      .Input("Offset")
-      .Output("Output")
-      .Finalize(net.NewOperatorDef());
   // Add input data
   net.AddRandomInput<DeviceType::OPENCL, float>(
      "Input", {batch, height, width, channels});
   net.AddRandomInput<DeviceType::OPENCL, float>("Scale", {channels});
   net.AddRandomInput<DeviceType::OPENCL, float>("Offset", {channels});
+  net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                  NHWC,
+                                                  "InputNCHW",
+                                                  NCHW);
+  OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
+      .Input("InputNCHW")
+      .Input("Scale")
+      .Input("Offset")
+      .Output("OutputNCHW")
+      .Finalize(net.NewOperatorDef());
   // run cpu
   net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                  NCHW,
+                                                  "Output",
+                                                  NHWC);
   // Check
   Tensor expected;
   expected.Copy(*net.GetOutput("Output"));
@@ -279,12 +211,12 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) {
                                                 kernels::BufferType::ARGUMENT);
   OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
       .Input("InputImage")
       .Input("ScaleImage")
       .Input("OffsetImage")
       .Output("OutputImage")
       .AddIntArg("T", static_cast<int>(DataType::DT_HALF))
       .Finalize(net.NewOperatorDef());
   // Run on opencl
   net.RunOp(DeviceType::OPENCL);
@@ -305,22 +237,33 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) {
   // Construct graph
   OpsTestNet net;
-  OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
-      .Input("Input")
-      .Input("Scale")
-      .Input("Offset")
-      .Output("Output")
-      .Finalize(net.NewOperatorDef());
   // Add input data
   net.AddRandomInput<DeviceType::OPENCL, float>(
      "Input", {batch, height, width, channels});
   net.AddRandomInput<DeviceType::OPENCL, float>("Scale", {channels});
   net.AddRandomInput<DeviceType::OPENCL, float>("Offset", {channels});
+  net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                  NHWC,
+                                                  "InputNCHW",
+                                                  NCHW);
+  OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
+      .Input("InputNCHW")
+      .Input("Scale")
+      .Input("Offset")
+      .Output("OutputNCHW")
+      .Finalize(net.NewOperatorDef());
   // run cpu
   net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                  NCHW,
+                                                  "Output",
+                                                  NHWC);
   // Check
   Tensor expected;
   expected.Copy(*net.GetOutput("Output"));
@@ -334,11 +277,11 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) {
                                                 kernels::BufferType::ARGUMENT);
   OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
       .Input("InputImage")
       .Input("ScaleImage")
       .Input("OffsetImage")
       .Output("OutputImage")
       .Finalize(net.NewOperatorDef());
   // Run on opencl
   net.RunOp(DeviceType::OPENCL);
@@ -358,22 +301,33 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) {
   // Construct graph
   OpsTestNet net;
-  OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
-      .Input("Input")
-      .Input("Scale")
-      .Input("Offset")
-      .Output("Output")
-      .Finalize(net.NewOperatorDef());
   // Add input data
   net.AddRandomInput<DeviceType::OPENCL, float>(
      "Input", {batch, height, width, channels});
   net.AddRandomInput<DeviceType::OPENCL, float>("Scale", {channels});
   net.AddRandomInput<DeviceType::OPENCL, float>("Offset", {channels});
+  net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                  NHWC,
+                                                  "InputNCHW",
+                                                  NCHW);
+  OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
+      .Input("InputNCHW")
+      .Input("Scale")
+      .Input("Offset")
+      .Output("OutputNCHW")
+      .Finalize(net.NewOperatorDef());
   // run cpu
   net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                  NCHW,
+                                                  "Output",
+                                                  NHWC);
   // Check
   Tensor expected;
   expected.Copy(*net.GetOutput("Output"));
@@ -387,12 +341,12 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) {
                                                 kernels::BufferType::ARGUMENT);
   OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
       .Input("InputImage")
       .Input("ScaleImage")
       .Input("OffsetImage")
       .Output("OutputImage")
       .AddIntArg("T", static_cast<int>(DataType::DT_HALF))
       .Finalize(net.NewOperatorDef());
   // Run on opencl
   net.RunOp(DeviceType::OPENCL);
......
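FoldedBatchNorm consumes a per-channel scale/offset pair that the test's CalculateScaleOffset helper derives from the batch-norm parameters. A sketch of that fold, assuming the standard definition y = gamma * (x - mean) / sqrt(var + epsilon) + beta (an illustration, not the helper itself):

#include <cmath>
#include <cstddef>
#include <vector>

// Fold batch-norm parameters into per-channel scale/offset so that
// y = scale * x + offset reproduces the normalized-and-shifted output.
void FoldBatchNorm(const std::vector<float> &gamma,
                   const std::vector<float> &beta,
                   const std::vector<float> &mean,
                   const std::vector<float> &var, float epsilon,
                   std::vector<float> *scale, std::vector<float> *offset) {
  scale->resize(gamma.size());
  offset->resize(gamma.size());
  for (std::size_t c = 0; c < gamma.size(); ++c) {
    (*scale)[c] = gamma[c] / std::sqrt(var[c] + epsilon);
    (*offset)[c] = beta[c] - mean[c] * (*scale)[c];
  }
}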
@@ -35,12 +35,6 @@ void Register_FullyConnected(OperatorRegistry *op_registry) {
                         .TypeConstraint<half>("T")
                         .Build(),
                     FullyConnectedOp<DeviceType::OPENCL, half>);
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC")
-                        .Device(DeviceType::NEON)
-                        .TypeConstraint<float>("T")
-                        .Build(),
-                    FullyConnectedOp<DeviceType::NEON, float>);
 }
 }  // namespace ops
......
...@@ -23,19 +23,19 @@ ...@@ -23,19 +23,19 @@
namespace mace { namespace mace {
namespace ops { namespace ops {
template <DeviceType D, class T> template<DeviceType D, class T>
class FullyConnectedOp : public Operator<D, T> { class FullyConnectedOp : public Operator<D, T> {
public: public:
FullyConnectedOp(const OperatorDef &operator_def, Workspace *ws) FullyConnectedOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_(static_cast<kernels::BufferType>( functor_(static_cast<kernels::BufferType>(
OperatorBase::GetSingleArgument<int>( OperatorBase::GetSingleArgument<int>(
"weight_type", static_cast<int>( "weight_type", static_cast<int>(
kernels::WEIGHT_WIDTH))), kernels::WEIGHT_WIDTH))),
kernels::StringToActivationType( kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")), "NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {} OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
...@@ -44,8 +44,17 @@ class FullyConnectedOp : public Operator<D, T> { ...@@ -44,8 +44,17 @@ class FullyConnectedOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
const index_t input_size = input->dim(1) * input->dim(2) * input->dim(3); const index_t input_size = input->dim(1) * input->dim(2) * input->dim(3);
MACE_CHECK(input_size == weight->dim(1) && weight->dim(0) == bias->dim(0)) MACE_CHECK(input_size == weight->dim(1) && weight->dim(0) == bias->dim(0),
<< "The size of Input, Weight and Bias don't match."; "The size of Input: ", input_size,
", Weight: ", weight->dim(1), ", ", weight->dim(0),
", and Bias: ", bias->dim(0),
" don't match.");
functor_(input, weight, bias, output, future); functor_(input, weight, bias, output, future);
return true; return true;
......
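For reference, the shape contract this check enforces, as a minimal standalone sketch with assumed example values (not part of the commit):

#include <cassert>
#include <cstdint>

int main() {
  // Assumed NCHW input {1, 8, 4, 4}: input_size = 8 * 4 * 4 = 128.
  const int64_t input_dims[4] = {1, 8, 4, 4};
  const int64_t input_size = input_dims[1] * input_dims[2] * input_dims[3];
  const int64_t weight_dims[2] = {16, 128};  // {out_channel, input_size}
  const int64_t bias_dims[1] = {16};         // {out_channel}
  // Mirrors: input_size == weight->dim(1) && weight->dim(0) == bias->dim(0)
  assert(input_size == weight_dims[1] && weight_dims[0] == bias_dims[0]);
  return 0;
}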
...@@ -36,8 +36,14 @@ void FCBenchmark( ...@@ -36,8 +36,14 @@ void FCBenchmark(
{out_channel, height * width * channel}); {out_channel, height * width * channel});
net.AddRandomInput<D, float>("Bias", {out_channel}); net.AddRandomInput<D, float>("Bias", {out_channel});
if (D == DeviceType::OPENCL) { if (D == DeviceType::CPU) {
const int width_size = height * width * channel; OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
kernels::BufferType weight_type = kernels::BufferType::WEIGHT_WIDTH; kernels::BufferType weight_type = kernels::BufferType::WEIGHT_WIDTH;
BufferToImage<D, T>(&net, "Weight", "WeightImage", BufferToImage<D, T>(&net, "Weight", "WeightImage",
weight_type); weight_type);
...@@ -55,12 +61,7 @@ void FCBenchmark( ...@@ -55,12 +61,7 @@ void FCBenchmark(
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else { } else {
OpDefBuilder("FC", "FullyConnectedTest") MACE_NOT_IMPLEMENTED;
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
} }
// Warm-up // Warm-up
......
...@@ -40,7 +40,18 @@ void Simple(const std::vector<index_t> &input_shape, ...@@ -40,7 +40,18 @@ void Simple(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>("Weight", weight_shape, weight_value); net.AddInputFromArray<D, float>("Weight", weight_shape, weight_value);
net.AddInputFromArray<D, float>("Bias", bias_shape, bias_value); net.AddInputFromArray<D, float>("Bias", bias_shape, bias_value);
if (D == DeviceType::OPENCL) { if (D == DeviceType::CPU) {
net.Transpose2D<D, float>("Weight", "WeightTranspose");
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage", BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Weight", "WeightImage", BufferToImage<D, float>(&net, "Weight", "WeightImage",
...@@ -62,14 +73,7 @@ void Simple(const std::vector<index_t> &input_shape, ...@@ -62,14 +73,7 @@ void Simple(const std::vector<index_t> &input_shape,
ImageToBuffer<D, float>(&net, "OutputImage", "Output", ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("FC", "FullyConnectedTest") MACE_NOT_IMPLEMENTED;
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} }
// Check // Check
...@@ -130,12 +134,6 @@ void Complex(const index_t batch, ...@@ -130,12 +134,6 @@ void Complex(const index_t batch,
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::OPENCL, float>( net.AddRandomInput<DeviceType::OPENCL, float>(
...@@ -144,9 +142,18 @@ void Complex(const index_t batch, ...@@ -144,9 +142,18 @@ void Complex(const index_t batch,
"Weight", {out_channel, height * width * channels}); "Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {out_channel}); net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {out_channel});
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -216,12 +223,6 @@ void TestWXFormat(const index_t batch, ...@@ -216,12 +223,6 @@ void TestWXFormat(const index_t batch,
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::OPENCL, float>( net.AddRandomInput<DeviceType::OPENCL, float>(
...@@ -230,9 +231,18 @@ void TestWXFormat(const index_t batch, ...@@ -230,9 +231,18 @@ void TestWXFormat(const index_t batch,
"Weight", {out_channel, height * width * channels}); "Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {out_channel}); net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {out_channel});
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -286,61 +296,6 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfWidthFormatAligned) { ...@@ -286,61 +296,6 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfWidthFormatAligned) {
TestWXFormat<half>(1, 16, 32, 32, 32); TestWXFormat<half>(1, 16, 32, 32, 32);
} }
namespace {
void FullyConnectedTestNEON(const index_t batch,
const index_t height,
const index_t width,
const index_t channels,
const index_t out_channel) {
srand(time(NULL));
// Construct graph
OpsTestNet net;
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::CPU, float>(
"Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::CPU, float>("Bias", {out_channel});
// run cpu
net.RunOp();
// Run on neon
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("OutputNeon")
.Finalize(net.NewOperatorDef());
// Run on device
net.RunOp(DeviceType::NEON);
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("OutputExpected",
"Output");
ExpectTensorNear<float>(*net.GetOutput("OutputExpected"),
*net.GetOutput("OutputNeon"),
1e-3, 1e-3);
}
} // namespace
TEST_F(FullyConnectedOpTest, TestNEON) {
FullyConnectedTestNEON(1, 7, 7, 32, 16);
FullyConnectedTestNEON(1, 7, 7, 512, 128);
FullyConnectedTestNEON(1, 1, 1, 2048, 1024);
FullyConnectedTestNEON(3, 1, 1, 16, 8);
FullyConnectedTestNEON(3, 7, 7, 32, 16);
}
} // namespace test } // namespace test
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
...@@ -35,11 +35,6 @@ void Register_FusedConv2D(OperatorRegistry *op_registry) { ...@@ -35,11 +35,6 @@ void Register_FusedConv2D(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
FusedConv2dOp<DeviceType::OPENCL, half>); FusedConv2dOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FusedConv2D")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
FusedConv2dOp<DeviceType::NEON, float>);
} }
} // namespace ops } // namespace ops
......
...@@ -37,7 +37,33 @@ void TestNHWCSimple3x3VALID() { ...@@ -37,7 +37,33 @@ void TestNHWCSimple3x3VALID() {
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, float>("Bias", {1}, {-0.1f}); net.AddInputFromArray<D, float>("Bias", {1}, {-0.1f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage", BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Filter", "FilterImage", BufferToImage<D, T>(&net, "Filter", "FilterImage",
...@@ -63,19 +89,7 @@ void TestNHWCSimple3x3VALID() { ...@@ -63,19 +89,7 @@ void TestNHWCSimple3x3VALID() {
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("FusedConv2D", "FusedConv2dTest") MACE_NOT_IMPLEMENTED;
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} }
auto expected = CreateTensor<float>({1, 1, 1, 1}, {0.0f}); auto expected = CreateTensor<float>({1, 1, 1, 1}, {0.0f});
...@@ -96,7 +110,33 @@ void TestNHWCSimple3x3SAME() { ...@@ -96,7 +110,33 @@ void TestNHWCSimple3x3SAME() {
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, float>("Bias", {1}, {-0.1f}); net.AddInputFromArray<D, float>("Bias", {1}, {-0.1f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage", BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Filter", "FilterImage", BufferToImage<D, T>(&net, "Filter", "FilterImage",
...@@ -122,19 +162,7 @@ void TestNHWCSimple3x3SAME() { ...@@ -122,19 +162,7 @@ void TestNHWCSimple3x3SAME() {
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("FusedConv2D", "FusedConv2dTest") MACE_NOT_IMPLEMENTED;
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} }
auto expected = CreateTensor<float>( auto expected = CreateTensor<float>(
...@@ -168,7 +196,33 @@ void TestNHWCSimple3x3WithoutBias() { ...@@ -168,7 +196,33 @@ void TestNHWCSimple3x3WithoutBias() {
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage", BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Filter", "FilterImage", BufferToImage<D, T>(&net, "Filter", "FilterImage",
...@@ -190,19 +244,7 @@ void TestNHWCSimple3x3WithoutBias() { ...@@ -190,19 +244,7 @@ void TestNHWCSimple3x3WithoutBias() {
ImageToBuffer<D, float>(&net, "OutputImage", "Output", ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("FusedConv2D", "FusedConv2dTest") MACE_NOT_IMPLEMENTED;
.Input("Input")
.Input("Filter")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} }
// Check // Check
...@@ -241,7 +283,31 @@ void TestConv1x1() { ...@@ -241,7 +283,31 @@ void TestConv1x1() {
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}); {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f});
net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f}); net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage", BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Filter", "FilterImage", BufferToImage<D, float>(&net, "Filter", "FilterImage",
...@@ -264,17 +330,7 @@ void TestConv1x1() { ...@@ -264,17 +330,7 @@ void TestConv1x1() {
ImageToBuffer<D, float>(&net, "OutputImage", "Output", ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("FusedConv2D", "FusedConv2dTest") MACE_NOT_IMPLEMENTED;
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} }
// Check // Check
...@@ -308,27 +364,43 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape) { ...@@ -308,27 +364,43 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
index_t width = shape[1]; index_t width = shape[1];
index_t input_channels = shape[2] + (rand_r(&seed) % 10); index_t input_channels = shape[2] + (rand_r(&seed) % 10);
index_t output_channels = shape[3] + (rand_r(&seed) % 10); index_t output_channels = shape[3] + (rand_r(&seed) % 10);
// Construct graph
OpsTestNet net; OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
// Construct graph
OpDefBuilder("FusedConv2D", "FusedConv2dTest") OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("Input") .Input("InputNCHW")
.Input("Filter") .Input("FilterOIHW")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w}) .AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type) .AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -386,17 +458,8 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape, ...@@ -386,17 +458,8 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape,
index_t width = shape[1]; index_t width = shape[1];
index_t input_channels = shape[2]; index_t input_channels = shape[2];
index_t output_channels = shape[3]; index_t output_channels = shape[3];
// Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
std::vector<float> float_input_data; std::vector<float> float_input_data;
GenerateRandomRealTypeData({batch, height, width, input_channels}, GenerateRandomRealTypeData({batch, height, width, input_channels},
...@@ -415,8 +478,33 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape, ...@@ -415,8 +478,33 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape,
float_filter_data); float_filter_data);
net.AddInputFromArray<D, float>("Bias", {output_channels}, float_bias_data); net.AddInputFromArray<D, float>("Bias", {output_channels}, float_bias_data);
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
// Construct graph
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -479,27 +567,42 @@ void TestGeneralConvNxNS12(const std::vector<index_t> &image_shape, ...@@ -479,27 +567,42 @@ void TestGeneralConvNxNS12(const std::vector<index_t> &image_shape,
index_t kernel_w = filter_shape[1]; index_t kernel_w = filter_shape[1];
index_t output_channels = filter_shape[2]; index_t output_channels = filter_shape[2];
index_t input_channels = filter_shape[3]; index_t input_channels = filter_shape[3];
// Construct graph
OpsTestNet net; OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
// Construct graph
OpDefBuilder("FusedConv2D", "FusedConv2dTest") OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("Input") .Input("InputNCHW")
.Input("Filter") .Input("FilterOIHW")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w}) .AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type) .AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -561,27 +664,44 @@ void TestAtrousConvNxN(const std::vector<index_t> &shape, ...@@ -561,27 +664,44 @@ void TestAtrousConvNxN(const std::vector<index_t> &shape,
index_t width = shape[1]; index_t width = shape[1];
index_t output_channels = shape[2]; index_t output_channels = shape[2];
index_t input_channels = shape[3]; index_t input_channels = shape[3];
// Construct graph
OpsTestNet net; OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
// Construct graph
OpDefBuilder("FusedConv2D", "FusedConv2dTest") OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("Input") .Input("InputNCHW")
.Input("Filter") .Input("FilterOIHW")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w}) .AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type) .AddIntArg("padding", type)
.AddIntsArg("dilations", {dilation, dilation}) .AddIntsArg("dilations", {dilation, dilation})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -651,17 +771,8 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape, ...@@ -651,17 +771,8 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape,
index_t kernel_w = filter_shape[1]; index_t kernel_w = filter_shape[1];
index_t output_channels = filter_shape[2]; index_t output_channels = filter_shape[2];
index_t input_channels = filter_shape[3]; index_t input_channels = filter_shape[3];
// Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", net.AddRandomInput<D, float>("Input",
...@@ -670,8 +781,33 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape, ...@@ -670,8 +781,33 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape,
"Filter", {kernel_h, kernel_w, output_channels, input_channels}); "Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, float>("Bias", {output_channels}); net.AddRandomInput<D, float>("Bias", {output_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
// Construct graph
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
...@@ -718,81 +854,6 @@ TEST_F(FusedConv2dOpTest, OPENCL15X15AtrousConvD4) { ...@@ -718,81 +854,6 @@ TEST_F(FusedConv2dOpTest, OPENCL15X15AtrousConvD4) {
{2, 2}); {2, 2});
} }
namespace {
void TestNEONGeneralConvNxNS12(
const std::vector<index_t> &image_shape,
const std::vector<index_t> &filter_shape) {
testing::internal::LogToStderr();
auto func = [&](int stride_h, int stride_w, Padding type) {
srand(time(NULL));
// generate random input
index_t batch = 1;
index_t height = image_shape[0];
index_t width = image_shape[1];
index_t kernel_h = filter_shape[0];
index_t kernel_w = filter_shape[1];
index_t output_channels = filter_shape[2];
index_t input_channels = filter_shape[3];
// Construct graph
OpsTestNet net;
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<CPU, float>("Input",
{batch, height, width, input_channels});
net.AddRandomInput<CPU, float>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<CPU, float>("Bias", {output_channels});
// run on cpu
net.RunOp();
OpDefBuilder("FusedConv2D", "FusedConv2dTest")
.Input("InputNeon")
.Input("FilterNeon")
.Input("Bias")
.Output("OutputNeon")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<float>::value))
.Finalize(net.NewOperatorDef());
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("InputNeon", "Input");
net.FillHWOIInputToOIHWInput<DeviceType::CPU, float>("FilterNeon",
"Filter");
// Run on device
net.RunOp(DeviceType::NEON);
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("OutputExpected",
"Output");
ExpectTensorNear<float>(*net.GetOutput("OutputExpected"),
*net.GetOutput("OutputNeon"),
1e-5, 1e-4);
};
for (int stride : {1, 2}) {
func(stride, stride, VALID);
func(stride, stride, SAME);
}
}
} // namespace
TEST_F(FusedConv2dOpTest, NEONTest) {
TestNEONGeneralConvNxNS12({32, 32}, {7, 7, 64, 3});
}
} // namespace test } // namespace test
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/global_avg_pooling.h"
namespace mace {
namespace ops {
void Register_GlobalAvgPooling(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("GlobalAvgPooling")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
GlobalAvgPoolingOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_GLOBAL_AVG_POOLING_H_
#define MACE_OPS_GLOBAL_AVG_POOLING_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/global_avg_pooling.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class GlobalAvgPoolingOp : public Operator<D, T> {
public:
GlobalAvgPoolingOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
std::vector<index_t> output_shape(4);
output_shape[0] = input->shape()[0];
output_shape[1] = input->shape()[1];
output_shape[2] = output_shape[3] = 1;
output->Resize(output_shape);
auto pooling_func = kernels::GlobalAvgPoolingFunctor<D, T>();
pooling_func(input->data<float>(), input->shape().data(),
output->mutable_data<float>(), future);
return true;
}
protected:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_GLOBAL_AVG_POOLING_H_
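The op above reduces each NCHW feature map to a single value. A minimal standalone sketch of the same computation (an assumed helper, independent of kernels::GlobalAvgPoolingFunctor):

#include <cstddef>
#include <vector>

// Averages each {H, W} plane of an NCHW tensor, yielding shape {N, C, 1, 1}.
std::vector<float> GlobalAvgPool(const std::vector<float> &in,
                                 std::size_t n, std::size_t c,
                                 std::size_t h, std::size_t w) {
  const std::size_t hw = h * w;
  std::vector<float> out(n * c, 0.0f);
  for (std::size_t i = 0; i < n * c; ++i) {
    for (std::size_t j = 0; j < hw; ++j) {
      out[i] += in[i * hw + j];
    }
    out[i] /= static_cast<float>(hw);
  }
  return out;
}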
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/global_avg_pooling.h"
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D>
void GlobalAvgPooling(
int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
OpDefBuilder("GlobalAvgPooling", "GlobalAvgPoolingTest")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input",
{batch, channels, height, width});
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
}
} // namespace
#define BM_GLOBAL_AVG_POOLING_MACRO(N, C, H, W, DEVICE) \
static void BM_GLOBAL_AVG_POOLING_##N##_##C##_##H##_##W##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(float))); \
GlobalAvgPooling<DEVICE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_GLOBAL_AVG_POOLING_##N##_##C##_##H##_##W##_##DEVICE)
#define BM_GLOBAL_AVG_POOLING(N, C, H, W) \
BM_GLOBAL_AVG_POOLING_MACRO(N, C, H, W, CPU);
// BM_GLOBAL_AVG_POOLING_MACRO(N, C, H, W, NEON);
BM_GLOBAL_AVG_POOLING(1, 3, 7, 7);
BM_GLOBAL_AVG_POOLING(1, 3, 64, 64);
BM_GLOBAL_AVG_POOLING(1, 3, 256, 256);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class GlobalAvgPoolingOpTest : public OpsTestBase {};
TEST_F(GlobalAvgPoolingOpTest, 3x7x7_CPU) {
// Construct graph
OpsTestNet net;
OpDefBuilder("GlobalAvgPooling", "GlobalAvgPoolingTest")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
std::vector<float> input(147);
for (int i = 0; i < 147; ++i) {
input[i] = i / 49 + 1;
}
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 7, 7}, input);
// Run
net.RunOp();
// Check
auto expected = CreateTensor<float>({1, 3, 1, 1}, {1, 2, 3});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
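The expected averages follow from the integer division in the loop above; a short check of the arithmetic:

// i / 49 + 1 fills each 7x7 channel (49 elements) with a constant value:
// channel 0 -> 1, channel 1 -> 2, channel 2 -> 3, so the means are {1, 2, 3}.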
} // namespace test
} // namespace ops
} // namespace mace
...@@ -13,12 +13,6 @@ void Register_LocalResponseNorm(OperatorRegistry *op_registry) { ...@@ -13,12 +13,6 @@ void Register_LocalResponseNorm(OperatorRegistry *op_registry) {
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
.Build(), .Build(),
LocalResponseNormOp<DeviceType::CPU, float>); LocalResponseNormOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("LocalResponseNorm")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
LocalResponseNormOp<DeviceType::NEON, float>);
} }
} // namespace ops } // namespace ops
......
...@@ -18,7 +18,9 @@ static void LocalResponseNorm( ...@@ -18,7 +18,9 @@ static void LocalResponseNorm(
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, channels}); if (D == DeviceType::CPU) {
net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
}
OpDefBuilder("LocalResponseNorm", "LocalResponseNormBM") OpDefBuilder("LocalResponseNorm", "LocalResponseNormBM")
.Input("Input") .Input("Input")
...@@ -54,8 +56,7 @@ static void LocalResponseNorm( ...@@ -54,8 +56,7 @@ static void LocalResponseNorm(
BENCHMARK(BM_LOCAL_RESPONSE_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) BENCHMARK(BM_LOCAL_RESPONSE_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_LOCAL_RESPONSE_NORM(N, C, H, W) \ #define BM_LOCAL_RESPONSE_NORM(N, C, H, W) \
BM_LOCAL_RESPONSE_NORM_MACRO(N, C, H, W, float, CPU); \ BM_LOCAL_RESPONSE_NORM_MACRO(N, C, H, W, float, CPU);
BM_LOCAL_RESPONSE_NORM_MACRO(N, C, H, W, float, NEON);
BM_LOCAL_RESPONSE_NORM(1, 1, 512, 512); BM_LOCAL_RESPONSE_NORM(1, 1, 512, 512);
BM_LOCAL_RESPONSE_NORM(1, 3, 128, 128); BM_LOCAL_RESPONSE_NORM(1, 3, 128, 128);
......
...@@ -19,16 +19,21 @@ void Simple() { ...@@ -19,16 +19,21 @@ void Simple() {
net.AddInputFromArray<D, float>("Input", {1, 1, 2, 6}, net.AddInputFromArray<D, float>("Input", {1, 1, 2, 6},
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}); {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
OpDefBuilder("LocalResponseNorm", "LocalResponseNormTest") if (D == DeviceType::CPU) {
.Input("Input") net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW);
OpDefBuilder("LocalResponseNorm", "LocalResponseNormTest")
.Input("InputNCHW")
.AddIntArg("depth_radius", 5) .AddIntArg("depth_radius", 5)
.AddFloatArg("bias", 1.0f) .AddFloatArg("bias", 1.0f)
.AddFloatArg("alpha", 1.0f) .AddFloatArg("alpha", 1.0f)
.AddFloatArg("beta", 0.5f) .AddFloatArg("beta", 0.5f)
.Output("Output") .Output("OutputNCHW")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC);
}
// Check // Check
auto expected = auto expected =
...@@ -40,57 +45,6 @@ void Simple() { ...@@ -40,57 +45,6 @@ void Simple() {
TEST_F(LocalResponseNormOpTest, SimpleCPU) { Simple<DeviceType::CPU>(); } TEST_F(LocalResponseNormOpTest, SimpleCPU) { Simple<DeviceType::CPU>(); }
TEST_F(LocalResponseNormOpTest, NEONTest) {
srand(time(NULL));
unsigned int seed;
// generate random input
index_t batch = 1 + rand_r(&seed) % 10;
index_t channels = 3 + rand_r(&seed) % 50;
index_t height = 64;
index_t width = 64;
// Construct graph
OpsTestNet net;
OpDefBuilder("LocalResponseNorm", "LocalResponseNormTest")
.Input("Input")
.AddIntArg("depth_radius", 5)
.AddFloatArg("bias", 1.0f)
.AddFloatArg("alpha", 1.0f)
.AddFloatArg("beta", 0.5f)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>(
"Input", {batch, height, width, channels});
// run cpu
net.RunOp();
OpDefBuilder("LocalResponseNorm", "LocalResponseNormTest")
.Input("InputNeon")
.AddIntArg("depth_radius", 5)
.AddFloatArg("bias", 1.0f)
.AddFloatArg("alpha", 1.0f)
.AddFloatArg("beta", 0.5f)
.Output("OutputNeon")
.Finalize(net.NewOperatorDef());
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("InputNeon", "Input");
// Run on neon
net.RunOp(DeviceType::NEON);
net.Sync();
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("OutputExpected",
"Output");
ExpectTensorNear<float>(*net.GetOutput("OutputExpected"),
*net.GetOutput("OutputNeon"),
0, 0.001);
}
} // namespace test } // namespace test
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
...@@ -35,6 +35,8 @@ namespace mace { ...@@ -35,6 +35,8 @@ namespace mace {
namespace ops { namespace ops {
namespace test { namespace test {
enum DataFormat { NHWC = 0, NCHW = 1, HWOI = 2, OIHW = 3, HWIO = 4 };
class OpDefBuilder { class OpDefBuilder {
public: public:
OpDefBuilder(const char *type, const std::string &name) { OpDefBuilder(const char *type, const std::string &name) {
...@@ -173,78 +175,161 @@ class OpsTestNet { ...@@ -173,78 +175,161 @@ class OpsTestNet {
} }
template<DeviceType D, typename T> template<DeviceType D, typename T>
void FillNHWCInputToNCHWInput(const std::string &name_nchw, void Transpose2D(const std::string &src_name,
const std::string &name_nhwc) { const std::string &dst_name) {
Tensor *input = ws_.GetTensor(name_nhwc); Tensor *input = ws_.GetTensor(src_name);
Tensor *output = ws_.CreateTensor(name_nchw, Tensor *output = ws_.CreateTensor(dst_name,
GetDeviceAllocator(D), GetDeviceAllocator(D),
DataTypeToEnum<T>::v()); DataTypeToEnum<T>::v());
const std::vector<index_t> input_shape = input->shape(); const std::vector<index_t> input_shape = input->shape();
index_t batch = input_shape[0]; MACE_CHECK(input_shape.size() == 2, "input shape != 2");
index_t height = input_shape[1]; output->Resize({input_shape[1], input_shape[0]});
index_t width = input_shape[2]; Tensor::MappingGuard input_guard(input);
index_t channels = input_shape[3]; Tensor::MappingGuard output_guard(output);
output->Resize({batch, channels, height, width});
const T *input_data = input->data<T>(); const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>(); T *output_data = output->mutable_data<T>();
for (index_t b = 0; b < batch; ++b) { for (index_t i = 0; i < input_shape[0]; ++i) {
for (index_t c = 0; c < channels; ++c) { for (index_t j = 0; j < input_shape[1]; ++j) {
for (index_t h = 0; h < height; ++h) { output_data[j * input_shape[0] + i] =
for (index_t w = 0; w < width; ++w) { input_data[i * input_shape[1] + j];
output_data[((b * channels + c) * height + h) * width + w] =
input_data[((b * height + h) * width + w) * channels + c];
}
}
} }
} }
} }
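Usage of the new Transpose2D helper, as in the FullyConnected Simple test earlier in this commit; a hedged sketch assuming a 2-D weight already registered in the workspace:

// Inside an OpsTestNet-based test (sketch; mirrors the Simple<CPU> path):
net.AddRandomInput<DeviceType::CPU, float>("Weight", {16, 128});
net.Transpose2D<DeviceType::CPU, float>("Weight", "WeightTranspose");
// "WeightTranspose" now holds the same data with shape {128, 16}.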
template<DeviceType D, typename T> template<DeviceType D, typename T>
void FillHWOIInputToOIHWInput(const std::string &name_oihw, void TransformDataFormat(const std::string &src_name,
const std::string &name_hwoi) { const DataFormat src_format,
Tensor *input = ws_.GetTensor(name_hwoi); const std::string &dst_name,
Tensor *output = ws_.CreateTensor(name_oihw, const DataFormat dst_format) {
Tensor *input = ws_.GetTensor(src_name);
Tensor *output = ws_.CreateTensor(dst_name,
GetDeviceAllocator(D), GetDeviceAllocator(D),
DataTypeToEnum<T>::v()); DataTypeToEnum<T>::v());
const std::vector<index_t> input_shape = input->shape(); const std::vector<index_t> input_shape = input->shape();
index_t height = input_shape[0]; MACE_CHECK(input_shape.size() == 4, "input shape != 4");
index_t width = input_shape[1];
index_t out_channels = input_shape[2]; if (src_format == NHWC && dst_format == NCHW) {
index_t in_channels = input_shape[3]; index_t batch = input_shape[0];
index_t hw = height * width; index_t height = input_shape[1];
index_t oi = out_channels * in_channels; index_t width = input_shape[2];
output->Resize({out_channels, in_channels, height, width}); index_t channels = input_shape[3];
const T *input_data = input->data<T>(); output->Resize({batch, channels, height, width});
T *output_data = output->mutable_data<T>(); Tensor::MappingGuard input_guard(input);
for (index_t i = 0; i < oi; ++i) { Tensor::MappingGuard output_guard(output);
for (index_t j = 0; j < hw; ++j) { const T *input_data = input->data<T>();
output_data[i * height * width + j] = T *output_data = output->mutable_data<T>();
input_data[j * out_channels * in_channels + i]; for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t h = 0; h < height; ++h) {
for (index_t w = 0; w < width; ++w) {
output_data[((b * channels + c) * height + h) * width + w] =
input_data[((b * height + h) * width + w) * channels + c];
}
}
}
}
} else if (src_format == NCHW && dst_format == NHWC) {
index_t batch = input_shape[0];
index_t channels = input_shape[1];
index_t height = input_shape[2];
index_t width = input_shape[3];
output->Resize({batch, height, width, channels});
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
for (index_t b = 0; b < batch; ++b) {
for (index_t h = 0; h < height; ++h) {
for (index_t w = 0; w < width; ++w) {
for (index_t c = 0; c < channels; ++c) {
output_data[((b * height + h) * width + w) * channels + c] =
input_data[((b * channels + c) * height + h) * width + w];
}
}
}
}
} else if (src_format == HWOI && dst_format == OIHW) {
index_t height = input_shape[0];
index_t width = input_shape[1];
index_t out_channels = input_shape[2];
index_t in_channels = input_shape[3];
index_t hw = height * width;
index_t oi = out_channels * in_channels;
output->Resize({out_channels, in_channels, height, width});
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
for (index_t i = 0; i < oi; ++i) {
for (index_t j = 0; j < hw; ++j) {
output_data[i * height * width + j] =
input_data[j * out_channels * in_channels + i];
}
}
} else if (src_format == OIHW && dst_format == HWOI) {
index_t out_channels = input_shape[0];
index_t in_channels = input_shape[1];
index_t height = input_shape[2];
index_t width = input_shape[3];
index_t hw = height * width;
index_t oi = out_channels * in_channels;
output->Resize({height, width, out_channels, in_channels});
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
for (index_t i = 0; i < hw; ++i) {
for (index_t j = 0; j < oi; ++j) {
output_data[i * out_channels * in_channels + j] =
input_data[j * height * width + i];
}
}
} else if (src_format == HWIO && dst_format == OIHW) {
index_t height = input_shape[0];
index_t width = input_shape[1];
index_t in_channels = input_shape[2];
index_t out_channels = input_shape[3];
index_t hw = height * width;
output->Resize({out_channels, in_channels, height, width});
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
for (index_t m = 0; m < out_channels; ++m) {
for (index_t c = 0; c < in_channels; ++c) {
for (index_t k = 0; k < hw; ++k) {
output_data[((m * in_channels) + c) * height * width + k] =
input_data[k * out_channels * in_channels + c * out_channels + m];
}
}
} }
} else {
MACE_NOT_IMPLEMENTED;
} }
} }
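Together with the new DataFormat enum, this helper supports the round-trip pattern the CPU test hunks in this commit follow; a condensed sketch whose names mirror those hunks:

// NHWC test data -> NCHW op -> NHWC check (the CPU path used throughout):
net.AddRandomInput<DeviceType::CPU, float>("Input", {1, 7, 7, 3});  // NHWC
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC,
                                                "InputNCHW", NCHW);
// ... finalize the op with Input("InputNCHW") / Output("OutputNCHW") ...
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW,
                                                "Output", NHWC);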
template<DeviceType D, typename T> template<DeviceType D, typename T>
void FillHWIOInputToOIHWInput(const std::string &name_oihw, void FillNHWCInputToNCHWInput(const std::string &name_nchw,
const std::string &name_hwio) { const std::string &name_nhwc) {
Tensor *input = ws_.GetTensor(name_hwio); Tensor *input = ws_.GetTensor(name_nhwc);
Tensor *output = ws_.CreateTensor(name_oihw, Tensor *output = ws_.CreateTensor(name_nchw,
GetDeviceAllocator(D), GetDeviceAllocator(D),
DataTypeToEnum<T>::v()); DataTypeToEnum<T>::v());
const std::vector<index_t> input_shape = input->shape(); const std::vector<index_t> input_shape = input->shape();
index_t height = input_shape[0]; index_t batch = input_shape[0];
index_t width = input_shape[1]; index_t height = input_shape[1];
index_t in_channels = input_shape[2]; index_t width = input_shape[2];
index_t out_channels = input_shape[3]; index_t channels = input_shape[3];
index_t hw = height * width; output->Resize({batch, channels, height, width});
output->Resize({out_channels, in_channels, height, width});
const T *input_data = input->data<T>(); const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>(); T *output_data = output->mutable_data<T>();
for (index_t m = 0; m < out_channels; ++m) { for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < in_channels; ++c) { for (index_t c = 0; c < channels; ++c) {
for (index_t k = 0; k < hw; ++k) { for (index_t h = 0; h < height; ++h) {
output_data[((m * in_channels) + c) * height * width + k] = for (index_t w = 0; w < width; ++w) {
input_data[k * out_channels * in_channels + c * out_channels + m]; output_data[((b * channels + c) * height + h) * width + w] =
input_data[((b * height + h) * width + w) * channels + c];
}
} }
} }
} }
...@@ -349,7 +434,7 @@ void GenerateRandomRealTypeData(const std::vector<index_t> &shape, ...@@ -349,7 +434,7 @@ void GenerateRandomRealTypeData(const std::vector<index_t> &shape,
std::generate(res->begin(), res->end(), std::generate(res->begin(), res->end(),
[&gen, &nd, positive] { [&gen, &nd, positive] {
return half_float::half_cast<half>( return half_float::half_cast<half>(
positive ? std::abs(nd(gen)) : nd(gen)); positive ? std::abs(nd(gen)) : nd(gen));
}); });
} else { } else {
std::generate(res->begin(), res->end(), [&gen, &nd, positive] { std::generate(res->begin(), res->end(), [&gen, &nd, positive] {
...@@ -528,7 +613,6 @@ struct Expector<EXP_TYPE, RES_TYPE, false> { ...@@ -528,7 +613,6 @@ struct Expector<EXP_TYPE, RES_TYPE, false> {
} }
}; };
template<typename T> template<typename T>
void ExpectTensorNear(const Tensor &x, const Tensor &y, void ExpectTensorNear(const Tensor &x, const Tensor &y,
const double rel_err = 1e-5, const double rel_err = 1e-5,
......
...@@ -23,12 +23,6 @@ void Register_Pooling(OperatorRegistry *op_registry) { ...@@ -23,12 +23,6 @@ void Register_Pooling(OperatorRegistry *op_registry) {
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
.Build(), .Build(),
PoolingOp<DeviceType::CPU, float>); PoolingOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling")
.Device(DeviceType::CPU)
.TypeConstraint<half>("T")
.Build(),
PoolingOp<DeviceType::CPU, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling") REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling")
.Device(DeviceType::OPENCL) .Device(DeviceType::OPENCL)
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
...@@ -39,11 +33,6 @@ void Register_Pooling(OperatorRegistry *op_registry) { ...@@ -39,11 +33,6 @@ void Register_Pooling(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
PoolingOp<DeviceType::OPENCL, half>); PoolingOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
PoolingOp<DeviceType::NEON, float>);
} }
} // namespace ops } // namespace ops
......
...@@ -23,7 +23,7 @@ namespace ops { ...@@ -23,7 +23,7 @@ namespace ops {
namespace test { namespace test {
namespace { namespace {
template <DeviceType D> template<DeviceType D>
void Pooling(int iters, void Pooling(int iters,
int batch, int batch,
int channels, int channels,
...@@ -36,7 +36,20 @@ void Pooling(int iters, ...@@ -36,7 +36,20 @@ void Pooling(int iters,
mace::testing::StopTiming(); mace::testing::StopTiming();
OpsTestNet net; OpsTestNet net;
OpDefBuilder("Pooling", "PoolingTest")
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input",
{batch, channels, height, width});
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, float>("Input",
{batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
if (D == DeviceType::CPU) {
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
.AddIntArg("pooling_type", pooling_type) .AddIntArg("pooling_type", pooling_type)
...@@ -45,10 +58,22 @@ void Pooling(int iters, ...@@ -45,10 +58,22 @@ void Pooling(int iters,
.AddIntArg("padding", padding) .AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
// Add input data OpDefBuilder("Pooling", "PoolingTest")
net.AddRandomInput<DeviceType::CPU, float>("Input", .Input("InputImage")
{batch, channels, height, width}); .Output("OutputImage")
.AddIntArg("pooling_type", pooling_type)
.AddIntsArg("kernels", {kernel, kernel})
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
} else {
MACE_NOT_IMPLEMENTED;
}
// Warm-up // Warm-up
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
...@@ -78,8 +103,8 @@ void Pooling(int iters, ...@@ -78,8 +103,8 @@ void Pooling(int iters,
##DEVICE) ##DEVICE)
#define BM_POOLING(N, C, H, W, K, S, PA, PO) \ #define BM_POOLING(N, C, H, W, K, S, PA, PO) \
BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU); BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU); \
// BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, NEON); BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, OPENCL);
BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX); BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX);
BM_POOLING(1, 3, 257, 257, 2, 2, SAME, MAX); BM_POOLING(1, 3, 257, 257, 2, 2, SAME, MAX);
......
...@@ -28,9 +28,21 @@ class PoolingOpTest : public OpsTestBase {}; ...@@ -28,9 +28,21 @@ class PoolingOpTest : public OpsTestBase {};
TEST_F(PoolingOpTest, MAX_VALID) { TEST_F(PoolingOpTest, MAX_VALID) {
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 4, 4, 2},
{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("InputNCHW")
.Output("Output") .Output("OutputNCHW")
.AddIntsArg("kernels", {2, 2}) .AddIntsArg("kernels", {2, 2})
.AddIntsArg("strides", {2, 2}) .AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::VALID) .AddIntArg("padding", Padding::VALID)
...@@ -38,15 +50,14 @@ TEST_F(PoolingOpTest, MAX_VALID) { ...@@ -38,15 +50,14 @@ TEST_F(PoolingOpTest, MAX_VALID) {
.AddIntArg("pooling_type", PoolingType::MAX) .AddIntArg("pooling_type", PoolingType::MAX)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 4, 4, 2},
{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
auto expected = auto expected =
CreateTensor<float>({1, 2, 2, 2}, {5, 21, 7, 23, 13, 29, 15, 31}); CreateTensor<float>({1, 2, 2, 2}, {5, 21, 7, 23, 13, 29, 15, 31});
...@@ -57,9 +68,19 @@ TEST_F(PoolingOpTest, MAX_VALID) { ...@@ -57,9 +68,19 @@ TEST_F(PoolingOpTest, MAX_VALID) {
TEST_F(PoolingOpTest, MAX_SAME) { TEST_F(PoolingOpTest, MAX_SAME) {
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 3, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("InputNCHW")
.Output("Output") .Output("OutputNCHW")
.AddIntsArg("kernels", {2, 2}) .AddIntsArg("kernels", {2, 2})
.AddIntsArg("strides", {2, 2}) .AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::SAME) .AddIntArg("padding", Padding::SAME)
...@@ -67,13 +88,14 @@ TEST_F(PoolingOpTest, MAX_SAME) { ...@@ -67,13 +88,14 @@ TEST_F(PoolingOpTest, MAX_SAME) {
.AddIntArg("pooling_type", PoolingType::MAX) .AddIntArg("pooling_type", PoolingType::MAX)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 3, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8});
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check // Check
auto expected = CreateTensor<float>({1, 2, 2, 1}, {4, 5, 7, 8}); auto expected = CreateTensor<float>({1, 2, 2, 1}, {4, 5, 7, 8});
@@ -83,9 +105,20 @@ TEST_F(PoolingOpTest, MAX_SAME) {
 TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
   // Construct graph
   OpsTestNet net;
+  // Add input data
+  net.AddInputFromArray<DeviceType::CPU, float>(
+      "Input", {1, 4, 4, 1},
+      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
+  net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                  NHWC,
+                                                  "InputNCHW",
+                                                  NCHW);
   OpDefBuilder("Pooling", "PoolingTest")
-      .Input("Input")
-      .Output("Output")
+      .Input("InputNCHW")
+      .Output("OutputNCHW")
       .AddIntsArg("kernels", {2, 2})
       .AddIntsArg("strides", {1, 1})
       .AddIntArg("padding", Padding::VALID)
@@ -93,14 +126,14 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
       .AddIntArg("pooling_type", PoolingType::MAX)
       .Finalize(net.NewOperatorDef());
-  // Add input data
-  net.AddInputFromArray<DeviceType::CPU, float>(
-      "Input", {1, 4, 4, 1},
-      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
   // Run
   net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                  NCHW,
+                                                  "Output",
+                                                  NHWC);
   // Check
   auto expected = CreateTensor<float>({1, 2, 2, 1}, {10, 11, 14, 15});
@@ -110,9 +143,20 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
 TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
   // Construct graph
   OpsTestNet net;
+  // Add input data
+  net.AddInputFromArray<DeviceType::CPU, float>(
+      "Input", {1, 2, 9, 1},
+      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
+  net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                  NHWC,
+                                                  "InputNCHW",
+                                                  NCHW);
   OpDefBuilder("Pooling", "PoolingTest")
-      .Input("Input")
-      .Output("Output")
+      .Input("InputNCHW")
+      .Output("OutputNCHW")
       .AddIntArg("pooling_type", PoolingType::MAX)
       .AddIntsArg("kernels", {2, 2})
       .AddIntsArg("strides", {2, 2})
@@ -120,13 +164,15 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
       .AddIntsArg("dilations", {1, 1})
       .Finalize(net.NewOperatorDef());
-  // Add input data
-  net.AddInputFromArray<DeviceType::CPU, float>(
-      "Input", {1, 2, 9, 1},
-      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
   // Run
   net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                  NCHW,
+                                                  "Output",
+                                                  NHWC);
   // Check
   auto expected = CreateTensor<float>({1, 1, 5, 1}, {10, 12, 14, 16, 17});
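A plain reference loop reproduces the {10, 12, 14, 16, 17} expectation: with SAME padding the 2x9 input maps to a 1x5 output, and the last window covers only the final column. A sketch for a single channel (illustrative only):

#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

// Naive single-channel 2x2 max pooling, stride 2, SAME padding.
std::vector<float> MaxPool2x2S2Same(const std::vector<float> &in,
                                    int H, int W) {
  const int OH = (H + 1) / 2, OW = (W + 1) / 2;
  std::vector<float> out(OH * OW, std::numeric_limits<float>::lowest());
  for (int h = 0; h < H; ++h)
    for (int w = 0; w < W; ++w) {
      float &o = out[(h / 2) * OW + w / 2];  // window owning this pixel
      o = std::max(o, in[h * W + w]);
    }
  return out;
}

int main() {
  std::vector<float> in(18);
  for (int i = 0; i < 18; ++i) in[i] = i;  // the 1x2x9x1 test input
  for (float v : MaxPool2x2S2Same(in, 2, 9)) std::printf("%g ", v);  // 10 12 14 16 17
}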
@@ -145,12 +191,15 @@ void SimpleMaxPooling3S2() {
                                {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                                 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26});
-  if (D == DeviceType::OPENCL) {
-    BufferToImage<D, float>(&net, "Input", "InputImage",
-                            kernels::BufferType::IN_OUT_CHANNEL);
+  if (D == DeviceType::CPU) {
+    net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                    NHWC,
+                                                    "InputNCHW",
+                                                    NCHW);
+    // Run
     OpDefBuilder("Pooling", "PoolingTest")
-        .Input("InputImage")
-        .Output("OutputImage")
+        .Input("InputNCHW")
+        .Output("OutputNCHW")
         .AddIntArg("pooling_type", PoolingType::MAX)
         .AddIntsArg("kernels", {3, 3})
         .AddIntsArg("strides", {2, 2})
@@ -158,13 +207,16 @@ void SimpleMaxPooling3S2() {
         .AddIntsArg("dilations", {1, 1})
         .Finalize(net.NewOperatorDef());
     net.RunOp(D);
-    ImageToBuffer<D, float>(&net, "OutputImage", "Output",
+    net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                    NCHW,
+                                                    "Output",
+                                                    NHWC);
+  } else if (D == DeviceType::OPENCL) {
+    BufferToImage<D, float>(&net, "Input", "InputImage",
                             kernels::BufferType::IN_OUT_CHANNEL);
-  } else {
-    // Run
     OpDefBuilder("Pooling", "PoolingTest")
-        .Input("Input")
-        .Output("Output")
+        .Input("InputImage")
+        .Output("OutputImage")
         .AddIntArg("pooling_type", PoolingType::MAX)
         .AddIntsArg("kernels", {3, 3})
         .AddIntsArg("strides", {2, 2})
@@ -172,6 +224,8 @@ void SimpleMaxPooling3S2() {
         .AddIntsArg("dilations", {1, 1})
         .Finalize(net.NewOperatorDef());
     net.RunOp(D);
+    ImageToBuffer<D, float>(&net, "OutputImage", "Output",
+                            kernels::BufferType::IN_OUT_CHANNEL);
   }
   // Check
@@ -194,9 +248,18 @@ void MaxPooling3S2(const std::vector<index_t> &input_shape,
                    Padding padding) {
   // Construct graph
   OpsTestNet net;
+  // Add input data
+  net.AddRandomInput<D, float>("Input", input_shape);
+  net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                  NHWC,
+                                                  "InputNCHW",
+                                                  NCHW);
   OpDefBuilder("Pooling", "PoolingTest")
-      .Input("Input")
-      .Output("Output")
+      .Input("InputNCHW")
+      .Output("OutputNCHW")
       .AddIntArg("pooling_type", PoolingType::MAX)
       .AddIntsArg("kernels", {3, 3})
       .AddIntsArg("strides", strides)
@@ -204,11 +267,14 @@ void MaxPooling3S2(const std::vector<index_t> &input_shape,
      .AddIntsArg("dilations", {1, 1})
       .Finalize(net.NewOperatorDef());
-  // Add input data
-  net.AddRandomInput<D, float>("Input", input_shape);
   // run on cpu
   net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                  NCHW,
+                                                  "Output",
+                                                  NHWC);
   Tensor expected;
   expected.Copy(*net.GetOutput("Output"));
@@ -226,7 +292,7 @@ void MaxPooling3S2(const std::vector<index_t> &input_shape,
       .Finalize(net.NewOperatorDef());
   net.RunOp(D);
   ImageToBuffer<D, float>(&net, "OutputImage", "OPENCLOutput",
                           kernels::BufferType::IN_OUT_CHANNEL);
   if (DataTypeToEnum<T>::value == DT_HALF) {
     ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"),
@@ -237,12 +303,6 @@ void MaxPooling3S2(const std::vector<index_t> &input_shape,
   }
 }  // namespace
// TODO(chenghui) : there is a bug.
// TEST_F(PoolingOpTest, NEONAlignedMaxPooling3S2) {
// AlignedMaxPooling3S2<NEON>(Padding::VALID);
// AlignedMaxPooling3S2<NEON>(Padding::SAME);
//}
 TEST_F(PoolingOpTest, OPENCLAlignedMaxPooling3S2) {
   MaxPooling3S2<OPENCL, float>({3, 64, 32, 32}, {1, 1}, Padding::VALID);
   MaxPooling3S2<OPENCL, float>({3, 64, 32, 32}, {2, 2}, Padding::VALID);
@@ -267,9 +327,21 @@ TEST_F(PoolingOpTest, OPENCLUnalignedMaxPooling3S2) {
 TEST_F(PoolingOpTest, AVG_VALID) {
   // Construct graph
   OpsTestNet net;
+  // Add input data
+  net.AddInputFromArray<DeviceType::CPU, float>(
+      "Input", {1, 4, 4, 2},
+      {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
+       8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
+  net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                  NHWC,
+                                                  "InputNCHW",
+                                                  NCHW);
   OpDefBuilder("Pooling", "PoolingTest")
-      .Input("Input")
-      .Output("Output")
+      .Input("InputNCHW")
+      .Output("OutputNCHW")
       .AddIntsArg("kernels", {2, 2})
       .AddIntsArg("strides", {2, 2})
       .AddIntArg("padding", Padding::VALID)
@@ -277,15 +349,15 @@ TEST_F(PoolingOpTest, AVG_VALID) {
       .AddIntArg("pooling_type", PoolingType::AVG)
       .Finalize(net.NewOperatorDef());
-  // Add input data
-  net.AddInputFromArray<DeviceType::CPU, float>(
-      "Input", {1, 4, 4, 2},
-      {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
-       8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
   // Run
   net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                  NCHW,
+                                                  "Output",
+                                                  NHWC);
   // Check
   auto expected = CreateTensor<float>(
       {1, 2, 2, 2}, {2.5, 18.5, 4.5, 20.5, 10.5, 26.5, 12.5, 28.5});
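The AVG expectation is easy to verify by hand: each output element is the mean of one 2x2 window, e.g. (0 + 1 + 4 + 5) / 4 = 2.5 for the first channel. A single-channel reference sketch (illustrative, not MACE's kernel):

#include <cstdio>
#include <vector>

// Naive single-channel 2x2 average pooling, stride 2, VALID padding.
std::vector<float> AvgPool2x2S2(const std::vector<float> &in, int H, int W) {
  std::vector<float> out((H / 2) * (W / 2));
  for (int oh = 0; oh < H / 2; ++oh)
    for (int ow = 0; ow < W / 2; ++ow)
      out[oh * (W / 2) + ow] =
          (in[(2 * oh) * W + 2 * ow] + in[(2 * oh) * W + 2 * ow + 1] +
           in[(2 * oh + 1) * W + 2 * ow] +
           in[(2 * oh + 1) * W + 2 * ow + 1]) / 4.f;
  return out;
}

int main() {
  std::vector<float> in(16);
  for (int i = 0; i < 16; ++i) in[i] = i;  // channel 0 of the test input
  for (float v : AvgPool2x2S2(in, 4, 4)) std::printf("%g ", v);  // 2.5 4.5 10.5 12.5
}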
@@ -339,9 +411,18 @@ void AvgPoolingTest(const std::vector<index_t> &shape,
                     Padding padding) {
   // Construct graph
   OpsTestNet net;
+  // Add input data
+  net.AddRandomInput<D, float>("Input", shape);
+  net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                  NHWC,
+                                                  "InputNCHW",
+                                                  NCHW);
   OpDefBuilder("Pooling", "PoolingTest")
-      .Input("Input")
-      .Output("Output")
+      .Input("InputNCHW")
+      .Output("OutputNCHW")
       .AddIntArg("pooling_type", PoolingType::AVG)
       .AddIntsArg("kernels", kernels)
       .AddIntsArg("strides", strides)
@@ -349,11 +430,14 @@ void AvgPoolingTest(const std::vector<index_t> &shape,
       .AddIntsArg("dilations", {1, 1})
       .Finalize(net.NewOperatorDef());
-  // Add input data
-  net.AddRandomInput<D, float>("Input", shape);
   // run on cpu
   net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                  NCHW,
+                                                  "Output",
+                                                  NHWC);
   Tensor expected;
   expected.Copy(*net.GetOutput("Output"));
@@ -371,8 +455,7 @@ void AvgPoolingTest(const std::vector<index_t> &shape,
       .Finalize(net.NewOperatorDef());
   net.RunOp(D);
   ImageToBuffer<D, float>(&net, "OutputImage", "OPENCLOutput",
                           kernels::BufferType::IN_OUT_CHANNEL);
   if (DataTypeToEnum<T>::value == DT_HALF) {
     ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"),
@@ -425,64 +508,6 @@ TEST_F(PoolingOpTest, OPENCLUnAlignedLargeKernelAvgPooling) {
                  Padding::SAME);
 }
namespace {
void AvgPoolingNEONTest(const std::vector<index_t> &shape,
const std::vector<int> &kernels,
const std::vector<int> &strides,
Padding padding,
PoolingType pooling_type) {
// Construct graph
OpsTestNet net;
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.AddIntArg("pooling_type", pooling_type)
.AddIntsArg("kernels", kernels)
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<CPU, float>("Input", shape);
// run on cpu
net.RunOp();
OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNeon")
.Output("OutputNeon")
.AddIntArg("pooling_type", pooling_type)
.AddIntsArg("kernels", kernels)
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("InputNeon", "Input");
// run on neon
net.RunOp(DeviceType::NEON);
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("OutputExptected",
"Output");
ExpectTensorNear<float>(*net.GetOutput("OutputExptected"),
*net.GetOutput("OutputNeon"),
1e-5, 1e-4);
}
} // namespace
TEST_F(PoolingOpTest, NEONTest) {
AvgPoolingNEONTest({3, 31, 37, 128}, {8, 8}, {8, 8},
Padding::VALID, PoolingType::MAX);
AvgPoolingNEONTest({3, 31, 37, 128}, {8, 8}, {8, 8},
Padding::SAME, PoolingType::MAX);
AvgPoolingNEONTest({3, 31, 37, 128}, {8, 8}, {8, 8},
Padding::VALID, PoolingType::AVG);
AvgPoolingNEONTest({3, 31, 37, 128}, {8, 8}, {8, 8},
Padding::SAME, PoolingType::AVG);
}
} // namespace test } // namespace test
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
@@ -23,11 +23,6 @@ void Register_Quantize(OperatorRegistry *op_registry) {
                         .TypeConstraint<uint8_t>("T")
                         .Build(),
                     QuantizeOp<DeviceType::CPU, uint8_t>);
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Quantize")
-                        .Device(DeviceType::NEON)
-                        .TypeConstraint<uint8_t>("T")
-                        .Build(),
-                    QuantizeOp<DeviceType::CPU, uint8_t>);
 }
 void Register_Dequantize(OperatorRegistry *op_registry) {
@@ -36,11 +31,6 @@ void Register_Dequantize(OperatorRegistry *op_registry) {
                         .TypeConstraint<uint8_t>("T")
                         .Build(),
                     DequantizeOp<DeviceType::CPU, uint8_t>);
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Dequantize")
-                        .Device(DeviceType::NEON)
-                        .TypeConstraint<uint8_t>("T")
-                        .Build(),
-                    DequantizeOp<DeviceType::CPU, uint8_t>);
 }
 void Register_Requantize(OperatorRegistry *op_registry) {
@@ -49,11 +39,6 @@ void Register_Requantize(OperatorRegistry *op_registry) {
                         .TypeConstraint<uint8_t>("T")
                         .Build(),
                     RequantizeOp<DeviceType::CPU, uint8_t>);
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Requantize")
-                        .Device(DeviceType::NEON)
-                        .TypeConstraint<uint8_t>("T")
-                        .Build(),
-                    RequantizeOp<DeviceType::CPU, uint8_t>);
 }
 }  // namespace ops
......
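The registrations being pruned above keyed one kernel per (op name, device, dtype) tuple; dropping the NEON entries leaves the CPU kernels as the only non-GPU path. A generic sketch of such a keyed registry, assuming hypothetical names (this is not MACE's actual OperatorRegistry):

#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

struct OpBase { virtual ~OpBase() = default; };

// Registry mapping "Name:Device:Type" keys to op factories.
class Registry {
 public:
  using Factory = std::function<std::unique_ptr<OpBase>()>;
  void Register(const std::string &key, Factory f) {
    factories_[key] = std::move(f);
  }
  std::unique_ptr<OpBase> Create(const std::string &key) const {
    auto it = factories_.find(key);
    if (it == factories_.end()) throw std::runtime_error("no kernel: " + key);
    return it->second();  // invoke the stored factory
  }
 private:
  std::map<std::string, Factory> factories_;
};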
@@ -35,11 +35,27 @@ void ResizeBilinearBenchmark(int iters,
   OpsTestNet net;
   // Add input data
-  net.AddRandomInput<D, float>("Input",
-                               {batch, input_height, input_width, channels});
+  if (D == DeviceType::CPU) {
+    net.AddRandomInput<D, float>("Input",
+                                 {batch, channels, input_height, input_width});
+  } else if (D == DeviceType::OPENCL) {
+    net.AddRandomInput<D, float>("Input",
+                                 {batch, input_height, input_width, channels});
+  } else {
+    MACE_NOT_IMPLEMENTED;
+  }
   net.AddInputFromArray<D, index_t>("OutSize", {2},
                                     {output_height, output_width});
-  if (D == DeviceType::OPENCL) {
+  if (D == DeviceType::CPU) {
+    OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark")
+        .Input("Input")
+        .Input("OutSize")
+        .Output("Output")
+        .AddIntsArg("size", {output_height, output_width})
+        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+        .Finalize(net.NewOperatorDef());
+  } else if (D == DeviceType::OPENCL) {
     BufferToImage<D, T>(&net, "Input", "InputImage",
                         kernels::BufferType::IN_OUT_CHANNEL);
     OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark")
@@ -50,13 +66,7 @@ void ResizeBilinearBenchmark(int iters,
         .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
         .Finalize(net.NewOperatorDef());
   } else {
-    OpDefBuilder("ResizeBilinear", "ResizeBilinearBenchmark")
-        .Input("Input")
-        .Input("OutSize")
-        .Output("Output")
-        .AddIntsArg("size", {output_height, output_width})
-        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
-        .Finalize(net.NewOperatorDef());
+    MACE_NOT_IMPLEMENTED;
   }
   // Warm-up
......
@@ -28,19 +28,28 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) {
   testing::internal::LogToStderr();
   // Construct graph
   OpsTestNet net;
-  OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
-      .Input("Input")
-      .Output("Output")
-      .AddIntsArg("size", {1, 2})
-      .Finalize(net.NewOperatorDef());
   // Add input data
   std::vector<float> input(24);
   std::iota(begin(input), end(input), 0);
   net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
+  net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                  NHWC,
+                                                  "InputNCHW",
+                                                  NCHW);
+  OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
+      .Input("InputNCHW")
+      .Output("OutputNCHW")
+      .AddIntsArg("size", {1, 2})
+      .Finalize(net.NewOperatorDef());
   // Run
   net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                  NCHW,
+                                                  "Output",
+                                                  NHWC);
   // Check
   auto expected = CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8});
@@ -52,20 +61,30 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) {
   testing::internal::LogToStderr();
   // Construct graph
   OpsTestNet net;
-  OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
-      .Input("Input")
-      .Output("Output")
-      .AddIntArg("align_corners", 1)
-      .AddIntsArg("size", {1, 2})
-      .Finalize(net.NewOperatorDef());
   // Add input data
   std::vector<float> input(24);
   std::iota(begin(input), end(input), 0);
   net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
+  net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                  NHWC,
+                                                  "InputNCHW",
+                                                  NCHW);
+  OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
+      .Input("InputNCHW")
+      .Output("OutputNCHW")
+      .AddIntArg("align_corners", 1)
+      .AddIntsArg("size", {1, 2})
+      .Finalize(net.NewOperatorDef());
   // Run
   net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                  NCHW,
+                                                  "Output",
+                                                  NHWC);
   // Check
   auto expected = CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11});
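Both expectations follow from the source-coordinate mapping. Without align_corners the scale is in/out, and the sampled row/columns here land exactly on pixels 0 and 2 of the 2x4 input, giving {0, 1, 2, 6, 7, 8}; with align_corners the scale is (in-1)/(out-1), so the second output column samples column 3 instead, giving {0, 1, 2, 9, 10, 11}. A sketch of the two mappings:

#include <cstdio>

// Source coordinate for destination index i when resizing n_in -> n_out.
float SrcCoord(int i, int n_in, int n_out, bool align_corners) {
  if (align_corners && n_out > 1)
    return i * static_cast<float>(n_in - 1) / (n_out - 1);
  return i * static_cast<float>(n_in) / n_out;
}

int main() {
  // Width 4 -> 2: second column samples x=2 (default) vs x=3 (align_corners).
  std::printf("%g %g\n", SrcCoord(1, 4, 2, false), SrcCoord(1, 4, 2, true));
}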
@@ -92,15 +111,24 @@ void TestRandomResizeBilinear() {
     // Add input data
     net.AddRandomInput<D, float>("Input",
                                  {batch, in_height, in_width, channels});
+    net.TransformDataFormat<DeviceType::CPU, float>("Input",
+                                                    NHWC,
+                                                    "InputNCHW",
+                                                    NCHW);
     OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
-        .Input("Input")
-        .Output("Output")
+        .Input("InputNCHW")
+        .Output("OutputNCHW")
         .AddIntArg("align_corners", align_corners)
         .AddIntsArg("size", {height, width})
         .Finalize(net.NewOperatorDef());
     // Run on CPU
     net.RunOp(DeviceType::CPU);
+    net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
+                                                    NCHW,
+                                                    "Output",
+                                                    NHWC);
     Tensor expected;
     expected.Copy(*net.GetOutput("Output"));
@@ -129,12 +157,6 @@ void TestRandomResizeBilinear() {
   }
 }  // namespace
-/*
-TEST_F(ResizeBilinearTest, NEONRandomResizeBilinear) {
-  TestRandomResizeBilinear<DeviceType::NEON>();
-}
-*/
 TEST_F(ResizeBilinearTest, OPENCLRandomResizeBilinear) {
   TestRandomResizeBilinear<DeviceType::OPENCL>();
 }
......
@@ -34,11 +34,6 @@ void Register_Slice(OperatorRegistry *op_registry) {
                         .TypeConstraint<half>("T")
                         .Build(),
                     SliceOp<DeviceType::OPENCL, half>);
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
-                        .Device(DeviceType::NEON)
-                        .TypeConstraint<float>("T")
-                        .Build(),
-                    SliceOp<DeviceType::NEON, float>);
 }
 }  // namespace ops
......
@@ -35,11 +35,6 @@ void Register_Softmax(OperatorRegistry *op_registry) {
                         .TypeConstraint<half>("T")
                         .Build(),
                     SoftmaxOp<DeviceType::OPENCL, half>);
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Softmax")
-                        .Device(DeviceType::NEON)
-                        .TypeConstraint<float>("T")
-                        .Build(),
-                    SoftmaxOp<DeviceType::NEON, float>);
 }
 }  // namespace ops
......
@@ -31,9 +31,20 @@ void SoftmaxBenchmark(
   OpsTestNet net;
   // Add input data
-  net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
+  if (D == DeviceType::CPU) {
+    net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
+  } else if (D == DeviceType::OPENCL) {
+    net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
+  } else {
+    MACE_NOT_IMPLEMENTED;
+  }
-  if (D == DeviceType::OPENCL) {
+  if (D == DeviceType::CPU) {
+    OpDefBuilder("Softmax", "SoftmaxBM")
+        .Input("Input")
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+  } else if (D == DeviceType::OPENCL) {
     BufferToImage<D, float>(&net, "Input", "InputImage",
                             kernels::BufferType::IN_OUT_CHANNEL);
@@ -42,10 +53,7 @@ void SoftmaxBenchmark(
         .Output("Output")
         .Finalize(net.NewOperatorDef());
   } else {
-    OpDefBuilder("Softmax", "SoftmaxBM")
-        .Input("Input")
-        .Output("Output")
-        .Finalize(net.NewOperatorDef());
+    MACE_NOT_IMPLEMENTED;
   }
   // Warm-up
......
@@ -30,7 +30,17 @@ void Simple() {
   net.AddInputFromArray<D, float>("Input", {1, 1, 2, 4},
                                   {1, 1, 1, 1, 1, 2, 3, 4});
-  if (D == DeviceType::OPENCL) {
+  if (D == DeviceType::CPU) {
+    net.TransformDataFormat<CPU, float>("Input", NHWC, "InputNCHW", NCHW);
+    OpDefBuilder("Softmax", "SoftmaxTest")
+        .Input("InputNCHW")
+        .Output("OutputNCHW")
+        .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+    net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC);
+  } else if (D == DeviceType::OPENCL) {
     BufferToImage<D, float>(&net, "Input", "InputImage",
                             kernels::BufferType::IN_OUT_CHANNEL);
@@ -46,13 +56,7 @@ void Simple() {
     ImageToBuffer<D, float>(&net, "OutputImage", "Output",
                             kernels::BufferType::IN_OUT_CHANNEL);
   } else {
-    OpDefBuilder("Softmax", "SoftmaxTest")
-        .Input("Input")
-        .Output("Output")
-        .Finalize(net.NewOperatorDef());
-    // Run
-    net.RunOp(D);
+    MACE_NOT_IMPLEMENTED;
   }
   auto expected = CreateTensor<float>(
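The expected values come from a standard softmax over the channel dimension; subtracting the row maximum first keeps exp() in range. A minimal sketch matching the {1, 2, 3, 4} row of this test:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over one logit vector.
std::vector<float> Softmax(std::vector<float> x) {
  const float m = *std::max_element(x.begin(), x.end());
  float sum = 0.f;
  for (float &v : x) { v = std::exp(v - m); sum += v; }  // shift, exponentiate
  for (float &v : x) v /= sum;                           // normalize
  return x;
}

int main() {
  for (float v : Softmax({1, 2, 3, 4}))
    std::printf("%.4f ", v);  // 0.0321 0.0871 0.2369 0.6439
}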
@@ -74,13 +78,18 @@ void Complex(const std::vector<index_t> &logits_shape) {
   // Add input data
   net.AddRandomInput<D, float>("Input", logits_shape);
+  net.TransformDataFormat<CPU, float>("Input", NHWC, "InputNCHW", NCHW);
   OpDefBuilder("Softmax", "SoftmaxTest")
-      .Input("Input")
-      .Output("Output")
+      .Input("InputNCHW")
+      .Output("OutputNCHW")
       .Finalize(net.NewOperatorDef());
   // Run on cpu
   net.RunOp();
+  net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC);
   Tensor expected;
   expected.Copy(*net.GetOutput("Output"));
...@@ -119,47 +128,6 @@ TEST_F(SoftmaxOpTest, OPENCLUnAligned) { ...@@ -119,47 +128,6 @@ TEST_F(SoftmaxOpTest, OPENCLUnAligned) {
Complex<DeviceType::OPENCL>({5, 211, 107, 1}); Complex<DeviceType::OPENCL>({5, 211, 107, 1});
} }
namespace {
void SoftMaxNEONTest(const std::vector<index_t> &logits_shape) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<CPU, float>("Input", logits_shape);
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run on cpu
net.RunOp();
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputNeon")
.Output("OutputNeon")
.Finalize(net.NewOperatorDef());
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("InputNeon", "Input");
// run on neon
net.RunOp(DeviceType::NEON);
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("OutputExptected",
"Output");
ExpectTensorNear<float>(*net.GetOutput("OutputExptected"),
*net.GetOutput("OutputNeon"),
1e-5, 1e-5);
}
} // namespace
TEST_F(SoftmaxOpTest, NEONTest) {
SoftMaxNEONTest({5, 64, 64, 3});
SoftMaxNEONTest({8, 128, 128, 8});
SoftMaxNEONTest({1, 113, 107, 13});
SoftMaxNEONTest({5, 211, 107, 1});
}
} // namespace test } // namespace test
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
...@@ -169,69 +169,6 @@ TEST(SpaceToBatchTest, MultiBatchAndChannelData) { ...@@ -169,69 +169,6 @@ TEST(SpaceToBatchTest, MultiBatchAndChannelData) {
17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, 28, 31, 32}); 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, 28, 31, 32});
} }
// TEST(SpaceTobatchTest, CompareTF) {
//
// const std::string space_file = "/data/local/tmp/test/input";
// const std::string batch_file = "/data/local/tmp/test/output";
// const std::vector<index_t> space_shape = {1, 256, 256, 32};
// const int space_size = std::accumulate(space_shape.begin(),
// space_shape.end(), 1, std::multiplies<int>());
// const std::vector<index_t> batch_shape = {4, 130, 130, 32};
// const int batch_size = std::accumulate(batch_shape.begin(),
// batch_shape.end(), 1, std::multiplies<int>());
//
// auto space_tensor = std::unique_ptr<Tensor>(new
// Tensor(GetDeviceAllocator(DeviceType::OPENCL),
// DataTypeToEnum<float>::v()));
// space_tensor->Resize(space_shape);
// std::vector<float> space_data(space_size, 0.0);
// std::ifstream in_file(space_file, std::ios::in | std::ios::binary);
// if (in_file.is_open()) {
// in_file.read(reinterpret_cast<char *>(space_data.data()),
// space_size * sizeof(float));
// in_file.close();
// Tensor::MappingGuard space_mapper(space_tensor.get());
// float *space_ptr = space_tensor->mutable_data<float>();
// MACE_CHECK(static_cast<size_t>(space_tensor->size()) == space_data.size())
// << "Space tensor size:" << space_tensor->size()
// << ", space data size:" << space_data.size();
// memcpy(space_ptr, space_data.data(), space_data.size() * sizeof(float));
// } else {
// VLOG(0) << "open space file failed";
// }
//
// auto batch_tensor = std::unique_ptr<Tensor>(new
// Tensor(GetDeviceAllocator(DeviceType::OPENCL),
// DataTypeToEnum<float>::v()));
// std::vector<float> batch_data(batch_size, 0.0);
// batch_tensor->Resize(batch_shape);
// {
// std::ifstream in_file(batch_file, std::ios::in | std::ios::binary);
// if (in_file.is_open()) {
// in_file.read(reinterpret_cast<char *>(batch_data.data()),
// batch_size * sizeof(float));
// in_file.close();
// } else {
// VLOG(0) << "open batch file failed";
// }
// Tensor::MappingGuard batch_mapper(batch_tensor.get());
// float *batch_ptr = batch_tensor->mutable_data<float>();
// MACE_CHECK(static_cast<size_t>(batch_tensor->size()) ==
// batch_data.size());
// memcpy(batch_ptr, batch_data.data(), batch_data.size() * sizeof(float));
// }
//
// RunSpaceToBatch<DeviceType::OPENCL>(space_shape, space_data,
// {2, 2},
// {2, 2, 2, 2},
// batch_tensor.get());
//
// RunBatchToSpace<DeviceType::OPENCL>(batch_shape, batch_data,
// {2, 2},
// {2, 2, 2, 2},
// space_tensor.get());
//}
} // namespace test } // namespace test
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
@@ -24,12 +24,12 @@
 namespace mace {
 namespace ops {
-template <DeviceType D, typename T>
+template<DeviceType D, typename T>
 class SpaceToDepthOp : public Operator<D, T> {
  public:
   SpaceToDepthOp(const OperatorDef &op_def, Workspace *ws)
       : Operator<D, T>(op_def, ws),
         functor_(OperatorBase::GetSingleArgument<int>("block_size", 1), false) {
   }
   bool Run(StatsFuture *future) override {
@@ -37,16 +37,27 @@ class SpaceToDepthOp : public Operator<D, T> {
     Tensor *output = this->Output(OUTPUT);
     MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
     const int block_size =
         OperatorBase::GetSingleArgument<int>("block_size", 1);
-    const int input_height = input->dim(1);
-    const int input_width = input->dim(2);
-    const int input_depth = input->dim(3);
+    index_t input_height;
+    index_t input_width;
+    index_t input_depth;
+    if (D == CPU) {
+      input_height = input->dim(2);
+      input_width = input->dim(3);
+      input_depth = input->dim(1);
+    } else if (D == OPENCL) {
+      input_height = input->dim(1);
+      input_width = input->dim(2);
+      input_depth = input->dim(3);
+    } else {
+      MACE_NOT_IMPLEMENTED;
+    }
     MACE_CHECK((input_depth % 4) == 0,
                "input channel should be dividable by 4");
     MACE_CHECK(
-        (input_width%block_size == 0)&&(input_height%block_size == 0),
+        (input_width % block_size == 0) && (input_height % block_size == 0),
         "input width and height should be dividable by block_size",
         input->dim(3));
     functor_(input, output, future);
     return true;
   }
......
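For reference, space-to-depth with block size b moves each b-by-b spatial tile into the channel dimension, which is why the op checks that height and width divide by the block size. A toy NHWC sketch, assuming the TensorFlow-style row-major tile flattening (channel ordering conventions differ between frameworks):

#include <cstddef>
#include <vector>

// Toy NHWC space-to-depth: each bxb tile becomes b*b*C output channels,
// out[n, h/b, w/b, ((h%b)*b + w%b)*C + c] = in[n, h, w, c].
std::vector<float> SpaceToDepth(const std::vector<float> &in,
                                int N, int H, int W, int C, int b) {
  const int OH = H / b, OW = W / b, OC = C * b * b;
  std::vector<float> out(in.size());
  for (int n = 0; n < N; ++n)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        for (int c = 0; c < C; ++c) {
          const int oc = ((h % b) * b + (w % b)) * C + c;
          out[((n * OH + h / b) * OW + w / b) * OC + oc] =
              in[((n * H + h) * W + w) * C + c];
        }
  return out;
}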
@@ -29,9 +29,20 @@ void SpaceToDepth(
   OpsTestNet net;
   // Add input data
-  net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
+  if (D == DeviceType::CPU) {
+    net.AddRandomInput<D, float>("Input", {batch, height, channels, width});
+  } else if (D == DeviceType::OPENCL) {
+    net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
+  } else {
+    MACE_NOT_IMPLEMENTED;
+  }
-  if (D == DeviceType::OPENCL) {
+  if (D == DeviceType::CPU) {
+    OpDefBuilder("SpaceToDepth", "SpaceToDepthBM")
+        .Input("Input")
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+  } else if (D == DeviceType::OPENCL) {
     BufferToImage<D, float>(&net, "Input", "InputImage",
                             kernels::BufferType::IN_OUT_CHANNEL);
@@ -41,10 +52,7 @@ void SpaceToDepth(
         .AddIntArg("block_size", block_size)
         .Finalize(net.NewOperatorDef());
   } else {
-    OpDefBuilder("SpaceToDepth", "SpaceToDepthBM")
-        .Input("Input")
-        .Output("Output")
-        .Finalize(net.NewOperatorDef());
+    MACE_NOT_IMPLEMENTED;
   }
   // Warm-up
......
@@ -19,16 +19,10 @@ namespace ops {
 void Register_Transpose(OperatorRegistry *op_registry) {
   REGISTER_OPERATOR(op_registry, OpKeyBuilder("Transpose")
                         .Device(DeviceType::CPU)
                         .TypeConstraint<float>("T")
                         .Build(),
                     TransposeOp<DeviceType::CPU, float>);
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Transpose")
-                        .Device(DeviceType::NEON)
-                        .TypeConstraint<float>("T")
-                        .Build(),
-                    TransposeOp<DeviceType::NEON, float>);
 }
 }  // namespace ops
......
...@@ -221,7 +221,7 @@ class CaffeConverter(object): ...@@ -221,7 +221,7 @@ class CaffeConverter(object):
arg.i = self.dt arg.i = self.dt
data_format_arg = op_def.arg.add() data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format' data_format_arg.name = 'data_format'
if self.device == 'neon': if self.device == 'cpu':
data_format_arg.s = 'NCHW' data_format_arg.s = 'NCHW'
else: else:
data_format_arg.s = 'NHWC' data_format_arg.s = 'NHWC'
...@@ -396,7 +396,7 @@ class CaffeConverter(object): ...@@ -396,7 +396,7 @@ class CaffeConverter(object):
def convert_conv2d(self, op): def convert_conv2d(self, op):
use_winograd = False use_winograd = False
if self.device == 'neon': if self.device == 'cpu':
use_winograd = self.check_winograd_conv(op) use_winograd = self.check_winograd_conv(op)
param = op.layer.convolution_param param = op.layer.convolution_param
...@@ -414,14 +414,14 @@ class CaffeConverter(object): ...@@ -414,14 +414,14 @@ class CaffeConverter(object):
# Add filter # Add filter
weight_tensor_name = op.name + '_weight:0' weight_tensor_name = op.name + '_weight:0'
if self.device == 'neon': if self.device == 'cpu':
weight_data = op.data[0] weight_data = op.data[0]
else: else:
# OIHW -> HWOI # OIHW -> HWOI
weight_data = op.data[0].transpose((2, 3, 0, 1)) weight_data = op.data[0].transpose((2, 3, 0, 1))
if self.device == 'neon' and use_winograd: if self.device == 'cpu' and use_winograd:
self.convert_winograd_conv_filter_neon(op, op_def) self.convert_winograd_conv_filter_cpu(op, op_def)
else: else:
self.add_tensor(weight_tensor_name, weight_data) self.add_tensor(weight_tensor_name, weight_data)
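The converter keeps Caffe's native OIHW filters for the CPU path and transposes to HWOI for the GPU path, mirroring the transpose((2, 3, 0, 1)) above. A tiny sketch of that axis permutation on a flat buffer (illustrative only):

#include <vector>

// Permute an OIHW filter buffer into HWOI order, i.e. the
// transpose((2, 3, 0, 1)) the converter applies for the GPU path.
std::vector<float> OIHWToHWOI(const std::vector<float> &src,
                              int O, int I, int H, int W) {
  std::vector<float> dst(src.size());
  for (int o = 0; o < O; ++o)
    for (int i = 0; i < I; ++i)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          dst[((h * W + w) * O + o) * I + i] =
              src[((o * I + i) * H + h) * W + w];
  return dst;
}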
...@@ -459,7 +459,7 @@ class CaffeConverter(object): ...@@ -459,7 +459,7 @@ class CaffeConverter(object):
final_op = op final_op = op
self.resolved_ops.add(op.name) self.resolved_ops.add(op.name)
input_format = 'NCHW' if self.device == 'neon' else 'NHWC' input_format = 'NCHW' if self.device == 'cpu' else 'NHWC'
output_shape = Shapes.conv_pool_shape( output_shape = Shapes.conv_pool_shape(
op.get_single_parent().output_shape_map[op.layer.bottom[0]], op.get_single_parent().output_shape_map[op.layer.bottom[0]],
weight_data.shape, paddings, strides, dilations, math.floor, weight_data.shape, paddings, strides, dilations, math.floor,
...@@ -486,7 +486,7 @@ class CaffeConverter(object): ...@@ -486,7 +486,7 @@ class CaffeConverter(object):
def check_winograd_conv(self, op): def check_winograd_conv(self, op):
param = op.layer.convolution_param param = op.layer.convolution_param
filter_shape = np.asarray(op.data[0].shape) filter_shape = np.asarray(op.data[0].shape)
if self.device != 'neon': if self.device != 'cpu':
filter_shape = filter_shape[[2, 3, 0, 1]] # OIHW -> HWOI filter_shape = filter_shape[[2, 3, 0, 1]] # OIHW -> HWOI
paddings, strides, _ = self.add_stride_pad_kernel_arg(param, None) paddings, strides, _ = self.add_stride_pad_kernel_arg(param, None)
...@@ -503,7 +503,7 @@ class CaffeConverter(object): ...@@ -503,7 +503,7 @@ class CaffeConverter(object):
elif len(param.dilation) == 2: elif len(param.dilation) == 2:
dilations = [param.dilation[0], param.dilation[1]] dilations = [param.dilation[0], param.dilation[1]]
input_format = 'NCHW' if self.device == 'neon' else 'NHWC' input_format = 'NCHW' if self.device == 'cpu' else 'NHWC'
output_shape = Shapes.conv_pool_shape( output_shape = Shapes.conv_pool_shape(
op.get_single_parent().output_shape_map[op.layer.bottom[0]], op.get_single_parent().output_shape_map[op.layer.bottom[0]],
filter_shape, paddings, strides, dilations, math.floor, filter_shape, paddings, strides, dilations, math.floor,
...@@ -519,13 +519,13 @@ class CaffeConverter(object): ...@@ -519,13 +519,13 @@ class CaffeConverter(object):
(16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \ (16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \
(16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \ (16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \
(width < OPENCL_IMAGE_MAX_SIZE) (width < OPENCL_IMAGE_MAX_SIZE)
elif self.device == 'neon': elif self.device == 'cpu':
return filter_shape[2] == 3 and \ return filter_shape[2] == 3 and \
filter_shape[2] == filter_shape[3] and \ filter_shape[2] == filter_shape[3] and \
filter_shape[0] >= 8 and filter_shape[1] >= 8 filter_shape[0] >= 8 and filter_shape[1] >= 8
return False return False
def convert_winograd_conv_filter_neon(self, op, op_def): def convert_winograd_conv_filter_cpu(self, op, op_def):
# Add filter # Add filter
weight_tensor_name = op.name + '_weight:0' weight_tensor_name = op.name + '_weight:0'
weight_data = op.data[0] # OIHW weight_data = op.data[0] # OIHW
...@@ -739,7 +739,7 @@ class CaffeConverter(object): ...@@ -739,7 +739,7 @@ class CaffeConverter(object):
weight_data = op.data[0].reshape(-1, op.data[0].shape[-1]) weight_data = op.data[0].reshape(-1, op.data[0].shape[-1])
assert weight_data.shape[1] == ( assert weight_data.shape[1] == (
input_shape[1] * input_shape[2] * input_shape[3]) input_shape[1] * input_shape[2] * input_shape[3])
if self.device != 'neon': if self.device != 'cpu':
weight_data = weight_data.reshape(-1, input_shape[3], weight_data = weight_data.reshape(-1, input_shape[3],
input_shape[1], input_shape[2]) input_shape[1], input_shape[2])
weight_data = weight_data.transpose((0, 2, 3, 1)).reshape( weight_data = weight_data.transpose((0, 2, 3, 1)).reshape(
...@@ -783,7 +783,7 @@ class CaffeConverter(object): ...@@ -783,7 +783,7 @@ class CaffeConverter(object):
op_def.input.extend([bias_tensor_name]) op_def.input.extend([bias_tensor_name])
self.resolved_ops.add(op.name) self.resolved_ops.add(op.name)
input_format = 'NCHW' if self.device == 'neon' else 'NHWC' input_format = 'NCHW' if self.device == 'cpu' else 'NHWC'
output_shape = Shapes.fully_connected_shape(input_shape, output_shape = Shapes.fully_connected_shape(input_shape,
weight_data.shape, weight_data.shape,
input_format) input_format)
...@@ -823,14 +823,14 @@ class CaffeConverter(object): ...@@ -823,14 +823,14 @@ class CaffeConverter(object):
0]] 0]]
if param.HasField('global_pooling') and param.global_pooling: if param.HasField('global_pooling') and param.global_pooling:
kernels = [input_shape[2], input_shape[3]] \ kernels = [input_shape[2], input_shape[3]] \
if self.device == 'neon' else \ if self.device == 'cpu' else \
[input_shape[1], input_shape[2]] [input_shape[1], input_shape[2]]
kernel_arg = op_def.arg.add() kernel_arg = op_def.arg.add()
kernel_arg.name = 'kernels' kernel_arg.name = 'kernels'
kernel_arg.ints.extend(kernels) kernel_arg.ints.extend(kernels)
if self.device != 'neon': if self.device != 'cpu':
filter_shape = [ filter_shape = [
kernels[0], kernels[1], input_shape[3], input_shape[3] kernels[0], kernels[1], input_shape[3], input_shape[3]
] ]
...@@ -838,7 +838,7 @@ class CaffeConverter(object): ...@@ -838,7 +838,7 @@ class CaffeConverter(object):
filter_shape = [ filter_shape = [
input_shape[1], input_shape[1], kernels[0], kernels[1] input_shape[1], input_shape[1], kernels[0], kernels[1]
] ]
input_format = 'NCHW' if self.device == 'neon' else 'NHWC' input_format = 'NCHW' if self.device == 'cpu' else 'NHWC'
output_shape = Shapes.conv_pool_shape(input_shape, filter_shape, output_shape = Shapes.conv_pool_shape(input_shape, filter_shape,
paddings, strides, [1, 1], paddings, strides, [1, 1],
math.ceil, input_format) math.ceil, input_format)
...@@ -897,7 +897,7 @@ class CaffeConverter(object): ...@@ -897,7 +897,7 @@ class CaffeConverter(object):
op_def = self.CommonConvert(op, 'Concat') op_def = self.CommonConvert(op, 'Concat')
axis_arg = op_def.arg.add() axis_arg = op_def.arg.add()
axis_arg.name = 'axis' axis_arg.name = 'axis'
axis_arg.i = 3 if self.device != 'neon' else 1 axis_arg.i = 3 if self.device != 'cpu' else 1
try: try:
if op.layer.concat_param.HasFeild('axis'): if op.layer.concat_param.HasFeild('axis'):
axis_arg.i = op.concat_param.axis axis_arg.i = op.concat_param.axis
...@@ -947,7 +947,7 @@ class CaffeConverter(object): ...@@ -947,7 +947,7 @@ class CaffeConverter(object):
axis_arg = op_def.arg.add() axis_arg = op_def.arg.add()
axis_arg.name = 'axis' axis_arg.name = 'axis'
axis_arg.i = 3 if self.device != 'neon' else 1 axis_arg.i = 3 if self.device != 'cpu' else 1
input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]] input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]]
num_outputs = len(op.layer.top) num_outputs = len(op.layer.top)
...@@ -958,7 +958,7 @@ class CaffeConverter(object): ...@@ -958,7 +958,7 @@ class CaffeConverter(object):
raise Exception( raise Exception(
'Mace do not support slice with input shape ' + 'Mace do not support slice with input shape ' +
str(input_shape) + ' and number of output ' + str(num_outputs)) str(input_shape) + ' and number of output ' + str(num_outputs))
input_format = 'NCHW' if self.device == 'neon' else 'NHWC' input_format = 'NCHW' if self.device == 'cpu' else 'NHWC'
output_shape = Shapes.slice_shape(input_shape, num_outputs, output_shape = Shapes.slice_shape(input_shape, num_outputs,
input_format) input_format)
for i in range(len(op.layer.top)): for i in range(len(op.layer.top)):
...@@ -978,14 +978,14 @@ class CaffeConverter(object): ...@@ -978,14 +978,14 @@ class CaffeConverter(object):
self.resolved_ops.add(op.name) self.resolved_ops.add(op.name)
def convert_reshape(self, op): def convert_reshape(self, op):
if self.device == 'neon': if self.device == 'cpu':
op_def = self.CommonConvert(op, 'Reshape') op_def = self.CommonConvert(op, 'Reshape')
else: else:
op_def = self.CommonConvert(op, 'ReOrganize') op_def = self.CommonConvert(op, 'ReOrganize')
input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]] input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]]
output_shape = input_shape output_shape = input_shape
shape_param = np.asarray(op.layer.reshape_param.shape.dim) shape_param = np.asarray(op.layer.reshape_param.shape.dim)
if self.device != 'neon': if self.device != 'cpu':
shape_param = shape_param[[0, 3, 1, 2]] shape_param = shape_param[[0, 3, 1, 2]]
for i in range(len(shape_param)): for i in range(len(shape_param)):
if shape_param[i] != 0: if shape_param[i] != 0:
...@@ -1060,7 +1060,7 @@ class CaffeConverter(object): ...@@ -1060,7 +1060,7 @@ class CaffeConverter(object):
assert len(input_nodes) == len(input_shapes) assert len(input_nodes) == len(input_shapes)
for i in range(len(input_nodes)): for i in range(len(input_nodes)):
input_op = self.ops_map[input_nodes[i]] input_op = self.ops_map[input_nodes[i]]
input_shape = input_shapes[i] if self.device != 'neon' else \ input_shape = input_shapes[i] if self.device != 'cpu' else \
[input_shapes[i][0], input_shapes[i][3], [input_shapes[i][0], input_shapes[i][3],
input_shapes[i][1], input_shapes[i][2]] input_shapes[i][1], input_shapes[i][2]]
if input_op.layer is not None: if input_op.layer is not None:
...@@ -1068,7 +1068,7 @@ class CaffeConverter(object): ...@@ -1068,7 +1068,7 @@ class CaffeConverter(object):
else: else:
input_op.output_shape_map[input_op.name] = input_shape input_op.output_shape_map[input_op.name] = input_shape
def add_neon_input_transform(self, names): def add_cpu_input_transform(self, names):
for name in names: for name in names:
new_input_name = MACE_INPUT_NODE_NAME + '_' + name + ":0" new_input_name = MACE_INPUT_NODE_NAME + '_' + name + ":0"
op_def = self.net_def.op.add() op_def = self.net_def.op.add()
...@@ -1085,7 +1085,7 @@ class CaffeConverter(object): ...@@ -1085,7 +1085,7 @@ class CaffeConverter(object):
arg.name = 'T' arg.name = 'T'
arg.i = self.dt arg.i = self.dt
def add_neon_output_transform(self, names): def add_cpu_output_transform(self, names):
for name in names: for name in names:
output_name = MACE_OUTPUT_NODE_NAME + '_' + name + ":0" output_name = MACE_OUTPUT_NODE_NAME + '_' + name + ":0"
op_def = self.net_def.op.add() op_def = self.net_def.op.add()
...@@ -1105,8 +1105,8 @@ class CaffeConverter(object): ...@@ -1105,8 +1105,8 @@ class CaffeConverter(object):
if self.device == 'gpu': if self.device == 'gpu':
self.add_input_transform(input_nodes) self.add_input_transform(input_nodes)
if self.device == 'neon': if self.device == 'cpu':
self.add_neon_input_transform(input_nodes) self.add_cpu_input_transform(input_nodes)
for op in self.ops: for op in self.ops:
if op.name in self.resolved_ops: if op.name in self.resolved_ops:
@@ -1152,10 +1152,7 @@ class CaffeConverter(object):
             self.add_output_transform(output_nodes)
         if self.device == 'cpu':
-            self.replace_in_out_name(input_nodes, output_nodes)
+            self.add_cpu_output_transform(output_nodes)
-        if self.device == 'neon':
-            self.add_neon_output_transform(output_nodes)
         for op in self.ops:
             if op.name in self.resolved_ops:
......
...@@ -157,7 +157,7 @@ class TFConverter(object): ...@@ -157,7 +157,7 @@ class TFConverter(object):
self.add_output_shape(self.ops[name].outputs, op_def) self.add_output_shape(self.ops[name].outputs, op_def)
def add_neon_input_transform(self, names): def add_cpu_input_transform(self, names):
for name in names: for name in names:
new_input_name = MACE_INPUT_NODE_NAME + '_' + name + ":0" new_input_name = MACE_INPUT_NODE_NAME + '_' + name + ":0"
op_def = self.net_def.op.add() op_def = self.net_def.op.add()
...@@ -187,7 +187,7 @@ class TFConverter(object): ...@@ -187,7 +187,7 @@ class TFConverter(object):
epsilon_arg.name = 'buffer_type' epsilon_arg.name = 'buffer_type'
epsilon_arg.i = buffer_type_map['IN_OUT_CHANNEL'] epsilon_arg.i = buffer_type_map['IN_OUT_CHANNEL']
def add_neon_output_transform(self, names): def add_cpu_output_transform(self, names):
for name in names: for name in names:
output_name = MACE_OUTPUT_NODE_NAME + '_' + name + ":0" output_name = MACE_OUTPUT_NODE_NAME + '_' + name + ":0"
op_def = self.net_def.op.add() op_def = self.net_def.op.add()
...@@ -281,7 +281,7 @@ class TFConverter(object): ...@@ -281,7 +281,7 @@ class TFConverter(object):
return (16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \ return (16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \
(16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \ (16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \
(width < OPENCL_IMAGE_MAX_SIZE) (width < OPENCL_IMAGE_MAX_SIZE)
elif self.device == 'neon': elif self.device == 'cpu':
return filter_shape[2] >= 8 and filter_shape[3] >= 8 return filter_shape[2] >= 8 and filter_shape[3] >= 8
return False return False
...@@ -375,7 +375,7 @@ class TFConverter(object): ...@@ -375,7 +375,7 @@ class TFConverter(object):
self.add_output_shape(final_op.outputs, iwt_op) self.add_output_shape(final_op.outputs, iwt_op)
self.net_def.op.extend([wt_op, matmul_op, iwt_op]) self.net_def.op.extend([wt_op, matmul_op, iwt_op])
def convert_conv_winograd_filter_neon(self, op, op_def): def convert_conv_winograd_filter_cpu(self, op, op_def):
weight_tensor = get_input_tensor(op, 1) weight_tensor = get_input_tensor(op, 1)
weight_tensor_value = weight_tensor.eval().astype(np.float32) weight_tensor_value = weight_tensor.eval().astype(np.float32)
input_shape = get_input_tensor(op, 0).shape.as_list() input_shape = get_input_tensor(op, 0).shape.as_list()
...@@ -421,7 +421,7 @@ class TFConverter(object): ...@@ -421,7 +421,7 @@ class TFConverter(object):
def convert_conv2d(self, op): def convert_conv2d(self, op):
use_winograd = False use_winograd = False
if self.device == 'neon': if self.device == 'cpu':
use_winograd = self.check_winograd_conv(op) use_winograd = self.check_winograd_conv(op)
op_def = mace_pb2.OperatorDef() op_def = mace_pb2.OperatorDef()
...@@ -434,7 +434,7 @@ class TFConverter(object): ...@@ -434,7 +434,7 @@ class TFConverter(object):
else: else:
op_def.type = op.type op_def.type = op.type
if self.device == 'neon' and not use_winograd: if self.device == 'cpu' and not use_winograd:
self.transpose_filter_tensor[get_input_tensor( self.transpose_filter_tensor[get_input_tensor(
op, 1).name] = (3, 2, 0, 1) op, 1).name] = (3, 2, 0, 1)
elif op.type == 'Conv2D': elif op.type == 'Conv2D':
...@@ -449,8 +449,8 @@ class TFConverter(object): ...@@ -449,8 +449,8 @@ class TFConverter(object):
output_name = self.add_buffer_to_image( output_name = self.add_buffer_to_image(
get_input_tensor(op, 1).name, buffer_type) get_input_tensor(op, 1).name, buffer_type)
op_def.input.extend([output_name]) op_def.input.extend([output_name])
elif self.device == 'neon' and use_winograd: elif self.device == 'cpu' and use_winograd:
self.convert_conv_winograd_filter_neon(op, op_def) self.convert_conv_winograd_filter_cpu(op, op_def)
else: else:
op_def.input.extend( op_def.input.extend(
[get_input_tensor(op, i).name for i in range(len(op.inputs))]) [get_input_tensor(op, i).name for i in range(len(op.inputs))])
...@@ -463,7 +463,7 @@ class TFConverter(object): ...@@ -463,7 +463,7 @@ class TFConverter(object):
strides_arg.ints.extend(op.get_attr('strides')[1:3]) strides_arg.ints.extend(op.get_attr('strides')[1:3])
data_format_arg = op_def.arg.add() data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format' data_format_arg.name = 'data_format'
if self.device == 'neon': if self.device == 'cpu':
data_format_arg.s = 'NCHW' data_format_arg.s = 'NCHW'
else: else:
data_format_arg.s = 'NHWC' data_format_arg.s = 'NHWC'
...@@ -502,7 +502,7 @@ class TFConverter(object): ...@@ -502,7 +502,7 @@ class TFConverter(object):
self.net_def.op.extend([op_def]) self.net_def.op.extend([op_def])
def check_conv_to_fc(self, op): def check_conv_to_fc(self, op):
if self.device != 'neon' or op.type != "Conv2D": if self.device != 'cpu' or op.type != "Conv2D":
return False return False
filter_shape = get_input_tensor(op, 1).shape.as_list() filter_shape = get_input_tensor(op, 1).shape.as_list()
input_shape = get_input_tensor(op, 0).shape.as_list() input_shape = get_input_tensor(op, 0).shape.as_list()
...@@ -569,7 +569,7 @@ class TFConverter(object): ...@@ -569,7 +569,7 @@ class TFConverter(object):
arg.i = self.dt arg.i = self.dt
data_format_arg = op_def.arg.add() data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format' data_format_arg.name = 'data_format'
if self.device == 'neon': if self.device == 'cpu':
data_format_arg.s = 'NCHW' data_format_arg.s = 'NCHW'
else: else:
data_format_arg.s = 'NHWC' data_format_arg.s = 'NHWC'
...@@ -675,7 +675,7 @@ class TFConverter(object): ...@@ -675,7 +675,7 @@ class TFConverter(object):
epsilon_arg.f = get_input_tensor(op, 1).eval().astype(np.float) epsilon_arg.f = get_input_tensor(op, 1).eval().astype(np.float)
data_format_arg = op_def.arg.add() data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format' data_format_arg.name = 'data_format'
if self.device == 'neon': if self.device == 'cpu':
data_format_arg.s = 'NCHW' data_format_arg.s = 'NCHW'
else: else:
data_format_arg.s = 'NHWC' data_format_arg.s = 'NHWC'
...@@ -709,7 +709,7 @@ class TFConverter(object): ...@@ -709,7 +709,7 @@ class TFConverter(object):
kernels_arg.ints.extend(op.get_attr('ksize')[1:3]) kernels_arg.ints.extend(op.get_attr('ksize')[1:3])
data_format_arg = op_def.arg.add() data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format' data_format_arg.name = 'data_format'
if self.device == 'neon': if self.device == 'cpu':
data_format_arg.s = 'NCHW' data_format_arg.s = 'NCHW'
else: else:
data_format_arg.s = 'NHWC' data_format_arg.s = 'NHWC'
...@@ -739,7 +739,7 @@ class TFConverter(object): ...@@ -739,7 +739,7 @@ class TFConverter(object):
kernels_arg.ints.extend(op.inputs[0].shape.as_list()[1:3]) kernels_arg.ints.extend(op.inputs[0].shape.as_list()[1:3])
data_format_arg = op_def.arg.add() data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format' data_format_arg.name = 'data_format'
if self.device == 'neon': if self.device == 'cpu':
data_format_arg.s = 'NCHW' data_format_arg.s = 'NCHW'
else: else:
data_format_arg.s = 'NHWC' data_format_arg.s = 'NHWC'
...@@ -802,7 +802,7 @@ class TFConverter(object): ...@@ -802,7 +802,7 @@ class TFConverter(object):
axis_arg = op_def.arg.add() axis_arg = op_def.arg.add()
axis_arg.name = 'axis' axis_arg.name = 'axis'
axis = get_input_tensor(op, len(op.inputs) - 1).eval().astype(np.int32) axis = get_input_tensor(op, len(op.inputs) - 1).eval().astype(np.int32)
if self.device == 'neon' and axis == 3: if self.device == 'cpu' and axis == 3:
axis = 1 axis = 1
axis_arg.i = axis axis_arg.i = axis
self.add_output_shape(op.outputs, op_def) self.add_output_shape(op.outputs, op_def)
...@@ -969,7 +969,7 @@ class TFConverter(object): ...@@ -969,7 +969,7 @@ class TFConverter(object):
strides_arg.ints.extend([1, 1]) strides_arg.ints.extend([1, 1])
data_format_arg = op_def.arg.add() data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format' data_format_arg.name = 'data_format'
if self.device == 'neon': if self.device == 'cpu':
data_format_arg.s = 'NCHW' data_format_arg.s = 'NCHW'
else: else:
data_format_arg.s = 'NHWC' data_format_arg.s = 'NHWC'
...@@ -1109,8 +1109,8 @@ class TFConverter(object): ...@@ -1109,8 +1109,8 @@ class TFConverter(object):
def convert(self, input_nodes, output_nodes): def convert(self, input_nodes, output_nodes):
if self.device == 'gpu': if self.device == 'gpu':
self.add_gpu_input_transform(input_nodes) self.add_gpu_input_transform(input_nodes)
if self.device == 'neon': if self.device == 'cpu':
self.add_neon_input_transform(input_nodes) self.add_cpu_input_transform(input_nodes)
for op in self.tf_ops: for op in self.tf_ops:
if self.resolved_ops[op.name] == 1: if self.resolved_ops[op.name] == 1:
@@ -1197,11 +1197,8 @@ class TFConverter(object):
         if self.device == 'gpu':
             self.add_gpu_output_transform(output_nodes)
-        if self.device == 'neon':
-            self.add_neon_output_transform(output_nodes)
         if self.device == 'cpu':
-            self.replace_in_out_name(input_nodes, output_nodes)
+            self.add_cpu_output_transform(output_nodes)
         for key in self.resolved_ops:
             if self.resolved_ops[key] != 1:
...@@ -1252,7 +1249,7 @@ class Optimizer: ...@@ -1252,7 +1249,7 @@ class Optimizer:
scale_tensor = self.tensor_map[scale_buffer_name] scale_tensor = self.tensor_map[scale_buffer_name]
weight_shape = weight_tensor.dims weight_shape = weight_tensor.dims
idx = 0 idx = 0
if self.device == 'neon': # OIHW if self.device == 'cpu': # OIHW
for oc in range(weight_shape[0]): for oc in range(weight_shape[0]):
for ic in range(weight_shape[1]): for ic in range(weight_shape[1]):
for i in range(weight_shape[2]): for i in range(weight_shape[2]):
......
...@@ -69,9 +69,6 @@ def get_data_and_device_type(runtime): ...@@ -69,9 +69,6 @@ def get_data_and_device_type(runtime):
elif runtime == "cpu": elif runtime == "cpu":
data_type = "DT_FLOAT" data_type = "DT_FLOAT"
device_type = "CPU" device_type = "CPU"
elif runtime == "neon":
data_type = "DT_FLOAT"
device_type = "NEON"
return data_type, device_type return data_type, device_type
......
...@@ -56,7 +56,6 @@ def compare_output(platform, mace_runtime, output_name, mace_out_value, ...@@ -56,7 +56,6 @@ def compare_output(platform, mace_runtime, output_name, mace_out_value,
print output_name, 'MACE VS', platform.upper( print output_name, 'MACE VS', platform.upper(
), 'similarity: ', similarity ), 'similarity: ', similarity
if (mace_runtime == "cpu" and similarity > 0.999) or \ if (mace_runtime == "cpu" and similarity > 0.999) or \
(mace_runtime == "neon" and similarity > 0.999) or \
(mace_runtime == "gpu" and similarity > 0.995) or \ (mace_runtime == "gpu" and similarity > 0.995) or \
(mace_runtime == "dsp" and similarity > 0.930): (mace_runtime == "dsp" and similarity > 0.930):
print '===================Similarity Test Passed==================' print '===================Similarity Test Passed=================='
......
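The similarity gate above compares MACE output against the reference framework with per-runtime thresholds (0.999 for cpu, 0.995 for gpu, 0.930 for dsp). A common choice for this kind of scalar similarity is cosine similarity, sketched below; the exact metric computed by the validation script is not shown in this diff, so treat this as an assumption:

#include <cmath>
#include <cstddef>
#include <vector>

// Cosine similarity between two output vectors; 1.0 means identical direction.
double CosineSimilarity(const std::vector<float> &a,
                        const std::vector<float> &b) {
  double dot = 0, na = 0, nb = 0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    dot += a[i] * b[i];
    na += a[i] * a[i];
    nb += b[i] * b[i];
  }
  return dot / (std::sqrt(na) * std::sqrt(nb) + 1e-12);  // guard zero norms
}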