diff --git a/mace/core/operator.cc b/mace/core/operator.cc index a260b2c48d9c712f36b61362a5b1d83449f3c8f5..908a934dbdfb6483e35d7d7417e04e2851a8b24d 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -82,6 +82,7 @@ extern void Register_BiasAdd(OperatorRegistry *op_registry); extern void Register_ChannelShuffle(OperatorRegistry *op_registry); extern void Register_Concat(OperatorRegistry *op_registry); extern void Register_Conv2D(OperatorRegistry *op_registry); +extern void Register_Deconv2D(OperatorRegistry *op_registry); extern void Register_DepthToSpace(OperatorRegistry *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry); extern void Register_Dequantize(OperatorRegistry *op_registry); @@ -122,6 +123,7 @@ OperatorRegistry::OperatorRegistry() { ops::Register_ChannelShuffle(this); ops::Register_Concat(this); ops::Register_Conv2D(this); + ops::Register_Deconv2D(this); ops::Register_DepthToSpace(this); ops::Register_DepthwiseConv2d(this); ops::Register_Dequantize(this); diff --git a/mace/kernels/conv_pool_2d_util.cc b/mace/kernels/conv_pool_2d_util.cc index 1f58311a015f1a8a32fec3fd8633304778d58ae9..c0d55f16b7e6fbad4f7c0a3d0073aa4e22311822 100644 --- a/mace/kernels/conv_pool_2d_util.cc +++ b/mace/kernels/conv_pool_2d_util.cc @@ -14,6 +14,7 @@ #include "mace/kernels/conv_pool_2d_util.h" +#include #include namespace mace { @@ -147,6 +148,8 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC output_shape[3] = output_channels; } + + void CalcOutputSize(const index_t *input_shape, // NHWC const index_t *filter_shape, // HWOI const int *padding_size, @@ -161,14 +164,7 @@ void CalcOutputSize(const index_t *input_shape, // NHWC "If dilations > 1, strides should be 1"); MACE_CHECK_NOTNULL(output_shape); MACE_CHECK_NOTNULL(padding_size); - /* - * Convlution arithmetic: - * o = floor((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1 - * Pooling arithmetic: - * o = ceil((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1 - * For 
details, see https://arxiv.org/pdf/1603.07285.pdf or - * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html - */ + output_shape[0] = input_shape[0]; if (round_type == FLOOR) { output_shape[1] = static_cast( @@ -454,5 +450,6 @@ void ConstructNHWCInputWithPadding(const Tensor *input_tensor, } } } + } // namespace kernels } // namespace mace diff --git a/mace/kernels/deconv_2d.h b/mace/kernels/deconv_2d.h new file mode 100644 index 0000000000000000000000000000000000000000..fef536853e19b4e975efaf2970293fefb815a565 --- /dev/null +++ b/mace/kernels/deconv_2d.h @@ -0,0 +1,350 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MACE_KERNELS_DECONV_2D_H_ +#define MACE_KERNELS_DECONV_2D_H_ + +#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) +#include +#endif +#include +#include +#include + +#include "mace/core/future.h" +#include "mace/core/runtime/opencl/cl2_header.h" +#include "mace/core/tensor.h" +#include "mace/kernels/activation.h" +#include "mace/kernels/conv_pool_2d_util.h" +#include "mace/utils/utils.h" + +namespace mace { +namespace kernels { + +namespace deconv { + +template +void Deconv2dNCHW(const T *input, + const T *filter, + const T *bias, + const index_t batch, + const index_t in_height, + const index_t in_width, + const index_t in_channels, + const index_t out_height, + const index_t out_width, + const index_t out_channels, + const index_t filter_height, + const index_t filter_width, + const index_t stride_h, + const index_t stride_w, + const int padding_top, + const int padding_left, + float *output) { +#pragma omp parallel for collapse(4) + for (index_t b = 0; b < batch; ++b) { + for (index_t oc = 0; oc < out_channels; ++oc) { + for (index_t oh = 0; oh < out_height; ++oh) { + for (index_t ow = 0; ow < out_width; ++ow) { + index_t filter_start_y, filter_start_x; + index_t start_x = std::max(0, ow + stride_w -1 - padding_left); + index_t start_y = std::max(0, oh + stride_h -1 - padding_top); + start_x /= stride_w; + start_y /= stride_h; + filter_start_x = padding_left + stride_w * start_x - ow; + filter_start_y = padding_top + stride_h * start_y - oh; + filter_start_x = filter_width - 1 - filter_start_x; + filter_start_y = filter_height - 1 - filter_start_y; + T out_value = 0; + index_t out_pos = + ((b * out_channels + oc) * out_height + oh) * out_width + ow; + for (index_t ic = 0; ic < in_channels; ++ic) { + for (index_t f_y = filter_start_y, ih = start_y; + f_y >= 0 && ih < in_height; f_y -= stride_h, ++ih) { + for (index_t f_x = filter_start_x, iw = start_x; + f_x >= 0 && iw < in_width; f_x -= stride_w, ++iw) { + index_t weight_pos = + ((oc * in_channels 
+ ic) * filter_height + f_y) + * filter_width + f_x; + index_t in_pos = + ((b * in_channels + ic) * in_height + ih) + * in_width + iw; + out_value += input[in_pos] * filter[weight_pos]; + } + } + } + if (bias != nullptr) + out_value += bias[oc]; + output[out_pos] = out_value; + } + } + } + } +} +} // namespace deconv + +struct Deconv2dFunctorBase { + Deconv2dFunctorBase(const int *strides, + const Padding &padding_type, + const std::vector &paddings, + const std::vector &output_shape, + const ActivationType activation, + const float relux_max_limit) + : strides_(strides), + padding_type_(padding_type), + paddings_(paddings), + output_shape_(output_shape), + activation_(activation), + relux_max_limit_(relux_max_limit) {} + + static void CalcDeconvOutputSize( + const index_t *input_shape, // NHWC + const index_t *filter_shape, // OIHW + const int *strides, + index_t *output_shape, + const int *padding_size, + const bool isNCHW = false, + const bool isOIHW = false) { + MACE_CHECK_NOTNULL(output_shape); + MACE_CHECK_NOTNULL(padding_size); + MACE_CHECK_NOTNULL(input_shape); + MACE_CHECK_NOTNULL(filter_shape); + MACE_CHECK_NOTNULL(strides); + + const index_t output_channel = isOIHW ? filter_shape[0] : filter_shape[2]; + + const index_t in_height = isNCHW ? input_shape[2] : input_shape[1]; + const index_t in_width = isNCHW ? input_shape[3] : input_shape[2]; + const index_t in_channels = isNCHW ? input_shape[1] : input_shape[3]; + + const index_t extended_input_height = + (in_height - 1) * strides[0] + 1 + padding_size[0]; + const index_t extended_input_width = + (in_width - 1) * strides[1] + 1 + padding_size[1]; + + const index_t filter_h = isOIHW ? filter_shape[2] : filter_shape[0]; + const index_t filter_w = isOIHW ? 
filter_shape[3] : filter_shape[1]; + + index_t out_height = extended_input_height - filter_h + 1; + index_t out_width = extended_input_width - filter_w + 1; + + output_shape[0] = input_shape[0]; + if (isNCHW) { + output_shape[1] = output_channel; + output_shape[2] = out_height; + output_shape[3] = out_width; + } else { + output_shape[1] = out_height; + output_shape[2] = out_width; + output_shape[3] = output_channel; + } + } + + static void CalcDeconvPaddingAndInputSize( + const index_t *input_shape, // NHWC + const index_t *filter_shape, // OIHW + const int *strides, + Padding padding, + const index_t *output_shape, + int *padding_size, + const bool isNCHW = false, + const bool isOIHW = false) { + MACE_CHECK_NOTNULL(output_shape); + MACE_CHECK_NOTNULL(padding_size); + MACE_CHECK_NOTNULL(input_shape); + MACE_CHECK_NOTNULL(filter_shape); + MACE_CHECK_NOTNULL(strides); + + const index_t in_height = isNCHW ? input_shape[2] : input_shape[1]; + const index_t in_width = isNCHW ? input_shape[3] : input_shape[2]; + const index_t in_channels = isNCHW ? input_shape[1] : input_shape[3]; + + const index_t out_height = isNCHW ? output_shape[2] : output_shape[1]; + const index_t out_width = isNCHW ? output_shape[3] : output_shape[2]; + const index_t out_channels = isNCHW ? output_shape[1] : output_shape[3]; + + const index_t extended_input_height = (in_height - 1) * strides[0] + 1; + const index_t extended_input_width = (in_width - 1) * strides[1] + 1; + + const index_t filter_h = isOIHW ? filter_shape[2] : filter_shape[0]; + const index_t filter_w = isOIHW ? 
filter_shape[3] : filter_shape[1]; + + index_t expected_input_height = 0, expected_input_width = 0; + + switch (padding) { + case VALID: + expected_input_height = + (out_height - filter_h) / strides[0] + 1; + expected_input_width = + (out_width - filter_w) / strides[1] + 1; + break; + case SAME: + expected_input_height = + (out_height - 1) / strides[0] + 1; + expected_input_width = + (out_width - 1) / strides[1] + 1; + break; + default: + MACE_CHECK(false, "Unsupported padding type: ", padding); + } + + MACE_CHECK(expected_input_height == in_height, + expected_input_height, "!=", in_height); + MACE_CHECK(expected_input_width == in_width, + expected_input_width, "!=", in_width); + + const int p_h = static_cast(out_height + + filter_h - 1 - extended_input_height); + const int p_w = static_cast(out_width + + filter_w - 1 - extended_input_width); + + padding_size[0] = std::max(0, p_h); + padding_size[1] = std::max(0, p_w); + } + + const int *strides_; // [stride_h, stride_w] + const Padding padding_type_; + std::vector paddings_; + const ActivationType activation_; + const float relux_max_limit_; + std::vector output_shape_; +}; + +template +struct Deconv2dFunctor : Deconv2dFunctorBase { + Deconv2dFunctor(const int *strides, + const Padding &padding_type, + const std::vector &paddings, + const std::vector &output_shape, + const ActivationType activation, + const float relux_max_limit, + const bool is_filter_transformed, + ScratchBuffer *scratch) + : Deconv2dFunctorBase(strides, + padding_type, + paddings, + output_shape, + activation, + relux_max_limit) {} + + void operator()(const Tensor *input, // NCHW + const Tensor *filter, // OIHW + const Tensor *bias, + Tensor *output, + StatsFuture *future) { + MACE_CHECK_NOTNULL(input); + MACE_CHECK_NOTNULL(filter); + MACE_CHECK_NOTNULL(output); + + std::vector output_shape(4); + if (output_shape_.size() == 4) { + output_shape[0] = output_shape_[0]; + output_shape[1] = output_shape_[3]; + output_shape[2] = output_shape_[1]; + 
output_shape[3] = output_shape_[2]; + paddings_.clear(); + paddings_ = std::vector(2, 0); + CalcDeconvPaddingAndInputSize( + input->shape().data(), + filter->shape().data(), + strides_, padding_type_, + output_shape.data(), + paddings_.data(), true, true); + output->Resize(output_shape); + } else { + output_shape_.clear(); + output_shape_ = std::vector(4, 0); + CalcDeconvOutputSize(input->shape().data(), + filter->shape().data(), + strides_, + output_shape_.data(), + paddings_.data(), true, true); + output->Resize(output_shape_); + } + index_t batch = output->dim(0); + index_t channels = output->dim(1); + index_t height = output->dim(2); + index_t width = output->dim(3); + + index_t input_batch = input->dim(0); + index_t input_channels = input->dim(1); + index_t input_height = input->dim(2); + index_t input_width = input->dim(3); + + index_t kernel_h = filter->dim(2); + index_t kernel_w = filter->dim(3); + MACE_CHECK(filter->dim(0) == channels, filter->dim(0), " != ", channels); + MACE_CHECK(filter->dim(1) == input_channels, filter->dim(1), " != ", + input_channels); + + index_t stride_h = strides_[0]; + index_t stride_w = strides_[1]; + + MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch"); + Tensor::MappingGuard input_mapper(input); + Tensor::MappingGuard filter_mapper(filter); + Tensor::MappingGuard bias_mapper(bias); + Tensor::MappingGuard output_mapper(output); + auto input_data = input->data(); + auto filter_data = filter->data(); + auto bias_data = bias == nullptr ? 
nullptr : bias->data(); + auto output_data = output->mutable_data(); + int padding_top = (paddings_[0] + 1) >> 1; + int padding_left = (paddings_[1] + 1) >> 1; + + deconv::Deconv2dNCHW(input_data, filter_data, bias_data, + batch, input_height, input_width, input_channels, + height, width, channels, + kernel_h, kernel_w, + stride_h, stride_w, padding_top, padding_left, + output_data); + + DoActivation(output_data, output_data, output->size(), activation_, + relux_max_limit_); + } +}; + +template +struct Deconv2dFunctor : Deconv2dFunctorBase { + Deconv2dFunctor(const int *strides, + const Padding &padding_type, + const std::vector &paddings, + const std::vector &output_shape, + const ActivationType activation, + const float relux_max_limit, + const bool is_filter_transformed, + ScratchBuffer *scratch) + : Deconv2dFunctorBase(strides, + padding_type, + paddings, + output_shape, + activation, + relux_max_limit) {} + + void operator()(const Tensor *input, + const Tensor *filter, + const Tensor *bias, + Tensor *output, + StatsFuture *future); + + cl::Kernel kernel_; + uint32_t kwg_size_; + std::unique_ptr kernel_error_; + std::vector input_shape_; +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_DECONV_2D_H_ diff --git a/mace/kernels/eltwise.h b/mace/kernels/eltwise.h index 134a06d22ebd2ec88f61abe748d4c5dcce60cc1d..cd45614b88aed97331b253cca04d64b0a899d739 100644 --- a/mace/kernels/eltwise.h +++ b/mace/kernels/eltwise.h @@ -40,7 +40,8 @@ enum EltwiseType { NEG = 6, ABS = 7, SQR_DIFF = 8, - NONE = 9, + POW = 9, + NONE = 10, }; inline void TensorScalar(const EltwiseType type, @@ -103,19 +104,25 @@ inline void TensorScalar(const EltwiseType type, output[i] = std::pow(input0[i] - value, 2.f); } break; + case POW: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output[i] = std::pow(input0[i], value); + } + break; default: LOG(FATAL) << "Eltwise op not support type " << type; } } -inline void TensorVector(const EltwiseType type, - 
const float *input0,
-                         const float *input1,
-                         const index_t batch,
-                         const index_t channel,
-                         const index_t hw,
-                         const bool swapped,
-                         float *output) {
+inline void TensorBatchVector(const EltwiseType type,
+                              const float *input0,
+                              const float *input1,
+                              const index_t batch,
+                              const index_t channel,
+                              const index_t hw,
+                              const bool swapped,
+                              float *output) {
   switch (type) {
     case SUM:
 #pragma omp parallel for collapse(3)
@@ -227,6 +234,161 @@ inline void TensorVector(const EltwiseType type,
         }
       }
       break;
+    case POW:
+#pragma omp parallel for collapse(3)
+      for (index_t b = 0; b < batch; ++b) {
+        for (index_t c = 0; c < channel; ++c) {
+          for (index_t i = 0; i < hw; ++i) {
+            const index_t idx0 = (b * channel + c) * hw + i;
+            const index_t idx1 = b * channel + c;
+            // pow is not commutative: restore the caller's operand order.
+            output[idx0] = swapped ? std::pow(input1[idx1], input0[idx0])
+                                   : std::pow(input0[idx0], input1[idx1]);
+          }
+        }
+      }
+      break;
+    default:
+      LOG(FATAL) << "Eltwise op not support type " << type;
+  }
+}
+// Broadcasts a per-channel vector input1[channel] over input0, which is
+// indexed as [batch][channel][hw]; unlike TensorBatchVector, input1 is shared
+// by every batch. `swapped` means the caller exchanged the two inputs, so
+// non-commutative ops (SUB, DIV, POW) take input1 as the left operand.
+inline void TensorVector(const EltwiseType type,
+                         const float *input0,
+                         const float *input1,
+                         const index_t batch,
+                         const index_t channel,
+                         const index_t hw,
+                         const bool swapped,
+                         float *output) {
+  switch (type) {
+    case SUM:
+#pragma omp parallel for collapse(3)
+      for (index_t b = 0; b < batch; ++b) {
+        for (index_t c = 0; c < channel; ++c) {
+          for (index_t i = 0; i < hw; ++i) {
+            const index_t idx0 = (b * channel + c) * hw + i;
+            const index_t idx1 = c;
+            output[idx0] = input0[idx0] + input1[idx1];
+          }
+        }
+      }
+      break;
+    case SUB:
+      if (swapped) {
+#pragma omp parallel for collapse(3)
+        for (index_t b = 0; b < batch; ++b) {
+          for (index_t c = 0; c < channel; ++c) {
+            for (index_t i = 0; i < hw; ++i) {
+              const index_t idx0 = (b * channel + c) * hw + i;
+              const index_t idx1 = c;
+              output[idx0] = input1[idx1] - input0[idx0];
+            }
+          }
+        }
+      } else {
+#pragma omp parallel for collapse(3)
+        for (index_t b = 0; b < batch; ++b) {
+          for (index_t c = 0; c < channel; ++c) {
+            for (index_t i = 0; i < hw; ++i) {
+              const index_t idx0 = (b * channel + c) * hw + i;
+              const index_t idx1 = c;
+              output[idx0] = input0[idx0] - input1[idx1];
+            }
+          }
+        }
+      }
+      break;
+    case PROD:
+#pragma omp parallel for collapse(3)
+      for (index_t b = 0; b < batch; ++b) {
+        for (index_t c = 0; c < channel; ++c) {
+          for (index_t i = 0; i < hw; ++i) {
+            const index_t idx0 = (b * channel + c) * hw + i;
+            const index_t idx1 = c;
+            output[idx0] = input0[idx0] * input1[idx1];
+          }
+        }
+      }
+      break;
+    case DIV:
+      if (swapped) {
+#pragma omp parallel for collapse(3)
+        for (index_t b = 0; b < batch; ++b) {
+          for (index_t c = 0; c < channel; ++c) {
+            for (index_t i = 0; i < hw; ++i) {
+              const index_t idx0 = (b * channel + c) * hw + i;
+              const index_t idx1 = c;
+              output[idx0] = input1[idx1] / input0[idx0];
+            }
+          }
+        }
+      } else {
+#pragma omp parallel for collapse(3)
+        for (index_t b = 0; b < batch; ++b) {
+          for (index_t c = 0; c < channel; ++c) {
+            for (index_t i = 0; i < hw; ++i) {
+              const index_t idx0 = (b * channel + c) * hw + i;
+              const index_t idx1 = c;
+              output[idx0] = input0[idx0] / input1[idx1];
+            }
+          }
+        }
+      }
+      break;
+    case MIN:
+#pragma omp parallel for collapse(3)
+      for (index_t b = 0; b < batch; ++b) {
+        for (index_t c = 0; c < channel; ++c) {
+          for (index_t i = 0; i < hw; ++i) {
+            const index_t idx0 = (b * channel + c) * hw + i;
+            const index_t idx1 = c;
+            output[idx0] = std::min(input0[idx0], input1[idx1]);
+          }
+        }
+      }
+      break;
+    case MAX:
+#pragma omp parallel for collapse(3)
+      for (index_t b = 0; b < batch; ++b) {
+        for (index_t c = 0; c < channel; ++c) {
+          for (index_t i = 0; i < hw; ++i) {
+            const index_t idx0 = (b * channel + c) * hw + i;
+            const index_t idx1 = c;
+            output[idx0] = std::max(input0[idx0], input1[idx1]);
+          }
+        }
+      }
+      break;
+    case SQR_DIFF:
+#pragma omp parallel for collapse(3)
+      for (index_t b = 0; b < batch; ++b) {
+        for (index_t c = 0; c < channel; ++c) {
+          for (index_t i = 0; i < hw; ++i) {
+            const index_t idx0 = (b * channel + c) * hw + i;
+            const index_t idx1 = c;
+            output[idx0] = std::pow(input0[idx0] - input1[idx1], 2.f);
+          }
+        }
+      }
+      break;
+    case POW:
+#pragma omp parallel for collapse(3)
+      for (index_t b = 0; b < batch; ++b) {
+        for (index_t c = 0; c < channel; ++c) {
+          for (index_t i = 0; i < hw; ++i) {
+            const index_t idx0 = (b * channel + c) * hw + i;
+            const index_t idx1 = c;
+            // pow is not commutative: restore the caller's operand order.
+            output[idx0] = swapped ? std::pow(input1[idx1], input0[idx0])
+                                   : std::pow(input0[idx0], input1[idx1]);
+          }
+        }
+      }
+      break;
     default:
       LOG(FATAL) << "Eltwise op not support type " << type;
   }
@@ -279,6 +441,12 @@ inline void TensorEltwise(const EltwiseType type,
         output[i] = std::pow(input0[i] - input1[i], 2.f);
       }
       break;
+    case POW:
+#pragma omp parallel for
+      for (index_t i = 0; i < size; ++i) {
+        output[i] = std::pow(input0[i], input1[i]);
+      }
+      break;
     default:
       LOG(FATAL) << "Eltwise op not support type " << type;
   }
