Commit 7c9ca067 authored by 李寅

Refactor cpu

Parent 7c1711d8
......@@ -91,7 +91,6 @@ extern void Register_Eltwise(OperatorRegistry *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
extern void Register_FullyConnected(OperatorRegistry *op_registry);
extern void Register_FusedConv2D(OperatorRegistry *op_registry);
extern void Register_GlobalAvgPooling(OperatorRegistry *op_registry);
extern void Register_ImageToBuffer(OperatorRegistry *op_registry);
extern void Register_LocalResponseNorm(OperatorRegistry *op_registry);
extern void Register_MatMul(OperatorRegistry *op_registry);
......@@ -132,7 +131,6 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this);
ops::Register_FusedConv2D(this);
ops::Register_GlobalAvgPooling(this);
ops::Register_ImageToBuffer(this);
ops::Register_LocalResponseNorm(this);
ops::Register_MatMul(this);
......
......@@ -318,6 +318,7 @@ class Tensor {
public:
explicit MappingGuard(const Tensor *tensor) : tensor_(tensor) {
if (tensor_ != nullptr) {
MACE_CHECK_NOTNULL(tensor_->buffer_);
tensor_->buffer_->Map(&mapped_image_pitch_);
}
}
......
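For context, a minimal sketch of how MappingGuard is meant to be used (an illustration only, assuming the Tensor API shown above; ConsumeTensor is a hypothetical name):

// Minimal MappingGuard usage sketch (assumption: Tensor as declared above).
// The guard maps the underlying buffer on construction and unmaps it when it
// goes out of scope, so raw data pointers are only valid inside the block.
void ConsumeTensor(const mace::Tensor *tensor) {
  mace::Tensor::MappingGuard guard(tensor);   // checks and maps buffer_
  const float *data = tensor->data<float>();  // valid while guard is alive
  // ... read data here ...
}                                             // unmapped on scope exit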
......@@ -121,23 +121,30 @@ void PReLUActivation(const T *input_ptr,
}
template <DeviceType D, typename T>
class ActivationFunctor {
class ActivationFunctor;
template <>
class ActivationFunctor<DeviceType::CPU, float> {
public:
ActivationFunctor(ActivationType type, T relux_max_limit)
ActivationFunctor(ActivationType type, float relux_max_limit)
: activation_(type), relux_max_limit_(relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future) {
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
const float *input_ptr = input->data<float>();
float *output_ptr = output->mutable_data<float>();
if (activation_ == PRELU) {
MACE_CHECK_NOTNULL(alpha);
const T *alpha_ptr = alpha->data<T>();
const index_t outer_size = output->dim(0) * output->dim(1)
* output->dim(2);
PReLUActivation(input_ptr, outer_size, input->dim(3), 1, alpha_ptr,
const float *alpha_ptr = alpha->data<float>();
const index_t outer_size = output->dim(0);
const index_t inner_size = output->dim(2) * output->dim(3);
PReLUActivation(input_ptr,
outer_size,
input->dim(1),
inner_size,
alpha_ptr,
output_ptr);
} else {
DoActivation(input_ptr, output_ptr, output->size(), activation_,
......@@ -145,22 +152,6 @@ class ActivationFunctor {
}
}
private:
ActivationType activation_;
T relux_max_limit_;
};
template <>
class ActivationFunctor<DeviceType::NEON, float> {
public:
ActivationFunctor(ActivationType type, float relux_max_limit)
: activation_(type), relux_max_limit_(relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future);
private:
ActivationType activation_;
float relux_max_limit_;
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/activation.h"
namespace mace {
namespace kernels {
void ActivationFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future) {
const float *input_ptr = input->data<float>();
float *output_ptr = output->mutable_data<float>();
if (activation_ == PRELU) {
MACE_CHECK_NOTNULL(alpha);
const float *alpha_ptr = alpha->data<float>();
const index_t outer_size = output->dim(0);
const index_t inner_size = output->dim(2) * output->dim(3);
PReLUActivation(input_ptr, outer_size, input->dim(1), inner_size, alpha_ptr,
output_ptr);
} else {
DoActivation(input_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
}
}
} // namespace kernels
} // namespace mace
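For reference, a scalar sketch of the NCHW PReLU contract that both call sites above assume (illustrative, not the optimized kernel; PReLUActivationRef is a hypothetical name):

// Scalar PReLU reference for NCHW layout (sketch).
// outer_size = batch, channels = dim(1), inner_size = height * width;
// out = in if in >= 0, else alpha[c] * in, with one alpha per channel.
template <typename T>
void PReLUActivationRef(const T *input, index_t outer_size, index_t channels,
                        index_t inner_size, const T *alpha, T *output) {
  for (index_t o = 0; o < outer_size; ++o) {
    for (index_t c = 0; c < channels; ++c) {
      const index_t base = (o * channels + c) * inner_size;
      for (index_t i = 0; i < inner_size; ++i) {
        const T in = input[base + i];
        output[base + i] = in >= 0 ? in : in * alpha[c];
      }
    }
  }
}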
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/batch_norm.h"
namespace mace {
namespace kernels {
void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future) {
// Batch normalization from the paper https://arxiv.org/abs/1502.03167.
// The inference-time formula is
//   Y = \frac{scale}{\sqrt{var + epsilon}} * X
//       + (offset - \frac{scale * mean}{\sqrt{var + epsilon}})
// which is folded into
//   new_scale = \frac{scale}{\sqrt{var + epsilon}}
//   new_offset = offset - mean * new_scale
//   Y = new_scale * X + new_offset
const index_t batch = input->dim(0);
const index_t channels = input->dim(1);
const index_t height = input->dim(2);
const index_t width = input->dim(3);
const float *input_ptr = input->data<float>();
const float *scale_ptr = scale->data<float>();
const float *offset_ptr = offset->data<float>();
float *output_ptr = output->mutable_data<float>();
std::vector<float> new_scale;
std::vector<float> new_offset;
if (!folded_constant_) {
new_scale.resize(channels);
new_offset.resize(channels);
const float *mean_ptr = mean->data<float>();
const float *var_ptr = var->data<float>();
#pragma omp parallel for
for (index_t c = 0; c < channels; ++c) {
new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon);
new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c];
}
}
const float *scale_data = folded_constant_ ? scale_ptr : new_scale.data();
const float *offset_data = folded_constant_ ? offset_ptr : new_offset.data();
index_t channel_size = height * width;
index_t batch_size = channels * channel_size;
// NEON is slower, so stick to the trivial implementation
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
index_t offset = b * batch_size + c * channel_size;
for (index_t hw = 0; hw < height * width; ++hw) {
output_ptr[offset + hw] =
scale_data[c] * input_ptr[offset + hw] + offset_data[c];
}
}
}
DoActivation(output_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
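The folding math in the comment can be sanity-checked with a tiny sketch (illustrative values; assumes <cmath>; FoldedBatchNorm is a hypothetical name):

// Worked example of the inference-time fold (illustrative values).
// scale = 2, offset = 1, mean = 0.5, var = 3.99, epsilon = 0.01:
//   new_scale  = 2 / sqrt(3.99 + 0.01) = 1
//   new_offset = 1 - 0.5 * 1           = 0.5
// so Y = 1 * X + 0.5 for every element of that channel.
float FoldedBatchNorm(float x, float scale, float offset,
                      float mean, float var, float epsilon) {
  const float new_scale = scale / std::sqrt(var + epsilon);
  const float new_offset = offset - mean * new_scale;
  return new_scale * x + new_offset;
}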
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/arm/conv_winograd.h"
namespace mace {
namespace kernels {
namespace {
void Conv2dNCHW(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
float *output) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
index_t out_offset =
((b * out_channels + m) * out_height + h) * out_width + w;
for (index_t c = 0; c < in_channels; ++c) {
for (index_t kh = 0; kh < filter_height; ++kh) {
for (index_t kw = 0; kw < filter_width; ++kw) {
index_t ih = h * stride_h + kh * dilation_h;
index_t iw = w * stride_w + kw * dilation_w;
index_t in_offset =
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t filter_offset =
(((m * in_channels) + c) * filter_height + kh) * filter_width
+ kw;
output[out_offset] += input[in_offset] * filter[filter_offset];
}
}
}
}
}
}
}
}
} // namespace
extern void Conv2dNeonK1x1S1(const float *input,
const float *filter,
const index_t batch,
const index_t height,
const index_t width,
const index_t in_channels,
const index_t out_channels,
float *output);
extern void Conv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output);
extern void Conv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output);
void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
MACE_CHECK_NOTNULL(input);
MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output);
std::vector<index_t> filter_shape(4);
if (is_filter_transformed_) {
// TOC -> OIHW
filter_shape[0] = filter->dim(1);
filter_shape[1] = filter->dim(2);
filter_shape[2] = filter_shape[3] = 3;
} else {
filter_shape = filter->shape();
}
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
if (paddings_.empty()) {
CalcNCHWPaddingAndOutputSize(input->shape().data(),
filter_shape.data(),
dilations_,
strides_,
padding_type_,
output_shape.data(),
paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input->shape().data(), filter_shape.data(),
paddings_.data(), dilations_, strides_, RoundType::FLOOR,
output_shape.data());
}
output->Resize(output_shape);
output->Clear();
index_t batch = output->dim(0);
index_t channels = output->dim(1);
index_t height = output->dim(2);
index_t width = output->dim(3);
index_t input_batch = input->dim(0);
index_t input_channels = input->dim(1);
index_t input_height = input->dim(2);
index_t input_width = input->dim(3);
index_t filter_h = filter_shape[2];
index_t filter_w = filter_shape[3];
MACE_CHECK(filter_shape[0] == channels, filter_shape[0], " != ", channels);
MACE_CHECK(filter_shape[1] == input_channels, filter_shape[1], " != ",
input_channels);
index_t stride_h = strides_[0];
index_t stride_w = strides_[1];
index_t dilation_h = dilations_[0];
index_t dilation_w = dilations_[1];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
index_t padded_input_height = input_height + paddings[0];
index_t padded_input_width = input_width + paddings[1];
index_t extra_input_height = padded_input_height;
index_t extra_input_width = padded_input_width;
index_t extra_output_height = height;
index_t extra_output_width = width;
int pad_top = paddings[0] >> 1;
int pad_bottom = paddings[0] - pad_top;
int pad_left = paddings[1] >> 1;
int pad_right = paddings[1] - pad_left;
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
auto output_data = output->mutable_data<float>();
std::function<void(const float *input, float *output)> conv_func;
bool use_winograd = is_filter_transformed_ || (filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1
&& input_channels >= 8 && channels >= 8);
bool use_neon_3x3_s1 = filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
bool use_neon_3x3_s2 = filter_h == 3 && filter_w == 3
&& stride_h == 2 && stride_w == 2 && dilation_h == 1 && dilation_w == 1;
bool use_neon_1x1_s1 = filter_h == 1 && filter_w == 1
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
std::vector<index_t> transformed_input_shape;
std::vector<index_t> transformed_output_shape;
std::vector<index_t> transformed_filter_shape;
// When size of input feature map is bigger than 16x16,
// set winograd out tile size to 6 to get higher performance.
index_t winograd_out_tile_size = 2;
if (input_height > 16 && input_width > 16) {
winograd_out_tile_size = 6;
}
if (use_winograd) {
extra_output_height = RoundUp<index_t>(height, winograd_out_tile_size);
extra_input_height = std::max(padded_input_height, extra_output_height + 2);
extra_output_width = RoundUp<index_t>(width, winograd_out_tile_size);
extra_input_width = std::max(padded_input_width, extra_output_width + 2);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
index_t tile_height_count = extra_output_height / winograd_out_tile_size;
index_t tile_width_count = extra_output_width / winograd_out_tile_size;
index_t tile_count = tile_height_count * tile_width_count;
index_t in_tile_area =
(winograd_out_tile_size + 2) * (winograd_out_tile_size + 2);
transformed_input_shape.insert(transformed_input_shape.end(),
{in_tile_area, batch, input_channels,
tile_count});
transformed_output_shape.insert(transformed_output_shape.end(),
{in_tile_area, batch, channels,
tile_count});
transformed_filter_shape.insert(transformed_filter_shape.end(),
{in_tile_area, channels, input_channels});
} else if (use_neon_3x3_s1) {
extra_output_height = RoundUp<index_t>(height, 2);
extra_input_height = std::max(padded_input_height, extra_output_height + 2);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width = std::max(padded_input_width, extra_output_width + 2);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
} else if (use_neon_3x3_s2) {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, (extra_output_height - 1) * 2 + 3);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width =
std::max(padded_input_width, (extra_output_width - 1) * 2 + 3);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
}
// decide scratch size before allocate it
index_t total_scratch_size = 0;
index_t transformed_input_size = 0;
index_t transformed_output_size = 0;
index_t padded_input_size = 0;
index_t padded_output_size = 0;
if (use_winograd) {
transformed_input_size =
std::accumulate(transformed_input_shape.begin(),
transformed_input_shape.end(),
1,
std::multiplies<index_t>()) * sizeof(float);
transformed_output_size =
std::accumulate(transformed_output_shape.begin(),
transformed_output_shape.end(),
1,
std::multiplies<index_t>()) * sizeof(float);
total_scratch_size += transformed_input_size + transformed_output_size;
}
if (extra_input_height != input_height || extra_input_width != input_width) {
padded_input_size =
batch * input_channels * (input_height + pad_top + pad_bottom)
* (input_width + pad_left + pad_right) * sizeof(float);
total_scratch_size += padded_input_size;
}
if (extra_output_height != height || extra_output_width != width) {
padded_output_size =
batch * channels * extra_output_height * extra_output_width
* sizeof(float);
total_scratch_size += padded_output_size;
}
// Init scratch buffer
scratch_->Rewind();
scratch_->GrowSize(total_scratch_size);
Tensor transformed_input(scratch_->Scratch(transformed_input_size), DT_FLOAT);
Tensor transformed_output(scratch_->Scratch(transformed_output_size), DT_FLOAT);
Tensor padded_input(scratch_->Scratch(padded_input_size), DT_FLOAT);
Tensor padded_output(scratch_->Scratch(padded_output_size), DT_FLOAT);
// decide which convolution function to call
if (use_winograd) {
transformed_input.Resize(transformed_input_shape);
transformed_output.Resize(transformed_output_shape);
const float *transformed_filter_ptr;
if (transformed_filter_.dim_size() == 0) {
transformed_filter_.Resize(transformed_filter_shape);
if (is_filter_transformed_) {
transformed_filter_ptr = filter_data;
} else {
switch (winograd_out_tile_size) {
case 2:
TransformFilter4x4(filter_data,
filter_shape[1],
filter_shape[0],
transformed_filter_.mutable_data<float>());
break;
case 6:
TransformFilter8x8(filter_data,
filter_shape[1],
filter_shape[0],
transformed_filter_.mutable_data<float>());
break;
default: MACE_NOT_IMPLEMENTED;
}
transformed_filter_ptr = transformed_filter_.data<float>();
}
} else {
transformed_filter_ptr = transformed_filter_.data<float>();
}
conv_func = [&](const float *pad_input, float *pad_output) {
WinoGradConv3x3s1(pad_input,
transformed_filter_ptr,
batch,
extra_input_height,
extra_input_width,
input_channels,
channels,
winograd_out_tile_size,
transformed_input.mutable_data<float>(),
transformed_output.mutable_data<float>(),
pad_output);
};
} else if (use_neon_3x3_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK3x3S1(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
pad_output);
};
} else if (use_neon_3x3_s2) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK3x3S2(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
pad_output);
};
} else if (use_neon_1x1_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK1x1S1(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
channels,
pad_output);
};
} else {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNCHW(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_output);
};
}
// pad input and output
const Tensor *pad_input_ptr = input;
if (extra_input_height != input_height || extra_input_width != input_width) {
padded_input.Clear();
ConstructNCHWInputWithSpecificPadding(input,
pad_top,
pad_bottom,
pad_left,
pad_right,
&padded_input);
pad_input_ptr = &padded_input;
}
Tensor *pad_output_ptr = output;
if (extra_output_height != height || extra_output_width != width) {
padded_output.Resize({batch, channels, extra_output_height,
extra_output_width});
padded_output.Clear();
pad_output_ptr = &padded_output;
}
const float *pad_input_data = pad_input_ptr->data<float>();
float *pad_output_data = pad_output_ptr->mutable_data<float>();
conv_func(pad_input_data, pad_output_data);
// unpack output
if (extra_output_height != height || extra_output_width != width) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t h = 0; h < height; ++h) {
memcpy(
output_data + b * channels * height * width + c * height * width
+ h * width,
pad_output_data
+ b * channels * extra_output_height * extra_output_width
+ c * extra_output_height * extra_output_width
+ h * extra_output_width,
sizeof(float) * width);
}
}
}
}
if (bias_data != nullptr) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t i = 0; i < height * width; ++i) {
output_data[(b * channels + c) * height * width + i] += bias_data[c];
}
}
}
}
DoActivation(output_data, output_data, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
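A sketch of the Winograd tile bookkeeping used above (it assumes the RoundUp helper already used in this file; ComputeWinogradTiles is a hypothetical name):

// Winograd tile bookkeeping (sketch; same arithmetic as above).
struct WinogradTiles {
  index_t tile_count;    // number of output tiles per feature map
  index_t in_tile_area;  // transformed elements per tile
};
WinogradTiles ComputeWinogradTiles(index_t out_height, index_t out_width,
                                   index_t out_tile_size) {
  const index_t extra_h = RoundUp<index_t>(out_height, out_tile_size);
  const index_t extra_w = RoundUp<index_t>(out_width, out_tile_size);
  // A 3x3 filter needs a one-pixel border, hence (out_tile_size + 2)^2.
  return {(extra_h / out_tile_size) * (extra_w / out_tile_size),
          (out_tile_size + 2) * (out_tile_size + 2)};
}
// e.g. a 64x64 output with out_tile_size 6: 11 * 11 = 121 tiles and
// in_tile_area 64, matching the transformed_* shapes built above.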
......@@ -12,49 +12,46 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_GLOBAL_AVG_POOLING_H_
#define MACE_KERNELS_GLOBAL_AVG_POOLING_H_
#ifndef MACE_KERNELS_ARM_CONV_2D_NEON_H_
#define MACE_KERNELS_ARM_CONV_2D_NEON_H_
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/types.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct GlobalAvgPoolingFunctor {
void operator()(const T *input,
const index_t *input_shape,
T *output,
StatsFuture *future) {
index_t batch = input_shape[0];
index_t channels = input_shape[1];
index_t height = input_shape[2];
index_t width = input_shape[3];
index_t image_size = height * width;
index_t input_offset = 0;
index_t total_channels = batch * channels;
for (int c = 0; c < total_channels; ++c) {
T sum = 0;
for (int i = 0; i < image_size; ++i) {
sum += input[input_offset + i];
}
output[c] = sum / image_size;
input_offset += image_size;
}
}
};
template <>
void GlobalAvgPoolingFunctor<DeviceType::NEON, float>::operator()(
const float *input,
const index_t *input_shape,
float *output,
StatsFuture *future);
extern void Conv2dNeonK1x1S1(const float *input,
const float *filter,
const index_t batch,
const index_t height,
const index_t width,
const index_t in_channels,
const index_t out_channels,
float *output);
extern void Conv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output);
extern void Conv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output);
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_GLOBAL_AVG_POOLING_H_
#endif // MACE_KERNELS_ARM_CONV_2D_NEON_H_
......@@ -16,7 +16,7 @@
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
#include "mace/kernels/arm/conv_2d_neon.h"
#include "mace/kernels/gemm.h"
namespace mace {
......
......@@ -16,7 +16,7 @@
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
#include "mace/kernels/arm/conv_2d_neon.h"
namespace mace {
namespace kernels {
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/depthwise_conv2d.h"
#include "mace/kernels/activation.h"
namespace mace {
namespace kernels {
namespace {
void DepthwiseConv2dNCHW(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t multiplier = out_channels / in_channels;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
index_t out_offset =
((b * out_channels + m) * out_height + h) * out_width + w;
index_t c = m / multiplier;
index_t o = m % multiplier;
float sum = 0;
for (index_t kh = 0; kh < filter_height; ++kh) {
for (index_t kw = 0; kw < filter_width; ++kw) {
index_t ih = h * stride_h + kh * dilation_h - pad_top;
index_t iw = w * stride_w + kw * dilation_w - pad_left;
if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
index_t in_offset =
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t filter_offset =
(((o * in_channels) + c) * filter_height + kh) * filter_width
+ kw;
sum += input[in_offset] * filter[filter_offset];
}
}
}
output[out_offset] = sum;
}
}
}
}
}
} // namespace
extern void DepthwiseConv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int pad_top,
const int pad_left,
const int valid_h_start,
const int valid_h_stop,
const int valid_w_start,
const int valid_w_stop,
float *output);
void DepthwiseConv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int pad_top,
const int pad_left,
const int valid_h_start,
const int valid_h_stop,
const int valid_w_start,
const int valid_w_stop,
float *output);
void DepthwiseConv2dFunctor<DeviceType::NEON,
float>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
MACE_CHECK_NOTNULL(input);
MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output);
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
std::vector<index_t> filter_shape
{filter->dim(0) * filter->dim(1), filter->dim(1), filter->dim(2),
filter->dim(3)};
if (paddings_.empty()) {
CalcNCHWPaddingAndOutputSize(input->shape().data(),
filter_shape.data(),
dilations_,
strides_,
padding_type_,
output_shape.data(),
paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input->shape().data(), filter_shape.data(),
paddings_.data(), dilations_, strides_, RoundType::FLOOR,
output_shape.data());
}
output->Resize(output_shape);
output->Clear();
index_t batch = output->dim(0);
index_t channels = output->dim(1);
index_t height = output->dim(2);
index_t width = output->dim(3);
index_t input_batch = input->dim(0);
index_t input_channels = input->dim(1);
index_t input_height = input->dim(2);
index_t input_width = input->dim(3);
index_t filter_h = filter_shape[2];
index_t filter_w = filter_shape[3];
MACE_CHECK(filter_shape[0] == channels, filter_shape[0], " != ", channels);
MACE_CHECK(filter_shape[1] == input_channels, filter_shape[1], " != ",
input_channels);
index_t stride_h = strides_[0];
index_t stride_w = strides_[1];
index_t dilation_h = dilations_[0];
index_t dilation_w = dilations_[1];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
int pad_top = paddings[0] >> 1;
int pad_bottom = paddings[0] - pad_top;
int pad_left = paddings[1] >> 1;
int pad_right = paddings[1] - pad_left;
int valid_h_start = pad_top == 0 ? 0 : (pad_top - 1) / stride_h + 1;
int valid_h_stop = pad_bottom == 0
? height
: height - ((pad_bottom - 1) / stride_h + 1);
int valid_w_start = pad_left == 0 ? 0 : (pad_left - 1) / stride_w + 1;
int valid_w_stop = pad_right == 0
? width
: width - ((pad_right - 1) / stride_w + 1);
std::function<void(const float *input, float *output)> conv_func;
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
auto output_data = output->mutable_data<float>();
if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1
&& dilation_h == 1 && dilation_w == 1) {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dNeonK3x3S1(input,
filter_data,
batch,
input_height,
input_width,
input_channels,
height,
width,
channels,
pad_top,
pad_left,
valid_h_start,
valid_h_stop,
valid_w_start,
valid_w_stop,
output);
};
} else if (filter_h == 3 && filter_w == 3 && stride_h == 2 && stride_w == 2
&& dilation_h == 1 && dilation_w == 1) {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dNeonK3x3S2(input,
filter_data,
batch,
input_height,
input_width,
input_channels,
height,
width,
channels,
pad_top,
pad_left,
valid_h_start,
valid_h_stop,
valid_w_start,
valid_w_stop,
output);
};
} else {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dNCHW(input,
filter_data,
batch,
input_height,
input_width,
input_channels,
height,
width,
channels,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_top,
pad_left,
output);
};
}
conv_func(input_data, output_data);
if (bias_data != nullptr) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t i = 0; i < height * width; ++i) {
output_data[(b * channels + c) * height * width + i] += bias_data[c];
}
}
}
}
DoActivation(output_data, output_data, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
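The valid-region arithmetic above determines where the NEON kernel can skip boundary checks; a compact restatement of the same formulas (a sketch; ValidStart/ValidStop are hypothetical names):

// First/last output rows whose filter window lies fully inside the unpadded
// input (same formulas as valid_h_start / valid_h_stop above; columns are
// symmetric).
index_t ValidStart(int pad, index_t stride) {
  return pad == 0 ? 0 : (pad - 1) / stride + 1;
}
index_t ValidStop(index_t out_size, int pad, index_t stride) {
  return pad == 0 ? out_size : out_size - ((pad - 1) / stride + 1);
}
// e.g. pad_top = pad_bottom = 1, stride = 1, out height = 8: rows [1, 7)
// need no bounds checks, and only rows 0 and 7 take the slow path.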
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_ARM_DEPTHWISE_CONV2D_NEON_H_
#define MACE_KERNELS_ARM_DEPTHWISE_CONV2D_NEON_H_
#include "mace/core/types.h"
namespace mace {
namespace kernels {
void DepthwiseConv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int pad_top,
const int pad_left,
const index_t valid_h_start,
const index_t valid_h_stop,
const index_t valid_w_start,
const index_t valid_w_stop,
float *output);
void DepthwiseConv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int pad_top,
const int pad_left,
const index_t valid_h_start,
const index_t valid_h_stop,
const index_t valid_w_start,
const index_t valid_w_stop,
float *output);
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_ARM_DEPTHWISE_CONV2D_NEON_H_
......@@ -16,7 +16,7 @@
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
#include "mace/kernels/arm/depthwise_conv2d_neon.h"
namespace mace {
namespace kernels {
......@@ -60,10 +60,10 @@ void DepthwiseConv2dNeonK3x3S1(const float *input,
const index_t out_channels,
const int pad_top,
const int pad_left,
const int valid_h_start,
const int valid_h_stop,
const int valid_w_start,
const int valid_w_stop,
const index_t valid_h_start,
const index_t valid_h_stop,
const index_t valid_w_start,
const index_t valid_w_stop,
float *output) {
const index_t multiplier = out_channels / in_channels;
const index_t in_image_size = in_height * in_width;
......@@ -277,10 +277,10 @@ void DepthwiseConv2dNeonK3x3S2(const float *input,
const index_t out_channels,
const int pad_top,
const int pad_left,
const int valid_h_start,
const int valid_h_stop,
const int valid_w_start,
const int valid_w_stop,
const index_t valid_h_start,
const index_t valid_h_stop,
const index_t valid_w_start,
const index_t valid_w_stop,
float *output) {
const index_t multiplier = out_channels / in_channels;
const index_t in_image_size = in_height * in_width;
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/fully_connected.h"
#include "mace/kernels/gemm.h"
namespace mace {
namespace kernels {
void FullyConnectedFunctor<DeviceType::NEON,
float>::operator()(const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
std::vector<index_t> output_shape = {input->dim(0), weight->dim(0), 1, 1};
output->Resize(output_shape);
const index_t N = output->dim(0);
const index_t input_size = weight->dim(1);
const index_t output_size = weight->dim(0);
const float *input_ptr = input->data<float>();
const float *weight_ptr = weight->data<float>();
const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
float *output_ptr = output->mutable_data<float>();
Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr);
  if (bias_ptr != nullptr) {
    for (index_t i = 0; i < N; ++i) {
      for (index_t j = 0; j < output_size; ++j) {
        output_ptr[j + i * output_size] += bias_ptr[j];
      }
    }
  }
DoActivation(output_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
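The CPU path now delegates to Gemv from mace/kernels/gemm.h; a scalar sketch of the contract it is assumed to satisfy (row-major weight of shape [output_size, input_size], batched vectors; GemvRef is a hypothetical name):

// Scalar reference of the assumed Gemv contract (sketch):
// out[b, h] = sum_w weight[h, w] * in[b, w], for each batch b.
void GemvRef(const float *weight, const float *input, index_t batch,
             index_t input_size, index_t output_size, float *output) {
  for (index_t b = 0; b < batch; ++b) {
    for (index_t h = 0; h < output_size; ++h) {
      float sum = 0.f;
      for (index_t w = 0; w < input_size; ++w) {
        sum += weight[h * input_size + w] * input[b * input_size + w];
      }
      output[b * output_size + h] = sum;
    }
  }
}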
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/local_response_norm.h"
namespace mace {
namespace kernels {
void LocalResponseNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
int depth_radius,
float bias,
float alpha,
float beta,
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t channels = input->dim(1);
const index_t height = input->dim(2);
const index_t width = input->dim(3);
const float *input_ptr = input->data<float>();
float *output_ptr = output->mutable_data<float>();
index_t image_size = height * width;
index_t batch_size = channels * image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const int begin_input_c = std::max(static_cast<index_t>(0),
c - depth_radius);
const int end_input_c = std::min(channels, c + depth_radius + 1);
index_t pos = b * batch_size;
for (index_t hw = 0; hw < height * width; ++hw, ++pos) {
float accum = 0.f;
for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
const float input_val = input_ptr[pos + input_c * image_size];
accum += input_val * input_val;
}
const float multiplier = std::pow(bias + alpha * accum, -beta);
output_ptr[pos + c * image_size] =
input_ptr[pos + c * image_size] * multiplier;
}
}
}
}
} // namespace kernels
} // namespace mace
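For reference, the across-channel LRN formula the loop implements, restated for one output element (a sketch; assumes <algorithm> and <cmath> and MACE's signed index_t; LrnOneElement is a hypothetical name):

// One output element of across-channel LRN, with r = depth_radius:
//   out[c, hw] = in[c, hw] *
//       (bias + alpha * sum over c' in [max(0, c-r), min(C, c+r+1)) of
//        in[c', hw]^2)^(-beta)
float LrnOneElement(const float *in_chw, index_t channels, index_t image_size,
                    index_t c, index_t hw, int depth_radius,
                    float bias, float alpha, float beta) {
  const index_t c0 = std::max<index_t>(0, c - depth_radius);
  const index_t c1 = std::min<index_t>(channels, c + depth_radius + 1);
  float accum = 0.f;
  for (index_t cc = c0; cc < c1; ++cc) {
    const float v = in_chw[cc * image_size + hw];
    accum += v * v;
  }
  return in_chw[c * image_size + hw] * std::pow(bias + alpha * accum, -beta);
}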
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/pooling.h"
namespace mace {
namespace kernels {
namespace {
void MaxPooling(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t channels,
const index_t out_height,
const index_t out_width,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = channels * in_image_size;
const index_t out_batch_size = channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t out_base = b * out_batch_size + c * out_image_size;
const index_t in_base = b * in_batch_size + c * in_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
const index_t out_offset = out_base + h * out_width + w;
float res = std::numeric_limits<float>::lowest();
for (int fh = 0; fh < filter_height; ++fh) {
for (int fw = 0; fw < filter_width; ++fw) {
int inh = h * stride_h + dilation_h * fh - pad_top;
int inw = w * stride_w + dilation_w * fw - pad_left;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
index_t input_offset = in_base + inh * in_width + inw;
res = std::max(res, input[input_offset]);
}
}
}
output[out_offset] = res;
}
}
}
}
}
void AvgPooling(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t channels,
const index_t out_height,
const index_t out_width,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = channels * in_image_size;
const index_t out_batch_size = channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t out_base = b * out_batch_size + c * out_image_size;
const index_t in_base = b * in_batch_size + c * in_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
const index_t out_offset = out_base + h * out_width + w;
float res = 0;
int block_size = 0;
for (int fh = 0; fh < filter_height; ++fh) {
for (int fw = 0; fw < filter_width; ++fw) {
int inh = h * stride_h + dilation_h * fh - pad_top;
int inw = w * stride_w + dilation_w * fw - pad_left;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
index_t input_offset = in_base + inh * in_width + inw;
res += input[input_offset];
++block_size;
}
}
}
output[out_offset] = res / block_size;
}
}
}
}
}
} // namespace
void PoolingFunctor<DeviceType::NEON,
float>::operator()(const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future) {
std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {
input_tensor->dim(1), input_tensor->dim(1), kernels_[0], kernels_[1]};
std::vector<int> paddings(2);
if (paddings_.empty()) {
kernels::CalcNCHWPaddingAndOutputSize(
input_tensor->shape().data(), filter_shape.data(), dilations_,
strides_, padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input_tensor->shape().data(), filter_shape.data(),
paddings_.data(), dilations_, strides_, RoundType::CEIL,
output_shape.data());
}
output_tensor->Resize(output_shape);
const float *input = input_tensor->data<float>();
float *output = output_tensor->mutable_data<float>();
const index_t *input_shape = input_tensor->shape().data();
index_t batch = output_shape[0];
index_t channels = output_shape[1];
index_t height = output_shape[2];
index_t width = output_shape[3];
index_t input_channels = input_shape[1];
index_t input_height = input_shape[2];
index_t input_width = input_shape[3];
index_t in_image_size = input_height * input_width;
int filter_h = kernels_[0];
int filter_w = kernels_[1];
int stride_h = strides_[0];
int stride_w = strides_[1];
int dilation_h = dilations_[0];
int dilation_w = dilations_[1];
int pad_top = paddings[0] / 2;
int pad_left = paddings[1] / 2;
if (pooling_type_ == PoolingType::MAX) {
MaxPooling(input,
batch,
input_height,
input_width,
channels,
height,
width,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_top,
pad_left,
output);
} else if (pooling_type_ == PoolingType::AVG) {
AvgPooling(input,
batch,
input_height,
input_width,
channels,
height,
width,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_top,
pad_left,
output);
} else {
MACE_NOT_IMPLEMENTED;
}
}
} // namespace kernels
} // namespace mace
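The size helpers are defined elsewhere; the relation they compute per spatial dimension is the usual one (a sketch under standard pooling arithmetic; PooledSize is a hypothetical name, and the real helpers also handle padding_type_):

// Sketch of the output-size relation behind CalcNCHWOutputSize.
index_t PooledSize(index_t in_size, int filter, int stride, int dilation,
                   int pad, bool ceil_mode) {
  const index_t effective_filter = (filter - 1) * dilation + 1;
  const index_t span = in_size + pad - effective_filter;
  return ceil_mode ? (span + stride - 1) / stride + 1 : span / stride + 1;
}
// e.g. in = 8, filter = 3, stride = 2, dilation = 1, pad = 0:
// FLOOR gives (8 - 3) / 2 + 1 = 3 outputs; CEIL gives 4.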
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/softmax.h"
namespace mace {
namespace kernels {
void SoftmaxFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t class_count = input->dim(1);
const index_t class_size = input->dim(2) * input->dim(3);
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
for (index_t b = 0; b < batch; ++b) {
std::vector<float>
max_val(class_size, std::numeric_limits<float>::lowest());
std::vector<float> sum_val(class_size, 0.f);
// calculate max for each class
for (index_t c = 0; c < class_count; ++c) {
const float *input_ptr = input_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
max_val[k] = std::max(max_val[k], input_ptr[k]);
}
}
// calculate exp(data - max) for each class
#pragma omp parallel for
for (index_t c = 0; c < class_count; ++c) {
const float *input_ptr = input_data + (b * class_count + c) * class_size;
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
output_ptr[k] = ::exp(input_ptr[k] - max_val[k]);
}
}
// calculate sum for each class
for (index_t c = 0; c < class_count; ++c) {
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
sum_val[k] += output_ptr[k];
}
}
// calculate exp(data - max) / sum for each class
for (index_t c = 0; c < class_count; ++c) {
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
output_ptr[k] /= sum_val[k];
}
}
}
}
} // namespace kernels
} // namespace mace
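The four passes above implement the numerically stable softmax over channels for each spatial position; condensed into a scalar sketch for a single position of one batch's [class_count, class_size] block (assumes <algorithm>, <cmath>, <limits>; SoftmaxOnePosition is a hypothetical name):

// Numerically stable softmax over channels at spatial position k (sketch).
void SoftmaxOnePosition(const float *in, index_t class_count,
                        index_t class_size, index_t k, float *out) {
  float max_val = std::numeric_limits<float>::lowest();
  for (index_t c = 0; c < class_count; ++c) {
    max_val = std::max(max_val, in[c * class_size + k]);
  }
  float sum = 0.f;
  for (index_t c = 0; c < class_count; ++c) {
    out[c * class_size + k] = std::exp(in[c * class_size + k] - max_val);
    sum += out[c * class_size + k];
  }
  for (index_t c = 0; c < class_count; ++c) {
    out[c * class_size + k] /= sum;
  }
}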
......@@ -43,8 +43,11 @@ struct BatchNormFunctorBase {
const float relux_max_limit_;
};
template <DeviceType D, typename T>
struct BatchNormFunctor : BatchNormFunctorBase {
template<DeviceType D, typename T>
struct BatchNormFunctor;
template<>
struct BatchNormFunctor<DeviceType::CPU, float> : BatchNormFunctorBase {
BatchNormFunctor(const bool folded_constant,
const ActivationType activation,
const float relux_max_limit)
......@@ -67,29 +70,29 @@ struct BatchNormFunctor : BatchNormFunctorBase {
// new_offset = offset - mean * new_scale;
// Y = new_scale * X + new_offset;
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t channels = input->dim(3);
const index_t channels = input->dim(1);
const index_t height = input->dim(2);
const index_t width = input->dim(3);
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard scale_mapper(scale);
Tensor::MappingGuard offset_mapper(offset);
Tensor::MappingGuard output_mapper(output);
const T *input_ptr = input->data<T>();
const T *scale_ptr = scale->data<T>();
const T *offset_ptr = offset->data<T>();
T *output_ptr = output->mutable_data<T>();
const float *input_ptr = input->data<float>();
const float *scale_ptr = scale->data<float>();
const float *offset_ptr = offset->data<float>();
float *output_ptr = output->mutable_data<float>();
std::vector<T> new_scale;
std::vector<T> new_offset;
std::vector<float> new_scale;
std::vector<float> new_offset;
if (!folded_constant_) {
new_scale.resize(channels);
new_offset.resize(channels);
Tensor::MappingGuard mean_mapper(mean);
Tensor::MappingGuard var_mapper(var);
const T *mean_ptr = mean->data<T>();
const T *var_ptr = var->data<T>();
const float *mean_ptr = mean->data<float>();
const float *var_ptr = var->data<float>();
#pragma omp parallel for
for (index_t c = 0; c < channels; ++c) {
new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon);
......@@ -97,44 +100,21 @@ struct BatchNormFunctor : BatchNormFunctorBase {
}
}
const T *scale_data = folded_constant_ ? scale_ptr : new_scale.data();
const T *offset_data = folded_constant_ ? offset_ptr : new_offset.data();
const float *scale_data = folded_constant_ ? scale_ptr : new_scale.data();
const float
*offset_data = folded_constant_ ? offset_ptr : new_offset.data();
const int elements = batch * height * width;
constexpr int c_tile_size = 4;
const int c_tiles = channels / c_tile_size;
const index_t remains_start = c_tiles * c_tile_size;
index_t channel_size = height * width;
index_t batch_size = channels * channel_size;
if (c_tiles > 0) {
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < elements; ++i) {
for (int cb = 0; cb < c_tiles; ++cb) {
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
static_assert(c_tile_size == 4, "channels tile size must be 4");
int c = cb * c_tile_size;
int pos = i * channels + c;
float32x4_t scales = vld1q_f32(scale_data + c);
float32x4_t offsets = vld1q_f32(offset_data + c);
float32x4_t in = vld1q_f32(input_ptr + pos);
float32x4_t out = vfmaq_f32(offsets, scales, in);
vst1q_f32(output_ptr + pos, out);
#else
for (int ci = 0; ci < c_tile_size; ++ci) {
int c = cb * c_tile_size + ci;
index_t pos = i * channels + c;
output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c];
}
#endif
}
}
}
if (remains_start < channels) {
// NEON is slower, so stick to the trivial implementation
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < elements; ++i) {
for (index_t c = remains_start; c < channels; ++c) {
index_t pos = i * channels + c;
output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c];
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
index_t offset = b * batch_size + c * channel_size;
for (index_t hw = 0; hw < height * width; ++hw) {
output_ptr[offset + hw] =
scale_data[c] * input_ptr[offset + hw] + offset_data[c];
}
}
}
......@@ -143,24 +123,7 @@ struct BatchNormFunctor : BatchNormFunctorBase {
}
};
template <>
struct BatchNormFunctor<DeviceType::NEON, float> : BatchNormFunctorBase {
BatchNormFunctor(const bool folded_constant,
const ActivationType activation,
const float relux_max_limit)
: BatchNormFunctorBase(folded_constant, activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future);
};
template <typename T>
template<typename T>
struct BatchNormFunctor<DeviceType::OPENCL, T> : BatchNormFunctorBase {
BatchNormFunctor(const bool folded_constant,
const ActivationType activation,
......
......@@ -26,49 +26,41 @@
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct BiasAddFunctor {
template<DeviceType D, typename T>
struct BiasAddFunctor;
template<>
struct BiasAddFunctor<DeviceType::CPU, float> {
void operator()(const Tensor *input,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t channels = input->dim(3);
const index_t channels = input->dim(1);
const index_t height = input->dim(2);
const index_t width = input->dim(3);
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard bias_mapper(bias);
Tensor::MappingGuard output_mapper(output);
const T *input_ptr = input->data<T>();
const T *bias_ptr = bias->data<T>();
T *output_ptr = output->mutable_data<T>();
const float *input_ptr = input->data<float>();
const float *bias_ptr = bias->data<float>();
float *output_ptr = output->mutable_data<float>();
#pragma omp parallel for collapse(4)
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (index_t h = 0; h < height; ++h) {
for (index_t w = 0; w < width; ++w) {
for (index_t c = 0; c < channels; ++c) {
index_t pos = (((n * height) + h) * width + w) * channels + c;
for (index_t hw = 0; hw < height * width; ++hw) {
index_t pos = (n * channels + c) * height * width + hw;
output_ptr[pos] = input_ptr[pos] + bias_ptr[c];
}
}
}
}
}
};
/*
template <>
void BiasAddFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
*/
template <typename T>
template<typename T>
struct BiasAddFunctor<DeviceType::OPENCL, T> {
void operator()(const Tensor *input,
const Tensor *bias,
......
......@@ -24,7 +24,7 @@
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
template<DeviceType D, typename T>
struct ChannelShuffleFunctor {
explicit ChannelShuffleFunctor(const int groups) : groups_(groups) {}
......@@ -39,20 +39,25 @@ struct ChannelShuffleFunctor {
T *output_ptr = output->mutable_data<T>();
index_t batch = input->dim(0);
index_t height = input->dim(1);
index_t width = input->dim(2);
index_t channels = input->dim(3);
index_t bhw_fuse = batch * height * width;
int channels_per_group = channels / groups_;
#pragma omp parallel for
for (int bhw = 0; bhw < bhw_fuse; ++bhw) {
for (int c = 0; c < channels; ++c) {
index_t channel_base = bhw * channels;
output_ptr[channel_base + c] =
input_ptr[channel_base + c % groups_ * channels_per_group
+ c / groups_];
index_t channels = input->dim(1);
index_t height = input->dim(2);
index_t width = input->dim(3);
index_t image_size = height * width;
index_t batch_size = channels * image_size;
index_t channels_per_group = channels / groups_;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const T *input_base = input_ptr + b * batch_size;
T *output_base = output_ptr + b * batch_size;
index_t g = c % groups_;
index_t idx = c / groups_;
for (index_t hw = 0; hw < height * width; ++hw) {
output_base[c * image_size + hw] =
input_base[(g * channels_per_group + idx) * image_size + hw];
}
}
}
}
......@@ -60,7 +65,7 @@ struct ChannelShuffleFunctor {
const int groups_;
};
template <typename T>
template<typename T>
struct ChannelShuffleFunctor<DeviceType::OPENCL, T> {
explicit ChannelShuffleFunctor(const int groups) : groups_(groups) {}
......
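The index mapping above is the standard channel shuffle; a compact statement of it (a sketch; ShuffledInputChannel is a hypothetical name):

// Channel shuffle mapping: with G groups and K = channels / G channels per
// group, output channel c reads input channel (c % G) * K + c / G, i.e. the
// [G, K] channel view is transposed to [K, G].
index_t ShuffledInputChannel(index_t c, index_t groups,
                             index_t channels_per_group) {
  return (c % groups) * channels_per_group + c / groups;
}
// e.g. channels = 6, groups = 2: output channels 0..5 read input
// channels 0, 3, 1, 4, 2, 5.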
This diff is collapsed.
......@@ -368,7 +368,6 @@ void ConstructNCHWInputWithSpecificPadding(const Tensor *input_tensor,
const int pad_left,
const int pad_right,
Tensor *output_tensor) {
Tensor::MappingGuard input_mapper(input_tensor);
const float *input = input_tensor->data<float>();
const index_t *input_shape = input_tensor->shape().data();
......
......@@ -25,15 +25,15 @@
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
template<DeviceType D, typename T>
struct DepthToSpaceOpFunctor {
explicit DepthToSpaceOpFunctor(const int block_size, bool d2s)
: block_size_(block_size), d2s_(d2s) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
const int batch_size = input->dim(0);
const int input_height = input->dim(1);
const int input_width = input->dim(2);
const int input_depth = input->dim(3);
const int input_depth = input->dim(1);
const int input_height = input->dim(2);
const int input_width = input->dim(3);
index_t output_depth, output_width, output_height;
......@@ -46,8 +46,8 @@ struct DepthToSpaceOpFunctor {
output_width = input_width / block_size_;
output_height = input_height / block_size_;
}
std::vector<index_t> output_shape = {batch_size, output_height,
output_width, output_depth};
std::vector<index_t> output_shape = {batch_size, output_depth,
output_height, output_width};
output->Resize(output_shape);
......@@ -59,23 +59,22 @@ struct DepthToSpaceOpFunctor {
if (d2s_) {
#pragma omp parallel for
for (int b = 0; b < batch_size; ++b) {
for (int d = 0; d < output_depth; ++d) {
for (int h = 0; h < output_height; ++h) {
const int in_h = h / block_size_;
const int offset_h = (h % block_size_);
for (int w = 0; w < output_width; ++w) {
const int in_w = w / block_size_;
const int offset_w = w % block_size_;
const int offset_d =
const index_t in_w = w / block_size_;
const index_t offset_w = w % block_size_;
const index_t offset_d =
(offset_h * block_size_ + offset_w) * output_depth;
for (int d = 0; d < output_depth; ++d) {
const int in_d = d + offset_d;
const int o_index =
((b * output_height + h) * output_width + w) * output_depth +
d;
const int i_index =
((b * input_height + in_h) * input_width + in_w) *
input_depth +
in_d;
const index_t in_d = d + offset_d;
const index_t o_index =
((b * output_depth + d) * output_height + h) * output_width + w;
const index_t i_index =
((b * input_depth + in_d) * input_height + in_h) * input_width
+ in_w;
output_ptr[o_index] = input_ptr[i_index];
}
}
......@@ -84,6 +83,7 @@ struct DepthToSpaceOpFunctor {
} else {
#pragma omp parallel for
for (int b = 0; b < batch_size; ++b) {
for (int d = 0; d < input_depth; ++d) {
for (int h = 0; h < input_height; ++h) {
const int out_h = h / block_size_;
const int offset_h = (h % block_size_);
......@@ -92,14 +92,14 @@ struct DepthToSpaceOpFunctor {
const int offset_w = (w % block_size_);
const int offset_d =
(offset_h * block_size_ + offset_w) * input_depth;
for (int d = 0; d < input_depth; ++d) {
const int out_d = d + offset_d;
const int o_index =
((b * output_height + out_h) * output_width + out_w) *
output_depth +
out_d;
const int i_index =
((b * input_height + h) * input_width + w) * input_depth + d;
const index_t o_index =
((b * output_depth + out_d) * output_height + out_h)
* output_width + out_w;
const index_t i_index =
((b * input_depth + d) * input_height + h) * input_width
+ w;
output_ptr[o_index] = input_ptr[i_index];
}
}
......@@ -112,7 +112,7 @@ struct DepthToSpaceOpFunctor {
bool d2s_;
};
template <typename T>
template<typename T>
struct DepthToSpaceOpFunctor<DeviceType::OPENCL, T> {
DepthToSpaceOpFunctor(const int block_size, bool d2s)
: block_size_(block_size), d2s_(d2s) {}
......
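The NCHW depth-to-space mapping above, rewritten as one index function (a sketch consistent with the loop; D2SInputIndex is a hypothetical name):

// NCHW depth-to-space: output (b, d, h, w) reads input
// (b, d + ((h % B) * B + w % B) * out_depth, h / B, w / B), B = block size.
index_t D2SInputIndex(index_t b, index_t d, index_t h, index_t w,
                      index_t block, index_t out_depth, index_t in_depth,
                      index_t in_height, index_t in_width) {
  const index_t in_h = h / block;
  const index_t in_w = w / block;
  const index_t in_d = d + ((h % block) * block + w % block) * out_depth;
  return ((b * in_depth + in_d) * in_height + in_h) * in_width + in_w;
}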
This diff is collapsed.
......@@ -23,6 +23,7 @@
#include "mace/core/tensor.h"
#include "mace/kernels/activation.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/kernels/gemm.h"
namespace mace {
namespace kernels {
......@@ -41,7 +42,10 @@ struct FullyConnectedBase {
};
template <DeviceType D, typename T>
struct FullyConnectedFunctor : FullyConnectedBase {
struct FullyConnectedFunctor;
template <>
struct FullyConnectedFunctor<DeviceType::CPU, float>: FullyConnectedBase {
FullyConnectedFunctor(const BufferType weight_type,
const ActivationType activation,
const float relux_max_limit)
......@@ -52,33 +56,25 @@ struct FullyConnectedFunctor : FullyConnectedBase {
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
std::vector<index_t> output_shape = {input->dim(0), 1, 1, weight->dim(0)};
std::vector<index_t> output_shape = {input->dim(0), weight->dim(0), 1, 1};
output->Resize(output_shape);
const index_t N = output->dim(0);
const index_t input_size = weight->dim(1);
const index_t output_size = weight->dim(0);
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_weight(weight);
Tensor::MappingGuard guard_bias(bias);
Tensor::MappingGuard guard_output(output);
const T *input_ptr = input->data<T>();
const T *weight_ptr = weight->data<T>();
const T *bias_ptr = bias == nullptr ? nullptr : bias->data<T>();
T *output_ptr = output->mutable_data<T>();
const float *input_ptr = input->data<float>();
const float *weight_ptr = weight->data<float>();
const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
float *output_ptr = output->mutable_data<float>();
#pragma omp parallel for collapse(2)
Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr);
for (int i = 0; i < N; ++i) {
for (int out_idx = 0; out_idx < output_size; ++out_idx) {
T sum = 0;
if (bias_ptr != nullptr) sum = bias_ptr[out_idx];
index_t input_offset = i * input_size;
index_t weight_offset = out_idx * input_size;
for (int in_idx = 0; in_idx < input_size; ++in_idx) {
sum += input_ptr[input_offset] * weight_ptr[weight_offset];
input_offset++;
weight_offset++;
}
output_ptr[i * output_size + out_idx] = sum;
if (bias_ptr != nullptr) {
  for (int j = 0; j < output_size; ++j) {
    output_ptr[j + i * output_size] += bias_ptr[j];
  }
}
}
......@@ -87,20 +83,6 @@ struct FullyConnectedFunctor : FullyConnectedBase {
}
};
template <>
struct FullyConnectedFunctor<DeviceType::NEON, float> : FullyConnectedBase {
FullyConnectedFunctor(const BufferType weight_type,
const ActivationType activation,
const float relux_max_limit)
: FullyConnectedBase(weight_type, activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
};
template <typename T>
struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase {
FullyConnectedFunctor(const BufferType weight_type,
......
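The refactored CPU path above replaces the hand-written inner product with a Gemv call followed by a row-wise bias add, and the output is now shaped {batch, output_size, 1, 1} to match NCHW. A scalar sketch of the intended overall semantics (the name FullyConnectedRef is illustrative; the real product is computed by the NEON Gemv in mace/kernels/gemm.h):

#include <cstdint>

using index_t = int64_t;

// Sketch only: reference semantics of the CPU fully-connected path above:
// output[i, j] = sum_k weight[j, k] * input[i, k] + bias[j],
// with weight stored row-major as [output_size, input_size].
inline void FullyConnectedRef(const float *input, const float *weight,
                              const float *bias, index_t batch,
                              index_t input_size, index_t output_size,
                              float *output) {
  for (index_t i = 0; i < batch; ++i) {
    for (index_t j = 0; j < output_size; ++j) {
      float sum = bias != nullptr ? bias[j] : 0.f;
      for (index_t k = 0; k < input_size; ++k) {
        sum += weight[j * input_size + k] * input[i * input_size + k];
      }
      output[i * output_size + j] = sum;
    }
  }
}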
......@@ -747,7 +747,7 @@ void Gemv(const float *m_ptr,
for (index_t h = 0; h < remain_h; ++h) {
float32x4_t vsum0 = vdupq_n_f32(0.f);
const float *m_ptr0 = m_ptr + (h + remain_start_height) * width;
const float *v_ptr0 = v_ptr;
const float *v_ptr0 = v_ptr + b * width;
for (index_t w = 0; w < width_d4; ++w) {
float32x4_t vm = vld1q_f32(m_ptr0);
float32x4_t vv = vld1q_f32(v_ptr0);
......@@ -761,7 +761,7 @@ void Gemv(const float *m_ptr,
m_ptr0++;
v_ptr0++;
}
out_ptr[remain_start_height + h] = sum;
out_ptr[remain_start_height + h + b * height] = sum;
}
}
#else
......
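The two one-line changes above add the per-batch offsets (v_ptr + b * width, out_ptr[... + b * height]) to the remainder-rows path so that batched inputs read the correct vector slice and write the correct output slice. A plain scalar reference of the batched contract, assuming the same (matrix, vector, batch, width, height, out) argument order as the call site in the fully-connected functor:

#include <cstdint>

using index_t = int64_t;

// Sketch only: scalar reference for the batched matrix-vector product,
// out[b, h] = sum_w m[h, w] * v[b, w]. The real kernel is NEON-optimized.
inline void GemvRef(const float *m_ptr, const float *v_ptr,
                    index_t batch, index_t width, index_t height,
                    float *out_ptr) {
  for (index_t b = 0; b < batch; ++b) {
    for (index_t h = 0; h < height; ++h) {
      float sum = 0.f;
      for (index_t w = 0; w < width; ++w) {
        sum += m_ptr[h * width + w] * v_ptr[b * width + w];
      }
      out_ptr[b * height + h] = sum;
    }
  }
}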
......@@ -17,8 +17,11 @@
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct LocalResponseNormFunctor {
template<DeviceType D, typename T>
struct LocalResponseNormFunctor;
template<>
struct LocalResponseNormFunctor<DeviceType::CPU, float> {
void operator()(const Tensor *input,
int depth_radius,
float bias,
......@@ -27,46 +30,37 @@ struct LocalResponseNormFunctor {
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t channels = input->dim(3);
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard output_mapper(output);
const index_t channels = input->dim(1);
const index_t height = input->dim(2);
const index_t width = input->dim(3);
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
const float *input_ptr = input->data<float>();
float *output_ptr = output->mutable_data<float>();
const int elements = batch * height * width;
index_t image_size = height * width;
index_t batch_size = channels * image_size;
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < elements; ++i) {
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const int begin_input_c = std::max(static_cast<index_t>(0),
c - depth_radius);
const int end_input_c = std::min(channels, c + depth_radius + 1);
index_t pos = i * channels;
index_t pos = b * batch_size;
for (index_t hw = 0; hw < height * width; ++hw, ++pos) {
float accum = 0.f;
for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
const float input_val = input_ptr[pos + input_c];
const float input_val = input_ptr[pos + input_c * image_size];
accum += input_val * input_val;
}
const float multiplier = std::pow(bias + alpha * accum, -beta);
output_ptr[pos + c] = input_ptr[pos + c] * multiplier;
output_ptr[pos + c * image_size] =
input_ptr[pos + c * image_size] * multiplier;
}
}
}
}
};
template <>
struct LocalResponseNormFunctor<DeviceType::NEON, float> {
void operator()(const Tensor *input,
int depth_radius,
float bias,
float alpha,
float beta,
Tensor *output,
StatsFuture *future);
};
} // namespace kernels
......
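The NCHW local response normalization above sums squares across neighboring channels at each spatial position. A scalar sketch of the per-element formula for a single batch (the name LRNRef is illustrative):

#include <algorithm>
#include <cmath>
#include <cstdint>

using index_t = int64_t;

// Sketch only: across-channel LRN for one NCHW batch,
// out[c, i] = in[c, i] * (bias + alpha * sum_{c'} in[c', i]^2)^(-beta),
// where c' ranges over [c - depth_radius, c + depth_radius], clamped to
// [0, channels). The kernel above fuses these loops.
inline void LRNRef(const float *in, index_t channels, index_t image_size,
                   int depth_radius, float bias, float alpha, float beta,
                   float *out) {
  for (index_t c = 0; c < channels; ++c) {
    const index_t c_begin = std::max<index_t>(0, c - depth_radius);
    const index_t c_end = std::min(channels, c + depth_radius + 1);
    for (index_t i = 0; i < image_size; ++i) {
      float accum = 0.f;
      for (index_t cc = c_begin; cc < c_end; ++cc) {
        const float v = in[cc * image_size + i];
        accum += v * v;
      }
      out[c * image_size + i] =
          in[c * image_size + i] *
          static_cast<float>(std::pow(bias + alpha * accum, -beta));
    }
  }
}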
......@@ -57,7 +57,10 @@ struct PoolingFunctorBase {
};
template <DeviceType D, typename T>
struct PoolingFunctor : PoolingFunctorBase {
struct PoolingFunctor;
template <>
struct PoolingFunctor<DeviceType::CPU, float>: PoolingFunctorBase {
PoolingFunctor(const PoolingType pooling_type,
const int *kernels,
const int *strides,
......@@ -68,43 +71,141 @@ struct PoolingFunctor : PoolingFunctorBase {
pooling_type, kernels, strides, padding_type, paddings, dilations) {
}
void MaxPooling(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t channels,
const index_t out_height,
const index_t out_width,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = channels * in_image_size;
const index_t out_batch_size = channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t out_base = b * out_batch_size + c * out_image_size;
const index_t in_base = b * in_batch_size + c * in_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
const index_t out_offset = out_base + h * out_width + w;
float res = std::numeric_limits<float>::lowest();
for (int fh = 0; fh < filter_height; ++fh) {
for (int fw = 0; fw < filter_width; ++fw) {
int inh = h * stride_h + dilation_h * fh - pad_top;
int inw = w * stride_w + dilation_w * fw - pad_left;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
index_t input_offset = in_base + inh * in_width + inw;
res = std::max(res, input[input_offset]);
}
}
}
output[out_offset] = res;
}
}
}
}
}
void AvgPooling(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t channels,
const index_t out_height,
const index_t out_width,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = channels * in_image_size;
const index_t out_batch_size = channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t out_base = b * out_batch_size + c * out_image_size;
const index_t in_base = b * in_batch_size + c * in_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
const index_t out_offset = out_base + h * out_width + w;
float res = 0;
int block_size = 0;
for (int fh = 0; fh < filter_height; ++fh) {
for (int fw = 0; fw < filter_width; ++fw) {
int inh = h * stride_h + dilation_h * fh - pad_top;
int inw = w * stride_w + dilation_w * fw - pad_left;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
index_t input_offset = in_base + inh * in_width + inw;
res += input[input_offset];
++block_size;
}
}
}
output[out_offset] = res / block_size;
}
}
}
}
}
void operator()(const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future) {
std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {
kernels_[0], kernels_[1], input_tensor->dim(3), input_tensor->dim(3)};
input_tensor->dim(1), input_tensor->dim(1), kernels_[0], kernels_[1]};
std::vector<int> paddings(2);
if (paddings_.empty()) {
kernels::CalcNHWCPaddingAndOutputSize(
kernels::CalcNCHWPaddingAndOutputSize(
input_tensor->shape().data(), filter_shape.data(), dilations_,
strides_, padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_;
CalcOutputSize(input_tensor->shape().data(), filter_shape.data(),
paddings_.data(), dilations_, strides_, RoundType::CEIL,
CalcNCHWOutputSize(input_tensor->shape().data(),
filter_shape.data(),
paddings_.data(),
dilations_,
strides_,
RoundType::CEIL,
output_shape.data());
}
output_tensor->Resize(output_shape);
Tensor::MappingGuard in_guard(input_tensor);
Tensor::MappingGuard out_guard(output_tensor);
const T *input = input_tensor->data<T>();
T *output = output_tensor->mutable_data<T>();
Tensor::MappingGuard input_guard(input_tensor);
Tensor::MappingGuard output_guard(output_tensor);
const float *input = input_tensor->data<float>();
float *output = output_tensor->mutable_data<float>();
const index_t *input_shape = input_tensor->shape().data();
index_t batch = output_shape[0];
index_t height = output_shape[1];
index_t width = output_shape[2];
index_t channels = output_shape[3];
index_t channels = output_shape[1];
index_t height = output_shape[2];
index_t width = output_shape[3];
index_t input_height = input_shape[1];
index_t input_width = input_shape[2];
index_t input_channels = input_shape[3];
index_t in_image_size = input_height * input_width;
index_t input_height = input_shape[2];
index_t input_width = input_shape[3];
int kernel_h = kernels_[0];
int kernel_w = kernels_[1];
int filter_h = kernels_[0];
int filter_w = kernels_[1];
int stride_h = strides_[0];
int stride_w = strides_[1];
......@@ -112,84 +213,47 @@ struct PoolingFunctor : PoolingFunctorBase {
int dilation_h = dilations_[0];
int dilation_w = dilations_[1];
// The left-upper most offset of the padded input
int padded_h_start = 0 - paddings[0] / 2;
int padded_w_start = 0 - paddings[1] / 2;
if (pooling_type_ == MAX) {
#pragma omp parallel for collapse(4)
for (int b = 0; b < batch; ++b) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
for (int c = 0; c < channels; ++c) {
index_t out_offset =
(((b * height) + h) * width + w) * channels + c;
index_t in_offset = b * in_image_size * input_channels + c;
T res = std::numeric_limits<T>::lowest();
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
if (inh >= 0 && inh < input_height && inw >= 0 &&
inw < input_width) {
index_t input_offset =
in_offset + (inh * input_width + inw) * input_channels;
res = std::max(res, input[input_offset]);
}
}
}
output[out_offset] = res;
}
}
}
}
} else if (pooling_type_ == AVG) {
#pragma omp parallel for collapse(4)
for (int b = 0; b < batch; ++b) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
for (int c = 0; c < channels; ++c) {
index_t out_offset =
(((b * height) + h) * width + w) * channels + c;
index_t in_offset = b * in_image_size * input_channels + c;
T sum = static_cast<T>(0);
int block_size = 0;
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
if (inh >= 0 && inh < input_height && inw >= 0 &&
inw < input_width) {
index_t input_offset =
in_offset + (inh * input_width + inw) * input_channels;
sum += input[input_offset];
block_size += 1;
}
}
}
output[out_offset] = sum / block_size;
}
}
}
}
}
}
};
int pad_top = paddings[0] / 2;
int pad_left = paddings[1] / 2;
template <>
struct PoolingFunctor<DeviceType::NEON, float> : PoolingFunctorBase {
PoolingFunctor(const PoolingType pooling_type,
const int *kernels,
const int *strides,
const Padding padding_type,
const std::vector<int> &paddings,
const int *dilations)
: PoolingFunctorBase(
pooling_type, kernels, strides, padding_type, paddings, dilations) {
if (pooling_type_ == PoolingType::MAX) {
MaxPooling(input,
batch,
input_height,
input_width,
channels,
height,
width,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_top,
pad_left,
output);
} else if (pooling_type_ == PoolingType::AVG) {
AvgPooling(input,
batch,
input_height,
input_width,
channels,
height,
width,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_top,
pad_left,
output);
} else {
MACE_NOT_IMPLEMENTED;
}
}
void operator()(const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future);
};
template <typename T>
......
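Both pooling kernels above rely on the output spatial sizes produced by CalcNCHWPaddingAndOutputSize / CalcNCHWOutputSize. A sketch of the usual output-size arithmetic with CEIL rounding, under the assumption that pad is the total padding along the axis (the functor splits it as pad_top = paddings[0] / 2 and pad_left = paddings[1] / 2):

#include <cstdint>

using index_t = int64_t;

// Sketch only: conventional pooling output size with dilation and CEIL
// rounding, matching RoundType::CEIL above. `pad` is the total padding.
inline index_t PooledSize(index_t in_size, int filter, int stride,
                          int dilation, int pad) {
  const index_t effective_filter = (filter - 1) * dilation + 1;
  return (in_size + pad - effective_filter + stride - 1) / stride + 1;
}

For example, a 112-wide input with a 3x3 kernel, stride 2, dilation 1, and total padding 1 gives (112 + 1 - 3 + 2 - 1) / 2 + 1 = 56.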
......@@ -68,8 +68,7 @@ inline float ComputeLerp(const float top_left,
return top + (bottom - top) * y_lerp;
}
template <typename T>
void ResizeImage(const T *images,
inline void ResizeImage(const float *images,
const index_t batch_size,
const index_t in_height,
const index_t in_width,
......@@ -78,38 +77,31 @@ void ResizeImage(const T *images,
const index_t channels,
const std::vector<CachedInterpolation> &xs_vec,
const std::vector<CachedInterpolation> &ys,
T *output) {
const index_t in_batch_num_values = channels * in_height * in_width;
const index_t out_batch_num_values = channels * out_height * out_width;
float *output) {
const CachedInterpolation *xs = xs_vec.data();
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch_size; ++b) {
for (index_t c = 0; c < channels; ++c) {
const float
*channel_input_ptr = images + (b * channels + c) * in_height * in_width;
float *channel_output_ptr =
output + (b * channels + c) * out_height * out_width;
for (index_t y = 0; y < out_height; ++y) {
const T *batch_input_ptr = images + in_batch_num_values * b;
T *batch_output_ptr = output + out_batch_num_values * b;
const T *y_lower_input_ptr =
batch_input_ptr + ys[y].lower * in_width * channels;
const T *y_upper_input_ptr =
batch_input_ptr + ys[y].upper * in_width * channels;
T *y_output_ptr = batch_output_ptr + y * out_width * channels;
const float *y_lower_input_ptr =
channel_input_ptr + ys[y].lower * in_width;
const float *y_upper_input_ptr =
channel_input_ptr + ys[y].upper * in_width;
const float ys_lerp = ys[y].lerp;
for (index_t x = 0; x < out_width; ++x) {
const float xs_lerp = xs[x].lerp;
const T *top_left_ptr = y_lower_input_ptr + xs[x].lower * channels;
const T *top_right_ptr = y_lower_input_ptr + xs[x].upper * channels;
const T *bottom_left_ptr = y_upper_input_ptr + xs[x].lower * channels;
const T *bottom_right_ptr = y_upper_input_ptr + xs[x].upper * channels;
T *output_ptr = y_output_ptr + x * channels;
for (index_t c = 0; c < channels; ++c) {
const T top_left = top_left_ptr[c];
const T top_right = top_right_ptr[c];
const T bottom_left = bottom_left_ptr[c];
const T bottom_right = bottom_right_ptr[c];
output_ptr[c] = ComputeLerp(top_left, top_right, bottom_left,
const float top_left = y_lower_input_ptr[xs[x].lower];
const float top_right = y_lower_input_ptr[xs[x].upper];
const float bottom_left = y_upper_input_ptr[xs[x].lower];
const float bottom_right = y_upper_input_ptr[xs[x].upper];
channel_output_ptr[y * out_width + x] =
ComputeLerp(top_left, top_right, bottom_left,
bottom_right, xs_lerp, ys_lerp);
}
}
......@@ -132,30 +124,35 @@ struct ResizeBilinearFunctorBase {
index_t out_width_;
};
template <DeviceType D, typename T>
struct ResizeBilinearFunctor : ResizeBilinearFunctorBase {
template<DeviceType D, typename T>
struct ResizeBilinearFunctor;
template<>
struct ResizeBilinearFunctor<DeviceType::CPU, float>
: ResizeBilinearFunctorBase {
ResizeBilinearFunctor(const std::vector<index_t> &size, bool align_corners)
: ResizeBilinearFunctorBase(size, align_corners) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t in_height = input->dim(1);
const index_t in_width = input->dim(2);
const index_t channels = input->dim(3);
const index_t channels = input->dim(1);
const index_t in_height = input->dim(2);
const index_t in_width = input->dim(3);
index_t out_height = out_height_;
index_t out_width = out_width_;
MACE_CHECK(out_height > 0 && out_width > 0);
std::vector<index_t> out_shape{batch, out_height, out_width, channels};
std::vector<index_t> out_shape{batch, channels, out_height, out_width};
output->Resize(out_shape);
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard output_mapper(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
if (out_height == in_height && out_width == in_width) {
std::copy(input_data, input_data + channels * in_height * in_width,
std::copy(input_data,
input_data + batch * channels * in_height * in_width,
output_data);
return;
}
......@@ -177,7 +174,7 @@ struct ResizeBilinearFunctor : ResizeBilinearFunctorBase {
}
};
template <typename T>
template<typename T>
struct ResizeBilinearFunctor<DeviceType::OPENCL, T>
: ResizeBilinearFunctorBase {
ResizeBilinearFunctor(const std::vector<index_t> &size, bool align_corners)
......
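The per-channel resize above indexes precomputed (lower, upper, lerp) triples for each output coordinate. A sketch of how such a table is typically built for the align_corners == false case; the struct and function names are illustrative and the scaling rule is an assumption, since the table-building code is not part of this hunk:

#include <algorithm>
#include <cstdint>
#include <vector>

using index_t = int64_t;

struct LerpEntry {
  index_t lower;  // floor of the source coordinate
  index_t upper;  // lower + 1, clamped to the input edge
  float lerp;     // fractional weight toward `upper`
};

// Sketch only: per-output-coordinate interpolation table.
inline std::vector<LerpEntry> BuildLerpTable(index_t out_size,
                                             index_t in_size) {
  const float scale =
      out_size > 0 ? static_cast<float>(in_size) / out_size : 0.f;
  std::vector<LerpEntry> table(out_size);
  for (index_t i = 0; i < out_size; ++i) {
    const float src = i * scale;
    table[i].lower = static_cast<index_t>(src);
    table[i].upper = std::min<index_t>(in_size - 1, table[i].lower + 1);
    table[i].lerp = src - table[i].lower;
  }
  return table;
}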
......@@ -19,6 +19,7 @@
#include <functional>
#include <memory>
#include <vector>
#include <limits>
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
......@@ -29,50 +30,66 @@
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct SoftmaxFunctor {
void operator()(const Tensor *logits, Tensor *output, StatsFuture *future) {
Tensor::MappingGuard logits_guard(logits);
template<DeviceType D, typename T>
struct SoftmaxFunctor;
template<>
struct SoftmaxFunctor<DeviceType::CPU, float> {
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t class_count = input->dim(1);
const index_t class_size = input->dim(2) * input->dim(3);
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *logits_ptr = logits->data<T>();
T *output_ptr = output->mutable_data<T>();
auto &logits_shape = logits->shape();
const index_t batch_size =
std::accumulate(logits_shape.begin(), logits_shape.end() - 1, 1,
std::multiplies<index_t>());
const index_t num_classes = logits_shape.back();
#pragma omp parallel
{
// Allocate per thread buffer
std::vector<T> exp_data(num_classes);
#pragma omp for
for (index_t i = 0; i < batch_size; ++i) {
const index_t pos = i * num_classes;
T max_value = logits_ptr[pos];
for (index_t c = 1; c < num_classes; ++c) {
max_value = std::max(max_value, logits_ptr[pos + c]);
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
for (index_t b = 0; b < batch; ++b) {
std::vector<float>
max_val(class_size, std::numeric_limits<float>::lowest());
std::vector<float> sum_val(class_size, 0.f);
// compute the max over classes at each spatial position
for (index_t c = 0; c < class_count; ++c) {
const float
*input_ptr = input_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
max_val[k] = std::max(max_val[k], input_ptr[k]);
}
// TODO(liuqi): check overflow?
T sum = 0;
for (index_t c = 0; c < num_classes; ++c) {
exp_data[c] = ::exp((logits_ptr[pos + c] - max_value));
sum += exp_data[c];
}
for (index_t c = 0; c < num_classes; ++c) {
output_ptr[pos + c] = exp_data[c] / sum;
// compute exp(data - max) for each class
#pragma omp parallel for
for (index_t c = 0; c < class_count; ++c) {
const float
*input_ptr = input_data + (b * class_count + c) * class_size;
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
output_ptr[k] = ::exp(input_ptr[k] - max_val[k]);
}
}
// accumulate the exp sums over classes at each position
for (index_t c = 0; c < class_count; ++c) {
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
sum_val[k] += output_ptr[k];
}
}
};
template <>
struct SoftmaxFunctor<DeviceType::NEON, float> {
void operator()(const Tensor *logits, Tensor *output, StatsFuture *future);
// normalize: divide each exp value by its per-position sum
for (index_t c = 0; c < class_count; ++c) {
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
output_ptr[k] /= sum_val[k];
}
}
}
}
};
template <typename T>
template<typename T>
struct SoftmaxFunctor<DeviceType::OPENCL, T> {
void operator()(const Tensor *logits, Tensor *output, StatsFuture *future);
......
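The blocked loops above compute a softmax over the channel axis independently at each spatial position, with the usual max subtraction for numerical stability. An equivalent unblocked scalar sketch for one batch (the name is illustrative):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

using index_t = int64_t;

// Sketch only: per-position softmax across class_count channels,
// where class_size = H * W for an NCHW tensor.
inline void SoftmaxNCHWRef(const float *in, index_t class_count,
                           index_t class_size, float *out) {
  for (index_t k = 0; k < class_size; ++k) {
    float max_v = std::numeric_limits<float>::lowest();
    for (index_t c = 0; c < class_count; ++c) {
      max_v = std::max(max_v, in[c * class_size + k]);
    }
    float sum = 0.f;
    for (index_t c = 0; c < class_count; ++c) {
      const float e = std::exp(in[c * class_size + k] - max_v);
      out[c * class_size + k] = e;
      sum += e;
    }
    for (index_t c = 0; c < class_count; ++c) {
      out[c * class_size + k] /= sum;
    }
  }
}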
......@@ -35,11 +35,6 @@ void Register_Activation(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
ActivationOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
ActivationOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -31,9 +31,21 @@ void ReluBenchmark(
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
if (D == DeviceType::OPENCL) {
if (D == DeviceType::CPU) {
OpDefBuilder("Activation", "ReluBM")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
......@@ -43,11 +55,7 @@ void ReluBenchmark(
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Activation", "ReluBM")
.Input("Input")
.Output("Output")
.AddStringArg("activation", "RELU")
.Finalize(net.NewOperatorDef());
MACE_NOT_IMPLEMENTED;
}
// Warm-up
......@@ -93,7 +101,11 @@ void ReluxBenchmark(
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
}
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
......@@ -157,10 +169,23 @@ void PreluBenchmark(
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, float>("Alpha", {channels});
if (D == DeviceType::OPENCL) {
if (D == DeviceType::CPU) {
OpDefBuilder("Activation", "PreluBM")
.Input("Input")
.Input("Alpha")
.Output("Output")
.AddStringArg("activation", "PRELU")
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Alpha", "AlphaImage",
......@@ -173,12 +198,7 @@ void PreluBenchmark(
.AddStringArg("activation", "PRELU")
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Activation", "PreluBM")
.Input("Input")
.Input("Alpha")
.Output("Output")
.AddStringArg("activation", "PRELU")
.Finalize(net.NewOperatorDef());
MACE_NOT_IMPLEMENTED;
}
// Warm-up
......@@ -224,7 +244,11 @@ void TanhBenchmark(
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
}
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
......@@ -286,7 +310,11 @@ void SigmoidBenchmark(
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
}
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
......
......@@ -269,16 +269,11 @@ void TestSimplePrelu() {
net.RunOp(D);
}
if (D == DeviceType::NEON) {
if (D == DeviceType::CPU) {
auto expected = CreateTensor<float>(
{2, 2, 2, 2},
{-14, 7, -12, 6, -15, -15, -12, -12, -6, 3, -4, 2, -3, -3, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
} else {
auto expected = CreateTensor<float>(
{2, 2, 2, 2},
{-14, 7, -12, 6, -10, -15, -8, -12, -6, 3, -4, 2, -2, -3, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
}
} // namespace
......@@ -287,10 +282,6 @@ TEST_F(ActivationOpTest, CPUSimplePrelu) {
TestSimplePrelu<DeviceType::CPU>();
}
TEST_F(ActivationOpTest, NEONSimplePrelu) {
TestSimplePrelu<DeviceType::NEON>();
}
TEST_F(ActivationOpTest, OPENCLSimplePrelu) {
TestSimplePrelu<DeviceType::OPENCL>();
}
......
......@@ -35,12 +35,6 @@ void Register_AddN(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
AddNOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("AddN")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
AddNOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -35,11 +35,6 @@ void Register_BatchNorm(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
BatchNormOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
BatchNormOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -30,13 +30,29 @@ void BatchNorm(
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, T>("Scale", {channels});
net.AddRandomInput<D, T>("Offset", {channels});
net.AddRandomInput<D, T>("Mean", {channels});
net.AddRandomInput<D, T>("Var", {channels}, true);
if (D == DeviceType::OPENCL) {
if (D == DeviceType::CPU) {
OpDefBuilder("BatchNorm", "BatchNormBM")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Scale", "ScaleImage",
......@@ -57,15 +73,7 @@ void BatchNorm(
.Output("Output")
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("BatchNorm", "BatchNormBM")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
MACE_NOT_IMPLEMENTED;
}
// tuning
......
......@@ -34,7 +34,22 @@ void Simple() {
net.AddInputFromArray<D, float>("Mean", {1}, {10});
net.AddInputFromArray<D, float>("Var", {1}, {11.67f});
if (D == DeviceType::OPENCL) {
if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Scale", "ScaleImage",
......@@ -61,18 +76,6 @@ void Simple() {
// Transfer output
ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
// Check
......@@ -97,17 +100,7 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
index_t height = 64;
index_t width = 64;
// Construct graph
OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
......@@ -117,9 +110,30 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
// Construct graph
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
......@@ -170,15 +184,6 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) {
// Construct graph
OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-1)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
......@@ -188,9 +193,29 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) {
net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-1)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
......@@ -242,15 +267,6 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
// Construct graph
OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
......@@ -260,9 +276,29 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
......@@ -313,15 +349,6 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) {
// Construct graph
OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-1)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
......@@ -331,9 +358,29 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) {
net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-1)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
......@@ -375,63 +422,6 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-1, 1e-2);
}
TEST_F(BatchNormOpTest, NEONTest) {
srand(time(NULL));
unsigned int seed;
// generate random input
index_t batch = 1 + rand_r(&seed) % 10;
index_t channels = 3 + rand_r(&seed) % 50;
index_t height = 64;
index_t width = 64;
// Construct graph
OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Var", {channels});
// run cpu
net.RunOp();
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNeon")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputNeon")
.Finalize(net.NewOperatorDef());
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("InputNeon", "Input");
// Run on neon
net.RunOp(DeviceType::NEON);
net.Sync();
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("OutputExptected",
"Output");
ExpectTensorNear<float>(*net.GetOutput("OutputExptected"),
*net.GetOutput("OutputNeon"),
1e-5, 1e-4);
}
} // namespace test
} // namespace ops
} // namespace mace
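The CPU paths in these tests wrap the operator in TransformDataFormat(NHWC, NCHW) and TransformDataFormat(NCHW, NHWC) conversions so results stay comparable with the NHWC OpenCL path. A sketch of the NHWC-to-NCHW transpose such a conversion is assumed to perform:

#include <cstdint>

using index_t = int64_t;

// Sketch only: NHWC -> NCHW layout transpose.
inline void NHWCToNCHW(const float *in, index_t batch, index_t height,
                       index_t width, index_t channels, float *out) {
  for (index_t b = 0; b < batch; ++b) {
    for (index_t c = 0; c < channels; ++c) {
      for (index_t h = 0; h < height; ++h) {
        for (index_t w = 0; w < width; ++w) {
          out[((b * channels + c) * height + h) * width + w] =
              in[((b * height + h) * width + w) * channels + c];
        }
      }
    }
  }
}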
......@@ -29,10 +29,22 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) {
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, T>("Bias", {channels}, true);
if (D == DeviceType::OPENCL) {
if (D == DeviceType::CPU) {
OpDefBuilder("BiasAdd", "BiasAddBM")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Bias", "BiasImage",
......@@ -43,11 +55,7 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) {
.Output("Output")
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("BiasAdd", "BiasAddBM")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
MACE_NOT_IMPLEMENTED;
}
// Warm-up
......
......@@ -31,7 +31,23 @@ void BiasAddSimple() {
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
net.AddInputFromArray<D, float>("Bias", {1}, {0.5f});
if (D == DeviceType::OPENCL) {
if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputNCHW")
.Input("Bias")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Bias", "BiasImage",
......@@ -49,13 +65,7 @@ void BiasAddSimple() {
ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
MACE_NOT_IMPLEMENTED;
}
// Check
......@@ -81,22 +91,33 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) {
index_t height = 64 + rand_r(&seed) % 50;
index_t width = 64 + rand_r(&seed) % 50;
// Construct graph
OpsTestNet net;
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
// Construct graph
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputNCHW")
.Input("Bias")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
......@@ -130,22 +151,32 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) {
index_t height = 103 + rand_r(&seed) % 100;
index_t width = 113 + rand_r(&seed) % 100;
// Construct graph
OpsTestNet net;
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("Input")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
// Construct graph
OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputNCHW")
.Input("Bias")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
......
......@@ -34,7 +34,14 @@ class ChannelShuffleOp : public Operator<D, T> {
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
int channels = input->dim(3);
int channels;
if (D == OPENCL) {
channels = input->dim(3);
} else if (D == CPU) {
channels = input->dim(1);
} else {
MACE_NOT_IMPLEMENTED;
}
MACE_CHECK(channels % group_ == 0,
"input channels must be an integral multiple of group. ",
channels);
......
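The CPU branch above reads the channel count from dim(1) because the tensor is NCHW. For reference, a sketch of the channel-shuffle index mapping the operator is assumed to implement: output channel c reads from input channel (c % group) * channels_per_group + c / group.

#include <cstdint>

using index_t = int64_t;

// Sketch only: NCHW channel shuffle for one batch. Channels are viewed as
// [group, channels_per_group] and transposed to [channels_per_group, group].
inline void ChannelShuffleNCHWRef(const float *in, index_t channels,
                                  index_t image_size, int group, float *out) {
  const index_t channels_per_group = channels / group;
  for (index_t c = 0; c < channels; ++c) {
    const index_t src = (c % group) * channels_per_group + c / group;
    for (index_t i = 0; i < image_size; ++i) {
      out[c * image_size + i] = in[src * image_size + i];
    }
  }
}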
......@@ -29,9 +29,20 @@ void ChannelShuffle(
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, height, channels, width});
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
if (D == DeviceType::OPENCL) {
if (D == DeviceType::CPU) {
OpDefBuilder("Softmax", "SoftmaxBM")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
......@@ -41,10 +52,7 @@ void ChannelShuffle(
.AddIntArg("group", group)
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Softmax", "SoftmaxBM")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
MACE_NOT_IMPLEMENTED;
}
// Warm-up
......
......@@ -22,21 +22,31 @@ namespace test {
class ChannelShuffleOpTest : public OpsTestBase {};
TEST_F(ChannelShuffleOpTest, C8G4_CPU) {
// Construct graph
OpsTestNet net;
OpDefBuilder("ChannelShuffle", "ChannelShuffleTest")
.Input("Input")
.Output("Output")
.AddIntArg("group", 4)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 1, 2, 8},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
// Construct graph
OpDefBuilder("ChannelShuffle", "ChannelShuffleTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntArg("group", 4)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
// Check
auto expected = CreateTensor<float>(
......
......@@ -33,11 +33,6 @@ void Register_Concat(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
ConcatOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
ConcatOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -35,12 +35,6 @@ void Register_Conv2D(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
Conv2dOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
Conv2dOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -41,21 +41,34 @@ void Conv2d(int iters,
OpsTestNet net;
// Add input data
if (D == DeviceType::NEON) {
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
net.AddRandomInput<D, float>("Filter",
{output_channels, channels, kernel_h,
kernel_w});
net.AddRandomInput<D, float>("Bias", {output_channels});
} else {
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
net.AddRandomInput<D, float>("Filter",
{kernel_h, kernel_w, output_channels,
channels});
net.AddRandomInput<D, float>("Bias", {output_channels});
} else {
MACE_NOT_IMPLEMENTED;
}
if (D == DeviceType::OPENCL) {
if (D == DeviceType::CPU) {
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {dilation, dilation})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Filter", "FilterImage",
......@@ -73,16 +86,7 @@ void Conv2d(int iters,
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {dilation, dilation})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
MACE_NOT_IMPLEMENTED;
}
net.Setup(D);
......@@ -134,7 +138,6 @@ void Conv2d(int iters,
#define BM_CONV_2D(N, C, H, W, KH, KW, S, D, P, OC) \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, CPU); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, NEON); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, OPENCL); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, OPENCL);
......
This diff is collapsed.
......@@ -37,10 +37,17 @@ class DepthToSpaceOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
int input_depth = input->dim(3);
int input_depth;
if (D == CPU) {
input_depth = input->dim(1);
} else if (D == OPENCL) {
input_depth = input->dim(3);
} else {
MACE_NOT_IMPLEMENTED;
}
MACE_CHECK(input_depth % (block_size_ * block_size_) == 0,
"input depth should be dividable by block_size * block_size",
input->dim(3));
input_depth);
MACE_CHECK((input_depth % 4) == 0,
"input channel should be dividable by 4");
functor_(input, output, future);
......
......@@ -29,9 +29,20 @@ void DepthToSpace(
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
}
if (D == DeviceType::OPENCL) {
if (D == DeviceType::CPU) {
OpDefBuilder("DepthToSpace", "DepthToSpaceBM")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
......@@ -41,10 +52,7 @@ void DepthToSpace(
.AddIntArg("block_size", block_size)
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("DepthToSpace", "DepthToSpaceBM")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
MACE_NOT_IMPLEMENTED;
}
// Warm-up
......
......@@ -36,11 +36,21 @@ void RunDepthToSpace(const bool d2s,
const char *ops_test_name = (d2s) ? "DepthToSpaceTest" : "SpaceToDepthTest";
// Construct graph
if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder(ops_name, ops_test_name)
.Input("Input")
.Output("Output")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntArg("block_size", block_size)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
} else {
BufferToImage<D, float>(&net, "Input", "InputImage",
......@@ -50,9 +60,10 @@ void RunDepthToSpace(const bool d2s,
.Output("OutputImage")
.AddIntArg("block_size", block_size)
.Finalize(net.NewOperatorDef());
}
// Run
net.RunOp(D);
}
if (D == DeviceType::OPENCL) {
ImageToBuffer<DeviceType::OPENCL, float>(&net, "OutputImage", "Output",
......@@ -176,22 +187,31 @@ void RandomTest(const bool d2s, const int block_size,
const char *ops_test_name = (d2s) ? "DepthToSpaceTest" : "SpaceToDepthTest";
// Add input data
net.AddRandomInput<D, float>("Input1", shape);
net.AddRandomInput<D, float>("Input", shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder(ops_name, ops_test_name)
.Input("Input1")
.Input("InputNCHW")
.AddIntArg("block_size", block_size)
.Output("Output")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
BufferToImage<D, T>(&net, "Input1", "InputImg1",
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW,
"Output",
NHWC);
BufferToImage<D, T>(&net, "Input", "InputImg",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder(ops_name, ops_test_name)
.Input("InputImg1")
.Input("InputImg")
.AddIntArg("block_size", block_size)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Output("OutputImg")
......
......@@ -35,12 +35,6 @@ void Register_DepthwiseConv2d(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
DepthwiseConv2dOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthwiseConv2d")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
DepthwiseConv2dOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -40,21 +40,34 @@ void DepthwiseConv2d(int iters,
OpsTestNet net;
// Add input data
if (D == DeviceType::NEON) {
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input",
{batch, input_channels, height, width});
net.AddRandomInput<D, float>(
"Filter", {multiplier, input_channels, kernel_h, kernel_w});
net.AddRandomInput<D, float>("Bias", {input_channels * multiplier});
} else {
} else if (D == DeviceType::OPENCL) {
net.AddRandomInput<D, float>("Input",
{batch, height, width, input_channels});
net.AddRandomInput<D, float>(
"Filter", {kernel_h, kernel_w, input_channels, multiplier});
net.AddRandomInput<D, float>("Bias", {input_channels * multiplier});
} else {
MACE_NOT_IMPLEMENTED;
}
if (D == DeviceType::OPENCL) {
if (D == DeviceType::CPU) {
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Filter", "FilterImage",
......@@ -72,16 +85,7 @@ void DepthwiseConv2d(int iters,
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
MACE_NOT_IMPLEMENTED;
}
net.Setup(D);
......@@ -131,8 +135,7 @@ void DepthwiseConv2d(int iters,
#define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, M) \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, CPU); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, OPENCL); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, half, OPENCL); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, NEON);
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, half, OPENCL);
BM_DEPTHWISE_CONV_2D(1, 32, 112, 112, 3, 3, 1, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 32, 56, 56, 3, 3, 2, VALID, 1);
......
This diff is collapsed.
......@@ -35,11 +35,6 @@ void Register_Eltwise(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
EltwiseOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
EltwiseOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -35,11 +35,6 @@ void Register_FoldedBatchNorm(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
FoldedBatchNormOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
FoldedBatchNormOp<DeviceType::NEON, float>);
}
} // namespace ops
......
(25 more file diffs are collapsed and not shown.)
......@@ -56,7 +56,6 @@ def compare_output(platform, mace_runtime, output_name, mace_out_value,
print output_name, 'MACE VS', platform.upper(), 'similarity: ', similarity
if (mace_runtime == "cpu" and similarity > 0.999) or \
(mace_runtime == "neon" and similarity > 0.999) or \
(mace_runtime == "gpu" and similarity > 0.995) or \
(mace_runtime == "dsp" and similarity > 0.930):
print '===================Similarity Test Passed=================='
......
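The validation script accepts a run when the similarity between the MACE output and the reference framework's output exceeds a per-runtime threshold (0.999 for cpu, 0.995 for gpu, 0.930 for dsp). A C++ sketch of the metric, assuming the script's similarity is the standard cosine similarity (the Python script itself computes it with NumPy/SciPy):

#include <cmath>
#include <cstddef>

// Sketch only: cosine similarity between two flat float buffers.
inline double CosineSimilarity(const float *a, const float *b, size_t n) {
  double dot = 0.0, na = 0.0, nb = 0.0;
  for (size_t i = 0; i < n; ++i) {
    dot += static_cast<double>(a[i]) * b[i];
    na += static_cast<double>(a[i]) * a[i];
    nb += static_cast<double>(b[i]) * b[i];
  }
  return dot / (std::sqrt(na) * std::sqrt(nb));
}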