Commit 2cf24f6d authored by yejianwu

merge with master

......@@ -82,7 +82,6 @@ extern void Register_BiasAdd(OperatorRegistry *op_registry);
extern void Register_ChannelShuffle(OperatorRegistry *op_registry);
extern void Register_Concat(OperatorRegistry *op_registry);
extern void Register_Conv2D(OperatorRegistry *op_registry);
-extern void Register_CWise(OperatorRegistry *op_registry);
extern void Register_DepthToSpace(OperatorRegistry *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry);
extern void Register_Dequantize(OperatorRegistry *op_registry);
......@@ -125,7 +124,6 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_ChannelShuffle(this);
ops::Register_Concat(this);
ops::Register_Conv2D(this);
-  ops::Register_CWise(this);
ops::Register_DepthToSpace(this);
ops::Register_DepthwiseConv2d(this);
ops::Register_Dequantize(this);
......
......@@ -119,19 +119,20 @@ void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) {
tensor_map_[const_tensor.name()] = std::move(tensor);
}
-  if (type == DeviceType::OPENCL) {
-    CreateImageOutputTensor(net_def);
  if (type == DeviceType::CPU || type == DeviceType::OPENCL) {
    CreateOutputTensorBuffer(net_def, type);
  }
}

-void Workspace::CreateImageOutputTensor(const NetDef &net_def) {
void Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
                                         DeviceType device_type) {
if (!net_def.has_mem_arena() || net_def.mem_arena().mem_block_size() == 0) {
return;
}
DataType dtype = DataType::DT_INVALID;
-  // We use the data type of the first op (with mem id, must be image),
-  // as GPU have consistent data type for each layer for now.
  // We use the data type of the first op with mem id,
  // as CPU&GPU have consistent data type for each layer for now.
// As DSP may have different data output type for each op,
// we stick to the same concept.
for (auto &op : net_def.op()) {
......@@ -148,11 +149,19 @@ void Workspace::CreateImageOutputTensor(const NetDef &net_def) {
}
MACE_CHECK(dtype != DataType::DT_INVALID, "data type is invalid.");
for (auto &mem_block : net_def.mem_arena().mem_block()) {
-    std::unique_ptr<BufferBase> image_buf(
-        new Image({mem_block.x(), mem_block.y()}, dtype));
-    preallocated_allocator_.SetBuffer(mem_block.mem_id(), std::move(image_buf));
    if (device_type == DeviceType::OPENCL) {
      std::unique_ptr<BufferBase> image_buf(
          new Image({mem_block.x(), mem_block.y()}, dtype));
      preallocated_allocator_.SetBuffer(mem_block.mem_id(),
                                        std::move(image_buf));
    } else {
      std::unique_ptr<BufferBase> tensor_buf(
          new Buffer(GetDeviceAllocator(device_type), mem_block.x()));
      preallocated_allocator_.SetBuffer(mem_block.mem_id(),
                                        std::move(tensor_buf));
    }
  }
-  VLOG(3) << "Preallocate image to tensors";
  VLOG(3) << "Preallocate buffer to tensors";
for (auto &op : net_def.op()) {
if (!op.mem_id().empty()) {
auto mem_ids = op.mem_id();
......@@ -161,15 +170,17 @@ void Workspace::CreateImageOutputTensor(const NetDef &net_def) {
std::unique_ptr<Tensor> tensor
(new Tensor(preallocated_allocator_.GetBuffer(mem_ids[i]), dtype));
tensor->SetSourceOpName(op.name());
VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")"
<< " Mem: " << mem_ids[i]
<< " Image shape: "
<< dynamic_cast<Image *>(tensor->UnderlyingBuffer())
->image_shape()[0]
<< ", "
<< dynamic_cast<Image *>(tensor->UnderlyingBuffer())
->image_shape()[1];
tensor_map_[op.output(i)] = std::move(tensor);
if (device_type == DeviceType::OPENCL) {
VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")"
<< " Mem: " << mem_ids[i]
<< " Image shape: "
<< dynamic_cast<Image *>(tensor->UnderlyingBuffer())
->image_shape()[0]
<< ", "
<< dynamic_cast<Image *>(tensor->UnderlyingBuffer())
->image_shape()[1];
}
}
}
}
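With this change the CPU runtime reuses the same mem_id-based preallocation path the GPU already had; only the backing storage differs (a flat Buffer versus a 2D Image). A minimal sketch of the sharing scheme, with hypothetical types standing in for MACE's allocator classes:

#include <algorithm>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical stand-in for the preallocated allocator: one block per
// mem_id; tensors whose lifetimes do not overlap may point at the same
// block, which is grown to the largest size ever requested for that id.
struct PreallocatedPool {
  std::map<int, std::vector<float>> blocks;
  void SetBuffer(int mem_id, std::size_t size) {
    blocks[mem_id].resize(std::max(blocks[mem_id].size(), size));
  }
  float *GetBuffer(int mem_id) { return blocks[mem_id].data(); }
};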
......
......@@ -52,7 +52,7 @@ class Workspace {
ScratchBuffer *GetScratchBuffer(DeviceType device_type);
private:
-  void CreateImageOutputTensor(const NetDef &net_def);
  void CreateOutputTensorBuffer(const NetDef &net_def, DeviceType device_type);
TensorMap tensor_map_;
......
......@@ -51,6 +51,39 @@ extern void Conv2dNeonK3x3S2(const float *input,
const index_t out_channels,
float *output);
extern void Conv2dNeonK7x7S1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output);
extern void Conv2dNeonK7x7S2(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output);
extern void Conv2dNeonK7x7S3(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output);
} // namespace kernels
} // namespace mace
......
This diff is collapsed.
......@@ -227,6 +227,12 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
&& stride_h == 2 && stride_w == 2 && dilation_h == 1 && dilation_w == 1;
bool use_neon_1x1_s1 = filter_h == 1 && filter_w == 1
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
bool use_neon_7x7_s1 = filter_h == 7 && filter_w == 7
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1;
bool use_neon_7x7_s2 = filter_h == 7 && filter_w == 7
&& stride_h == 2 && stride_w == 2 && dilation_h == 1 && dilation_w == 1;
bool use_neon_7x7_s3 = filter_h == 7 && filter_w == 7
&& stride_h == 3 && stride_w == 3 && dilation_h == 1 && dilation_w == 1;
std::vector<index_t> transformed_input_shape;
std::vector<index_t> transformed_output_shape;
......@@ -291,6 +297,44 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
} else if (use_neon_7x7_s1) {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, extra_output_height + 6);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width = std::max(padded_input_width, extra_output_width + 6);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
} else if (use_neon_7x7_s2) {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, (extra_output_height - 1) * 2 + 7);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width =
std::max(padded_input_width, (extra_output_width - 1) * 2 + 7);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
} else if (use_neon_7x7_s3) {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, (extra_output_height - 1) * 3 + 7);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width =
std::max(padded_input_width, (extra_output_width - 1) * 3 + 7);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
}
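      // In each 7x7 branch the padded extent follows the usual convolution
      // relation: a KxK kernel with stride S needs
      // (extra_output - 1) * S + K input rows/cols, e.g. (out - 1) * 2 + 7
      // for the stride-2 case above, and the output width is rounded up to
      // a multiple of 4 so the NEON kernels always write full 4-lane vectors.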
// decide scratch size before allocate it
......@@ -416,6 +460,45 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
channels,
pad_output);
};
} else if (use_neon_7x7_s1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK7x7S1(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
pad_output);
};
} else if (use_neon_7x7_s2) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK7x7S2(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
pad_output);
};
} else if (use_neon_7x7_s3) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK7x7S3(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
pad_output);
};
} else {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dGeneral(pad_input,
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_CWISE_H_
#define MACE_KERNELS_CWISE_H_
#include <algorithm>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/cl2_header.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace kernels {
enum CWiseType {
MUL = 0,
ADD = 1,
MAX = 2,
MIN = 3,
SUB = 4,
DIV = 5,
NEG = 6,
ABS = 7,
};
struct CWiseFunctorBase {
CWiseFunctorBase(const CWiseType type, const float coeff)
: type_(type), coeff_(coeff) {}
CWiseType type_;
float coeff_;
};
template <DeviceType D, typename T>
struct CWiseFunctor : CWiseFunctorBase {
CWiseFunctor(const CWiseType type, const float coeff)
: CWiseFunctorBase(type, coeff) {}
void operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
const index_t size = input->size();
switch (type_) {
case MUL:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = coeff_ * input_ptr[i];
}
break;
case ADD:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = coeff_ + input_ptr[i];
}
break;
case MAX:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::max<T>(input_ptr[i], coeff_);
}
break;
case MIN:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::min<T>(input_ptr[i], coeff_);
}
break;
case SUB:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = input_ptr[i] - coeff_;
}
break;
case DIV:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = input_ptr[i] / coeff_;
}
break;
case NEG:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = 0 - input_ptr[i];
}
break;
case ABS:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
T val = input_ptr[i];
          output_ptr[i] = (val > 0) ? val : 0 - val;
}
break;
default:
LOG(FATAL) << "CWise op not support type " << type_;
}
}
};
#ifdef MACE_ENABLE_OPENCL
template <typename T>
struct CWiseFunctor<DeviceType::OPENCL, T> : CWiseFunctorBase {
CWiseFunctor(const CWiseType type, const float coeff)
: CWiseFunctorBase(type, coeff) {}
void operator()(const Tensor *input,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
uint32_t kwg_size_;
std::unique_ptr<BufferBase> kernel_error_;
std::vector<index_t> input_shape_;
};
#endif // MACE_ENABLE_OPENCL
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_CWISE_H_
......@@ -35,10 +35,15 @@ enum EltwiseType {
MAX = 2,
MIN = 3,
SUB = 4,
DIV = 5,
NEG = 6,
ABS = 7,
SQR_DIFF = 8,
};
struct EltwiseFunctorBase {
-  EltwiseFunctorBase(const EltwiseType type, const std::vector<float> &coeff)
  EltwiseFunctorBase(const EltwiseType type,
                     const std::vector<float> &coeff)
: type_(type), coeff_(coeff) {}
EltwiseType type_;
......@@ -47,63 +52,195 @@ struct EltwiseFunctorBase {
template <DeviceType D, typename T>
struct EltwiseFunctor : EltwiseFunctorBase {
-  EltwiseFunctor(const EltwiseType type, const std::vector<float> &coeff)
  EltwiseFunctor(const EltwiseType type,
                 const std::vector<float> &coeff)
: EltwiseFunctorBase(type, coeff) {}
void operator()(const Tensor *input0,
const Tensor *input1,
const index_t start_axis,
const bool is_scaler,
const float value,
const bool swap,
Tensor *output,
StatsFuture *future) {
-    Tensor::MappingGuard input0_guard(input0);
-    Tensor::MappingGuard input1_guard(input1);
-    Tensor::MappingGuard output_guard(output);
-    const T *input0_ptr = input0->data<T>();
-    const T *input1_ptr = input1->data<T>();
-    T *output_ptr = output->mutable_data<T>();
-    const index_t size = input0->size();
-    switch (type_) {
-      case PROD:
-#pragma omp parallel for
-        for (index_t i = 0; i < size; ++i) {
-          output_ptr[i] = input0_ptr[i] * input1_ptr[i];
-        }
-        break;
-      case SUM:
-        if (coeff_.empty()) {
-#pragma omp parallel for
-          for (index_t i = 0; i < size; ++i) {
-            output_ptr[i] = input0_ptr[i] + input1_ptr[i];
-          }
-        } else {
-#pragma omp parallel for
-          for (index_t i = 0; i < size; ++i) {
-            output_ptr[i] =
-                coeff_[0] * input0_ptr[i] + coeff_[1] * input1_ptr[i];
-          }
-        }
-        break;
-      case MAX:
-#pragma omp parallel for
-        for (index_t i = 0; i < size; ++i) {
-          output_ptr[i] = std::max<T>(input0_ptr[i], input1_ptr[i]);
-        }
-        break;
-      case MIN:
-#pragma omp parallel for
-        for (index_t i = 0; i < size; ++i) {
-          output_ptr[i] = std::min<T>(input0_ptr[i], input1_ptr[i]);
-        }
-        break;
-      case SUB:
-#pragma omp parallel for
-        for (index_t i = 0; i < size; ++i) {
-          output_ptr[i] = input0_ptr[i] - input1_ptr[i];
-        }
-        break;
-      default:
-        LOG(FATAL) << "Eltwise op not support type " << type_;
-    }
    if (is_scaler) {
      Tensor::MappingGuard input0_guard(input0);
      Tensor::MappingGuard output_guard(output);
      const T *input0_ptr = input0->data<T>();
      T *output_ptr = output->mutable_data<T>();
      const index_t num = input0->size();
      switch (type_) {
        case PROD:
#pragma omp parallel for
          for (index_t i = 0; i < num; ++i) {
            output_ptr[i] = input0_ptr[i] * value;
          }
          break;
        case SUM:
          if (coeff_.empty()) {
#pragma omp parallel for
            for (index_t i = 0; i < num; ++i) {
              output_ptr[i] = input0_ptr[i] + value;
            }
          } else {
            const float coeff_0 = swap ? coeff_[1] : coeff_[0];
            const float coeff_1 = swap ? coeff_[0] : coeff_[1];
#pragma omp parallel for
            for (index_t i = 0; i < num; ++i) {
              output_ptr[i] = coeff_0 * input0_ptr[i] +
                  coeff_1 * value;
            }
          }
          break;
        case MAX:
#pragma omp parallel for
          for (index_t i = 0; i < num; ++i) {
            output_ptr[i] = std::max<T>(input0_ptr[i], value);
          }
          break;
        case MIN:
#pragma omp parallel for
          for (index_t i = 0; i < num; ++i) {
            output_ptr[i] = std::min<T>(input0_ptr[i], value);
          }
          break;
        case SUB:
#pragma omp parallel for
          for (index_t i = 0; i < num; ++i) {
            output_ptr[i] = swap ? value - input0_ptr[i] :
                input0_ptr[i] - value;
          }
          break;
        case DIV:
          if (!swap) {
            MACE_CHECK(fabs(value) > 1e-6, "cannot divide by 0.");
#pragma omp parallel for
            for (index_t i = 0; i < num; ++i) {
              output_ptr[i] = input0_ptr[i] / value;
            }
          } else {
#pragma omp parallel for
            for (index_t i = 0; i < num; ++i) {
              MACE_CHECK(fabs(input0_ptr[i]) > 1e-6, "cannot divide by 0.");
              output_ptr[i] = value / input0_ptr[i];
            }
          }
          break;
        case SQR_DIFF:
#pragma omp parallel for
          for (index_t i = 0; i < num; ++i) {
            const float tmp = input0_ptr[i] - value;
            output_ptr[i] = tmp * tmp;
          }
          break;
        default:
          LOG(FATAL) << "Eltwise op does not support type " << type_;
      }
} else {
MACE_CHECK_NOTNULL(input0);
MACE_CHECK_NOTNULL(input1);
Tensor::MappingGuard input0_guard(input0);
Tensor::MappingGuard input1_guard(input1);
Tensor::MappingGuard output_guard(output);
const T *input0_ptr = input0->data<T>();
const T *input1_ptr = input1->data<T>();
T *output_ptr = output->mutable_data<T>();
const index_t size0 = input0->size();
const index_t size1 = input1->size();
const index_t num = size0 / size1;
switch (type_) {
case PROD:
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < num; ++i) {
          for (index_t j = 0; j < size1; ++j) {
output_ptr[i * size1 + j] =
input0_ptr[i * size1 + j] * input1_ptr[j];
}
}
break;
case SUM:
if (coeff_.empty()) {
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < num; ++i) {
for (index_t j = 0; j < size1; ++j) {
output_ptr[i * size1 + j] =
input0_ptr[i * size1 + j] + input1_ptr[j];
}
}
} else {
const float coeff_0 = swap ? coeff_[1] : coeff_[0];
const float coeff_1 = swap ? coeff_[0] : coeff_[1];
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < num; ++i) {
for (index_t j = 0; j < size1; ++j) {
output_ptr[i * size1 + j] =
coeff_0 * input0_ptr[i * size1 + j] +
coeff_1 * input1_ptr[j];
}
}
}
break;
case MAX:
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < num; ++i) {
for (index_t j = 0; j < size1; ++j) {
output_ptr[i * size1 + j] =
std::max<T>(input0_ptr[i * size1 + j], input1_ptr[j]);
}
}
break;
case MIN:
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < num; ++i) {
for (index_t j = 0; j < size1; ++j) {
output_ptr[i * size1 + j] =
std::min<T>(input0_ptr[i * size1 + j], input1_ptr[j]);
}
}
break;
case SUB:
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < num; ++i) {
for (index_t j = 0; j < size1; ++j) {
              output_ptr[i * size1 + j] = swap ?
                  input1_ptr[j] - input0_ptr[i * size1 + j] :
                  input0_ptr[i * size1 + j] - input1_ptr[j];
}
}
break;
case DIV:
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < num; ++i) {
for (index_t j = 0; j < size1; ++j) {
if (!swap) {
                MACE_CHECK(fabs(input1_ptr[j]) > 1e-6, "cannot divide by 0.");
output_ptr[i * size1 + j] =
input0_ptr[i * size1 + j] / input1_ptr[j];
} else {
                MACE_CHECK(fabs(input0_ptr[i * size1 + j]) > 1e-6,
                           "cannot divide by 0.");
output_ptr[i * size1 + j] =
input1_ptr[j] / input0_ptr[i * size1 + j];
}
}
}
break;
case SQR_DIFF:
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < num; ++i) {
for (index_t j = 0; j < size1; ++j) {
const T tmp = input0_ptr[i * size1 + j] - input1_ptr[j];
output_ptr[i * size1 + j] = tmp * tmp;
}
}
break;
default:
LOG(FATAL) << "Eltwise op not support type " << type_;
}
}
}
};
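The non-scalar branch above assumes input1's size evenly divides input0's size, so num = size0 / size1 tiles input0 into size1-length runs that each combine with all of input1. A standalone sketch of the same indexing (plain C++, not part of the patch):

#include <cstddef>
#include <vector>

// out[i * size1 + j] = a[i * size1 + j] + b[j], the tiling used by the
// broadcast branch; assumes a.size() is a multiple of b.size().
std::vector<float> BroadcastAdd(const std::vector<float> &a,
                                const std::vector<float> &b) {
  const std::size_t size1 = b.size();
  const std::size_t num = a.size() / size1;
  std::vector<float> out(a.size());
  for (std::size_t i = 0; i < num; ++i)
    for (std::size_t j = 0; j < size1; ++j)
      out[i * size1 + j] = a[i * size1 + j] + b[j];
  return out;
}

For example, BroadcastAdd({1, 2, 3, 4, 5, 6}, {1, 2, 3}) yields {2, 4, 6, 5, 7, 9}, matching the {1, 1, 2, 3} + {1, 1, 1, 3} SUM case in the tests further down.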
......@@ -111,11 +248,16 @@ struct EltwiseFunctor : EltwiseFunctorBase {
#ifdef MACE_ENABLE_OPENCL
template <typename T>
struct EltwiseFunctor<DeviceType::OPENCL, T> : EltwiseFunctorBase {
-  EltwiseFunctor(const EltwiseType type, const std::vector<float> &coeff)
  EltwiseFunctor(const EltwiseType type,
                 const std::vector<float> &coeff)
: EltwiseFunctorBase(type, coeff) {}
void operator()(const Tensor *input0,
const Tensor *input1,
const index_t start_axis,
const bool is_scaler,
const float value,
const bool swap,
Tensor *output,
StatsFuture *future);
......
#include <common.h>
__kernel void cwise(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM2
__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__private const int width,
__private const int channel,
__private const float value,
__write_only image2d_t output) {
const int w = get_global_id(0);
const int hb = get_global_id(1);
#ifndef NON_UNIFORM_WORK_GROUP
if (w >= global_size_dim0 || hb >= global_size_dim1) return;
#endif
const int remain_chan = channel - mul24((w / width), 4);
DATA_TYPE4 in0 = READ_IMAGET(input, SAMPLER, (int2)(w, hb));
DATA_TYPE4 in1 = (DATA_TYPE4){value, value, value, value};
DATA_TYPE4 out;
#if CWISE_TYPE == 0
out = in0 * in1;
#elif CWISE_TYPE == 1
out = in0 + in1;
#elif CWISE_TYPE == 2
out = fmax(in0, in1);
#elif CWISE_TYPE == 3
out = fmin(in0, in1);
#elif CWISE_TYPE == 4
out = in0 - in1;
#elif CWISE_TYPE == 5
out = in0 / in1;
#elif CWISE_TYPE == 6
in1 = (DATA_TYPE4)(0, 0, 0, 0);
out = in1 - in0;
#elif CWISE_TYPE == 7
out = fabs(in0);
#endif
#if CWISE_TYPE == 1 || CWISE_TYPE == 2 || CWISE_TYPE == 3 || CWISE_TYPE == 4
if (remain_chan < 4) {
switch (remain_chan) {
case 1:
out.y = 0;
case 2:
out.z = 0;
case 3:
out.w = 0;
}
}
#endif
WRITE_IMAGET(output, (int2)(w, hb), out);
}
#include <common.h>
__kernel void eltwise(KERNEL_ERROR_PARAMS
-                      GLOBAL_WORK_GROUP_SIZE_DIM2
-                      __read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */
                      GLOBAL_WORK_GROUP_SIZE_DIM3
                      __read_only image2d_t input0,
                      __read_only image2d_t input1,
                      __private const float value,
                      __private const int height,
                      __private const int width,
                      __private const int channel,
#ifdef COEFF_SUM
                      __private const float coeff0,
                      __private const float coeff1,
#endif
                      __write_only image2d_t output) {
-  const int w = get_global_id(0);
-  const int hb = get_global_id(1);
  const int c = get_global_id(0);
  const int w = get_global_id(1);
  const int hb = get_global_id(2);

#ifndef NON_UNIFORM_WORK_GROUP
-  if (w >= global_size_dim0 || hb >= global_size_dim1) return;
  if (c >= global_size_dim0 || w >= global_size_dim1 || hb >= global_size_dim2)
    return;
#endif

-  DATA_TYPE4 in0 = READ_IMAGET(input0, SAMPLER, (int2)(w, hb));
-  DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(w, hb));
  int pos_w;
  int pos_h;
#if START_AXIS == 0
  pos_w = mad24(c, width, w);
  pos_h = hb;
#elif START_AXIS == 1
  pos_w = mad24(c, width, w);
  pos_h = hb % height;
#elif START_AXIS == 2
  pos_w = mad24(c, width, w);
  pos_h = 0;
#elif START_AXIS == 3
  pos_w = c;
  pos_h = 0;
#endif
  const int pos = mad24(c, width, w);
  const int remain_channel = channel - 4 * c;
  DATA_TYPE4 in0 = READ_IMAGET(input0, SAMPLER, (int2)(pos, hb));
  DATA_TYPE4 in1;
#if IS_SCALER == 1
  in1 = (DATA_TYPE4){value, value, value, value};
#else
  in1 = READ_IMAGET(input1, SAMPLER, (int2)(pos_w, pos_h));
#endif
DATA_TYPE4 out;
#if ELTWISE_TYPE == 0
out = in0 * in1;
#elif ELTWISE_TYPE == 1
#ifdef COEFF_SUM
-  out = mad(coeff0, in0, mad(coeff1, in1, 0));
#if NEEDSWAP == 0
  out = mad(coeff0, in0, mad(coeff1, in1, 0));
#else
  out = mad(coeff1, in0, mad(coeff0, in1, 0));
#endif
#else
out = in0 + in1;
#endif
......@@ -34,8 +66,49 @@ __kernel void eltwise(KERNEL_ERROR_PARAMS
#elif ELTWISE_TYPE == 3
out = fmin(in0, in1);
#elif ELTWISE_TYPE == 4
-  out = in0 - in1;
#if NEEDSWAP == 0
  out = in0 - in1;
#else
  out = in1 - in0;
#endif
#elif ELTWISE_TYPE == 5
#if NEEDSWAP == 0
if (fabs(in1.x) > 0.000001f)
out.x = in0.x / in1.x;
if (fabs(in1.y) > 0.000001f)
out.y = in0.y / in1.y;
if (fabs(in1.z) > 0.000001f)
out.z = in0.z / in1.z;
if (fabs(in1.w) > 0.000001f)
out.w = in0.w / in1.w;
#else
if (fabs(in1.x) > 0.000001f)
out.x = in1.x / in0.x;
if (fabs(in1.y) > 0.000001f)
out.y = in1.y / in0.y;
if (fabs(in1.z) > 0.000001f)
out.z = in1.z / in0.z;
if (fabs(in1.w) > 0.000001f)
out.w = in1.w / in0.w;
#endif
#elif ELTWISE_TYPE == 8
DATA_TYPE4 diff = in0 - in1;
out = diff * diff;
#endif
#if ELTWISE_TYPE == 1 || ELTWISE_TYPE == 2 || ELTWISE_TYPE == 3 \
|| ELTWISE_TYPE == 4 || ELTWISE_TYPE == 8
if (remain_channel < 4) {
switch (remain_channel) {
case 1:
out.y = 0;
case 2:
out.z = 0;
case 3:
out.w = 0;
}
}
#endif
-  WRITE_IMAGET(output, (int2)(w, hb), out);
  WRITE_IMAGET(output, (int2)(pos, hb), out);
}
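The START_AXIS chain above picks where to sample input1 given how many leading axes are broadcast; the image layout puts a pixel at x = c_block * W + w, y = b * H + h. A host-side restatement of the same arithmetic (plain C++, illustrative only):

#include <utility>

// Image coordinate of input1 for output position (c_block, w, hb).
// START_AXIS is the first axis, counting 0 = batch, at which input1's
// shape still matches input0's; earlier axes are broadcast (size 1).
std::pair<int, int> Input1Coord(int start_axis, int c_block, int w, int hb,
                                int height, int width) {
  switch (start_axis) {
    case 0: return {c_block * width + w, hb};           // same shape
    case 1: return {c_block * width + w, hb % height};  // input1 {1,H,W,C}
    case 2: return {c_block * width + w, 0};            // input1 {1,1,W,C}
    default: return {c_block, 0};                       // input1 {1,1,1,C}
  }
}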
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/cwise.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
template <typename T>
void CWiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const index_t width_pixels = channel_blocks * width;
const index_t batch_height_pixels = batch * height;
auto runtime = OpenCLRuntime::Global();
const uint32_t gws[2] = {static_cast<uint32_t>(width_pixels),
static_cast<uint32_t>(batch_height_pixels)};
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("cwise");
built_options.emplace("-Dcwise=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(MakeString("-DCWISE_TYPE=", type_));
if (runtime->IsOutOfRangeCheckEnabled()) {
built_options.emplace("-DOUT_OF_RANGE_CHECK");
kernel_error_ = std::move(std::unique_ptr<Buffer>(
new Buffer(GetDeviceAllocator(DeviceType::OPENCL), 1)));
kernel_error_->Map(nullptr);
*(kernel_error_->mutable_data<char>()) = 0;
kernel_error_->UnMap();
}
if (runtime->IsNonUniformWorkgroupsSupported()) {
built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
}
kernel_ = runtime->BuildKernel("cwise", kernel_name, built_options);
kwg_size_ =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
if (!IsVecEqual(input_shape_, input->shape())) {
uint32_t idx = 0;
if (runtime->IsOutOfRangeCheckEnabled()) {
kernel_.setArg(idx++,
*(static_cast<cl::Buffer *>(kernel_error_->buffer())));
}
if (!runtime->IsNonUniformWorkgroupsSupported()) {
kernel_.setArg(idx++, gws[0]);
kernel_.setArg(idx++, gws[1]);
}
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, static_cast<int32_t>(width));
kernel_.setArg(idx++, static_cast<int32_t>(channels));
kernel_.setArg(idx++, static_cast<float>(coeff_));
kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape();
}
const std::vector<uint32_t> lws = {kwg_size_ / 16, 16, 0};
std::stringstream ss;
ss << "cwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1)
<< "_" << output->dim(2) << "_" << output->dim(3);
TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
if (runtime->IsOutOfRangeCheckEnabled()) {
kernel_error_->Map(nullptr);
char *kerror_code = kernel_error_->mutable_data<char>();
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
}
template struct CWiseFunctor<DeviceType::OPENCL, float>;
template struct CWiseFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
......@@ -23,6 +23,10 @@ namespace kernels {
template <typename T>
void EltwiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input0,
const Tensor *input1,
const index_t start_axis,
const bool is_scaler,
const float value,
const bool swap,
Tensor *output,
StatsFuture *future) {
const index_t batch = input0->dim(0);
......@@ -31,14 +35,15 @@ void EltwiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input0,
const index_t channels = input0->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
-  const index_t width_pixels = channel_blocks * width;
  const index_t batch_height_pixels = batch * height;

-  const uint32_t gws[2] = {static_cast<uint32_t>(width_pixels),
  const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
                           static_cast<uint32_t>(width),
                           static_cast<uint32_t>(batch_height_pixels)};
  const int scaler = is_scaler ? 1 : 0;
  const int need_swap = swap ? 1 : 0;
auto runtime = OpenCLRuntime::Global();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
......@@ -47,6 +52,9 @@ void EltwiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input0,
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(MakeString("-DELTWISE_TYPE=", type_));
built_options.emplace(MakeString("-DSTART_AXIS=", start_axis));
built_options.emplace(MakeString("-DIS_SCALER=", scaler));
built_options.emplace(MakeString("-DNEEDSWAP=", need_swap));
if (runtime->IsOutOfRangeCheckEnabled()) {
built_options.emplace("-DOUT_OF_RANGE_CHECK");
kernel_error_ = std::move(std::unique_ptr<Buffer>(
......@@ -73,9 +81,14 @@ void EltwiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input0,
if (!runtime->IsNonUniformWorkgroupsSupported()) {
kernel_.setArg(idx++, gws[0]);
kernel_.setArg(idx++, gws[1]);
kernel_.setArg(idx++, gws[2]);
}
kernel_.setArg(idx++, *(input0->opencl_image()));
kernel_.setArg(idx++, *(input1->opencl_image()));
kernel_.setArg(idx++, value);
kernel_.setArg(idx++, static_cast<int32_t>(height));
kernel_.setArg(idx++, static_cast<int32_t>(width));
kernel_.setArg(idx++, static_cast<int32_t>(channels));
if (!coeff_.empty()) {
kernel_.setArg(idx++, coeff_[0]);
kernel_.setArg(idx++, coeff_[1]);
......@@ -85,11 +98,11 @@ void EltwiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input0,
input_shape_ = input0->shape();
}
-  const std::vector<uint32_t> lws = {kwg_size_ / 16, 16, 0};
  const std::vector<uint32_t> lws = {8, kwg_size_ / 64, 8, 0};
std::stringstream ss;
ss << "eltwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1)
<< "_" << output->dim(2) << "_" << output->dim(3);
-  TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
  TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
if (runtime->IsOutOfRangeCheckEnabled()) {
kernel_error_->Map(nullptr);
char *kerror_code = kernel_error_->mutable_data<char>();
......
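Promoting channel blocks to their own global dimension is what lets the kernel derive per-axis coordinates for broadcast reads. The launch geometry follows the NHWC shape directly; a brief sketch of the mapping (illustrative, not the framework's API):

#include <cstdint>

// One work item per (channel_block, w, b * h); channels packed 4-wide.
struct Gws3 { uint32_t c_blocks, width, batch_height; };

Gws3 EltwiseGws(int64_t n, int64_t h, int64_t w, int64_t c) {
  return {static_cast<uint32_t>((c + 3) / 4),
          static_cast<uint32_t>(w),
          static_cast<uint32_t>(n * h)};
}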
......@@ -152,6 +152,9 @@ BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, 1, SAME, 128);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, 1, SAME, 128);
BM_CONV_2D(1, 64, 32, 31, 15, 1, 1, 1, SAME, 128);
BM_CONV_2D(1, 64, 32, 31, 1, 15, 1, 1, SAME, 128);
BM_CONV_2D(1, 64, 32, 31, 7, 7, 1, 1, SAME, 128);
BM_CONV_2D(1, 64, 32, 31, 7, 7, 2, 1, SAME, 128);
BM_CONV_2D(1, 64, 32, 31, 7, 7, 3, 1, SAME, 128);
// 3 channels input
BM_CONV_2D(1, 3, 480, 480, 1, 1, 1, 1, VALID, 3);
......
......@@ -878,7 +878,7 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape,
1e-4, 1e-4);
};
-  for (int kernel_size : {3, 5}) {
  for (int kernel_size : {3, 5, 7}) {
for (int stride : {2, 3}) {
func(kernel_size, kernel_size, stride, stride);
}
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/cwise.h"
namespace mace {
namespace ops {
void Register_CWise(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
CWiseOp<DeviceType::CPU, float>);
#ifdef MACE_ENABLE_OPENCL
REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
CWiseOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
CWiseOp<DeviceType::OPENCL, half>);
#endif // MACE_ENABLE_OPENCL
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_CWISE_H_
#define MACE_OPS_CWISE_H_
#include <string>
#include "mace/core/operator.h"
#include "mace/kernels/cwise.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class CWiseOp : public Operator<D, T> {
public:
CWiseOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
x_(OperatorBase::GetSingleArgument<float>("x", 1.0)),
functor_(static_cast<kernels::CWiseType>(
OperatorBase::GetSingleArgument<int>(
"type", static_cast<int>(
kernels::CWiseType::ADD))),
this->x_) {}
bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
Tensor *output_tensor = this->Output(OUTPUT);
output_tensor->ResizeLike(input_tensor);
functor_(input_tensor, output_tensor, future);
return true;
}
protected:
const float x_;
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
private:
kernels::CWiseFunctor<D, T> functor_;
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_CWISE_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void CWise(int iters, int batch, int channels,
int height, int width, float x, int type) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("CWise", "CWiseBM")
.Input("InputImage")
.Output("Output")
.AddIntArg("type", type)
.AddFloatArg("x", x)
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("CWise", "CWiseBM")
.Input("Input")
.Output("Output")
.AddIntArg("type", type)
.AddFloatArg("x", x)
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define BM_CWISE_MACRO(N, C, H, W, X, G, TYPE, DEVICE) \
static void \
BM_CWISE_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
CWise<DEVICE, TYPE>(iters, N, C, H, W, X, G); \
} \
BENCHMARK( \
BM_CWISE_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE)
#define BM_CWISE(N, C, H, W, X, G) \
BM_CWISE_MACRO(N, C, H, W, X, G, float, CPU); \
BM_CWISE_MACRO(N, C, H, W, X, G, float, OPENCL); \
BM_CWISE_MACRO(N, C, H, W, X, G, half, OPENCL);
BM_CWISE(1, 1, 512, 512, 2, 0);
BM_CWISE(1, 3, 128, 128, 2, 1);
BM_CWISE(1, 3, 512, 512, 2, 4);
BM_CWISE(1, 32, 112, 112, 2, 5);
BM_CWISE(1, 32, 112, 112, 2, 6);
BM_CWISE(1, 32, 112, 112, 2, 7);
BM_CWISE(1, 64, 256, 256, 3, 0);
BM_CWISE(1, 64, 512, 512, 3, 1);
BM_CWISE(1, 128, 56, 56, 3, 4);
BM_CWISE(1, 128, 256, 256, 3, 5);
BM_CWISE(1, 64, 512, 512, 3, 6);
BM_CWISE(1, 64, 512, 512, 3, 7);
BM_CWISE(1, 256, 14, 14, 3, 0);
BM_CWISE(1, 512, 14, 14, 3, 1);
BM_CWISE(1, 1024, 7, 7, 3, 4);
BM_CWISE(32, 1, 256, 256, 3, 5);
BM_CWISE(32, 1, 256, 256, 3, 6);
BM_CWISE(32, 1, 256, 256, 3, 7);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
#include "../kernels/cwise.h"
namespace mace {
namespace ops {
namespace test {
class CWiseOpTest : public OpsTestBase {};
namespace {
template <DeviceType D>
void Simple(const kernels::CWiseType type,
const std::vector<index_t> &shape,
const std::vector<float> &input0,
const float x,
const std::vector<float> &output) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input1", shape, input0);
if (D == DeviceType::CPU) {
OpDefBuilder("CWise", "CWiseTest")
.Input("Input1")
.AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", x)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else {
BufferToImage<D, half>(&net, "Input1", "InputImg1",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("CWise", "CWiseTest")
.Input("InputImg1")
.AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", x)
.Output("OutputImg")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
ImageToBuffer<D, float>(&net, "OutputImg", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
}
auto expected = CreateTensor<float>(shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5, 1e-3);
}
} // namespace
TEST_F(CWiseOpTest, CPUSimple) {
Simple<DeviceType::CPU>(kernels::CWiseType::MUL, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6});
Simple<DeviceType::CPU>(kernels::CWiseType::ADD, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8});
Simple<DeviceType::CPU>(kernels::CWiseType::DIV, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60});
Simple<DeviceType::CPU>(kernels::CWiseType::SUB, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4});
Simple<DeviceType::CPU>(kernels::CWiseType::NEG, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {-1, -2, -3, -4, -5, -6});
Simple<DeviceType::CPU>(kernels::CWiseType::ABS, {1, 1, 2, 3},
{1, -2, -0.0001, 4, 5, 6}, 2.0, {1, 2, 0.0001, 4, 5, 6});
}
TEST_F(CWiseOpTest, GPUSimple) {
Simple<DeviceType::OPENCL>(kernels::CWiseType::MUL, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6});
Simple<DeviceType::OPENCL>(kernels::CWiseType::ADD, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8});
Simple<DeviceType::OPENCL>(kernels::CWiseType::DIV, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60});
Simple<DeviceType::OPENCL>(kernels::CWiseType::SUB, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4});
Simple<DeviceType::OPENCL>(kernels::CWiseType::NEG, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {-1, -2, -3, -4, -5, -6});
Simple<DeviceType::OPENCL>(kernels::CWiseType::ABS, {1, 1, 2, 3},
{1, -2, -0.0001, 4, 5, 6}, 2.0, {1, 2, 0.0001, 4, 5, 6});
}
namespace {
template <DeviceType D, typename T>
void RandomTest(const kernels::CWiseType type,
const std::vector<index_t> &shape) {
testing::internal::LogToStderr();
srand(time(NULL));
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input1", shape);
OpDefBuilder("CWise", "CWiseTest")
.Input("Input1")
.AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", 1.2)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
BufferToImage<D, T>(&net, "Input1", "InputImg1",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("CWise", "CWiseTest")
.Input("InputImg1")
.AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", 1.2)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Output("OutputImg")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
ImageToBuffer<D, float>(&net, "OutputImg", "OPENCLOutput",
kernels::BufferType::IN_OUT_CHANNEL);
if (DataTypeToEnum<T>::value == DT_FLOAT) {
ExpectTensorNear<float>(*net.GetTensor("Output"),
*net.GetOutput("OPENCLOutput"), 1e-5, 1e-4);
} else {
ExpectTensorNear<float>(*net.GetTensor("Output"),
*net.GetOutput("OPENCLOutput"), 1e-2, 1e-2);
}
}
} // namespace
TEST_F(CWiseOpTest, OPENCLRandomFloat) {
RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::MUL,
{3, 23, 37, 19});
RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::ADD,
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::SUB,
{3, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::DIV,
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::NEG,
{13, 32, 32, 64});
}
TEST_F(CWiseOpTest, OPENCLRandomHalf) {
RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::MUL,
{3, 23, 37, 19});
RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::ADD,
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::SUB,
{3, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::DIV,
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::NEG,
{13, 32, 32, 64});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -32,24 +32,53 @@ class EltwiseOp : public Operator<D, T> {
OperatorBase::GetRepeatedArgument<float>("coeff")) {}
bool Run(StatsFuture *future) override {
-    const Tensor *input0 = this->Input(0);
-    const Tensor *input1 = this->Input(1);
-    Tensor *output = this->Output(OUTPUT);
-    MACE_CHECK(input0->dim_size() == input1->dim_size())
-        << "Inputs of Eltwise op must be same shape";
-    for (int i = 0; i < input0->dim_size(); ++i) {
-      MACE_CHECK(input0->dim(i) == input1->dim(i))
-          << "Inputs of Eltwise op must be same shape";
-    }
-    output->ResizeLike(input0);
-    functor_(input0, input1, output, future);
    if (this->InputSize() == 1) {
      const Tensor* input = this->Input(0);
      Tensor *output = this->Output(OUTPUT);
      start_axis_ = input->dim_size() - 1;
      is_scaler_ = true;
      output->ResizeLike(input);
      const float x = OperatorBase::GetSingleArgument<float>("x", 1.0);
      functor_(input, nullptr, start_axis_,
               is_scaler_, x, false, output, future);
    } else {
      const index_t size0 = this->Input(0)->size();
      const index_t size1 = this->Input(1)->size();
      const bool swap = (size0 < size1);
      const Tensor *input0 = swap ? this->Input(1) : this->Input(0);
      const Tensor *input1 = swap ? this->Input(0) : this->Input(1);
      Tensor *output = this->Output(OUTPUT);
      MACE_CHECK(input0->dim_size() == input1->dim_size())
          << "Inputs of Eltwise op must be same shape";
      start_axis_ = input0->dim_size() - 1;
      is_scaler_ = (input1->size() == 1);
      uint32_t compared_size = 1;
      if (!is_scaler_) {
        while (start_axis_ >= 0) {
          MACE_CHECK(input0->dim(start_axis_) == input1->dim(start_axis_),
                     "Invalid inputs dimension at axis: ") << start_axis_
              << "input 0: " << input0->dim(start_axis_)
              << "input 1: " << input1->dim(start_axis_);
          compared_size *= input1->dim(start_axis_);
          if (compared_size == input1->size()) {
            break;
          }
          start_axis_--;
        }
      }
      output->ResizeLike(input0);
      const float x = OperatorBase::GetSingleArgument<float>("x", 1.0);
      functor_(input0, input1, start_axis_,
               is_scaler_, x, swap, output, future);
    }
return true;
}
private:
kernels::EltwiseFunctor<D, T> functor_;
index_t start_axis_;
bool is_scaler_;
private:
OP_OUTPUT_TAGS(OUTPUT);
......
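The start_axis scan in Run walks from the last dimension backwards until input1's trailing dimensions account for its entire size, which is what the kernels later use to index the smaller input. A simplified sketch of the scan (plain C++, dimension-equality checks omitted):

#include <cstdint>
#include <vector>

// First axis, counting from the right, at which input1's trailing dims
// multiply out to its total size; -1 if they never do.
int StartAxis(const std::vector<int64_t> &dims1, int64_t size1) {
  int axis = static_cast<int>(dims1.size()) - 1;
  int64_t compared = 1;
  while (axis >= 0) {
    compared *= dims1[axis];
    if (compared == size1) return axis;
    --axis;
  }
  return -1;
}

StartAxis({1, 1, 1, 3}, 3) returns 3, so a {1, 1, 1, 3} second input broadcasts along the channel axis of a {1, 1, 2, 3} first input, as exercised by the tests below.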
......@@ -25,23 +25,26 @@ class EltwiseOpTest : public OpsTestBase {};
namespace {
template <DeviceType D>
void Simple(const kernels::EltwiseType type,
-            const std::vector<index_t> &shape,
            const std::vector<index_t> &shape0,
            const std::vector<index_t> &shape1,
            const std::vector<float> &input0,
            const std::vector<float> &input1,
            const std::vector<float> &output,
            const float x = 1.f,
const std::vector<float> coeff = {}) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input1", shape, input0);
net.AddInputFromArray<D, float>("Input2", shape, input1);
net.AddInputFromArray<D, float>("Input1", shape0, input0);
net.AddInputFromArray<D, float>("Input2", shape1, input1);
if (D == DeviceType::CPU) {
OpDefBuilder("Eltwise", "EltwiseTest")
.Input("Input1")
.Input("Input2")
.AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", x)
.AddFloatsArg("coeff", coeff)
.Output("Output")
.Finalize(net.NewOperatorDef());
......@@ -57,6 +60,7 @@ void Simple(const kernels::EltwiseType type,
.Input("InputImg1")
.Input("InputImg2")
.AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", x)
.AddFloatsArg("coeff", coeff)
.Output("OutputImg")
.Finalize(net.NewOperatorDef());
......@@ -68,7 +72,7 @@ void Simple(const kernels::EltwiseType type,
kernels::BufferType::IN_OUT_CHANNEL);
}
-  auto expected = CreateTensor<float>(shape, output);
  auto expected = CreateTensor<float>(shape0, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
......@@ -76,53 +80,200 @@ void Simple(const kernels::EltwiseType type,
TEST_F(EltwiseOpTest, CPUSimple) {
Simple<DeviceType::CPU>(kernels::EltwiseType::PROD, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
{1, 4, 9, 16, 25, 36});
Simple<DeviceType::CPU>(kernels::EltwiseType::SUM, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
{2, 4, 6, 8, 10, 12});
Simple<DeviceType::CPU>(kernels::EltwiseType::SUM, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
-                          {3, 6, 9, 12, 15, 18}, {2, 1});
                          {3, 6, 9, 12, 15, 18}, 1., {2, 1});
Simple<DeviceType::CPU>(kernels::EltwiseType::MAX, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6},
{1, 2, 3, 4, 6, 6});
Simple<DeviceType::CPU>(kernels::EltwiseType::MIN, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6},
{1, 1, 3, 3, 5, 6});
Simple<DeviceType::CPU>(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6},
{0, 1, 0, 1, 1, 0});
Simple<DeviceType::CPU>(kernels::EltwiseType::DIV, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3, 2, 10, 24},
{1, 2, 1, 2, 0.5, 0.25});
Simple<DeviceType::CPU>(kernels::EltwiseType::PROD, {1, 1, 2, 3},
{1, 1, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 2, 3},
{1, 4, 9, 4, 10, 18});
Simple<DeviceType::CPU>(kernels::EltwiseType::SUM, {1, 1, 2, 3},
{1, 1, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 2, 3},
{2, 4, 6, 5, 7, 9});
Simple<DeviceType::CPU>(kernels::EltwiseType::SUM, {1, 1, 2, 3},
{1, 1, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 2, 3},
{3, 6, 9, 9, 12, 15}, 1., {2, 1});
Simple<DeviceType::CPU>(kernels::EltwiseType::MAX, {1, 1, 2, 3},
{1, 1, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3},
{1, 2, 3, 4, 5, 6});
Simple<DeviceType::CPU>(kernels::EltwiseType::MIN, {1, 1, 2, 3},
{1, 1, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3},
{1, 1, 3, 1, 1, 3});
Simple<DeviceType::CPU>(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3},
{1, 1, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3},
{0, 1, 0, 9, 16, 9});
Simple<DeviceType::CPU>(kernels::EltwiseType::DIV, {1, 1, 2, 3},
{1, 1, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3},
{1, 2, 1, 4, 5, 2});
Simple<DeviceType::CPU>(kernels::EltwiseType::PROD, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {2},
{2, 4, 6, 8, 10, 12}, 2);
Simple<DeviceType::CPU>(kernels::EltwiseType::SUM, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {2},
{3, 4, 5, 6, 7, 8}, 2);
Simple<DeviceType::CPU>(kernels::EltwiseType::SUM, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {2},
{4, 6, 8, 10, 12, 14}, 2, {2, 1});
Simple<DeviceType::CPU>(kernels::EltwiseType::MAX, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {3},
{3, 3, 3, 4, 5, 6}, 3);
Simple<DeviceType::CPU>(kernels::EltwiseType::MIN, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {3},
{1, 2, 3, 3, 3, 3}, 3);
Simple<DeviceType::CPU>(kernels::EltwiseType::DIV, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {0.5},
{2, 4, 6, 8, 10, 12}, 0.5);
Simple<DeviceType::CPU>(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {3},
{4, 1, 0, 1, 4, 9}, 3);
}
TEST_F(EltwiseOpTest, GPUSimple) {
Simple<DeviceType::OPENCL>(kernels::EltwiseType::PROD, {1, 1, 2, 3},
-                             {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
-                             {1, 4, 9, 16, 25, 36});
                             {1, 1, 2, 3},
                             {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
                             {1, 4, 9, 16, 25, 36});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::SUM, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
{2, 4, 6, 8, 10, 12});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::SUM, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
{3, 6, 9, 12, 15, 18}, 1., {2, 1});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::MAX, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6},
{1, 2, 3, 4, 6, 6});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::MIN, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6},
{1, 1, 3, 3, 5, 6});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::DIV, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3, 2, 10, 24},
{1, 2, 1, 2, 0.5, 0.25});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3},
{1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6},
{0, 1, 0, 1, 1, 0});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::PROD, {1, 1, 2, 3},
{1, 1, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 2, 3},
{1, 4, 9, 4, 10, 18});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::SUM, {1, 1, 2, 3},
-                             {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
-                             {2, 4, 6, 8, 10, 12});
                             {1, 1, 1, 3},
                             {1, 2, 3, 4, 5, 6}, {1, 2, 3},
                             {2, 4, 6, 5, 7, 9});
  Simple<DeviceType::OPENCL>(kernels::EltwiseType::SUM, {1, 1, 2, 3},
-                             {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6},
-                             {3, 6, 9, 12, 15, 18}, {2, 1});
                             {1, 1, 1, 3},
                             {1, 2, 3, 4, 5, 6}, {1, 2, 3},
                             {3, 6, 9, 9, 12, 15}, 1., {2, 1});
  Simple<DeviceType::OPENCL>(kernels::EltwiseType::MAX, {1, 1, 2, 3},
-                             {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6},
-                             {1, 2, 3, 4, 6, 6});
                             {1, 1, 1, 3},
                             {1, 2, 3, 4, 5, 6}, {1, 1, 3},
                             {1, 2, 3, 4, 5, 6});
  Simple<DeviceType::OPENCL>(kernels::EltwiseType::MIN, {1, 1, 2, 3},
-                             {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6},
-                             {1, 1, 3, 3, 5, 6});
                             {1, 1, 1, 3},
                             {1, 2, 3, 4, 5, 6}, {1, 1, 3},
                             {1, 1, 3, 1, 1, 3});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3},
{1, 1, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3},
{0, 1, 0, 9, 16, 9});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::DIV, {1, 1, 2, 3},
{1, 1, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 3},
{1, 2, 1, 4, 5, 2});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::PROD, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {2},
{2, 4, 6, 8, 10, 12}, 2);
Simple<DeviceType::OPENCL>(kernels::EltwiseType::SUM, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {2},
{3, 4, 5, 6, 7, 8}, 2);
Simple<DeviceType::OPENCL>(kernels::EltwiseType::SUM, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {2},
{4, 6, 8, 10, 12, 14}, 2, {2, 1});
Simple<DeviceType::OPENCL>(kernels::EltwiseType::MAX, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {3},
{3, 3, 3, 4, 5, 6}, 3);
Simple<DeviceType::OPENCL>(kernels::EltwiseType::MIN, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {3},
{1, 2, 3, 3, 3, 3}, 3);
Simple<DeviceType::OPENCL>(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {3},
{4, 1, 0, 1, 4, 9}, 3);
Simple<DeviceType::OPENCL>(kernels::EltwiseType::DIV, {1, 1, 2, 3},
{1, 1, 1, 1},
{1, 2, 3, 4, 5, 6}, {0.5},
{2, 4, 6, 8, 10, 12}, 0.5);
}
namespace {
template <DeviceType D, typename T>
void RandomTest(const kernels::EltwiseType type,
-                const std::vector<index_t> &shape) {
                const std::vector<index_t> &shape1,
                const std::vector<index_t> &shape2) {
testing::internal::LogToStderr();
srand(time(NULL));
// Construct graph
OpsTestNet net;
bool is_divide = (type == kernels::EltwiseType::DIV);
// Add input data
net.AddRandomInput<D, float>("Input1", shape);
net.AddRandomInput<D, float>("Input2", shape);
net.AddRandomInput<D, float>("Input1", shape1, true, is_divide);
net.AddRandomInput<D, float>("Input2", shape2, true, is_divide);
OpDefBuilder("Eltwise", "EltwiseTest")
.Input("Input1")
......@@ -166,24 +317,110 @@ void RandomTest(const kernels::EltwiseType type,
TEST_F(EltwiseOpTest, OPENCLRandomFloat) {
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::PROD,
{3, 23, 37, 19},
{3, 23, 37, 19});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::SUM,
{13, 32, 32, 64},
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::MAX,
{3, 32, 32, 64},
{3, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::MIN,
{13, 32, 32, 64},
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::DIV,
{13, 32, 32, 64},
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::SQR_DIFF,
{13, 32, 32, 64},
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::PROD,
{3, 23, 37, 19},
{1, 1, 37, 19});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::SUM,
{13, 32, 32, 64},
{1, 1, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::MAX,
{3, 32, 32, 64},
{1, 1, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::MIN,
{13, 32, 32, 64},
{1, 1, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::DIV,
{13, 32, 32, 63},
{1, 1, 32, 63});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::SQR_DIFF,
{13, 32, 32, 64},
{1, 1, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::PROD,
{3, 23, 37, 19},
{1, 1, 1, 19});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::SUM,
{13, 32, 32, 64},
{1, 1, 1, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::MAX,
{3, 32, 32, 64},
{1, 1, 1, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::MIN,
{13, 32, 32, 64},
{1, 1, 1, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::DIV,
{13, 32, 32, 64},
{1, 1, 1, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::EltwiseType::SQR_DIFF,
{13, 32, 32, 64},
{1, 1, 1, 64});
}
TEST_F(EltwiseOpTest, OPENCLRandomHalf) {
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::PROD,
{3, 23, 37, 19},
{3, 23, 37, 19});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::PROD,
{3, 23, 37, 19},
{1, 23, 37, 19});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::PROD,
{3, 23, 37, 19},
{1, 1, 37, 19});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::PROD,
{3, 23, 37, 19},
{1, 1, 1, 19});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::SUM,
-                                       {13, 32, 32, 64});
                                       {13, 32, 32, 64},
                                       {1, 1, 1, 1});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::SUM,
{13, 32, 32, 64},
{1, 1, 1, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::SUM,
{13, 32, 32, 64},
{1, 1, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::MAX,
{3, 32, 32, 64},
{3, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::MAX,
{3, 32, 32, 64},
{1, 1, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::MIN,
{13, 32, 32, 64},
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::SQR_DIFF,
{13, 32, 32, 64},
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::SQR_DIFF,
{13, 32, 32, 64},
{1, 1, 1, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::SQR_DIFF,
{13, 32, 32, 64},
{1, 1, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::DIV,
{13, 32, 32, 64},
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::DIV,
{13, 32, 32, 64},
{1, 1, 1, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::EltwiseType::DIV,
{13, 32, 32, 64},
{1, 1, 32, 64});
}
} // namespace test
......
......@@ -150,7 +150,8 @@ class OpsTestNet {
template<DeviceType D, typename T>
void AddRandomInput(const std::string &name,
const std::vector<index_t> &shape,
-                      bool positive = true) {
                      bool positive = true,
                      bool truncate = false) {
Tensor *input =
ws_.CreateTensor(name, GetDeviceAllocator(D), DataTypeToEnum<T>::v());
input->Resize(shape);
......@@ -162,14 +163,24 @@ class OpsTestNet {
std::normal_distribution<float> nd(0, 1);
if (DataTypeToEnum<T>::value == DT_HALF) {
std::generate(
-          input_data, input_data + input->size(), [&gen, &nd, positive] {
-            return half_float::half_cast<half>(positive ? std::abs(nd(gen))
-                                                        : nd(gen));
          input_data, input_data + input->size(),
          [&gen, &nd, positive, truncate] {
            float d = nd(gen);
            if (truncate) {
              if (std::abs(d) > 100.f) d = 100.f;
              if (std::abs(d) < 0.001f) d = 0.001f;
            }
            return half_float::half_cast<half>(positive ? std::abs(d) : d);
});
} else {
std::generate(input_data, input_data + input->size(),
-                    [&gen, &nd, positive] {
-                      return positive ? std::abs(nd(gen)) : nd(gen);
                    [&gen, &nd, positive, truncate] {
                      float d = nd(gen);
                      if (truncate) {
                        if (std::abs(d) > 100.f) d = 100.f;
                        if (std::abs(d) < 0.001f) d = 0.001f;
                      }
                      return (positive ? std::abs(d) : d);
});
}
}
......
......@@ -1188,8 +1188,11 @@ def convert_to_mace_pb(model_file, weight_file, input_node_str,
print "PB Converted."
if device == 'gpu':
print "start optimize memory."
-        mem_optimizer = memory_optimizer.MemoryOptimizer(net_def)
-        mem_optimizer.optimize()
        memory_optimizer.optimize_gpu_memory(net_def)
print "Memory optimization done."
elif device == 'cpu':
print "start optimize memory."
memory_optimizer.optimize_cpu_memory(net_def)
print "Memory optimization done."
return net_def
......@@ -22,13 +22,13 @@ class MemoryOptimizer(object):
self.net_def = net_def
self.idle_mem = set()
self.op_mem = {} # op_name->mem_id
self.mem_block = {} # mem_id->[x, y]
self.mem_block = {} # mem_id->[size] or mem_id->[x, y]
self.total_mem_count = 0
self.ref_counter = {}
consumers = {}
for op in net_def.op:
if self.is_buffer_image_op(op):
if not self.op_need_optimize_memory(op):
continue
for ipt in op.input:
if ipt not in consumers:
......@@ -36,7 +36,7 @@ class MemoryOptimizer(object):
consumers[ipt].append(op)
# only ref op's output tensor
for op in net_def.op:
if self.is_buffer_image_op(op):
if not self.op_need_optimize_memory(op):
continue
for output in op.output:
tensor_name = output
......@@ -45,29 +45,47 @@ class MemoryOptimizer(object):
else:
self.ref_counter[tensor_name] = 0
def is_buffer_image_op(self, op):
if op.type == 'BufferToImage':
for arg in op.arg:
if arg.name == 'mode' and arg.i == 0:
return True
return op.type == 'ImageToBuffer'
def op_need_optimize_memory(self, op):
return True
def get_mem_size(self, op_type, output_shape):
mem_size = [0, 0]
if op_type == 'WinogradTransform' or op_type == 'MatMul':
mem_size[0] = output_shape[2] * output_shape[3]
mem_size[1] = output_shape[0] * int((output_shape[1] + 3) / 4)
else:
mem_size[0] = output_shape[2] * int((output_shape[3] + 3) / 4)
mem_size[1] = output_shape[0] * output_shape[1]
return mem_size
def get_op_mem_block(self, op_type, output_shape):
return [reduce(operator.mul, output_shape, 1)]
def mem_size(self, memory_block):
return memory_block[0]
def sub_mem_block(self, mem_block1, mem_block2):
return self.mem_size(mem_block1) - self.mem_size(mem_block2)
def resize_mem_block(self, old_mem_block, op_mem_block):
return [max(old_mem_block[0], op_mem_block[0])]
def add_net_mem_blocks(self):
for mem in self.mem_block:
arena = self.net_def.mem_arena
block = arena.mem_block.add()
block.mem_id = mem
block.x = self.mem_block[mem][0]
block.y = 1
def mem_area(self, memory_size):
return memory_size[0] * memory_size[1]
def get_total_origin_mem_size(self):
origin_mem_size = 0
for op in self.net_def.op:
if not self.op_need_optimize_memory(op):
continue
origin_mem_size += reduce(operator.mul, op.output_shape[0].dims, 1)
return origin_mem_size
def get_total_optimized_mem_size(self):
optimized_mem_size = 0
for mem in self.mem_block:
print mem, self.mem_block[mem]
optimized_mem_size += self.mem_size(self.mem_block[mem])
return optimized_mem_size
def optimize(self):
for op in self.net_def.op:
if self.is_buffer_image_op(op):
if not self.op_need_optimize_memory(op):
continue
if not op.output_shape:
print('WARNING: There is no output shape information to '
......@@ -78,38 +96,42 @@ class MemoryOptimizer(object):
'the number of output.')
return
for i in range(len(op.output)):
op_mem_size = self.get_mem_size(op.type,
op.output_shape[i].dims)
op_mem_block = self.get_op_mem_block(op.type,
op.output_shape[i].dims)
mem_id = -1
if len(self.idle_mem) > 0:
best_mem_candidate_id = -1
best_mem_candidate_delta_area = sys.maxint
best_mem_candidate_shape = []
best_mem_add_size = sys.maxint
best_mem_waste_size = sys.maxint
for mid in self.idle_mem:
reuse_mem_size = self.mem_block[mid]
resize_mem_size = [
max(reuse_mem_size[0], op_mem_size[0]),
max(reuse_mem_size[1], op_mem_size[1])
]
delta_mem_area = self.mem_area(
resize_mem_size) - self.mem_area(reuse_mem_size)
if delta_mem_area < best_mem_candidate_delta_area:
best_mem_candidate_id = mid
best_mem_candidate_delta_area = delta_mem_area
best_mem_candidate_shape = resize_mem_size
if best_mem_candidate_delta_area <= self.mem_area(
op_mem_size):
# reuse
self.mem_block[
best_mem_candidate_id] = best_mem_candidate_shape
mem_id = best_mem_candidate_id
old_mem_block = self.mem_block[mid]
new_mem_block = self.resize_mem_block(
old_mem_block, op_mem_block)
add_mem_size = self.sub_mem_block(new_mem_block,
old_mem_block)
waste_mem_size = self.sub_mem_block(new_mem_block,
op_mem_block)
# minimize add_mem_size; if best_mem_add_size is 0,
# then minimize waste_mem_size
if (best_mem_add_size > 0 and
add_mem_size < best_mem_add_size) \
or (best_mem_add_size == 0 and
waste_mem_size < best_mem_waste_size):
best_mem_id = mid
best_mem_add_size = add_mem_size
best_mem_waste_size = waste_mem_size
best_mem_block = new_mem_block
# reuse it if add mem size <= op mem size
if best_mem_add_size <= self.mem_size(op_mem_block):
self.mem_block[best_mem_id] = best_mem_block
mem_id = best_mem_id
self.idle_mem.remove(mem_id)
if mem_id == -1:
mem_id = self.total_mem_count
self.total_mem_count += 1
self.mem_block[mem_id] = op_mem_size
self.mem_block[mem_id] = op_mem_block
op.mem_id.extend([mem_id])
self.op_mem[op.output[i]] = mem_id
......@@ -123,6 +145,43 @@ class MemoryOptimizer(object):
elif self.ref_counter[ipt] < 0:
raise Exception('ref count is less than 0')
self.add_net_mem_blocks()
        print('total op: %d' % len(self.net_def.op))
        print('origin mem: %d, optimized mem: %d' % (
            self.get_total_origin_mem_size(),
            self.get_total_optimized_mem_size()))
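Condensed, the reuse policy above is: grow the idle block that needs the fewest extra elements, prefer the least waste once a zero-growth candidate has been seen, and allocate a fresh block whenever any reuse would cost more than the output itself. A minimal re-statement under the CPU sizing rule (flat element counts; function and variable names are illustrative):

def plan_reuse(idle_blocks, op_block):
    # idle_blocks: {mem_id: [size]}, op_block: [size] needed by the output
    best_id, best_add, best_waste = -1, float('inf'), float('inf')
    for mid, old in idle_blocks.items():
        new = max(old[0], op_block[0])
        add, waste = new - old[0], new - op_block[0]
        if (best_add > 0 and add < best_add) or \
                (best_add == 0 and waste < best_waste):
            best_id, best_add, best_waste = mid, add, waste
    return best_id if best_add <= op_block[0] else -1

print(plan_reuse({0: [512], 1: [900]}, [1024]))  # -> 1 (grows by only 124)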
class GPUMemoryOptimizer(MemoryOptimizer):
def op_need_optimize_memory(self, op):
if op.type == 'BufferToImage':
for arg in op.arg:
if arg.name == 'mode' and arg.i == 0:
return False
return op.type != 'ImageToBuffer'
def get_op_mem_block(self, op_type, output_shape):
mem_block = [0, 0]
if op_type == 'WinogradTransform' or op_type == 'MatMul':
mem_block[0] = output_shape[2] * output_shape[3]
mem_block[1] = output_shape[0] * int((output_shape[1] + 3) / 4)
else:
mem_block[0] = output_shape[2] * int((output_shape[3] + 3) / 4)
mem_block[1] = output_shape[0] * output_shape[1]
return mem_block
def mem_size(self, memory_block):
return memory_block[0] * memory_block[1] * 4
def resize_mem_block(self, old_mem_block, op_mem_block):
resize_mem_block = [
max(old_mem_block[0], op_mem_block[0]),
max(old_mem_block[1], op_mem_block[1])
]
return resize_mem_block
def add_net_mem_blocks(self):
for mem in self.mem_block:
arena = self.net_def.mem_arena
block = arena.mem_block.add()
......@@ -130,21 +189,12 @@ class MemoryOptimizer(object):
block.x = self.mem_block[mem][0]
block.y = self.mem_block[mem][1]
        print('total op: %d' % len(self.net_def.op))
origin_mem_size = 0
optimized_mem_size = 0
for op in self.net_def.op:
if self.is_buffer_image_op(op):
continue
origin_mem_size += reduce(operator.mul, op.output_shape[0].dims, 1)
for mem in self.mem_block:
print mem, self.mem_block[mem]
optimized_mem_size += reduce(operator.mul, self.mem_block[mem], 4)
        print('origin mem: %d, optimized mem: %d' % (origin_mem_size,
                                                     optimized_mem_size))
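For the GPU subclass a memory block is a 2-D OpenCL image extent rather than a flat count: four consecutive channels pack into one RGBA pixel, so a regular NHWC output [N, H, W, C] maps to [W * ceil(C/4), N * H], and mem_size weights x * y by 4 so the totals stay comparable with the CPU element counts. A sketch of the default branch (hypothetical helper name):

def gpu_image_block(output_shape):
    # Default branch of GPUMemoryOptimizer.get_op_mem_block for NHWC:
    # each image pixel holds 4 channels, so width = W * ceil(C/4).
    n, h, w, c = output_shape
    return [w * ((c + 3) // 4), n * h]

block = gpu_image_block([1, 64, 64, 32])
print('image %s holds %d elements' % (block, block[0] * block[1] * 4))
# -> image [512, 64] holds 131072 elements (== 1 * 64 * 64 * 32)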
def optimize_gpu_memory(net_def):
mem_optimizer = GPUMemoryOptimizer(net_def)
mem_optimizer.optimize()
def optimize_memory(net_def):
def optimize_cpu_memory(net_def):
mem_optimizer = MemoryOptimizer(net_def)
mem_optimizer.optimize()
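These two functions are the only entry points the converters above call. A hedged usage sketch; net_def is assumed to be a mace_pb2.NetDef already populated by convert_to_mace_pb:

import memory_optimizer

memory_optimizer.optimize_cpu_memory(net_def)  # net_def assumed in scope
for block in net_def.mem_arena.mem_block:
    # On the CPU path x carries the flat size and y is fixed to 1.
    print('mem_id=%d x=%d y=%d' % (block.mem_id, block.x, block.y))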
......@@ -829,37 +829,25 @@ class TFConverter(object):
self.resolved_ops[op.name] = 1
self.unused_tensor.add(get_input_tensor(op, 1).name)
def convert_math(self, op, math_type):
def convert_eltwise(self, op, math_type):
op_def = self.net_def.op.add()
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self.dt
op_def.name = op.name
if len(op.inputs) == 1:
op_def.type = "CWise"
op_def.input.extend([input.name for input in op.inputs])
x_arg = op_def.arg.add()
x_arg.name = 'x'
x_arg.f = 0
elif len(op.inputs) >= 2:
op_def.type = "Eltwise"
op_def.input.extend([input.name for input in op.inputs])
x_value = op.get_attr('x')
if len(op.inputs) >= 2:
input_tensor0 = get_input_tensor(op, 0)
input_tensor1 = get_input_tensor(op, 1)
if input_tensor0.shape == input_tensor1.shape:
op_def.type = "Eltwise"
op_def.input.extend([input.name for input in op.inputs])
else:
op_def.type = "CWise"
x_value = 0
if len(input_tensor1.shape) == 4:
op_def.input.extend([op.inputs[1].name])
x_value = get_input_tensor(op, 0).eval().astype(np.float32)
else:
op_def.input.extend([op.inputs[0].name])
x_value = get_input_tensor(op, 1).eval().astype(np.float32)
x_arg = op_def.arg.add()
x_arg.name = 'x'
x_arg.f = x_value
if len(input_tensor0) == 1:
x_value = input_tensor0.eval().astype(np.float32)
elif len(input_tensor1) == 1:
x_value = input_tensor1.eval().astype(np.float32)
x_arg = op_def.arg.add()
x_arg.name = 'x'
x_arg.f = x_value
type_arg = op_def.arg.add()
type_arg.name = 'type'
type_arg.i = math_type_mode[math_type]
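The unified convert_eltwise always emits an Eltwise op; when one of two inputs is a single-element tensor, its value is evaluated and additionally attached as the scalar argument 'x'. A NumPy sketch of what the emitted op computes in the scalar case (illustrative; the real kernel is MACE's Eltwise implementation):

import numpy as np

def eltwise_with_scalar(math_type, feature_map, x):
    # Mirrors the Eltwise semantics when arg 'x' carries the scalar operand.
    if math_type == 'MUL':
        return feature_map * np.float32(x)
    if math_type == 'SUB':
        return feature_map - np.float32(x)
    raise ValueError('unsupported type: %s' % math_type)

out = eltwise_with_scalar('MUL', np.ones((1, 4, 4, 8), np.float32), 0.5)
assert float(out.mean()) == 0.5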
......@@ -1156,11 +1144,11 @@ class TFConverter(object):
elif op.type == 'SpaceToDepth':
self.convert_depth_to_space(op, False)
elif op.type in ['Neg', 'neg', 'Negative', 'negative']:
self.convert_math(op, 'NEG')
self.convert_eltwise(op, 'NEG')
elif op.type == 'Mul':
self.convert_math(op, 'MUL')
self.convert_eltwise(op, 'MUL')
elif op.type == 'Sub':
self.convert_math(op, 'SUB')
self.convert_eltwise(op, 'SUB')
elif self.is_softmax(op):
self.convert_softmax(op)
elif op.type in ['Relu', 'Sigmoid', 'Tanh']:
......@@ -1367,8 +1355,11 @@ def convert_to_mace_pb(model_file, input_node, input_shape, output_node,
print "Model Converted."
if device == 'gpu':
print "start optimize memory."
mem_optimizer = memory_optimizer.MemoryOptimizer(net_def)
mem_optimizer.optimize()
memory_optimizer.optimize_gpu_memory(net_def)
print "Memory optimization done."
elif device == 'cpu':
print "start optimize memory."
memory_optimizer.optimize_cpu_memory(net_def)
print "Memory optimization done."
return net_def