@@ -312,18 +480,25 @@ struct EltwiseFunctor: EltwiseFunctorBase {
                   StatsFuture *future) {
     bool swapped = false;
     if (input1 != nullptr) {
-      MACE_CHECK(input0->dim_size() == input1->dim_size())
+      MACE_CHECK(input0->dim_size() == input1->dim_size()
+                     || input0->dim_size() == 1
+                     || input1->dim_size() == 1)
         << "Inputs of Eltwise op must be same shape";
       if (input0->size() != input1->size()) {
         if (input0->size() < input1->size()) {
           std::swap(input0, input1);
           swapped = true;
         }
-        MACE_CHECK(input0->dim(0) == input1->dim(0) &&
-            input0->dim(1) == input1->dim(1) &&
-            input1->dim(2) == 1 &&
-            input1->dim(3) == 1)
-          << "Element-Wise op only support channel dimension broadcast";
+        if (input1->dim_size() == 1) {
+          MACE_CHECK(input0->dim(1) == input1->dim(0))
+            << "Element-Wise op only support channel dimension broadcast";
+        } else {
+          MACE_CHECK((input0->dim(0) == input1->dim(0) || input1->dim(0) == 1)
+                         && input0->dim(1) == input1->dim(1)
+                         && input1->dim(2) == 1
+                         && input1->dim(3) == 1)
+            << "Element-Wise op only support channel dimension broadcast";
+        }
       }
     }
     output->ResizeLike(input0);
@@ -344,8 +519,12 @@ struct EltwiseFunctor: EltwiseFunctorBase {
       const index_t batch = input0->dim(0);
       const index_t channel = input0->dim(1);
       const index_t hw = 
input0->dim(2) * input0->dim(3); - TensorVector(type_, input0_ptr, input1_ptr, - batch, channel, hw, swapped, output_ptr); + if (input1->dim(0) == 1 || input1->dim_size() == 1) + TensorVector(type_, input0_ptr, input1_ptr, + batch, channel, hw, swapped, output_ptr); + else + TensorBatchVector(type_, input0_ptr, input1_ptr, + batch, channel, hw, swapped, output_ptr); } else { if (!coeff_.empty() && type_ == SUM) { #pragma omp parallel for diff --git a/mace/kernels/opencl/cl/deconv_2d.cl b/mace/kernels/opencl/cl/deconv_2d.cl new file mode 100644 index 0000000000000000000000000000000000000000..f64dc7bb900fe3c6ceb8dd1a74dcd62d0ffb3c58 --- /dev/null +++ b/mace/kernels/opencl/cl/deconv_2d.cl @@ -0,0 +1,164 @@ +#include + +__kernel void deconv_2d(KERNEL_ERROR_PARAMS + GLOBAL_WORK_GROUP_SIZE_DIM3 + __read_only image2d_t input, + __read_only image2d_t weights, +#ifdef BIAS + __read_only image2d_t bias, +#endif + __write_only image2d_t output, + __private const float relux_max_limit, + __private const int in_height, + __private const int in_width, + __private const int in_channels, + __private const int out_height, + __private const int out_width, + __private const int out_channel, + __private const int stride, + __private const float stride_r, + __private const int align_h, + __private const int align_w, + __private const int padding_h, + __private const int padding_w, + __private const int kernel_h, + __private const int kernel_w, + __private const int kernel_size, + __private const int in_channel_blocks, + __private const int out_channel_blocks) +{ + const int c = get_global_id(0); + const int w_id = get_global_id(1); + const int hb = get_global_id(2); + +#ifndef NON_UNIFORM_WORK_GROUP + if (c >= global_size_dim0 || w_id >= global_size_dim1 + || hb >= global_size_dim2) { + return; + } +#endif + +#ifdef BIAS + DATA_TYPE4 out0 = + READ_IMAGET(bias, SAMPLER, (int2)(c, 0)); + DATA_TYPE4 out1 = out0; + DATA_TYPE4 out2 = out0; + DATA_TYPE4 out3 = out0; + DATA_TYPE4 out4 = out0; 
+#else + DATA_TYPE4 out0 = 0; + DATA_TYPE4 out1 = 0; + DATA_TYPE4 out2 = 0; + DATA_TYPE4 out3 = 0; + DATA_TYPE4 out4 = 0; +#endif + + const int n_stride = mad(w_id, stride_r, 0); + const int mod_stride = w_id - mul24(n_stride, stride); + const int w = mad24(mul24(n_stride, 5), stride, mod_stride); + const int b = hb / out_height; + const int h = hb - mul24(b, out_height); + if (w < out_width) { + int start_x = floor((float) (w + align_w) * stride_r); + int start_y = (h + align_h) * stride_r; + start_y = max(0, start_y); + + int f_start_x = mad24(start_x, stride, padding_w) - w; + int f_start_y = mad24(start_y, stride, padding_h) - h; + f_start_x = kernel_w - 1 - f_start_x; + f_start_y = kernel_h - 1 - f_start_y; + + int2 in_pos; + int f_pos_x0, f_pos_x1, f_pos_x2, f_pos_x3, f_pos_y; + DATA_TYPE4 in0, in1, in2, in3, in4; + DATA_TYPE4 weight0, weight1, weight2, weight3; + int idx_w0, idx_w1, idx_w2, idx_w3, idx_w4; + int index_x, index_y; + for (int ic = 0; ic < in_channel_blocks; ++ic) { + f_pos_x0 = mul24(ic, 4); + f_pos_x1 = f_pos_x0 + 1; + f_pos_x2 = f_pos_x0 + 2; + f_pos_x3 = f_pos_x0 + 3; + for (int f_y = f_start_y, idx_h = start_y ; f_y >= 0; f_y -= stride, ++idx_h) { + index_y = mad24(b, in_height, idx_h); + in_pos.y = select(index_y, -1, idx_h < 0 || idx_h >= in_height); + for (int f_x = f_start_x, idx_w = start_x; f_x >= 0; f_x -= stride, ++idx_w) { + f_pos_y = mad24(f_y, kernel_w, f_x); + f_pos_y = mad24(c, kernel_size, f_pos_y); + weight0 = READ_IMAGET(weights, SAMPLER, (int2)(f_pos_x0, f_pos_y)); + weight1 = READ_IMAGET(weights, SAMPLER, (int2)(f_pos_x1, f_pos_y)); + weight2 = READ_IMAGET(weights, SAMPLER, (int2)(f_pos_x2, f_pos_y)); + weight3 = READ_IMAGET(weights, SAMPLER, (int2)(f_pos_x3, f_pos_y)); + + idx_w0 = idx_w; + idx_w1 = idx_w + 1; + idx_w2 = idx_w + 2; + idx_w3 = idx_w + 3; + idx_w4 = idx_w + 4; + +#define READ_INPUT(i) \ + index_x = mad24(ic, in_width, idx_w##i); \ + in_pos.x = \ + select(index_x, -1, idx_w##i < 0 || idx_w##i >= in_width); 
\ + in##i = READ_IMAGET(input, SAMPLER, in_pos); + + READ_INPUT(0); + READ_INPUT(1); + READ_INPUT(2); + READ_INPUT(3); + READ_INPUT(4); +#undef READ_INPUT + +#define CALC_OUTPUT(i) \ + out##i = mad(in##i.x, weight0, out##i); \ + out##i = mad(in##i.y, weight1, out##i); \ + out##i = mad(in##i.z, weight2, out##i); \ + out##i = mad(in##i.w, weight3, out##i); + + CALC_OUTPUT(0); + CALC_OUTPUT(1); + CALC_OUTPUT(2); + CALC_OUTPUT(3); + CALC_OUTPUT(4); +#undef CALC_OUTPUT + } + } + } + +#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID) + out0 = do_activation(out0, relux_max_limit); + out1 = do_activation(out1, relux_max_limit); + out2 = do_activation(out2, relux_max_limit); + out3 = do_activation(out3, relux_max_limit); + out4 = do_activation(out4, relux_max_limit); +#endif + + int2 out_pos; + out_pos.y = hb; + + int ow = w; + if (ow >= out_width) return; + out_pos.x = mad24(c, out_width, ow); + WRITE_IMAGET(output, out_pos, out0); + + ow += stride; + if (ow >= out_width) return; + out_pos.x += stride; + WRITE_IMAGET(output, out_pos, out1); + + ow += stride; + if (ow >= out_width) return; + out_pos.x += stride; + WRITE_IMAGET(output, out_pos, out2); + + ow += stride; + if (ow >= out_width) return; + out_pos.x += stride; + WRITE_IMAGET(output, out_pos, out3); + + ow += stride; + if (ow >= out_width) return; + out_pos.x += stride; + WRITE_IMAGET(output, out_pos, out4); + } +} \ No newline at end of file diff --git a/mace/kernels/opencl/cl/eltwise.cl b/mace/kernels/opencl/cl/eltwise.cl index 3a0ea33cfeffe3f68cd307d58c9bf70ea1e8f43b..e3cd7ecfaf4d155939da8bfb70d51d98758b1944 100644 --- a/mace/kernels/opencl/cl/eltwise.cl +++ b/mace/kernels/opencl/cl/eltwise.cl @@ -33,6 +33,8 @@ __kernel void eltwise(KERNEL_ERROR_PARAMS #elif INPUT_TYPE == 2 const int batch_idx = hb / height; DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(chan_idx, batch_idx)); +#elif INPUT_TYPE == 3 + DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(chan_idx, 
0));
 #else
   DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(pos, hb));
 #endif
