提交 6ae2fe6f 编写于 作者: 刘琦

Merge branch 'add-deconv' into 'master'

Add deconv

See merge request !401
......@@ -82,6 +82,7 @@ extern void Register_BiasAdd(OperatorRegistry *op_registry);
extern void Register_ChannelShuffle(OperatorRegistry *op_registry);
extern void Register_Concat(OperatorRegistry *op_registry);
extern void Register_Conv2D(OperatorRegistry *op_registry);
extern void Register_Deconv2D(OperatorRegistry *op_registry);
extern void Register_DepthToSpace(OperatorRegistry *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry);
extern void Register_Dequantize(OperatorRegistry *op_registry);
......@@ -122,6 +123,7 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_ChannelShuffle(this);
ops::Register_Concat(this);
ops::Register_Conv2D(this);
ops::Register_Deconv2D(this);
ops::Register_DepthToSpace(this);
ops::Register_DepthwiseConv2d(this);
ops::Register_Dequantize(this);
......
......@@ -14,6 +14,7 @@
#include "mace/kernels/conv_pool_2d_util.h"
#include <algorithm>
#include <vector>
namespace mace {
......@@ -147,6 +148,8 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
output_shape[3] = output_channels;
}
void CalcOutputSize(const index_t *input_shape, // NHWC
const index_t *filter_shape, // HWOI
const int *padding_size,
......@@ -161,14 +164,7 @@ void CalcOutputSize(const index_t *input_shape, // NHWC
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
/*
* Convolution arithmetic:
* o = floor((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1
* Pooling arithmetic:
* o = ceil((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1
* For details, see https://arxiv.org/pdf/1603.07285.pdf or
* http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
*/
output_shape[0] = input_shape[0];
if (round_type == FLOOR) {
output_shape[1] = static_cast<index_t>(
......@@ -454,5 +450,6 @@ void ConstructNHWCInputWithPadding(const Tensor *input_tensor,
}
}
}
} // namespace kernels
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_DECONV_2D_H_
#define MACE_KERNELS_DECONV_2D_H_
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include <algorithm>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
#include "mace/kernels/activation.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
namespace deconv {
template<typename T>
void Deconv2dNCHW(const T *input,
const T *filter,
const T *bias,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const index_t filter_height,
const index_t filter_width,
const index_t stride_h,
const index_t stride_w,
const int padding_top,
const int padding_left,
float *output) {
#pragma omp parallel for collapse(4)
for (index_t b = 0; b < batch; ++b) {
for (index_t oc = 0; oc < out_channels; ++oc) {
for (index_t oh = 0; oh < out_height; ++oh) {
for (index_t ow = 0; ow < out_width; ++ow) {
index_t filter_start_y, filter_start_x;
index_t start_x = std::max<int>(0, ow + stride_w -1 - padding_left);
index_t start_y = std::max<int>(0, oh + stride_h -1 - padding_top);
start_x /= stride_w;
start_y /= stride_h;
filter_start_x = padding_left + stride_w * start_x - ow;
filter_start_y = padding_top + stride_h * start_y - oh;
filter_start_x = filter_width - 1 - filter_start_x;
filter_start_y = filter_height - 1 - filter_start_y;
T out_value = 0;
index_t out_pos =
((b * out_channels + oc) * out_height + oh) * out_width + ow;
for (index_t ic = 0; ic < in_channels; ++ic) {
for (index_t f_y = filter_start_y, ih = start_y;
f_y >= 0 && ih < in_height; f_y -= stride_h, ++ih) {
for (index_t f_x = filter_start_x, iw = start_x;
f_x >= 0 && iw < in_width; f_x -= stride_w, ++iw) {
index_t weight_pos =
((oc * in_channels + ic) * filter_height + f_y)
* filter_width + f_x;
index_t in_pos =
((b * in_channels + ic) * in_height + ih)
* in_width + iw;
out_value += input[in_pos] * filter[weight_pos];
}
}
}
if (bias != nullptr)
out_value += bias[oc];
output[out_pos] = out_value;
}
}
}
}
}
} // namespace deconv
// Configuration and shape arithmetic shared by the Deconv2d functors.
struct Deconv2dFunctorBase {
  // The initializer list follows the member declaration order below, which
  // is the order in which the members are actually initialized.
  Deconv2dFunctorBase(const int *strides,
                      const Padding &padding_type,
                      const std::vector<int> &paddings,
                      const std::vector<index_t> &output_shape,
                      const ActivationType activation,
                      const float relux_max_limit)
      : strides_(strides),
        padding_type_(padding_type),
        paddings_(paddings),
        activation_(activation),
        relux_max_limit_(relux_max_limit),
        output_shape_(output_shape) {}

  // Computes the deconvolution output shape from explicit padding sizes.
  //
  //   input_shape  : 4 values, read as NCHW when isNCHW is true, else NHWC.
  //   filter_shape : 4 values, read as OIHW when isOIHW is true, else HWOI.
  //   strides      : [stride_h, stride_w].
  //   output_shape : receives 4 values in the same layout as input_shape.
  //   padding_size : [pad_h, pad_w], added to the stride-expanded input
  //                  extent before the filter extent is subtracted.
  //
  // Per spatial dimension:
  //   out = (in - 1) * stride + 1 + pad - filter + 1
  static void CalcDeconvOutputSize(
      const index_t *input_shape,
      const index_t *filter_shape,
      const int *strides,
      index_t *output_shape,
      const int *padding_size,
      const bool isNCHW = false,
      const bool isOIHW = false) {
    MACE_CHECK_NOTNULL(output_shape);
    MACE_CHECK_NOTNULL(padding_size);
    MACE_CHECK_NOTNULL(input_shape);
    MACE_CHECK_NOTNULL(filter_shape);
    MACE_CHECK_NOTNULL(strides);

    const index_t output_channel = isOIHW ? filter_shape[0] : filter_shape[2];

    const index_t in_height = isNCHW ? input_shape[2] : input_shape[1];
    const index_t in_width = isNCHW ? input_shape[3] : input_shape[2];

    const index_t extended_input_height =
        (in_height - 1) * strides[0] + 1 + padding_size[0];
    const index_t extended_input_width =
        (in_width - 1) * strides[1] + 1 + padding_size[1];

    const index_t filter_h = isOIHW ? filter_shape[2] : filter_shape[0];
    const index_t filter_w = isOIHW ? filter_shape[3] : filter_shape[1];

    const index_t out_height = extended_input_height - filter_h + 1;
    const index_t out_width = extended_input_width - filter_w + 1;

    output_shape[0] = input_shape[0];
    if (isNCHW) {
      output_shape[1] = output_channel;
      output_shape[2] = out_height;
      output_shape[3] = out_width;
    } else {
      output_shape[1] = out_height;
      output_shape[2] = out_width;
      output_shape[3] = output_channel;
    }
  }

  // Derives the padding sizes for a caller-supplied output shape and checks
  // that the output shape is compatible with the input under `padding`.
  //
  //   input_shape / output_shape : 4 values each, NCHW when isNCHW is true,
  //                                else NHWC.
  //   filter_shape               : OIHW when isOIHW is true, else HWOI.
  //   padding_size               : receives [pad_h, pad_w], each
  //       max(0, out + filter - 1 - ((in - 1) * stride + 1)).
  //
  // Fails a MACE_CHECK when the input extent implied by output_shape
  // (VALID: (out - filter) / stride + 1, SAME: (out - 1) / stride + 1)
  // differs from the actual input extent, or for any other padding type.
  static void CalcDeconvPaddingAndInputSize(
      const index_t *input_shape,
      const index_t *filter_shape,
      const int *strides,
      Padding padding,
      const index_t *output_shape,
      int *padding_size,
      const bool isNCHW = false,
      const bool isOIHW = false) {
    MACE_CHECK_NOTNULL(output_shape);
    MACE_CHECK_NOTNULL(padding_size);
    MACE_CHECK_NOTNULL(input_shape);
    MACE_CHECK_NOTNULL(filter_shape);
    MACE_CHECK_NOTNULL(strides);

    const index_t in_height = isNCHW ? input_shape[2] : input_shape[1];
    const index_t in_width = isNCHW ? input_shape[3] : input_shape[2];

    const index_t out_height = isNCHW ? output_shape[2] : output_shape[1];
    const index_t out_width = isNCHW ? output_shape[3] : output_shape[2];

    const index_t extended_input_height = (in_height - 1) * strides[0] + 1;
    const index_t extended_input_width = (in_width - 1) * strides[1] + 1;

    const index_t filter_h = isOIHW ? filter_shape[2] : filter_shape[0];
    const index_t filter_w = isOIHW ? filter_shape[3] : filter_shape[1];

    index_t expected_input_height = 0, expected_input_width = 0;

    switch (padding) {
      case VALID:
        expected_input_height =
            (out_height - filter_h) / strides[0] + 1;
        expected_input_width =
            (out_width - filter_w) / strides[1] + 1;
        break;
      case SAME:
        expected_input_height =
            (out_height - 1) / strides[0] + 1;
        expected_input_width =
            (out_width - 1) / strides[1] + 1;
        break;
      default:
        MACE_CHECK(false, "Unsupported padding type: ", padding);
    }

    MACE_CHECK(expected_input_height == in_height,
               expected_input_height, "!=", in_height);
    MACE_CHECK(expected_input_width == in_width,
               expected_input_width, "!=", in_width);

    const int p_h = static_cast<int>(out_height +
        filter_h - 1 - extended_input_height);
    const int p_w = static_cast<int>(out_width +
        filter_w - 1 - extended_input_width);

    padding_size[0] = std::max<int>(0, p_h);
    padding_size[1] = std::max<int>(0, p_w);
  }

  const int *strides_;  // [stride_h, stride_w]
  const Padding padding_type_;
  std::vector<int> paddings_;
  const ActivationType activation_;
  const float relux_max_limit_;
  std::vector<index_t> output_shape_;
};
// Generic Deconv2d functor: runs the reference Deconv2dNCHW loop on mapped
// tensor data and then applies the configured activation in place.
template <DeviceType D, typename T>
struct Deconv2dFunctor : Deconv2dFunctorBase {
  // `is_filter_transformed` and `scratch` are accepted but not used here.
  Deconv2dFunctor(const int *strides,
                  const Padding &padding_type,
                  const std::vector<int> &paddings,
                  const std::vector<index_t> &output_shape,
                  const ActivationType activation,
                  const float relux_max_limit,
                  const bool is_filter_transformed,
                  ScratchBuffer *scratch)
      : Deconv2dFunctorBase(strides,
                            padding_type,
                            paddings,
                            output_shape,
                            activation,
                            relux_max_limit) {}

  // Computes `output` from `input`, `filter` and the optional `bias`
  // (may be nullptr). `future` is not used by this implementation.
  //
  // The output shape and paddings are resolved into locals on every call.
  // Writing them back into output_shape_ / paddings_ would leave a
  // 4-element NCHW shape in output_shape_, which the next call would then
  // misinterpret as a caller-supplied NHWC output shape.
  void operator()(const Tensor *input,   // NCHW
                  const Tensor *filter,  // OIHW
                  const Tensor *bias,
                  Tensor *output,
                  StatsFuture *future) {
    MACE_CHECK_NOTNULL(input);
    MACE_CHECK_NOTNULL(filter);
    MACE_CHECK_NOTNULL(output);

    std::vector<int> paddings(2, 0);
    std::vector<index_t> output_shape(4, 0);
    if (output_shape_.size() == 4) {
      // A 4-element output_shape_ is permuted N,H,W,C -> N,C,H,W, and the
      // paddings are derived from it.
      output_shape[0] = output_shape_[0];
      output_shape[1] = output_shape_[3];
      output_shape[2] = output_shape_[1];
      output_shape[3] = output_shape_[2];
      CalcDeconvPaddingAndInputSize(
          input->shape().data(),
          filter->shape().data(),
          strides_, padding_type_,
          output_shape.data(),
          paddings.data(), true, true);
    } else {
      // Otherwise the output shape is derived from the configured paddings.
      MACE_CHECK(paddings_.size() >= 2,
                 "paddings should hold [pad_h, pad_w]");
      paddings[0] = paddings_[0];
      paddings[1] = paddings_[1];
      CalcDeconvOutputSize(input->shape().data(),
                           filter->shape().data(),
                           strides_,
                           output_shape.data(),
                           paddings.data(), true, true);
    }
    output->Resize(output_shape);

    index_t batch = output->dim(0);
    index_t channels = output->dim(1);
    index_t height = output->dim(2);
    index_t width = output->dim(3);

    index_t input_batch = input->dim(0);
    index_t input_channels = input->dim(1);
    index_t input_height = input->dim(2);
    index_t input_width = input->dim(3);

    index_t kernel_h = filter->dim(2);
    index_t kernel_w = filter->dim(3);
    MACE_CHECK(filter->dim(0) == channels, filter->dim(0), " != ", channels);
    MACE_CHECK(filter->dim(1) == input_channels, filter->dim(1), " != ",
               input_channels);

    index_t stride_h = strides_[0];
    index_t stride_w = strides_[1];

    MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");

    Tensor::MappingGuard input_mapper(input);
    Tensor::MappingGuard filter_mapper(filter);
    Tensor::MappingGuard bias_mapper(bias);
    Tensor::MappingGuard output_mapper(output);
    auto input_data = input->data<T>();
    auto filter_data = filter->data<T>();
    auto bias_data = bias == nullptr ? nullptr : bias->data<T>();
    auto output_data = output->mutable_data<T>();

    // Top/left share of each padding value, rounded up.
    int padding_top = (paddings[0] + 1) >> 1;
    int padding_left = (paddings[1] + 1) >> 1;

    deconv::Deconv2dNCHW(input_data, filter_data, bias_data,
                         batch, input_height, input_width, input_channels,
                         height, width, channels,
                         kernel_h, kernel_w,
                         stride_h, stride_w, padding_top, padding_left,
                         output_data);

    DoActivation(output_data, output_data, output->size(), activation_,
                 relux_max_limit_);
  }
};
// GPU specialization of the Deconv2d functor. operator() is only declared
// here; the members below hold the OpenCL kernel object and related state
// between calls.
template <typename T>
struct Deconv2dFunctor<DeviceType::GPU, T> : Deconv2dFunctorBase {
  // `is_filter_transformed` and `scratch` are accepted but not used here.
  Deconv2dFunctor(const int *strides,
                  const Padding &padding_type,
                  const std::vector<int> &paddings,
                  const std::vector<index_t> &output_shape,
                  const ActivationType activation,
                  const float relux_max_limit,
                  const bool is_filter_transformed,
                  ScratchBuffer *scratch)
      : Deconv2dFunctorBase(strides,
                            padding_type,
                            paddings,
                            output_shape,
                            activation,
                            relux_max_limit),
        kwg_size_(0) {}  // give the work-group size a defined initial value

  void operator()(const Tensor *input,
                  const Tensor *filter,
                  const Tensor *bias,
                  Tensor *output,
                  StatsFuture *future);

  cl::Kernel kernel_;
  uint32_t kwg_size_;
  std::unique_ptr<BufferBase> kernel_error_;
  std::vector<index_t> input_shape_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_DECONV_2D_H_
......@@ -40,7 +40,8 @@ enum EltwiseType {
NEG = 6,
ABS = 7,
SQR_DIFF = 8,
NONE = 9,
POW = 9,
NONE = 10,
};
inline void TensorScalar(const EltwiseType type,
......@@ -103,19 +104,25 @@ inline void TensorScalar(const EltwiseType type,
output[i] = std::pow(input0[i] - value, 2.f);
}
break;
case POW:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output[i] = std::pow(input0[i], value);
}
break;
default:
LOG(FATAL) << "Eltwise op not support type " << type;
}
}
inline void TensorVector(const EltwiseType type,
const float *input0,
const float *input1,
const index_t batch,
const index_t channel,
const index_t hw,
const bool swapped,
float *output) {
inline void TensorBatchVector(const EltwiseType type,
const float *input0,
const float *input1,
const index_t batch,
const index_t channel,
const index_t hw,
const bool swapped,
float *output) {
switch (type) {
case SUM:
#pragma omp parallel for collapse(3)
......@@ -227,6 +234,153 @@ inline void TensorVector(const EltwiseType type,
}
}
break;
case POW:
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channel; ++c) {
for (index_t i = 0; i < hw; ++i) {
const index_t idx0 = (b * channel + c) * hw + i;
const index_t idx1 = b * channel + c;
output[idx0] = std::pow(input0[idx0], input1[idx1]);
}
}
}
break;
default:
LOG(FATAL) << "Eltwise op not support type " << type;
}
}
inline void TensorVector(const EltwiseType type,
const float *input0,
const float *input1,
const index_t batch,
const index_t channel,
const index_t hw,
const bool swapped,
float *output) {
switch (type) {
case SUM:
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channel; ++c) {
for (index_t i = 0; i < hw; ++i) {
const index_t idx0 = (b * channel + c) * hw + i;
const index_t idx1 = c;
output[idx0] = input0[idx0] + input1[idx1];
}
}
}
break;
case SUB:
if (swapped) {
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channel; ++c) {
for (index_t i = 0; i < hw; ++i) {
const index_t idx0 = (b * channel + c) * hw + i;
const index_t idx1 = c;
output[idx0] = input1[idx1] - input0[idx0];
}
}
}
} else {
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channel; ++c) {
for (index_t i = 0; i < hw; ++i) {
const index_t idx0 = (b * channel + c) * hw + i;
const index_t idx1 = c;
output[idx0] = input0[idx0] - input1[idx1];
}
}
}
}
break;
case PROD:
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channel; ++c) {
for (index_t i = 0; i < hw; ++i) {
const index_t idx0 = (b * channel + c) * hw + i;
const index_t idx1 = c;
output[idx0] = input0[idx0] * input1[idx1];
}
}
}
break;
case DIV:
if (swapped) {
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channel; ++c) {
for (index_t i = 0; i < hw; ++i) {
const index_t idx0 = (b * channel + c) * hw + i;
const index_t idx1 = c;
output[idx0] = input1[idx1] / input0[idx0];
}
}
}
} else {
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channel; ++c) {
for (index_t i = 0; i < hw; ++i) {
const index_t idx0 = (b * channel + c) * hw + i;
const index_t idx1 = c;
output[idx0] = input0[idx0] / input1[idx1];
}
}
}
}
break;
case MIN:
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channel; ++c) {
for (index_t i = 0; i < hw; ++i) {
const index_t idx0 = (b * channel + c) * hw + i;
const index_t idx1 = c;
output[idx0] = std::min<float>(input0[idx0], input1[idx1]);
}
}
}
break;
case MAX:
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channel; ++c) {
for (index_t i = 0; i < hw; ++i) {
const index_t idx0 = (b * channel + c) * hw + i;
const index_t idx1 = c;
output[idx0] = std::max<float>(input0[idx0], input1[idx1]);
}
}
}
break;
case SQR_DIFF:
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channel; ++c) {
for (index_t i = 0; i < hw; ++i) {
const index_t idx0 = (b * channel + c) * hw + i;
const index_t idx1 = c;
output[idx0] = std::pow(input0[idx0] - input1[idx1], 2.f);
}
}
}
break;
case POW:
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channel; ++c) {
for (index_t i = 0; i < hw; ++i) {
const index_t idx0 = (b * channel + c) * hw + i;
const index_t idx1 = c;
output[idx0] = std::pow(input0[idx0], input1[idx1]);
}
}
}
break;
default:
LOG(FATAL) << "Eltwise op not support type " << type;
}
......@@ -279,6 +433,12 @@ inline void TensorEltwise(const EltwiseType type,
output[i] = std::pow(input0[i] - input1[i], 2.f);
}
break;
case POW:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output[i] = std::pow(input0[i], input1[i]);
}
break;
default:
LOG(FATAL) << "Eltwise op not support type " << type;
}
......@@ -312,18 +472,25 @@ struct EltwiseFunctor<DeviceType::CPU, float>: EltwiseFunctorBase {
StatsFuture *future) {
bool swapped = false;
if (input1 != nullptr) {
MACE_CHECK(input0->dim_size() == input1->dim_size())
MACE_CHECK(input0->dim_size() == input1->dim_size()
|| input0->dim_size() == 1
|| input1->dim_size() == 1)
<< "Inputs of Eltwise op must be same shape";
if (input0->size() != input1->size()) {
if (input0->size() < input1->size()) {
std::swap(input0, input1);
swapped = true;
}
MACE_CHECK(input0->dim(0) == input1->dim(0) &&
input0->dim(1) == input1->dim(1) &&
input1->dim(2) == 1 &&
input1->dim(3) == 1)
<< "Element-Wise op only support channel dimension broadcast";
if (input1->dim_size() == 1) {
MACE_CHECK(input0->dim(1) == input1->dim(0))
<< "Element-Wise op only support channel dimension broadcast";
} else {
MACE_CHECK((input0->dim(0) == input1->dim(0) || input1->dim(0) == 1)
&& input0->dim(1) == input1->dim(1)
&& input1->dim(2) == 1
&& input1->dim(3) == 1)
<< "Element-Wise op only support channel dimension broadcast";
}
}
}
output->ResizeLike(input0);
......@@ -344,8 +511,12 @@ struct EltwiseFunctor<DeviceType::CPU, float>: EltwiseFunctorBase {
const index_t batch = input0->dim(0);
const index_t channel = input0->dim(1);
const index_t hw = input0->dim(2) * input0->dim(3);
TensorVector(type_, input0_ptr, input1_ptr,
batch, channel, hw, swapped, output_ptr);
if (input1->dim(0) == 1 || input1->dim_size() == 1)
TensorVector(type_, input0_ptr, input1_ptr,
batch, channel, hw, swapped, output_ptr);
else
TensorBatchVector(type_, input0_ptr, input1_ptr,
batch, channel, hw, swapped, output_ptr);
} else {
if (!coeff_.empty() && type_ == SUM) {
#pragma omp parallel for
......
#include <common.h>
// 2-D transposed convolution over image2d data.
//
// Work-item ids:
//   dim0 -> c    : x coordinate used to read `bias` and, as c * out_width,
//                  the x offset of the output writes.
//   dim1 -> w_id : width block id, expanded to the first output column `w`.
//   dim2 -> hb   : fused index, split below into b = hb / out_height and
//                  h = hb - b * out_height.
//
// Each work-item accumulates up to five DATA_TYPE4 results (out0..out4) for
// output columns w, w + stride, ..., w + 4 * stride of row hb, and writes
// only those whose column is < out_width.
//
// NOTE(review): stride_r is presumably 1.0f / stride and a single stride is
// applied to both axes -- confirm against the host code that sets the args.
__kernel void deconv_2d(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__read_only image2d_t weights,
#ifdef BIAS
__read_only image2d_t bias,
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const int in_height,
__private const int in_width,
__private const int in_channels,
__private const int out_height,
__private const int out_width,
__private const int out_channel,
__private const int stride,
__private const float stride_r,
__private const int align_h,
__private const int align_w,
__private const int padding_h,
__private const int padding_w,
__private const int kernel_h,
__private const int kernel_w,
__private const int kernel_size,
__private const int in_channel_blocks,
__private const int out_channel_blocks)
{
const int c = get_global_id(0);
const int w_id = get_global_id(1);
const int hb = get_global_id(2);
// Without non-uniform work-group support, discard work-items that fall
// outside the requested global size.
#ifndef NON_UNIFORM_WORK_GROUP
if (c >= global_size_dim0 || w_id >= global_size_dim1
|| hb >= global_size_dim2) {
return;
}
#endif
// Accumulators start from the bias value when BIAS is defined, else zero.
#ifdef BIAS
DATA_TYPE4 out0 =
READ_IMAGET(bias, SAMPLER, (int2)(c, 0));
DATA_TYPE4 out1 = out0;
DATA_TYPE4 out2 = out0;
DATA_TYPE4 out3 = out0;
DATA_TYPE4 out4 = out0;
#else
DATA_TYPE4 out0 = 0;
DATA_TYPE4 out1 = 0;
DATA_TYPE4 out2 = 0;
DATA_TYPE4 out3 = 0;
DATA_TYPE4 out4 = 0;
#endif
// n_stride = w_id * stride_r truncated to int; mod_stride is what remains
// of w_id after removing n_stride * stride. The first output column is
// w = n_stride * 5 * stride + mod_stride (5 = number of accumulators).
const int n_stride = mad(w_id, stride_r, 0);
const int mod_stride = w_id - mul24(n_stride, stride);
const int w = mad24(mul24(n_stride, 5), stride, mod_stride);
const int b = hb / out_height;
const int h = hb - mul24(b, out_height);
if (w < out_width) {
// First input column/row visited. Only start_y is clamped at zero here;
// input columns that fall outside the image are handled by READ_INPUT.
int start_x = floor((float) (w + align_w) * stride_r);
int start_y = (h + align_h) * stride_r;
start_y = max(0, start_y);
// Filter tap paired with the first input pixel; the loops below step it
// down by `stride` while the input index steps up by one.
int f_start_x = mad24(start_x, stride, padding_w) - w;
int f_start_y = mad24(start_y, stride, padding_h) - h;
f_start_x = kernel_w - 1 - f_start_x;
f_start_y = kernel_h - 1 - f_start_y;
int2 in_pos;
int f_pos_x0, f_pos_x1, f_pos_x2, f_pos_x3, f_pos_y;
DATA_TYPE4 in0, in1, in2, in3, in4;
DATA_TYPE4 weight0, weight1, weight2, weight3;
int idx_w0, idx_w1, idx_w2, idx_w3, idx_w4;
int index_x, index_y;
for (int ic = 0; ic < in_channel_blocks; ++ic) {
// Four consecutive x positions of the weights image are read per block.
f_pos_x0 = mul24(ic, 4);
f_pos_x1 = f_pos_x0 + 1;
f_pos_x2 = f_pos_x0 + 2;
f_pos_x3 = f_pos_x0 + 3;
for (int f_y = f_start_y, idx_h = start_y ; f_y >= 0; f_y -= stride, ++idx_h) {
// Input row inside batch b; rows outside [0, in_height) are redirected to
// coordinate -1 (presumably read back as zero by SAMPLER -- TODO confirm).
index_y = mad24(b, in_height, idx_h);
in_pos.y = select(index_y, -1, idx_h < 0 || idx_h >= in_height);
for (int f_x = f_start_x, idx_w = start_x; f_x >= 0; f_x -= stride, ++idx_w) {
// Weights image row: c * kernel_size + f_y * kernel_w + f_x.
f_pos_y = mad24(f_y, kernel_w, f_x);
f_pos_y = mad24(c, kernel_size, f_pos_y);
weight0 = READ_IMAGET(weights, SAMPLER, (int2)(f_pos_x0, f_pos_y));
weight1 = READ_IMAGET(weights, SAMPLER, (int2)(f_pos_x1, f_pos_y));
weight2 = READ_IMAGET(weights, SAMPLER, (int2)(f_pos_x2, f_pos_y));
weight3 = READ_IMAGET(weights, SAMPLER, (int2)(f_pos_x3, f_pos_y));
// The five accumulators read five consecutive input columns and share the
// same filter tap.
idx_w0 = idx_w;
idx_w1 = idx_w + 1;
idx_w2 = idx_w + 2;
idx_w3 = idx_w + 3;
idx_w4 = idx_w + 4;
// READ_INPUT(i): load in<i> from x = ic * in_width + idx_w<i>, or from
// x = -1 when idx_w<i> lies outside [0, in_width).
#define READ_INPUT(i) \
index_x = mad24(ic, in_width, idx_w##i); \
in_pos.x = \
select(index_x, -1, idx_w##i < 0 || idx_w##i >= in_width); \
in##i = READ_IMAGET(input, SAMPLER, in_pos);
READ_INPUT(0);
READ_INPUT(1);
READ_INPUT(2);
READ_INPUT(3);
READ_INPUT(4);
#undef READ_INPUT
// CALC_OUTPUT(i): out<i> += in<i>.x * weight0 + in<i>.y * weight1 +
// in<i>.z * weight2 + in<i>.w * weight3, as four chained mad() calls.
#define CALC_OUTPUT(i) \
out##i = mad(in##i.x, weight0, out##i); \
out##i = mad(in##i.y, weight1, out##i); \
out##i = mad(in##i.z, weight2, out##i); \
out##i = mad(in##i.w, weight3, out##i);
CALC_OUTPUT(0);
CALC_OUTPUT(1);
CALC_OUTPUT(2);
CALC_OUTPUT(3);
CALC_OUTPUT(4);
#undef CALC_OUTPUT
}
}
}
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit);
out1 = do_activation(out1, relux_max_limit);
out2 = do_activation(out2, relux_max_limit);
out3 = do_activation(out3, relux_max_limit);
out4 = do_activation(out4, relux_max_limit);
#endif
// Write the results `stride` columns apart, stopping at the first column
// that falls outside the output width.
int2 out_pos;
out_pos.y = hb;
int ow = w;
if (ow >= out_width) return;
out_pos.x = mad24(c, out_width, ow);
WRITE_IMAGET(output, out_pos, out0);
ow += stride;
if (ow >= out_width) return;
out_pos.x += stride;
WRITE_IMAGET(output, out_pos, out1);
ow += stride;
if (ow >= out_width) return;
out_pos.x += stride;
WRITE_IMAGET(output, out_pos, out2);
ow += stride;
if (ow >= out_width) return;
out_pos.x += stride;
WRITE_IMAGET(output, out_pos, out3);
ow += stride;
if (ow >= out_width) return;
out_pos.x += stride;
WRITE_IMAGET(output, out_pos, out4);
}
}
\ No newline at end of file
......@@ -33,6 +33,8 @@ __kernel void eltwise(KERNEL_ERROR_PARAMS
#elif INPUT_TYPE == 2
const int batch_idx = hb / height;
DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(chan_idx, batch_idx));
#elif INPUT_TYPE == 3
DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(chan_idx, 0));
#else
DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(pos, hb));
#endif
......@@ -70,10 +72,17 @@ __kernel void eltwise(KERNEL_ERROR_PARAMS
#elif ELTWISE_TYPE == 8
DATA_TYPE4 diff = in0 - in1;
out = diff * diff;
#elif ELTWISE_TYPE == 9
// POW is not commutative: in0 is the base unless the host swapped the
// operands (SWAPPED), in which case in1 is the base.
#ifdef SWAPPED
out = pow(in1, in0);
#else
out = pow(in0, in1);
#endif
#endif
#if INPUT_TYPE == 1
#if ELTWISE_TYPE == 0 || ELTWISE_TYPE == 1 || ELTWISE_TYPE == 4 || ELTWISE_TYPE == 5 || ELTWISE_TYPE == 8
#if ELTWISE_TYPE == 0 || ELTWISE_TYPE == 1 || ELTWISE_TYPE == 4 || \
ELTWISE_TYPE == 5 || ELTWISE_TYPE == 8 || ELTWISE_TYPE == 9
const int remain_channel = channel - 4 * chan_idx;
if (remain_channel < 4) {
switch (remain_channel) {
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/deconv_2d.h"
#include "mace/kernels/opencl/helper.h"
namespace mace {
namespace kernels {
namespace {
// Builds (on first use), configures and launches the "deconv_2d" OpenCL
// kernel.
//
//   kernel           : cached kernel object; built here while it is empty.
//   input / output   : image tensors; dims are read as N, H, W, C.
//   filter           : image tensor; dim(0) and dim(1) are passed to the
//                      kernel as kernel height and width.
//   bias             : optional (nullptr disables the BIAS kernel option).
//   stride           : single stride value handed to the kernel; must be > 0.
//   paddings         : [pad_h, pad_w]; the kernel receives the rounded-up
//                      half of each value.
//   prev_input_shape : input shape for which the kernel arguments were last
//                      set; arguments are re-sent only when it changes.
//   kwg_size         : receives the kernel's max work-group size on build.
//   kernel_error     : buffer for the kernel's out-of-range error code, used
//                      only when the runtime enables that check.
void Deconv2dOpencl(cl::Kernel *kernel,
                    const Tensor *input,
                    const Tensor *filter,
                    const Tensor *bias,
                    const int stride,
                    const int *paddings,
                    const ActivationType activation,
                    const float relux_max_limit,
                    const DataType dt,
                    std::vector<index_t> *prev_input_shape,
                    Tensor *output,
                    StatsFuture *future,
                    uint32_t *kwg_size,
                    std::unique_ptr<BufferBase> *kernel_error) {
  const index_t batch = output->dim(0);
  const index_t height = output->dim(1);
  const index_t width = output->dim(2);
  const index_t channels = output->dim(3);
  const index_t input_channels = input->dim(3);

  const index_t channel_blocks = RoundUpDiv4(channels);
  const index_t input_channel_blocks = RoundUpDiv4(input_channels);
  MACE_CHECK(stride > 0, "stride should > 0.");

  // Output columns handled per work-item; must match the number of
  // accumulators (out0..out4) in the deconv_2d kernel.
  const index_t kWidthBlock = 5;
  const index_t n_strides = (width + stride - 1) / stride;
  const index_t width_blocks =
      ((n_strides + kWidthBlock - 1) / kWidthBlock) * stride;
  const float stride_r = 1.f / static_cast<float>(stride);

  // Rounded-up half of each padding value. The width padding comes from
  // paddings[1]; reading paddings[0] for both axes was a copy-paste bug.
  const int padding_h = (paddings[0] + 1) >> 1;
  const int padding_w = (paddings[1] + 1) >> 1;

  const int align_h = stride - 1 - padding_h;
  const int align_w = stride - 1 - padding_w;
  const int kernel_size = filter->dim(0) * filter->dim(1);

  auto runtime = OpenCLRuntime::Global();

  if (kernel->get() == nullptr) {
    std::set<std::string> built_options;
    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("deconv_2d");
    built_options.emplace("-Ddeconv_2d=" + kernel_name);
    built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
    built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
    if (runtime->IsOutOfRangeCheckEnabled()) {
      built_options.emplace("-DOUT_OF_RANGE_CHECK");
      // One-byte error flag, cleared before the first launch.
      *kernel_error = std::unique_ptr<Buffer>(
          new Buffer(GetDeviceAllocator(DeviceType::GPU), 1));
      (*kernel_error)->Map(nullptr);
      *((*kernel_error)->mutable_data<char>()) = 0;
      (*kernel_error)->UnMap();
    }
    if (runtime->IsNonUniformWorkgroupsSupported()) {
      built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
    }
    built_options.emplace(bias != nullptr ? "-DBIAS" : "");
    switch (activation) {
      case NOOP:
        break;
      case RELU:
        built_options.emplace("-DUSE_RELU");
        break;
      case RELUX:
        built_options.emplace("-DUSE_RELUX");
        break;
      case TANH:
        built_options.emplace("-DUSE_TANH");
        break;
      case SIGMOID:
        built_options.emplace("-DUSE_SIGMOID");
        break;
      default:
        LOG(FATAL) << "Unknown activation type: " << activation;
    }

    *kernel = runtime->BuildKernel("deconv_2d", kernel_name, built_options);

    *kwg_size =
        static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(*kernel));
  }

  const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
                           static_cast<uint32_t>(width_blocks),
                           static_cast<uint32_t>(height * batch)};

  if (!IsVecEqual(*prev_input_shape, input->shape())) {
    // Argument order must match the deconv_2d kernel signature.
    uint32_t idx = 0;
    if (runtime->IsOutOfRangeCheckEnabled()) {
      kernel->setArg(idx++,
                     *(static_cast<cl::Buffer *>((*kernel_error)->buffer())));
    }
    if (!runtime->IsNonUniformWorkgroupsSupported()) {
      kernel->setArg(idx++, gws[0]);
      kernel->setArg(idx++, gws[1]);
      kernel->setArg(idx++, gws[2]);
    }
    kernel->setArg(idx++, *(input->opencl_image()));
    kernel->setArg(idx++, *(filter->opencl_image()));
    if (bias != nullptr) {
      kernel->setArg(idx++, *(bias->opencl_image()));
    }
    kernel->setArg(idx++, *(output->opencl_image()));
    kernel->setArg(idx++, relux_max_limit);
    kernel->setArg(idx++, static_cast<int32_t>(input->dim(1)));
    kernel->setArg(idx++, static_cast<int32_t>(input->dim(2)));
    kernel->setArg(idx++, static_cast<int32_t>(input->dim(3)));
    kernel->setArg(idx++, static_cast<int32_t>(height));
    kernel->setArg(idx++, static_cast<int32_t>(width));
    kernel->setArg(idx++, static_cast<int32_t>(channels));
    kernel->setArg(idx++, static_cast<int32_t>(stride));
    kernel->setArg(idx++, stride_r);
    kernel->setArg(idx++, static_cast<int32_t>(align_h));
    kernel->setArg(idx++, static_cast<int32_t>(align_w));
    kernel->setArg(idx++, static_cast<int32_t>(padding_h));
    kernel->setArg(idx++, static_cast<int32_t>(padding_w));
    kernel->setArg(idx++, static_cast<int32_t>(filter->dim(0)));
    kernel->setArg(idx++, static_cast<int32_t>(filter->dim(1)));
    kernel->setArg(idx++, static_cast<int32_t>(kernel_size));
    kernel->setArg(idx++, static_cast<int32_t>(input_channel_blocks));
    kernel->setArg(idx++, static_cast<int32_t>(channel_blocks));

    *prev_input_shape = input->shape();
  }

  const std::vector<uint32_t> lws = {8, *kwg_size / 64, 8, 0};
  std::string tuning_key =
      Concat("deconv2d_opencl_kernel_", activation, output->dim(0),
             output->dim(1), output->dim(2), output->dim(3));
  TuningOrRun3DKernel(*kernel, tuning_key, gws, lws, future);

  if (runtime->IsOutOfRangeCheckEnabled()) {
    (*kernel_error)->Map(nullptr);
    char *kerror_code = (*kernel_error)->mutable_data<char>();
    MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
    (*kernel_error)->UnMap();
  }
}
} // namespace
// GPU Deconv2d entry point: resolves the output shape and paddings, resizes
// the output image and launches the OpenCL kernel via Deconv2dOpencl.
//
// The shape and paddings are resolved into locals on every call. Storing
// the derived shape back into output_shape_ would make every later call
// take the "output shape supplied" branch, which overwrites the configured
// paddings and re-validates against padding_type_.
template <typename T>
void Deconv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
                                                     const Tensor *filter,
                                                     const Tensor *bias,
                                                     Tensor *output,
                                                     StatsFuture *future) {
  MACE_CHECK_NOTNULL(input);
  MACE_CHECK_NOTNULL(filter);
  MACE_CHECK_NOTNULL(output);

  std::vector<int> paddings(2, 0);
  std::vector<index_t> output_shape(4, 0);
  if (output_shape_.size() == 4) {
    // A 4-element output_shape_ is used as given and the paddings are
    // derived from it.
    output_shape = output_shape_;
    CalcDeconvPaddingAndInputSize(
        input->shape().data(),
        filter->shape().data(),
        strides_, padding_type_,
        output_shape.data(),
        paddings.data());
  } else {
    // Otherwise the output shape is derived from the configured paddings.
    MACE_CHECK(paddings_.size() >= 2,
               "paddings should hold [pad_h, pad_w]");
    paddings[0] = paddings_[0];
    paddings[1] = paddings_[1];
    CalcDeconvOutputSize(input->shape().data(),
                         filter->shape().data(),
                         strides_,
                         output_shape.data(),
                         paddings.data());
  }

  std::vector<size_t> output_image_shape;
  CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
                  &output_image_shape);
  output->ResizeImage(output_shape, output_image_shape);

  Deconv2dOpencl(&kernel_, input, filter, bias,
                 strides_[0], paddings.data(),
                 activation_, relux_max_limit_,
                 DataTypeToEnum<T>::value, &input_shape_,
                 output, future, &kwg_size_, &kernel_error_);
}
template struct Deconv2dFunctor<DeviceType::GPU, float>;
template struct Deconv2dFunctor<DeviceType::GPU, half>;
} // namespace kernels
} // namespace mace
......@@ -27,21 +27,40 @@ void EltwiseFunctor<DeviceType::GPU, T>::operator()(const Tensor *input0,
StatsFuture *future) {
bool swapped = false;
if (input1 != nullptr) {
MACE_CHECK(input0->dim_size() == input1->dim_size())
MACE_CHECK(input0->dim_size() == input1->dim_size()
|| input0->dim_size() == 1
|| input1->dim_size() == 1)
<< "Inputs of Eltwise op must be same shape";
if (input0->size() != input1->size()) {
if (input0->size() < input1->size()) {
std::swap(input0, input1);
swapped = true;
}
MACE_CHECK(input0->dim(0) == input1->dim(0) &&
input1->dim(1) == 1 &&
input1->dim(2) == 1 &&
input0->dim(3) == input1->dim(3))
<< "Element-Wise op only support channel dimension broadcast";
if (input1->dim_size() == 1) {
MACE_CHECK(input0->dim(3) == input1->dim(0))
<< "Element-Wise op only support channel dimension broadcast";
} else {
MACE_CHECK((input0->dim(0) == input1->dim(0) || input1->dim(0) == 1) &&
input0->dim(3) == input1->dim(3) &&
input1->dim(1) == 1 &&
input1->dim(2) == 1)
<< "Element-Wise op only support channel dimension broadcast";
}
}
}
output->ResizeLike(input0);
std::vector<index_t > output_shape(4);
output_shape[0] = input0->dim(0);
output_shape[1] = input0->dim(1);
output_shape[2] = input0->dim(2);
output_shape[3] = input0->dim(3);
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape,
BufferType::IN_OUT_CHANNEL,
&output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
const index_t batch = output->dim(0);
const index_t height = output->dim(1);
const index_t width = output->dim(2);
......@@ -66,7 +85,10 @@ void EltwiseFunctor<DeviceType::GPU, T>::operator()(const Tensor *input0,
if (input1 == nullptr) {
built_options.emplace("-DINPUT_TYPE=1");
} else if (input0->size() != input1->size()) {
built_options.emplace("-DINPUT_TYPE=2");
if (input1->dim(0) == 1 || input1->dim_size() == 1)
built_options.emplace("-DINPUT_TYPE=3");
else
built_options.emplace("-DINPUT_TYPE=2");
if (swapped) built_options.emplace("-DSWAPPED");
}
if (!coeff_.empty()) built_options.emplace("-DCOEFF_SUM");
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "gtest/gtest.h"
......
......@@ -105,7 +105,8 @@ void TestNHWCSimple3x3SAME() {
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, T>(
"Filter", {3, 3, 1, 2},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, T>("Bias", {1}, {0.1f});
......@@ -191,7 +192,8 @@ void TestNHWCSimple3x3WithoutBias() {
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, T>(
"Filter", {3, 3, 1, 2},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
if (D == DeviceType::CPU) {
......@@ -351,10 +353,12 @@ void TestFusedNHWCSimple3x3VALID() {
// Add input data
net.AddInputFromArray<D, float>(
"Input", {1, 3, 3, 2},
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1});
{-1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1});
net.AddInputFromArray<D, float>(
"Filter", {3, 3, 1, 2},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, float>("Bias", {1}, {-0.1f});
......@@ -423,10 +427,13 @@ void TestFusedNHWCSimple3x3WithoutBias() {
// Add input data
net.AddInputFromArray<D, float>(
"Input", {1, 3, 3, 2},
{-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1});
{-1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1});
net.AddInputFromArray<D, float>(
"Filter", {3, 3, 1, 2},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
if (D == DeviceType::CPU) {
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/deconv_2d.h"
namespace mace {
namespace ops {
// Registers the "Deconv2D" operator with `op_registry`, once per supported
// device / element-type combination: CPU + float, GPU + float and GPU + half.
// Each registration binds the key built by OpKeyBuilder to the matching
// Deconv2dOp<Device, T> instantiation.
void Register_Deconv2D(OperatorRegistry *op_registry) {
// CPU implementation, float data.
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Deconv2D")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
Deconv2dOp<DeviceType::CPU, float>);
// GPU implementation, float data.
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Deconv2D")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
.Build(),
Deconv2dOp<DeviceType::GPU, float>);
// GPU implementation, half-precision data.
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Deconv2D")
.Device(DeviceType::GPU)
.TypeConstraint<half>("T")
.Build(),
Deconv2dOp<DeviceType::GPU, half>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_DECONV_2D_H_
#define MACE_OPS_DECONV_2D_H_
#include <memory>
#include <string>
#include "mace/core/operator.h"
#include "mace/kernels/deconv_2d.h"
#include "mace/ops/conv_pool_2d_base.h"
namespace mace {
namespace ops {
// Operator wrapper around kernels::Deconv2dFunctor<D, T>.
//
// The constructor reads the operator arguments and forwards them to the
// functor:
//   - strides, padding type and paddings come from the ConvPool2dOpBase
//     members strides_, padding_type_ and paddings_;
//   - "output_shape": repeated index_t argument;
//   - "activation": string argument, defaults to "NOOP";
//   - "max_limit": float argument, defaults to 0.0f;
//   - "is_filter_transformed": int argument converted to bool, defaults to
//     false.
//
// Inputs are INPUT, FILTER and an optional BIAS (used only when the operator
// has at least three inputs). The single output is OUTPUT.
template <DeviceType D, typename T>
class Deconv2dOp : public ConvPool2dOpBase<D, T> {
 public:
  Deconv2dOp(const OperatorDef &op_def, Workspace *ws)
      : ConvPool2dOpBase<D, T>(op_def, ws),
        functor_(this->strides_.data(),
                 this->padding_type_,
                 this->paddings_,
                 OperatorBase::GetRepeatedArgument<index_t>("output_shape"),
                 // The TF converter folds a trailing activation op into this
                 // operator by emitting "activation" / "max_limit" arguments
                 // and dropping the activation op itself, so the arguments
                 // must be honoured here. Hard-coding NOOP would silently
                 // discard the activation. Both default to a no-op when the
                 // arguments are absent.
                 // NOTE(review): confirm the functor applies the activation
                 // on every device it is instantiated for.
                 kernels::StringToActivationType(
                     OperatorBase::GetSingleArgument<std::string>("activation",
                                                                  "NOOP")),
                 OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
                 static_cast<bool>(OperatorBase::GetSingleArgument<int>(
                     "is_filter_transformed", false)),
                 ws->GetScratchBuffer(D)) {}

  // Fetches the input/output tensors and invokes the functor. Always returns
  // true; `future` is handed through to the functor unchanged.
  bool Run(StatsFuture *future) override {
    const Tensor *input = this->Input(INPUT);
    const Tensor *filter = this->Input(FILTER);
    // The bias is optional: it is only present as a third input.
    const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr;
    Tensor *output = this->Output(OUTPUT);

    functor_(input, filter, bias, output, future);

    return true;
  }

 private:
  kernels::Deconv2dFunctor<D, T> functor_;

 protected:
  OP_INPUT_TAGS(INPUT, FILTER, BIAS);
  OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_DECONV_2D_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/deconv_2d.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
// Benchmark driver for a single Deconv2D operator.
//
// Builds a one-operator net with random "Input" and "Filter" tensors, runs it
// twice untimed as warm-up, then runs it `iters` times with the benchmark
// timer started. Every run is followed by net.Sync().
//
// Tensor shapes depend on the device:
//   - CPU:       Input  {batch, channels, height, width}
//                Filter {output_channels, channels, kernel_h, kernel_w}
//   - otherwise: Input  {batch, height, width, channels}
//                Filter {kernel_h, kernel_w, output_channels, channels}
// On GPU both tensors are first converted to images and the operator consumes
// the converted "InputImage" / "FilterImage" tensors instead.
//
// `stride` is used for both stride dimensions; the "output_shape" argument is
// always passed as {batch, out_h, out_w, output_channels}.
template <DeviceType D, typename T>
static void Deconv2d(int iters,
                     int batch,
                     int channels,
                     int height,
                     int width,
                     int kernel_h,
                     int kernel_w,
                     int stride,
                     int out_h,
                     int out_w,
                     Padding padding,
                     int output_channels) {
  // Setup must not be counted towards the measured time.
  mace::testing::StopTiming();

  OpsTestNet net;

  // Random input and filter, laid out according to the target device.
  if (D == DeviceType::CPU) {
    net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
    net.AddRandomInput<D, float>(
        "Filter", {output_channels, channels, kernel_h, kernel_w});
  } else {
    net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
    net.AddRandomInput<D, float>(
        "Filter", {kernel_h, kernel_w, output_channels, channels});
  }

  // The operator reads the raw buffers unless we are on GPU, in which case it
  // reads their image counterparts.
  const char *op_input = "Input";
  const char *op_filter = "Filter";
  if (D == DeviceType::GPU) {
    BufferToImage<D, T>(&net, "Input", "InputImage",
                        kernels::BufferType::IN_OUT_CHANNEL);
    BufferToImage<D, T>(&net, "Filter", "FilterImage",
                        kernels::BufferType::CONV2D_FILTER);
    op_input = "InputImage";
    op_filter = "FilterImage";
  }

  OpDefBuilder("Deconv2D", "Deconv2dTest")
      .Input(op_input)
      .Input(op_filter)
      .Output("Output")
      .AddIntsArg("strides", {stride, stride})
      .AddIntArg("padding", padding)
      .AddIntsArg("output_shape", {batch, out_h, out_w, output_channels})
      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
      .Finalize(net.NewOperatorDef());

  net.Setup(D);

  // Warm-up: two untimed runs.
  int warmup_runs_left = 2;
  while (warmup_runs_left-- > 0) {
    net.Run();
    net.Sync();
  }

  // Timed section.
  mace::testing::StartTiming();
  for (; iters != 0; --iters) {
    net.Run();
    net.Sync();
  }
}
// A typical network has more than one layer, so this is used to approximate
// the amortized latency. The OpenCL runtime for Mali/Adreno is in-order.
//
// BM_DECONV_2D_MACRO defines and registers one benchmark function whose name
// encodes every macro parameter. The generated function:
//   - reports iters * N * OC * OH * OW * (KH * KW * C + 1) through
//     MaccProcessed();
//   - reports iters * N * C * H * W * sizeof(TYPE) through BytesProcessed();
//   - calls Deconv2d<DEVICE, TYPE> with the same parameters, passing
//     mace::Padding::P as the padding mode.
// The function is then handed to BENCHMARK() for registration.
#define BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, OH, OW, P, OC, TYPE, \
DEVICE) \
static void \
BM_DECONV_2D_##N##_##C##_##H##_##W##_##KH##_##KW##_##STRIDE##_##OH##_##OW\
##_##P##_##OC##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
int64_t oh = OH; \
int64_t ow = OW; \
const int64_t macc = \
static_cast<int64_t>(iters) * N * OC * oh * ow * (KH * KW * C + 1); \
mace::testing::MaccProcessed(macc); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
Deconv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, OH, OW, \
mace::Padding::P, OC); \
} \
BENCHMARK( \
BM_DECONV_2D_##N##_##C##_##H##_##W##_##KH##_##KW##_##STRIDE##_##OH##_##OW##\
_##P##_##OC##_##TYPE##_##DEVICE)
// Expands one benchmark configuration into three registered benchmarks:
// float on CPU, float on GPU and half on GPU.
#define BM_DECONV_2D(N, C, H, W, KH, KW, S, OH, OW, P, OC) \
BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, float, CPU); \
BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, float, GPU); \
BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, half, GPU);
// Benchmark cases. Argument order:
//   N  (batch), C (input channels), H, W (input height / width),
//   KH, KW (kernel height / width), S (stride, used for both dimensions),
//   OH, OW (requested output height / width), P (padding mode),
//   OC (output channels).
BM_DECONV_2D(1, 512, 15, 15, 1, 1, 1, 15, 15, VALID, 1024);
BM_DECONV_2D(1, 32, 60, 60, 1, 1, 1, 60, 60, VALID, 128);
BM_DECONV_2D(1, 128, 60, 60, 3, 3, 1, 62, 62, VALID, 128);
BM_DECONV_2D(1, 32, 60, 60, 3, 3, 1, 60, 60, SAME, 32);
BM_DECONV_2D(1, 3, 512, 512, 7, 7, 2, 1023, 1023, SAME, 64);
BM_DECONV_2D(1, 128, 16, 16, 5, 5, 1, 20, 20, VALID, 32);
BM_DECONV_2D(1, 128, 64, 64, 5, 5, 1, 68, 68, VALID, 32);
BM_DECONV_2D(1, 3, 480, 480, 1, 1, 1, 480, 480, VALID, 3);
BM_DECONV_2D(1, 64, 32, 32, 1, 1, 1, 32, 32, VALID, 128);
BM_DECONV_2D(1, 64, 33, 32, 3, 3, 2, 65, 63, SAME, 128);
BM_DECONV_2D(1, 3, 224, 224, 3, 3, 2, 447, 447, SAME, 32);
BM_DECONV_2D(1, 3, 224, 224, 3, 3, 2, 449, 449, VALID, 32);
} // namespace test
} // namespace ops
} // namespace mace
此差异已折叠。
......@@ -553,6 +553,8 @@ TEST_F(EltwiseOpTest, RandomTensorVecFloat) {
{1, 32, 32, 16}, {1, 1, 1, 16});
RandomTensorEltwise<float>(kernels::EltwiseType::SUB,
{5, 32, 32, 16}, {5, 1, 1, 16});
RandomTensorEltwise<float>(kernels::EltwiseType::SUB,
{5, 32, 32, 16}, {1, 1, 1, 16});
RandomTensorEltwise<float>(kernels::EltwiseType::SUB,
{5, 1, 1, 16}, {5, 32, 32, 16});
RandomTensorEltwise<float>(kernels::EltwiseType::PROD,
......@@ -574,12 +576,16 @@ TEST_F(EltwiseOpTest, RandomTensorVecHalf) {
{1, 32, 32, 16}, {1, 1, 1, 16});
RandomTensorEltwise<half>(kernels::EltwiseType::SUB,
{3, 32, 32, 16}, {3, 1, 1, 16});
RandomTensorEltwise<half>(kernels::EltwiseType::SUB,
{3, 32, 32, 16}, {1, 1, 1, 16});
RandomTensorEltwise<half>(kernels::EltwiseType::SUB,
{3, 1, 1, 16}, {3, 32, 32, 16});
RandomTensorEltwise<half>(kernels::EltwiseType::PROD,
{1, 1, 1, 17}, {1, 31, 37, 17});
RandomTensorEltwise<half>(kernels::EltwiseType::DIV,
{5, 31, 37, 17}, {5, 1, 1, 17});
RandomTensorEltwise<half>(kernels::EltwiseType::DIV,
{5, 31, 37, 17}, {1, 1, 1, 17});
RandomTensorEltwise<half>(kernels::EltwiseType::DIV,
{5, 1, 1, 17}, {5, 31, 37, 17});
RandomTensorEltwise<half>(kernels::EltwiseType::MIN,
......
......@@ -38,6 +38,8 @@ math_type_mode = {
'MAX': 5,
'NEG': 6,
'ABS': 7,
'SQR_DIFF': 8,
'POW': 9,
}
buffer_type_map = {
......@@ -528,6 +530,103 @@ class TFConverter(object):
self.add_output_shape(final_op.outputs, op_def)
self.net_def.op.extend([op_def])
def convert_deconv2d(self, op):
    """Emit a MACE ``Deconv2D`` operator for the given TF op.

    The op may have two or three inputs:

    * two inputs: the output shape is read from the ``output_shape``
      attribute; input 0 is the data tensor and input 1 the filter;
    * three inputs: input 0 is evaluated to obtain the output shape (and is
      recorded as unused), input 1 is the filter and input 2 the data tensor.

    Any other input count raises ``Exception``.

    The filter is registered in ``self.transpose_filter_tensor`` with a
    device-dependent permutation. On GPU the filter (and a fused bias) are
    routed through ``add_buffer_to_image``.

    A single consumer of type ``BiasAdd`` is fused as an extra input, and a
    single following consumer whose type is in ``activation_name_map`` is
    fused as an ``activation`` argument (plus ``max_limit = 6`` for
    ``Relu6``). Every fused op is marked in ``self.resolved_ops``; the
    outputs of the last fused op become the outputs of the emitted operator.
    """
    on_cpu = self.device == 'cpu'
    on_gpu = self.device == 'gpu'

    op_def = mace_pb2.OperatorDef()

    def add_arg(arg_name):
        # Append a fresh argument to op_def and label it.
        new_arg = op_def.arg.add()
        new_arg.name = arg_name
        return new_arg

    add_arg('T').i = self.dt
    op_def.name = op.name
    op_def.type = 'Deconv2D'

    out_shape_value = None
    input_count = len(op.inputs)
    if input_count == 2:
        out_shape_value = op.get_attr('output_shape')
        filter_perm = (3, 2, 0, 1) if on_cpu else (0, 1, 3, 2)
        self.transpose_filter_tensor[
            get_input_tensor(op, 1).name] = filter_perm
        if on_gpu:
            filter_image = self.add_buffer_to_image(
                get_input_tensor(op, 1).name, "CONV2D_FILTER")
            op_def.input.extend([op.inputs[0].name, filter_image])
        else:
            op_def.input.extend(
                [get_input_tensor(op, idx).name
                 for idx in range(len(op.inputs))])
    elif input_count == 3:
        # Input 0 carries the requested output shape; it is folded into an
        # argument and therefore no longer needed as a tensor.
        out_shape_value = \
            get_input_tensor(op, 0).eval().astype(np.int32).flat
        self.unused_tensor.add(op.inputs[0].name)
        filter_perm = (2, 3, 0, 1) if on_cpu else (0, 1, 2, 3)
        self.transpose_filter_tensor[
            get_input_tensor(op, 1).name] = filter_perm
        if on_gpu:
            filter_image = self.add_buffer_to_image(
                get_input_tensor(op, 1).name, "CONV2D_FILTER")
            op_def.input.extend([op.inputs[2].name, filter_image])
        else:
            op_def.input.extend([op.inputs[2].name, op.inputs[1].name])
    else:
        raise Exception('Too many inputs. Op: %s, type: %s' % (op.name,
                                                               op.type))

    if out_shape_value is not None:
        add_arg('output_shape').ints.extend(out_shape_value)

    add_arg('padding').i = padding_mode[op.get_attr('padding')]
    # Only the two middle entries of the TF strides attribute are kept.
    add_arg('strides').ints.extend(op.get_attr('strides')[1:3])
    add_arg('data_format').s = 'NCHW' if on_cpu else 'NHWC'

    final_op = op
    self.resolved_ops[op.name] = 1

    # Fuse a directly following BiasAdd when it is the only consumer.
    consumers = self.tf_graph.get(op.name, [])
    if len(consumers) == 1 and consumers[0].type == 'BiasAdd':
        bias_add_op = consumers[0]
        bias_input = get_input_tensor(bias_add_op, 1).name
        if on_gpu:
            bias_input = self.add_buffer_to_image(bias_input, "ARGUMENT")
        op_def.input.extend([bias_input])
        final_op = bias_add_op
        self.resolved_ops[bias_add_op.name] = 1

    # Fuse a directly following activation when it is the only consumer.
    consumers = self.tf_graph.get(final_op.name, [])
    if len(consumers) == 1 and consumers[0].type in activation_name_map:
        activation_op = consumers[0]
        add_arg('activation').s = activation_name_map[activation_op.type]
        if activation_op.type == 'Relu6':
            add_arg('max_limit').f = 6
        final_op = activation_op
        self.resolved_ops[activation_op.name] = 1

    op_def.output.extend([output.name for output in final_op.outputs])
    self.add_output_shape(final_op.outputs, op_def)
    self.net_def.op.extend([op_def])
def check_conv_to_fc(self, op):
if self.device != 'cpu' or op.type != "Conv2D":
return False
......@@ -857,6 +956,7 @@ class TFConverter(object):
if len(op.inputs) == 2:
input_tensor0 = get_input_tensor(op, 0)
input_tensor1 = get_input_tensor(op, 1)
x_value = None
if np.asarray(input_tensor1.shape).size == 0:
x_value = input_tensor1.eval()
......@@ -867,7 +967,22 @@ class TFConverter(object):
op_def.input.extend([op.inputs[1].name])
self.unused_tensor.add(input_tensor0.name)
else:
op_def.input.extend([input.name for input in op.inputs])
if np.asarray(input_tensor0.shape).size == 1 \
and input_tensor0.op.type == 'Const':
if self.device == 'gpu':
output_name = self.add_buffer_to_image(
input_tensor0.name, "ARGUMENT")
op_def.input.extend([output_name])
else:
op_def.input.extend([input_tensor0.name])
if np.asarray(input_tensor1.shape).size == 1 \
and input_tensor1.op.type == 'Const':
if self.device == 'gpu':
output_name = self.add_buffer_to_image(
input_tensor1.name, "ARGUMENT")
op_def.input.extend([output_name])
else:
op_def.input.extend([input_tensor1.name])
if x_value is not None:
x_arg = op_def.arg.add()
x_arg.name = 'x'
......@@ -1150,6 +1265,8 @@ class TFConverter(object):
self.convert_winograd_conv_gpu(op)
else:
self.convert_conv2d(op)
elif op.type == 'Conv2DBackpropInput':
self.convert_deconv2d(op)
elif op.type == 'FusedBatchNorm':
self.convert_fused_batchnorm(op)
elif op.type == 'Mul' and op.name.find('batchnorm/mul') != -1:
......@@ -1159,7 +1276,10 @@ class TFConverter(object):
elif op.type == 'Relu6':
self.convert_relu6(op)
elif op.type == 'Add':
self.convert_add(op)
if len(op.inputs) > 2:
self.convert_add(op)
else:
self.convert_eltwise(op, 'ADD')
elif op.type == 'ConcatV2':
self.convert_concat(op)
elif op.type == 'ResizeBilinear':
......@@ -1176,6 +1296,12 @@ class TFConverter(object):
self.convert_depth_to_space(op, False)
elif op.type in ['Neg', 'neg', 'Negative', 'negative']:
self.convert_eltwise(op, 'NEG')
elif op.type in ['RealDiv', 'Div']:
self.convert_eltwise(op, 'DIV')
elif op.type in ['SquaredDifference']:
self.convert_eltwise(op, 'SQR_DIFF')
elif op.type in ['Pow']:
self.convert_eltwise(op, 'POW')
elif op.type == 'Mul':
self.convert_eltwise(op, 'MUL')
elif op.type == 'Sub':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册