@@ -70,10 +72,18 @@ __kernel void eltwise(KERNEL_ERROR_PARAMS
 #elif ELTWISE_TYPE == 8
   DATA_TYPE4 diff = in0 - in1;
   out = diff * diff;
+#elif ELTWISE_TYPE == 9
+  // POW: in0 is the base unless the host swapped the inputs (-DSWAPPED).
+  #ifdef SWAPPED
+    out = pow(in1, in0);
+  #else
+    out = pow(in0, in1);
+  #endif
 #endif
 
 #if INPUT_TYPE == 1
-  #if ELTWISE_TYPE == 0 || ELTWISE_TYPE == 1 || ELTWISE_TYPE == 4 || ELTWISE_TYPE == 5 || ELTWISE_TYPE == 8
+  #if ELTWISE_TYPE == 0 || ELTWISE_TYPE == 1 || ELTWISE_TYPE == 4 || \
+      ELTWISE_TYPE == 5 || ELTWISE_TYPE == 8 || ELTWISE_TYPE == 9
     const int remain_channel = channel - 4 * chan_idx;
     if (remain_channel < 4) {
       switch (remain_channel) {
diff --git a/mace/kernels/opencl/deconv_2d_opencl.cc b/mace/kernels/opencl/deconv_2d_opencl.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1f02bc47d4d8f425942a005c36b44c7a709ebd37
--- /dev/null
+++ b/mace/kernels/opencl/deconv_2d_opencl.cc
@@ -0,0 +1,200 @@
+// Copyright 2018 Xiaomi, Inc.  All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/kernels/deconv_2d.h"
+#include "mace/kernels/opencl/helper.h"
+
+namespace mace {
+namespace kernels {
+
+namespace {
+// Builds the deconv_2d kernel on first use, binds its args and runs it.
+void Deconv2dOpencl(cl::Kernel *kernel,
+                    const Tensor *input,
+                    const Tensor *filter,
+                    const Tensor *bias,
+                    const int stride,
+                    const int *paddings,
+                    const ActivationType activation,
+                    const float relux_max_limit,
+                    const DataType dt,
+                    std::vector *prev_input_shape,
+                    Tensor *output,
+                    StatsFuture *future,
+                    uint32_t *kwg_size,
+                    std::unique_ptr *kernel_error) {
+  const index_t batch = output->dim(0);
+  const index_t height = output->dim(1);
+  const index_t width = output->dim(2);
+  const index_t channels = output->dim(3);
+  const index_t input_channels = input->dim(3);
+
+  const index_t channel_blocks = RoundUpDiv4(channels);
+  const index_t input_channel_blocks = RoundUpDiv4(input_channels);
+  MACE_CHECK(stride > 0, "stride should > 0.");
+#define WIDTH_BLK 5  // width outputs computed per work-item
+  const index_t n_strides = (width + stride - 1) / stride;
+  const index_t width_blocks = ((n_strides + WIDTH_BLK -1)/ WIDTH_BLK) * stride;
+  const float stride_r = 1.f / static_cast(stride);
+  const int padding_h = (paddings[0]+1) >> 1;  // top pad: ceil(total_h / 2)
+  const int padding_w = (paddings[1]+1) >> 1;  // left pad: ceil(total_w / 2)
+
+  const int align_h = stride - 1 - padding_h;
+  const int align_w = stride - 1 - padding_w;
+  const int kernel_size = filter->dim(0) * filter->dim(1);
+
+  auto runtime = OpenCLRuntime::Global();
+
+  if (kernel->get() == nullptr) {
+    std::set built_options;
+    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("deconv_2d");
+    built_options.emplace("-Ddeconv_2d=" + kernel_name);
+    built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
+    built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
+    if (runtime->IsOutOfRangeCheckEnabled()) {
+      built_options.emplace("-DOUT_OF_RANGE_CHECK");
+      *kernel_error = std::move(std::unique_ptr(
+          new Buffer(GetDeviceAllocator(DeviceType::GPU), 1)));
+      (*kernel_error)->Map(nullptr);
+      *((*kernel_error)->mutable_data()) = 0;
+      (*kernel_error)->UnMap();
+    }
+    if (runtime->IsNonUniformWorkgroupsSupported()) {
+      built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
+    }
+    built_options.emplace(bias != nullptr ? "-DBIAS" : "");
+    switch (activation) {
+      case NOOP:break;
+      case RELU:built_options.emplace("-DUSE_RELU");
+        break;
+      case RELUX:built_options.emplace("-DUSE_RELUX");
+        break;
+      case TANH:built_options.emplace("-DUSE_TANH");
+        break;
+      case SIGMOID:built_options.emplace("-DUSE_SIGMOID");
+        break;
+      default:LOG(FATAL) << "Unknown activation type: " << activation;
+    }
+
+    *kernel = runtime->BuildKernel("deconv_2d", kernel_name, built_options);
+
+    *kwg_size =
+        static_cast(runtime->GetKernelMaxWorkGroupSize(*kernel));
+  }
+  // Work-items: [out-channel blocks, width blocks, height * batch].
+  const uint32_t gws[3] = {static_cast(channel_blocks),
+                           static_cast(width_blocks),
+                           static_cast(height * batch)};
+  // Args are (re)bound only when the input shape differs from the last run.
+  if (!IsVecEqual(*prev_input_shape, input->shape())) {
+    uint32_t idx = 0;
+    if (runtime->IsOutOfRangeCheckEnabled()) {
+      kernel->setArg(idx++,
+                     *(static_cast((*kernel_error)->buffer())));
+    }
+    if (!runtime->IsNonUniformWorkgroupsSupported()) {
+      kernel->setArg(idx++, gws[0]);
+      kernel->setArg(idx++, gws[1]);
+      kernel->setArg(idx++, gws[2]);
+    }
+    kernel->setArg(idx++, *(input->opencl_image()));
+    kernel->setArg(idx++, *(filter->opencl_image()));
+    if (bias != nullptr) {
+      kernel->setArg(idx++, *(bias->opencl_image()));
+    }
+    kernel->setArg(idx++, *(output->opencl_image()));
+    kernel->setArg(idx++, relux_max_limit);
+    kernel->setArg(idx++, static_cast(input->dim(1)));
+    kernel->setArg(idx++, static_cast(input->dim(2)));
+    kernel->setArg(idx++, static_cast(input->dim(3)));
+    kernel->setArg(idx++, static_cast(height));
+    kernel->setArg(idx++, static_cast(width));
+    kernel->setArg(idx++, static_cast(channels));
+    kernel->setArg(idx++, static_cast(stride));
+    kernel->setArg(idx++, stride_r);
+    kernel->setArg(idx++, static_cast(align_h));
+    kernel->setArg(idx++, static_cast(align_w));
+    kernel->setArg(idx++, static_cast(padding_h));
+    kernel->setArg(idx++, static_cast(padding_w));
+    kernel->setArg(idx++, static_cast(filter->dim(0)));
+    kernel->setArg(idx++, static_cast(filter->dim(1)));
+    kernel->setArg(idx++, static_cast(kernel_size));
+    kernel->setArg(idx++, static_cast(input_channel_blocks));
+    kernel->setArg(idx++, static_cast(channel_blocks));
+
+    *prev_input_shape = input->shape();
+  }
+
+  const std::vector lws = {8, *kwg_size / 64, 8, 0};
+  std::string tuning_key =
+      Concat("deconv2d_opencl_kernel_", activation, output->dim(0),
+             output->dim(1), output->dim(2), output->dim(3));
+  TuningOrRun3DKernel(*kernel, tuning_key, gws, lws, future);
+
+  if (runtime->IsOutOfRangeCheckEnabled()) {
+    (*kernel_error)->Map(nullptr);
+    char *kerror_code = (*kernel_error)->mutable_data();
+    MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
+    (*kernel_error)->UnMap();
+  }
+}
+
+}  // namespace
+// GPU functor: resolves paddings/output shape, then launches the kernel.
+template
+void Deconv2dFunctor::operator()(const Tensor *input,
+                                 const Tensor *filter,
+                                 const Tensor *bias,
+                                 Tensor *output,
+                                 StatsFuture *future) {
+  MACE_CHECK_NOTNULL(input);
+  MACE_CHECK_NOTNULL(filter);
+  MACE_CHECK_NOTNULL(output);
+  // A 4-D output_shape_ determines the paddings; else paddings give the shape.
+  if (output_shape_.size() == 4) {
+    paddings_.clear();
+    paddings_ = std::vector(2, 0);
+    CalcDeconvPaddingAndInputSize(
+        input->shape().data(),
+        filter->shape().data(),
+        strides_, padding_type_,
+        output_shape_.data(),
+        paddings_.data());
+  } else {
+    output_shape_.clear();
+    output_shape_ = std::vector(4, 0);
+    CalcDeconvOutputSize(input->shape().data(),
+                         filter->shape().data(),
+                         strides_,
+                         output_shape_.data(),
+                         paddings_.data());
+  }
+
+  std::vector output_image_shape;
+  CalImage2DShape(output_shape_, BufferType::IN_OUT_CHANNEL,
+                  &output_image_shape);
+  output->ResizeImage(output_shape_, output_image_shape);
+
+  Deconv2dOpencl(&kernel_, input, filter, bias,
+                 strides_[0], paddings_.data(),  // one stride for H and W
+                 activation_, relux_max_limit_,
+                 DataTypeToEnum::value, &input_shape_,
+                 output, future, &kwg_size_, &kernel_error_);
+}
+
+template struct 
Deconv2dFunctor; +template struct Deconv2dFunctor; + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/opencl/eltwise.cc b/mace/kernels/opencl/eltwise.cc index e3f4b8f8f7db189fc1faf8d52140cd259768af9f..3557032c7fdad9c6f4f1755cd4a05dad83a4809f 100644 --- a/mace/kernels/opencl/eltwise.cc +++ b/mace/kernels/opencl/eltwise.cc @@ -27,21 +27,40 @@ void EltwiseFunctor::operator()(const Tensor *input0, StatsFuture *future) { bool swapped = false; if (input1 != nullptr) { - MACE_CHECK(input0->dim_size() == input1->dim_size()) + MACE_CHECK(input0->dim_size() == input1->dim_size() + || input0->dim_size() == 1 + || input1->dim_size() == 1) << "Inputs of Eltwise op must be same shape"; if (input0->size() != input1->size()) { if (input0->size() < input1->size()) { std::swap(input0, input1); swapped = true; } - MACE_CHECK(input0->dim(0) == input1->dim(0) && - input1->dim(1) == 1 && - input1->dim(2) == 1 && - input0->dim(3) == input1->dim(3)) - << "Element-Wise op only support channel dimension broadcast"; + if (input1->dim_size() == 1) { + MACE_CHECK(input0->dim(3) == input1->dim(0)) + << "Element-Wise op only support channel dimension broadcast"; + } else { + MACE_CHECK((input0->dim(0) == input1->dim(0) || input1->dim(0) == 1) && + input0->dim(3) == input1->dim(3) && + input1->dim(1) == 1 && + input1->dim(2) == 1) + << "Element-Wise op only support channel dimension broadcast"; + } } } - output->ResizeLike(input0); + + std::vector output_shape(4); + output_shape[0] = input0->dim(0); + output_shape[1] = input0->dim(1); + output_shape[2] = input0->dim(2); + output_shape[3] = input0->dim(3); + + std::vector output_image_shape; + CalImage2DShape(output_shape, + BufferType::IN_OUT_CHANNEL, + &output_image_shape); + output->ResizeImage(output_shape, output_image_shape); + const index_t batch = output->dim(0); const index_t height = output->dim(1); const index_t width = output->dim(2); @@ -66,7 +85,10 @@ void EltwiseFunctor::operator()(const Tensor *input0, if 
(input1 == nullptr) { built_options.emplace("-DINPUT_TYPE=1"); } else if (input0->size() != input1->size()) { - built_options.emplace("-DINPUT_TYPE=2"); + if (input1->dim(0) == 1 || input1->dim_size() == 1) + built_options.emplace("-DINPUT_TYPE=3"); + else + built_options.emplace("-DINPUT_TYPE=2"); if (swapped) built_options.emplace("-DSWAPPED"); } if (!coeff_.empty()) built_options.emplace("-DCOEFF_SUM"); diff --git a/mace/kernels/opencl/out_of_range_check_test.cc b/mace/kernels/opencl/out_of_range_check_test.cc index a67cae0d5966a8a3becf30631d657d9ca016e9b6..8dc0a372fb8852e8f0d43e5e28db537ba7a875e5 100644 --- a/mace/kernels/opencl/out_of_range_check_test.cc +++ b/mace/kernels/opencl/out_of_range_check_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include "gtest/gtest.h" diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 2f51df772ce579cc4c95619a5bca9501837d1505..36a2ec37b087702b4e112a5c778a7b8445654b5b 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -105,7 +105,8 @@ void TestNHWCSimple3x3SAME() { {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); net.AddInputFromArray( "Filter", {3, 3, 1, 2}, - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); net.AddInputFromArray("Bias", {1}, {0.1f}); @@ -191,7 +192,8 @@ void TestNHWCSimple3x3WithoutBias() { {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); net.AddInputFromArray( "Filter", {3, 3, 1, 2}, - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); if (D == DeviceType::CPU) { @@ -351,10 +353,12 @@ void TestFusedNHWCSimple3x3VALID() { // Add input data net.AddInputFromArray( "Input", {1, 3, 3, 
2}, - {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}); + {-1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1}); net.AddInputFromArray( "Filter", {3, 3, 1, 2}, - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); net.AddInputFromArray("Bias", {1}, {-0.1f}); @@ -423,10 +427,13 @@ void TestFusedNHWCSimple3x3WithoutBias() { // Add input data net.AddInputFromArray( "Input", {1, 3, 3, 2}, - {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}); + {-1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1}); net.AddInputFromArray( "Filter", {3, 3, 1, 2}, - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); if (D == DeviceType::CPU) { diff --git a/mace/ops/deconv_2d.cc b/mace/ops/deconv_2d.cc new file mode 100644 index 0000000000000000000000000000000000000000..b49776937fd4f94d223b3d42edb85ecd65b0337d --- /dev/null +++ b/mace/ops/deconv_2d.cc @@ -0,0 +1,41 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mace/ops/deconv_2d.h" + +namespace mace { +namespace ops { + /* Register_Deconv2D: adds three "Deconv2D" entries to op_registry, one keyed on DeviceType::CPU and two keyed on DeviceType::GPU, each built with a type constraint named "T". NOTE(review): the template arguments of TypeConstraint and Deconv2dOp are not visible in this copy of the patch, so the two GPU entries look identical here; presumably they differ by element type (the benchmark in this patch instantiates float and half on GPU) - confirm against the real file. */ +void Register_Deconv2D(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Deconv2D") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + Deconv2dOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Deconv2D") + .Device(DeviceType::GPU) + .TypeConstraint("T") + .Build(), + Deconv2dOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Deconv2D") + .Device(DeviceType::GPU) + .TypeConstraint("T") + .Build(), + Deconv2dOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/deconv_2d.h b/mace/ops/deconv_2d.h new file mode 100644 index 0000000000000000000000000000000000000000..d20d76d8b93d2a5ec7ba35e5d9fd52ef22928b66 --- /dev/null +++ b/mace/ops/deconv_2d.h @@ -0,0 +1,64 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MACE_OPS_DECONV_2D_H_ +#define MACE_OPS_DECONV_2D_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/deconv_2d.h" +#include "mace/ops/conv_pool_2d_base.h" + +namespace mace { +namespace ops { + /* Deconv2dOp: operator wrapper around kernels::Deconv2dFunctor. The constructor forwards strides_, padding_type_ and paddings_ (reached through this->, presumably members of ConvPool2dOpBase), the repeated "output_shape" argument, a fixed ActivationType::NOOP with 0.0f, the "is_filter_transformed" argument (default false) and the workspace scratch buffer for device D. Run() reads INPUT and FILTER, uses BIAS only when at least three inputs are present (nullptr otherwise), invokes the functor and always returns true. NOTE(review): convert_deconv2d in the TF converter changed by this same patch emits "activation" and "max_limit" args when it fuses a following activation op, but nothing here reads them - confirm whether fused activations are silently dropped. */ +template +class Deconv2dOp : public ConvPool2dOpBase { + public: + Deconv2dOp(const OperatorDef &op_def, Workspace *ws) + : ConvPool2dOpBase(op_def, ws), + functor_(this->strides_.data(), + this->padding_type_, + this->paddings_, + OperatorBase::GetRepeatedArgument("output_shape"), + kernels::ActivationType::NOOP, + 0.0f, + static_cast(OperatorBase::GetSingleArgument( + "is_filter_transformed", false)), + ws->GetScratchBuffer(D)) {} + + bool Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + const Tensor *filter = this->Input(FILTER); + const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr; + Tensor *output = this->Output(OUTPUT); + + functor_(input, filter, bias, output, future); + + return true; + } + + private: + kernels::Deconv2dFunctor functor_; + + protected: + OP_INPUT_TAGS(INPUT, FILTER, BIAS); + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_DECONV_2D_H_ diff --git a/mace/ops/deconv_2d_benchmark.cc b/mace/ops/deconv_2d_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..4401abdf035448241723075262d777bd9e9aa6d3 --- /dev/null +++ b/mace/ops/deconv_2d_benchmark.cc @@ -0,0 +1,143 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/deconv_2d.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + /* Deconv2d: benchmark driver for the Deconv2D operator. With timing stopped it builds an OpsTestNet with random "Input" and "Filter" tensors: shapes {batch, channels, height, width} and {output_channels, channels, kernel_h, kernel_w} when D is CPU, {batch, height, width, channels} and {kernel_h, kernel_w, output_channels, channels} otherwise. On GPU both tensors are first converted to images. The op is defined with strides {stride, stride}, the given padding and output_shape {batch, out_h, out_w, output_channels}. After net.Setup(D) it performs two untimed warm-up runs, then starts timing and runs iters iterations, calling net.Sync() after every run. */ +template +static void Deconv2d(int iters, + int batch, + int channels, + int height, + int width, + int kernel_h, + int kernel_w, + int stride, + int out_h, + int out_w, + Padding padding, + int output_channels) { + mace::testing::StopTiming(); + + OpsTestNet net; + + // Add input data + if (D == DeviceType::CPU) { + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Filter", + {output_channels, channels, kernel_h, + kernel_w}); + } else { + net.AddRandomInput("Input", {batch, height, width, channels}); + net.AddRandomInput("Filter", + {kernel_h, kernel_w, output_channels, + channels}); + } + if (D == DeviceType::GPU) { + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "Filter", "FilterImage", + kernels::BufferType::CONV2D_FILTER); + OpDefBuilder("Deconv2D", "Deconv2dTest") + .Input("InputImage") + .Input("FilterImage") + .Output("Output") + .AddIntsArg("strides", {stride, stride}) + .AddIntArg("padding", padding) + .AddIntsArg("output_shape", {batch, out_h, out_w, output_channels}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + } else { + OpDefBuilder("Deconv2D", "Deconv2dTest") + .Input("Input") + .Input("Filter") + .Output("Output") + .AddIntsArg("strides", {stride, stride}) + .AddIntArg("padding", padding) + .AddIntsArg("output_shape", {batch, out_h, out_w, output_channels}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + } + + net.Setup(D); + + // Warm-up + for (int i = 0; i < 2; ++i) { + net.Run(); + net.Sync(); + } + + mace::testing::StartTiming(); + while 
(iters--) { + net.Run(); + net.Sync(); + } +} + +// In a typical network there is usually more than one layer; this is used to +// approximate the amortized latency. The OpenCL runtime for Mali/Adreno is +// in-order. + +#define BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, OH, OW, P, OC, TYPE, \ + DEVICE) \ + static void \ + BM_DECONV_2D_##N##_##C##_##H##_##W##_##KH##_##KW##_##STRIDE##_##OH##_##OW\ + ##_##P##_##OC##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + int64_t oh = OH; \ + int64_t ow = OW; \ + const int64_t macc = \ + static_cast(iters) * N * OC * oh * ow * (KH * KW * C + 1); \ + mace::testing::MaccProcessed(macc); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + Deconv2d(iters, N, C, H, W, KH, KW, STRIDE, OH, OW, \ + mace::Padding::P, OC); \ + } \ + BENCHMARK( \ + BM_DECONV_2D_##N##_##C##_##H##_##W##_##KH##_##KW##_##STRIDE##_##OH##_##OW##\ + _##P##_##OC##_##TYPE##_##DEVICE) + +#define BM_DECONV_2D(N, C, H, W, KH, KW, S, OH, OW, P, OC) \ + BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, float, CPU); \ + BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, float, GPU); \ + BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, half, GPU); + +BM_DECONV_2D(1, 512, 15, 15, 1, 1, 1, 15, 15, VALID, 1024); +BM_DECONV_2D(1, 32, 60, 60, 1, 1, 1, 60, 60, VALID, 128); + +BM_DECONV_2D(1, 128, 60, 60, 3, 3, 1, 62, 62, VALID, 128); +BM_DECONV_2D(1, 32, 60, 60, 3, 3, 1, 60, 60, SAME, 32); +BM_DECONV_2D(1, 3, 512, 512, 7, 7, 2, 1023, 1023, SAME, 64); +BM_DECONV_2D(1, 128, 16, 16, 5, 5, 1, 20, 20, VALID, 32); +BM_DECONV_2D(1, 128, 64, 64, 5, 5, 1, 68, 68, VALID, 32); + +BM_DECONV_2D(1, 3, 480, 480, 1, 1, 1, 480, 480, VALID, 3); + +BM_DECONV_2D(1, 64, 32, 32, 1, 1, 1, 32, 32, VALID, 128); +BM_DECONV_2D(1, 64, 33, 32, 3, 3, 2, 65, 63, SAME, 128); +BM_DECONV_2D(1, 3, 224, 224, 3, 3, 2, 447, 447, SAME, 32); +BM_DECONV_2D(1, 3, 224, 224, 3, 3, 2, 449, 449, VALID, 32); + +} // namespace test +} // 
namespace ops +} // namespace mace diff --git a/mace/ops/deconv_2d_test.cc b/mace/ops/deconv_2d_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..98aa1bd1656f8092d15a6c3a8947805c389d1bc2 --- /dev/null +++ b/mace/ops/deconv_2d_test.cc @@ -0,0 +1,595 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mace/ops/deconv_2d.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class Deconv2dOpTest : public OpsTestBase {}; + +template +void RunTestSimple(const std::vector &input_shape, + const std::vector &input_data, + const int stride, + Padding padding, + const std::vector &padding_size, + const std::vector &output_shape, + const std::vector &filter_shape, + const std::vector &filter_data, + const std::vector &expected_shape, + const std::vector &expected_data) { + OpsTestNet net; + // Add input data + net.AddInputFromArray("Input", input_shape, input_data); + net.AddInputFromArray("Filter", filter_shape, filter_data); + + if (D == DeviceType::GPU) { + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "Filter", "FilterImage", + kernels::BufferType::CONV2D_FILTER); + OpDefBuilder("Deconv2D", "Deconv2dTest") + .Input("InputImage") + .Input("FilterImage") + .Output("OutputImage") + .AddIntsArg("strides", {stride, stride}) + .AddIntArg("padding", padding) 
+ .AddIntsArg("padding_values", padding_size) + .AddIntsArg("output_shape", output_shape) + .Finalize(net.NewOperatorDef()); + + net.RunOp(D); + + // Transfer output + ImageToBuffer(&net, "OutputImage", "Output", + kernels::BufferType::IN_OUT_CHANNEL); + } else { + net.TransformDataFormat("Input", + NHWC, + "InputNCHW", + NCHW); + net.TransformDataFormat("Filter", + HWOI, + "FilterOIHW", + OIHW); + OpDefBuilder("Deconv2D", "Deconv2dTest") + .Input("InputNCHW") + .Input("FilterOIHW") + .Output("OutputNCHW") + .AddIntsArg("strides", {stride, stride}) + .AddIntArg("padding", padding) + .AddIntsArg("padding_values", padding_size) + .AddIntsArg("output_shape", output_shape) + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + net.TransformDataFormat("OutputNCHW", + NCHW, + "Output", + NHWC); + } + + auto expected = CreateTensor(expected_shape, expected_data); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.0001); +} + +template +void TestNHWCSimple3x3SAME_S1() { + RunTestSimple({1, 3, 3, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1}, + 1, + Padding::SAME, + {0, 0}, + {1, 3, 3, 3}, + {3, 3, 3, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 3, 3, 3}, + {4, 4, 4, 6, 6, 6, 4, 4, 4, + 6, 6, 6, 9, 9, 9, 6, 6, 6, + 4, 4, 4, 6, 6, 6, 4, 4, 4}); + RunTestSimple({1, 3, 3, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1}, + 1, + Padding::VALID, + {2, 2}, + {0}, + {3, 3, 3, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 3, 3, 3}, + {4, 4, 4, 6, 6, 6, 4, 4, 4, + 6, 6, 6, 9, 9, 9, 6, 6, 6, + 4, 4, 4, 6, 6, 6, 4, 4, 4}); + RunTestSimple({1, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}, + 1, + Padding::SAME, + {0, 0}, + {1, 3, 3, 3}, + {3, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27}, + {1, 3, 3, 3}, + {54, 66, 78, 126, 147, 168, 130, 146, 162, + 198, 225, 252, 405, 450, 495, 366, 399, 432, + 354, 378, 402, 630, 669, 708, 502, 
530, 558}); + RunTestSimple({1, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}, + 1, + Padding::SAME, + {2, 2}, + {0}, + {3, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27}, + {1, 3, 3, 3}, + {54, 66, 78, 126, 147, 168, 130, 146, 162, + 198, 225, 252, 405, 450, 495, 366, 399, 432, + 354, 378, 402, 630, 669, 708, 502, 530, 558}); +} + +template +void TestNHWCSimple3x3SAME_S2() { + RunTestSimple({1, 3, 3, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1}, + 2, + Padding::SAME, + {0, 0}, + {1, 6, 6, 3}, + {3, 3, 3, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 6, 6, 3}, + {1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 4, 4, 4, 2, 2, 2, 4, 4, 4, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 4, 4, 4, 2, 2, 2, 4, 4, 4, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1}); + RunTestSimple({1, 3, 3, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1}, + 2, + Padding::SAME, + {2, 2}, + {0}, + {3, 3, 3, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 5, 5, 3}, + {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, + 2, 2, 2, 4, 4, 4, 2, 2, 2, 4, 4, 4, 2, 2, 2, + 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, + 2, 2, 2, 4, 4, 4, 2, 2, 2, 4, 4, 4, 2, 2, 2, + 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1}); + RunTestSimple({1, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}, + 2, + Padding::SAME, + {0, 0}, + {1, 6, 6, 3}, + {3, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27}, + {1, 6, 6, 3}, + {1, 2, 3, 4, 5, 6, 9, 12, 15, + 8, 10, 12, 17, 22, 27, 12, 15, 18, + 10, 11, 12, 13, 14, 15, 36, 39, 42, + 26, 28, 30, 62, 67, 72, 39, 42, 45, + 23, 28, 33, 38, 43, 48, 96, 108, 120, + 64, 71, 78, 148, 164, 180, 90, 99, 108, + 40, 44, 48, 52, 56, 60, 114, 123, 132, + 65, 70, 75, 
140, 151, 162, 78, 84, 90, + 83, 94, 105, 116, 127, 138, 252, 276, 300, + 142, 155, 168, 304, 332, 360, 168, 183, 198, + 70, 77, 84, 91, 98, 105, 192, 207, 222, + 104, 112, 120, 218, 235, 252, 117, 126, 135}); + RunTestSimple({1, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}, + 2, + Padding::SAME, + {2, 2}, + {0}, + {3, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27}, + {1, 5, 5, 3}, + {13, 14, 15, 36, 39, 42, + 26, 28, 30, 62, 67, 72, 39, 42, 45, + 38, 43, 48, 96, 108, 120, + 64, 71, 78, 148, 164, 180, 90, 99, 108, + 52, 56, 60, 114, 123, 132, + 65, 70, 75, 140, 151, 162, 78, 84, 90, + 116, 127, 138, 252, 276, 300, + 142, 155, 168, 304, 332, 360, 168, 183, 198, + 91, 98, 105, 192, 207, 222, + 104, 112, 120, 218, 235, 252, 117, 126, 135}); +} + +template +void TestNHWCSimple3x3SAME_S2_1() { + RunTestSimple({1, 3, 3, 1}, + {12, 18, 12, 18, 27, 18, 12, 18, 12}, + 2, + Padding::SAME, + {0, 0}, + {1, 5, 5, 3}, + {3, 3, 3, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 5, 5, 3}, + {12, 12, 12, 30, 30, 30, 18, 18, 18, + 30, 30, 30, 12, 12, 12, + 30, 30, 30, 75, 75, 75, 45, 45, 45, + 75, 75, 75, 30, 30, 30, + 18, 18, 18, 45, 45, 45, 27, 27, 27, + 45, 45, 45, 18, 18, 18, + 30, 30, 30, 75, 75, 75, 45, 45, 45, + 75, 75, 75, 30, 30, 30, + 12, 12, 12, 30, 30, 30, 18, 18, 18, + 30, 30, 30, 12, 12, 12}); +} + +template +void TestNHWCSimple3x3VALID_S2() { + RunTestSimple({1, 3, 3, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1}, + 2, + Padding::VALID, + {0, 0}, + {1, 7, 7, 3}, + {3, 3, 3, 1}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 7, 7, 3}, + {1, 1, 1, 1, 1, 1, 2, 2, 2, + 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 2, 2, 2, + 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 4, 4, 4, 2, 2, 2, + 4, 4, 4, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, + 2, 2, 2, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 4, 
4, 4, 2, 2, 2, + 4, 4, 4, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, + 2, 2, 2, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, + 2, 2, 2, 1, 1, 1, 1, 1, 1}); +} + +template +void TestNHWCSimple3x3VALID_S1() { + RunTestSimple({1, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}, + 1, + Padding::VALID, + {0, 0}, + {1, 5, 5, 3}, + {3, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27}, + {1, 5, 5, 3}, + {1, 2, 3, + 6, 9, 12, + 18, 24, 30, + 26, 31, 36, + 21, 24, 27, + 14, 19, 24, + 54, 66, 78, + 126, 147, 168, + 130, 146, 162, + 90, 99, 108, + 66, 78, 90, + 198, 225, 252, + 405, 450, 495, + 366, 399, 432, + 234, 252, 270, + 146, 157, 168, + 354, 378, 402, + 630, 669, 708, + 502, 530, 558, + 294, 309, 324, + 133, 140, 147, + 306, 321, 336, + 522, 546, 570, + 398, 415, 432, + 225, 234, 243}); + RunTestSimple({1, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}, + 1, + Padding::VALID, + {4, 4}, + {0}, + {3, 3, 3, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27}, + {1, 5, 5, 3}, + {1, 2, 3, + 6, 9, 12, + 18, 24, 30, + 26, 31, 36, + 21, 24, 27, + 14, 19, 24, + 54, 66, 78, + 126, 147, 168, + 130, 146, 162, + 90, 99, 108, + 66, 78, 90, + 198, 225, 252, + 405, 450, 495, + 366, 399, 432, + 234, 252, 270, + 146, 157, 168, + 354, 378, 402, + 630, 669, 708, + 502, 530, 558, + 294, 309, 324, + 133, 140, 147, + 306, 321, 336, + 522, 546, 570, + 398, 415, 432, + 225, 234, 243}); +} + +template +void TestNHWCSimple2x2SAME() { + RunTestSimple({1, 2, 2, 1}, + {1, 1, 1, 1}, + 1, + Padding::SAME, + {0, 0}, + {1, 2, 2, 1}, + {3, 3, 1, 1}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, + {1, 2, 2, 1}, + {4.f, 4.f, 4.f, 4.f}); +} + +template +void TestNHWCSimple2x2VALID() { + RunTestSimple({1, 2, 2, 1}, + {1, 1, 1, 1}, + 2, + Padding::VALID, + {0, 0}, + {1, 5, 5, 1}, + {3, 3, 1, 1}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, + {1, 5, 5, 
1}, + {1.f, 1.f, 2.f, 1.f, 1.f, + 1.f, 1.f, 2.f, 1.f, 1.f, + 2.f, 2.f, 4.f, 2.f, 2.f, + 1.f, 1.f, 2.f, 1.f, 1.f, + 1.f, 1.f, 2.f, 1.f, 1.f}); +} + +TEST_F(Deconv2dOpTest, CPUSimple3X3PaddingSame_S1) { + TestNHWCSimple3x3SAME_S1(); +} + +TEST_F(Deconv2dOpTest, CPUSimple3X3PaddingSame_S2) { +TestNHWCSimple3x3SAME_S2(); +} + +TEST_F(Deconv2dOpTest, CPUSimple3X3PaddingSame_S2_1) { +TestNHWCSimple3x3SAME_S2_1(); +} + +TEST_F(Deconv2dOpTest, CPUSimple2X2PaddingSame) { + TestNHWCSimple2x2SAME(); +} + +TEST_F(Deconv2dOpTest, CPUSimple2X2PaddingValid) { + TestNHWCSimple2x2VALID(); +} + +TEST_F(Deconv2dOpTest, CPUSimple3X3PaddingValid_S1) { + TestNHWCSimple3x3VALID_S1(); +} + +TEST_F(Deconv2dOpTest, CPUSimple3X3PaddingValid_S2) { + TestNHWCSimple3x3VALID_S2(); +} + +TEST_F(Deconv2dOpTest, OPENCLSimple2X2PaddingSame) { + TestNHWCSimple2x2SAME(); +} + +TEST_F(Deconv2dOpTest, OPENCLSimple3X3PaddingSame_S1) { + TestNHWCSimple3x3SAME_S1(); +} + +TEST_F(Deconv2dOpTest, OPENCLSimple3X3PaddingSame_S2) { +TestNHWCSimple3x3SAME_S2(); +} + +TEST_F(Deconv2dOpTest, OPENCLSimple3X3PaddingSame_S2_1) { +TestNHWCSimple3x3SAME_S2_1(); +} + +TEST_F(Deconv2dOpTest, OPENCLSimple2X2PaddingValid) { + TestNHWCSimple2x2VALID(); +} + +TEST_F(Deconv2dOpTest, OPENCLSimple3X3PaddingValid_S1) { + TestNHWCSimple3x3VALID_S1(); +} + +TEST_F(Deconv2dOpTest, OPENCLSimple3X3PaddingValid_S2) { + TestNHWCSimple3x3VALID_S2(); +} + +namespace { +template +void TestComplexDeconvNxNS12(const std::vector &shape, + const int stride) { + testing::internal::LogToStderr(); + auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w, + Padding type, int padding) { + // generate random input + static unsigned int seed = time(NULL); + int batch = 3 + (rand_r(&seed) % 10); + int height = shape[0]; + int width = shape[1]; + int input_channels = shape[2] + (rand_r(&seed) % 10); + int output_channels = shape[3] + (rand_r(&seed) % 10); + + OpsTestNet net; + + // Add input data + net.AddRandomInput("Input", {batch, 
height, width, input_channels}); + net.AddRandomInput( + "Filter", {kernel_h, kernel_w, output_channels, input_channels}); + net.AddRandomInput("Bias", {output_channels}); + net.TransformDataFormat("Input", + NHWC, + "InputNCHW", + NCHW); + net.TransformDataFormat("Filter", + HWOI, + "FilterOIHW", + OIHW); + int out_h = 0; + int out_w = 0; + + std::vectorpaddings; + std::vector output_shape; + + if (padding < 0) { + if (type == Padding::SAME) { + out_h = (height - 1) * stride_h + 1; + out_w = (width - 1) * stride_w + 1; + } else { + out_h = (height - 1) * stride_h + kernel_h; + out_w = (width - 1) * stride_w + kernel_w; + } + output_shape.push_back(batch); + output_shape.push_back(out_h); + output_shape.push_back(out_w); + output_shape.push_back(output_channels); + } else { +// out_h = (height - 1) * stride + 1 + padding - kernel_h + 1; +// out_w = (width -1) * stride + 1 + padding - kernel_w + 1; + paddings.push_back(padding); + paddings.push_back(padding); + } + // Construct graph + OpDefBuilder("Deconv2D", "Deconv2dTest") + .Input("InputNCHW") + .Input("FilterOIHW") + .Input("Bias") + .Output("OutputNCHW") + .AddIntsArg("strides", {stride_h, stride_w}) + .AddIntArg("padding", type) + .AddIntsArg("padding_values", paddings) + .AddIntsArg("output_shape", output_shape) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + + + // run on cpu + net.RunOp(); + + net.TransformDataFormat("OutputNCHW", + NCHW, + "Output", + NHWC); + + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // run on gpu + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "Filter", "FilterImage", + kernels::BufferType::CONV2D_FILTER); + BufferToImage(&net, "Bias", "BiasImage", + kernels::BufferType::ARGUMENT); + + OpDefBuilder("Deconv2D", "Deconv2dTest") + .Input("InputImage") + .Input("FilterImage") + .Input("BiasImage") + .Output("OutputImage") + .AddIntsArg("strides", 
{stride_h, stride_w}) + .AddIntArg("padding", type) + .AddIntsArg("padding_values", paddings) + .AddIntsArg("output_shape", output_shape) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + // Run on device + net.RunOp(D); + + ImageToBuffer(&net, "OutputImage", "OPENCLOutput", + kernels::BufferType::IN_OUT_CHANNEL); + ExpectTensorNear(expected, + *net.GetOutput("OPENCLOutput"), 1e-4, 1e-4); + }; + + for (int kernel_size : {1, 3, 5, 7}) { + func(kernel_size, kernel_size, stride, stride, VALID, -1); + func(kernel_size, kernel_size, stride, stride, SAME, -1); + func(kernel_size, kernel_size, stride, stride, VALID, 1); + func(kernel_size, kernel_size, stride, stride, VALID, 2); + func(kernel_size, kernel_size, stride, stride, VALID, 3); + func(kernel_size, kernel_size, stride, stride, VALID, 4); + } +} +} // namespace + +TEST_F(Deconv2dOpTest, OPENCLAlignedDeconvNxNS12) { + TestComplexDeconvNxNS12({32, 16, 16, 32}, 1); + TestComplexDeconvNxNS12({32, 16, 16, 32}, 2); + TestComplexDeconvNxNS12({33, 17, 16, 32}, 1); + TestComplexDeconvNxNS12({33, 17, 16, 32}, 2); +} + +TEST_F(Deconv2dOpTest, OPENCLAlignedDeconvNxNS34) { + TestComplexDeconvNxNS12({32, 16, 16, 32}, 3); + TestComplexDeconvNxNS12({32, 16, 16, 32}, 4); +} + +TEST_F(Deconv2dOpTest, OPENCLUnalignedDeconvNxNS12) { +TestComplexDeconvNxNS12({17, 113, 5, 7}, 1); +TestComplexDeconvNxNS12({17, 113, 5, 7}, 2); +} + +TEST_F(Deconv2dOpTest, OPENCLUnalignedDeconvNxNS34) { + TestComplexDeconvNxNS12({17, 113, 5, 7}, 3); + TestComplexDeconvNxNS12({17, 113, 5, 7}, 4); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/eltwise_test.cc b/mace/ops/eltwise_test.cc index a156d95f54953aa2166676bf2dfa75bff9947a29..f2d887125ac27ae6baba0d7510709d3d3f79fac0 100644 --- a/mace/ops/eltwise_test.cc +++ b/mace/ops/eltwise_test.cc @@ -553,6 +553,8 @@ TEST_F(EltwiseOpTest, RandomTensorVecFloat) { {1, 32, 32, 16}, {1, 1, 1, 16}); 
RandomTensorEltwise(kernels::EltwiseType::SUB, {5, 32, 32, 16}, {5, 1, 1, 16}); + RandomTensorEltwise(kernels::EltwiseType::SUB, + {5, 32, 32, 16}, {1, 1, 1, 16}); RandomTensorEltwise(kernels::EltwiseType::SUB, {5, 1, 1, 16}, {5, 32, 32, 16}); RandomTensorEltwise(kernels::EltwiseType::PROD, @@ -574,12 +576,16 @@ TEST_F(EltwiseOpTest, RandomTensorVecHalf) { {1, 32, 32, 16}, {1, 1, 1, 16}); RandomTensorEltwise(kernels::EltwiseType::SUB, {3, 32, 32, 16}, {3, 1, 1, 16}); + RandomTensorEltwise(kernels::EltwiseType::SUB, + {3, 32, 32, 16}, {1, 1, 1, 16}); RandomTensorEltwise(kernels::EltwiseType::SUB, {3, 1, 1, 16}, {3, 32, 32, 16}); RandomTensorEltwise(kernels::EltwiseType::PROD, {1, 1, 1, 17}, {1, 31, 37, 17}); RandomTensorEltwise(kernels::EltwiseType::DIV, {5, 31, 37, 17}, {5, 1, 1, 17}); + RandomTensorEltwise(kernels::EltwiseType::DIV, + {5, 31, 37, 17}, {1, 1, 1, 17}); RandomTensorEltwise(kernels::EltwiseType::DIV, {5, 1, 1, 17}, {5, 31, 37, 17}); RandomTensorEltwise(kernels::EltwiseType::MIN, diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 6230cf0e4180d4da94d5d66fdf59df73f6965a4f..8647b246af2b560384f1be2b7bc385e8bf4f0982 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -38,6 +38,8 @@ math_type_mode = { 'MAX': 5, 'NEG': 6, 'ABS': 7, + 'SQR_DIFF': 8, + 'POW': 9, } buffer_type_map = { @@ -528,6 +530,103 @@ class TFConverter(object): self.add_output_shape(final_op.outputs, op_def) self.net_def.op.extend([op_def]) + def convert_deconv2d(self, op): + op_def = mace_pb2.OperatorDef() + arg = op_def.arg.add() + arg.name = 'T' + arg.i = self.dt + op_def.name = op.name + op_def.type = 'Deconv2D' + + out_shape_value = None + if len(op.inputs) == 2: + out_shape_value = op.get_attr('output_shape') + if self.device == 'cpu': + self.transpose_filter_tensor[get_input_tensor( + op, 1).name] = (3, 2, 0, 1) + else: + self.transpose_filter_tensor[get_input_tensor( + op, 1).name] = (0, 1, 
3, 2) + if self.device == 'gpu': + op_def.input.extend([op.inputs[0].name]) + buffer_type = "CONV2D_FILTER" + output_name = self.add_buffer_to_image( + get_input_tensor(op, 1).name, buffer_type) + op_def.input.extend([output_name]) + else: + op_def.input.extend( + [get_input_tensor(op, i).name + for i in range(len(op.inputs))]) + elif len(op.inputs) == 3: + out_shape_value = \ + get_input_tensor(op, 0).eval().astype(np.int32).flat + self.unused_tensor.add(op.inputs[0].name) + if self.device == 'cpu': + self.transpose_filter_tensor[get_input_tensor( + op, 1).name] = (2, 3, 0, 1) + else: + self.transpose_filter_tensor[get_input_tensor( + op, 1).name] = (0, 1, 2, 3) + if self.device == 'gpu': + op_def.input.extend([op.inputs[2].name]) + buffer_type = "CONV2D_FILTER" + output_name = self.add_buffer_to_image( + get_input_tensor(op, 1).name, buffer_type) + op_def.input.extend([output_name]) + else: + op_def.input.extend([op.inputs[2].name]) + op_def.input.extend([op.inputs[1].name]) + else: + raise Exception('Too many inputs. 
Op: %s, type: %s' % (op.name, + op.type)) + if out_shape_value is not None: + out_shape_arg = op_def.arg.add() + out_shape_arg.name = 'output_shape' + out_shape_arg.ints.extend(out_shape_value) + padding_arg = op_def.arg.add() + padding_arg.name = 'padding' + padding_arg.i = padding_mode[op.get_attr('padding')] + strides_arg = op_def.arg.add() + strides_arg.name = 'strides' + strides_arg.ints.extend(op.get_attr('strides')[1:3]) + data_format_arg = op_def.arg.add() + data_format_arg.name = 'data_format' + if self.device == 'cpu': + data_format_arg.s = 'NCHW' + else: + data_format_arg.s = 'NHWC' + final_op = op + self.resolved_ops[op.name] = 1 + + if len(self.tf_graph.get(op.name, [])) == 1 and \ + self.tf_graph[op.name][0].type == 'BiasAdd': + bias_add_op = self.tf_graph[op.name][0] + if self.device == 'gpu': + output_name = self.add_buffer_to_image( + get_input_tensor(bias_add_op, 1).name, "ARGUMENT") + op_def.input.extend([output_name]) + else: + op_def.input.extend([get_input_tensor(bias_add_op, 1).name]) + final_op = bias_add_op + self.resolved_ops[bias_add_op.name] = 1 + + if len(self.tf_graph.get(final_op.name, [])) == 1 and \ + self.tf_graph[final_op.name][0].type in activation_name_map: + activation_op = self.tf_graph[final_op.name][0] + fused_act_arg = op_def.arg.add() + fused_act_arg.name = 'activation' + fused_act_arg.s = activation_name_map[activation_op.type] + if activation_op.type == 'Relu6': + max_limit_arg = op_def.arg.add() + max_limit_arg.name = 'max_limit' + max_limit_arg.f = 6 + final_op = activation_op + self.resolved_ops[activation_op.name] = 1 + + op_def.output.extend([output.name for output in final_op.outputs]) + self.add_output_shape(final_op.outputs, op_def) + self.net_def.op.extend([op_def]) + def check_conv_to_fc(self, op): if self.device != 'cpu' or op.type != "Conv2D": return False @@ -857,6 +956,7 @@ class TFConverter(object): if len(op.inputs) == 2: input_tensor0 = get_input_tensor(op, 0) input_tensor1 = get_input_tensor(op, 1) + 
x_value = None if np.asarray(input_tensor1.shape).size == 0: x_value = input_tensor1.eval() @@ -867,7 +967,22 @@ class TFConverter(object): op_def.input.extend([op.inputs[1].name]) self.unused_tensor.add(input_tensor0.name) else: - op_def.input.extend([input.name for input in op.inputs]) + if np.asarray(input_tensor0.shape).size == 1 \ + and input_tensor0.op.type == 'Const': + if self.device == 'gpu': + output_name = self.add_buffer_to_image( + input_tensor0.name, "ARGUMENT") + op_def.input.extend([output_name]) + else: + op_def.input.extend([input_tensor0.name]) + if np.asarray(input_tensor1.shape).size == 1 \ + and input_tensor1.op.type == 'Const': + if self.device == 'gpu': + output_name = self.add_buffer_to_image( + input_tensor1.name, "ARGUMENT") + op_def.input.extend([output_name]) + else: + op_def.input.extend([input_tensor1.name]) if x_value is not None: x_arg = op_def.arg.add() x_arg.name = 'x' @@ -1150,6 +1265,8 @@ class TFConverter(object): self.convert_winograd_conv_gpu(op) else: self.convert_conv2d(op) + elif op.type == 'Conv2DBackpropInput': + self.convert_deconv2d(op) elif op.type == 'FusedBatchNorm': self.convert_fused_batchnorm(op) elif op.type == 'Mul' and op.name.find('batchnorm/mul') != -1: @@ -1159,7 +1276,10 @@ class TFConverter(object): elif op.type == 'Relu6': self.convert_relu6(op) elif op.type == 'Add': - self.convert_add(op) + if len(op.inputs) > 2: + self.convert_add(op) + else: + self.convert_eltwise(op, 'ADD') elif op.type == 'ConcatV2': self.convert_concat(op) elif op.type == 'ResizeBilinear': @@ -1176,6 +1296,12 @@ class TFConverter(object): self.convert_depth_to_space(op, False) elif op.type in ['Neg', 'neg', 'Negative', 'negative']: self.convert_eltwise(op, 'NEG') + elif op.type in ['RealDiv', 'Div']: + self.convert_eltwise(op, 'DIV') + elif op.type in ['SquaredDifference']: + self.convert_eltwise(op, 'SQR_DIFF') + elif op.type in ['Pow']: + self.convert_eltwise(op, 'POW') elif op.type == 'Mul': self.convert_eltwise(op, 'MUL') elif 
op.type == 'Sub':