Commit 6b9a114c authored by 李寅

Merge branch 'add_reduce_mean_op' into 'master'

add reduce mean

See merge request !582
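
This merge request adds a ReduceMean operator to MACE: operator registration, an OpenCL reduction kernel, GPU and CPU functors, the op definition, unit tests and benchmarks, plus TensorFlow converter support that maps Mean over the spatial axes to ReduceMean instead of an average pooling. A minimal sketch of the semantics the new op computes (NumPy is used only for illustration and is not part of this change):

import numpy as np

# NHWC input reduced over the H/W axes with keepdims=True, matching the
# GPU path and the axis={1, 2} unit tests below.
x = np.arange(24, dtype=np.float32).reshape(1, 2, 3, 4)
y = x.mean(axis=(1, 2), keepdims=True)  # shape (1, 1, 1, 4) -> [10, 11, 12, 13]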
......@@ -98,6 +98,7 @@ extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_Proposal(OperatorRegistry *op_registry);
extern void Register_PSROIAlign(OperatorRegistry *op_registry);
extern void Register_Quantize(OperatorRegistry *op_registry);
extern void Register_ReduceMean(OperatorRegistry *op_registry);
extern void Register_Requantize(OperatorRegistry *op_registry);
extern void Register_Reshape(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
......@@ -145,6 +146,7 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_Proposal(this);
ops::Register_PSROIAlign(this);
ops::Register_Quantize(this);
ops::Register_ReduceMean(this);
ops::Register_Requantize(this);
ops::Register_Reshape(this);
ops::Register_ResizeBilinear(this);
......
#include <common.h>
__kernel void reduce_mean(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__local float4* group_sum,
__private const int group_size,
__private const int partial_len,
__private const int remain_index,
__private const int batch,
__private const int in_height,
__private const int in_width,
__private const float in_height_r,
__private const float in_width_r,
__private const int channel_blocks,
__write_only image2d_t output) {
const int i = get_local_id(0);
const int j = get_local_id(1);
const int k = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (i >= local_size_dim0 || j >= local_size_dim1 || k >= global_size_dim2)
return;
const int dim0_size = local_size_dim0;
#else
const int dim0_size = get_local_size(0);
#endif
DATA_TYPE4 tmp = (DATA_TYPE4){0, 0, 0, 0};
const int index = j * dim0_size + i;
const int b = k / channel_blocks;
const int ch = k - b * channel_blocks;
DATA_TYPE4 in;
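// Each work item sums partial_len pixels of one (batch, channel-block)
// plane; when image_size is not divisible by the group size, items with
// index >= remain_index sum one pixel fewer (valid_part_len).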
const int valid_part_len = select(partial_len,
partial_len - 1,
remain_index > 0 && index >= remain_index);
const int full_offset = index * partial_len;
const int base_offset = select(full_offset,
full_offset - (index - remain_index),
valid_part_len < partial_len);
#pragma unroll
for (int l = 0; l < valid_part_len; ++l) {
int offset = base_offset + l;
int h_id = floor(offset * in_width_r);
int w_id = offset - h_id * in_width;
int pos_x = mad24(ch, in_width, w_id);
int pos_y = mad24(b, in_height, h_id);
in = READ_IMAGET(input, SAMPLER, (int2)(pos_x, pos_y));
tmp = tmp + in;
}
group_sum[index] = tmp;
#ifdef NON_QUALCOMM_ADRENO
barrier(CLK_LOCAL_MEM_FENCE);
#endif
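// Work item (0, 0) reduces the per-item partial sums from local memory and
// scales by 1 / (in_height * in_width) to produce the mean.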
if (i == 0 && j == 0) {
DATA_TYPE4 out = (DATA_TYPE4){0, 0, 0, 0};
#pragma unroll
for (int l = 0; l < group_size; ++l) {
out = out + group_sum[l];
}
out = out * in_height_r * in_width_r;
WRITE_IMAGET(output, (int2)(ch, b), out);
}
}
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/reduce_mean.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
template <typename T>
MaceStatus ReduceMeanFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input,
Tensor *output,
StatsFuture *future) {
MACE_CHECK_NOTNULL(input);
MACE_CHECK(keep_dims_, "reduce mean gpu only supports keep_dims.");
MACE_CHECK(input->dim_size() == 4,
"reduce mean gpu only supports 4-dim input");
MACE_CHECK(axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2,
"reduce mean gpu only supports reducing axes 1 and 2");
index_t batch = input->dim(0);
const index_t in_height = input->dim(1);
const index_t in_width = input->dim(2);
const index_t channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const uint32_t image_size = static_cast<uint32_t>(in_height * in_width);
auto runtime = OpenCLRuntime::Global();
std::vector<uint32_t> gws(3);
std::vector<uint32_t> lws(3);
std::vector<index_t> output_shape{batch, 1, 1, channels};
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));
if (kernel_.get() == nullptr) {
const DataType dt = DataTypeToEnum<T>::value;
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("reduce_mean");
built_options.emplace("-Dreduce_mean=" + kernel_name);
if (input->dtype() == output->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
built_options.emplace(dt == DT_HALF ? "-DFP16" : "");
} else {
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
}
if (runtime->gpu_type() != GPUType::QUALCOMM_ADRENO) {
built_options.emplace("-DNON_QUALCOMM_ADRENO");
}
if (runtime->IsOutOfRangeCheckEnabled()) {
built_options.emplace("-DOUT_OF_RANGE_CHECK");
kernel_error_ = std::unique_ptr<Buffer>(
new Buffer(GetDeviceAllocator(DeviceType::GPU)));
MACE_RETURN_IF_ERROR(kernel_error_->Allocate(1));
kernel_error_->Map(nullptr);
*(kernel_error_->mutable_data<char>()) = 0;
kernel_error_->UnMap();
}
if (runtime->IsNonUniformWorkgroupsSupported()) {
built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
}
kernel_ = runtime->BuildKernel("reduce_mean", kernel_name, built_options);
kwg_size_ =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
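// One workgroup reduces one (batch, channel-block) plane. On Adreno the
// group is sized to the kernel wave so the kernel can skip the local
// barrier (see NON_QUALCOMM_ADRENO in reduce_mean.cl).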
if (runtime->gpu_type() == GPUType::QUALCOMM_ADRENO) {
const uint32_t wave_size =
static_cast<uint32_t>(runtime->GetKernelWaveSize(kernel_));
gws = {4, (wave_size / 4), static_cast<uint32_t>(batch * channel_blocks)};
} else {
gws = {4, 16, static_cast<uint32_t>(batch * channel_blocks)};
}
lws = {gws[0], gws[1], 1};
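// partial_len is ceil(image_size / group_size); reciprocals of height and
// width are passed so the kernel avoids integer division per pixel.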
const int group_size = lws[0] * lws[1] * lws[2];
const int partial_len = (image_size + group_size - 1) / group_size;
const int remain_index = image_size % group_size;
const float in_width_r = 1.f / in_width;
const float in_height_r = 1.f / in_height;
if (!IsVecEqual(input_shape_, input->shape())) {
uint32_t idx = 0;
if (runtime->IsOutOfRangeCheckEnabled()) {
kernel_.setArg(idx++,
*(static_cast<cl::Buffer *>(kernel_error_->buffer())));
}
if (!runtime->IsNonUniformWorkgroupsSupported()) {
kernel_.setArg(idx++, gws[0]);
kernel_.setArg(idx++, gws[1]);
kernel_.setArg(idx++, gws[2]);
}
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, (group_size * 4 * sizeof(float)),
nullptr);
kernel_.setArg(idx++, static_cast<int32_t>(group_size));
kernel_.setArg(idx++, static_cast<int32_t>(partial_len));
kernel_.setArg(idx++, static_cast<int32_t>(remain_index));
kernel_.setArg(idx++, static_cast<int32_t>(batch));
kernel_.setArg(idx++, static_cast<int32_t>(in_height));
kernel_.setArg(idx++, static_cast<int32_t>(in_width));
kernel_.setArg(idx++, in_height_r);
kernel_.setArg(idx++, in_width_r);
kernel_.setArg(idx++, static_cast<int32_t>(channel_blocks));
kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape();
}
cl::Event event;
cl_int error;
if (runtime->IsNonUniformWorkgroupsSupported()) {
error = runtime->command_queue().enqueueNDRangeKernel(
kernel_, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
} else {
std::vector<uint32_t> roundup_gws(lws.size());
for (size_t i = 0; i < lws.size(); ++i) {
roundup_gws[i] = RoundUp(gws[i], lws[i]);
}
error = runtime->command_queue().enqueueNDRangeKernel(
kernel_, cl::NullRange,
cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
}
if (runtime->IsOutOfRangeCheckEnabled()) {
kernel_error_->Map(nullptr);
char *kerror_code = kernel_error_->mutable_data<char>();
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
event.wait();
if (stats != nullptr) {
runtime->GetCallStats(event, stats);
}
};
}
return MACE_SUCCESS;
}
template struct ReduceMeanFunctor<DeviceType::GPU, float>;
template struct ReduceMeanFunctor<DeviceType::GPU, half>;
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_REDUCE_MEAN_H_
#define MACE_KERNELS_REDUCE_MEAN_H_
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include <algorithm>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
struct ReduceFunctorBase {
ReduceFunctorBase(const std::vector<int> &axis,
const bool keep_dims)
: keep_dims_(keep_dims),
axis_(axis) {}
bool keep_dims_;
bool reduce_first_axis_;
const std::vector<int> axis_;
std::vector<int> data_reshape_;
std::vector<index_t> out_shape_;
};
template <DeviceType D, typename T>
struct ReduceMeanFunctor : ReduceFunctorBase {
ReduceMeanFunctor(const std::vector<int> &axis,
const bool keep_dims)
: ReduceFunctorBase(axis, keep_dims) {}
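// Collapse adjacent input dimensions that share the same reduce flag into
// data_reshape_, record whether the first collapsed axis is reduced, and
// build out_shape_ (keeping reduced dims as 1 when keep_dims is set).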
void Simplify(const Tensor *input,
const bool keep_dims) {
std::vector<bool> bitmap(static_cast<uint32_t>(input->dim_size()), false);
if (axis_.size() == 0) {
for (int i = 0; i < input->dim_size(); ++i) {
bitmap[i] = true;
}
} else {
for (unsigned int i = 0; i < axis_.size(); ++i) {
const int index = axis_[i] >= 0 ?
axis_[i] :
axis_[i] + input->dim_size();
bitmap[index] = true;
}
}
out_shape_.clear();
for (unsigned int i = 0; i < input->dim_size(); ++i) {
if (!bitmap[i]) {
out_shape_.push_back(input->dim(i));
} else if (keep_dims) {
out_shape_.push_back(1);
}
}
data_reshape_.clear();
unsigned int dim_index = 0;
for (; dim_index < input->dim_size(); ++dim_index) {
if (input->dim(dim_index) != 1) break;
}
if (dim_index >= input->dim_size()) {
reduce_first_axis_ = true;
} else {
reduce_first_axis_ = bitmap[dim_index];
data_reshape_.push_back(input->dim(dim_index));
++dim_index;
for (; dim_index < input->dim_size(); ++dim_index) {
const int n = input->dim(dim_index);
if (n == 1) {
bitmap[dim_index] = bitmap[dim_index - 1];
}
if (bitmap[dim_index-1] != bitmap[dim_index]) {
data_reshape_.push_back(n);
} else {
data_reshape_.back() *= n;
}
}
}
}
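// Sum over the reduced collapsed axes (they alternate with the kept axes)
// and divide by their total extent; handles up to four collapsed dims.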
void Compute(const Tensor *input, Tensor *output) {
Tensor::MappingGuard input_mapper(input);
const T *input_ptr = input->data<T>();
Tensor::MappingGuard output_map(output);
T *output_ptr = output->mutable_data<T>();
memset(output_ptr, 0, output->size() * sizeof(T));
switch (data_reshape_.size()) {
case 1:
if (reduce_first_axis_) {
T sum = 0;
#pragma omp parallel for reduction(+:sum)
for (int i = 0; i < data_reshape_[0]; ++i) {
sum = sum + input_ptr[i];
}
output_ptr[0] = sum / data_reshape_[0];
} else {
#pragma omp parallel for
for (int i = 0; i < data_reshape_[0]; ++i) {
output_ptr[i] = input_ptr[i];
}
}
break;
case 2:
if (reduce_first_axis_) {
#pragma omp parallel for
for (int i = 0; i < data_reshape_[1]; ++i) {
for (int j = 0; j < data_reshape_[0]; ++j) {
output_ptr[i] += input_ptr[j * data_reshape_[1] + i];
}
output_ptr[i] /= data_reshape_[0];
}
} else {
#pragma omp parallel for
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[1]; ++j) {
output_ptr[i] += input_ptr[i * data_reshape_[1] + j];
}
output_ptr[i] /= data_reshape_[1];
}
}
break;
case 3:
if (reduce_first_axis_) {
#pragma omp parallel for
for (int i = 0; i < data_reshape_[1]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[0]; ++k) {
output_ptr[i] +=
input_ptr[(k * data_reshape_[1] + i) * data_reshape_[2]
+ j];
}
}
output_ptr[i] /= (data_reshape_[0] * data_reshape_[2]);
}
} else {
#pragma omp parallel for collapse(2)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[1]; ++k) {
output_ptr[i * data_reshape_[2] + j] +=
input_ptr[(i * data_reshape_[1] + k) * data_reshape_[2]
+ j];
}
output_ptr[i * data_reshape_[2] + j] /= data_reshape_[1];
}
}
}
break;
case 4:
if (reduce_first_axis_) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < data_reshape_[1]; ++i) {
for (int j = 0; j < data_reshape_[3]; ++j) {
for (int k = 0; k < data_reshape_[2]; ++k) {
for (int t = 0; t < data_reshape_[0]; ++t) {
output_ptr[i * data_reshape_[3] + j] +=
input_ptr[((t * data_reshape_[1] + i) *
data_reshape_[2] + k)*data_reshape_[3] + j];
}
}
output_ptr[i * data_reshape_[3] + j] /=
(data_reshape_[0] * data_reshape_[2]);
}
}
} else {
#pragma omp parallel for collapse(2)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[1]; ++k) {
for (int t = 0; t < data_reshape_[3]; ++t) {
output_ptr[i * data_reshape_[2] + j] +=
input_ptr[((i * data_reshape_[1] + k) *
data_reshape_[2] + j)*data_reshape_[3] + t];
}
}
output_ptr[i * data_reshape_[2] + j] /=
(data_reshape_[1] * data_reshape_[3]);
}
}
}
break;
default:
MACE_CHECK(false, "not implemented in mace")
<< "data reshape size: " << data_reshape_.size()
<< ", reduce first axis: " << reduce_first_axis_;
break;
}
}
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
Simplify(input, true);
output->Resize(out_shape_);
Compute(input, output);
return MACE_SUCCESS;
}
};
#ifdef MACE_ENABLE_OPENCL
template <typename T>
struct ReduceMeanFunctor<DeviceType::GPU, T>
: ReduceFunctorBase {
ReduceMeanFunctor(const std::vector<int> &axis,
const bool keep_dims)
: ReduceFunctorBase(axis, keep_dims) {}
MaceStatus operator()(const Tensor *input,
Tensor *output_tensor,
StatsFuture *future);
cl::Kernel kernel_;
uint32_t kwg_size_;
std::unique_ptr<BufferBase> kernel_error_;
std::vector<index_t> input_shape_;
};
#endif
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_REDUCE_MEAN_H_
......@@ -33,7 +33,7 @@ struct StridedSliceFunctor {
int ellipsis_mask,
int new_axis_mask,
int shrink_axis_mask,
bool is_slice = false)
bool is_slice)
: begin_mask_(begin_mask),
end_mask_(end_mask),
ellipsis_mask_(ellipsis_mask),
......
......@@ -103,13 +103,15 @@ void Pooling(int iters,
##DEVICE)
#define MACE_BM_POOLING(N, C, H, W, K, S, PA, PO) \
MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU); \
MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, GPU);
MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, GPU); \
MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU);
MACE_BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX);
MACE_BM_POOLING(1, 3, 257, 257, 2, 2, SAME, MAX);
MACE_BM_POOLING(1, 3, 513, 513, 2, 2, SAME, MAX);
MACE_BM_POOLING(1, 3, 1025, 1025, 2, 2, SAME, MAX);
MACE_BM_POOLING(1, 32, 480, 640, 480, 640, VALID, AVG);
} // namespace test
} // namespace ops
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/reduce_mean.h"
namespace mace {
namespace ops {
void Register_ReduceMean(OperatorRegistry *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ReduceMean")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ReduceMeanOp<DeviceType::CPU, float>);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ReduceMean")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
.Build(),
ReduceMeanOp<DeviceType::GPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ReduceMean")
.Device(DeviceType::GPU)
.TypeConstraint<half>("T")
.Build(),
ReduceMeanOp<DeviceType::GPU, half>);
#endif
}
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_REDUCE_MEAN_H_
#define MACE_OPS_REDUCE_MEAN_H_
#include <string>
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/reduce_mean.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class ReduceMeanOp : public Operator<D, T> {
public:
ReduceMeanOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetRepeatedArgs<int>("axis"),
OperatorBase::GetOptionalArg<bool>("keepdims", true)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const std::vector<int> axis =
OperatorBase::GetRepeatedArgs<int>("axis");
const int left = -static_cast<int>(input->dim_size());
const int right = static_cast<int>(input->dim_size());
if (axis.size()) {
for (unsigned int i = 0; i < axis.size(); ++i) {
MACE_CHECK(axis[i] >= left && axis[i] < right, "Axis is out of range.");
}
}
Tensor *output = this->Output(OUTPUT);
return functor_(input, output, future);
}
private:
kernels::ReduceMeanFunctor<D, T> functor_;
protected:
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_REDUCE_MEAN_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void ReduceMean(int iters, int batch, int channels,
int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
if (D == DeviceType::GPU) {
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("ReduceMean", "ReduceMeanBM")
.Input("InputImage")
.AddIntsArg("axis", {1, 2})
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
} else {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
OpDefBuilder("ReduceMean", "ReduceMeanBM")
.Input("InputNCHW")
.AddIntsArg("axis", {2, 3})
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, TYPE, DEVICE) \
static void \
MACE_BM_REDUCE_MEAN_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(\
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
ReduceMean<DEVICE, TYPE>(iters, N, C, H, W); \
} \
MACE_BENCHMARK( \
MACE_BM_REDUCE_MEAN_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_REDUCE_MEAN(N, C, H, W) \
MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, float, GPU); \
MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, half, GPU); \
MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, float, CPU);
MACE_BM_REDUCE_MEAN(1, 1, 512, 512);
MACE_BM_REDUCE_MEAN(4, 3, 128, 128);
MACE_BM_REDUCE_MEAN(4, 3, 512, 512);
MACE_BM_REDUCE_MEAN(16, 32, 112, 112);
MACE_BM_REDUCE_MEAN(8, 32, 112, 112);
MACE_BM_REDUCE_MEAN(8, 64, 256, 256);
MACE_BM_REDUCE_MEAN(1, 32, 480, 640);
} // namespace test
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ReduceMeanOpTest : public OpsTestBase {};
namespace {
template <DeviceType D>
void Simple(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int> &axis,
const std::vector<index_t> &output_shape,
const std::vector<float> &output) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape, input);
if (D == DeviceType::CPU) {
OpDefBuilder("ReduceMean", "ReduceMeanTest")
.Input("Input")
.AddIntsArg("axis", axis)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else {
BufferToImage<D, float>(&net, "Input", "InputImg",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("ReduceMean", "ReduceMeanTest")
.Input("InputImg")
.AddIntsArg("axis", axis)
.Output("OutputImg")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
ImageToBuffer<D, float>(&net, "OutputImg", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
}
auto expected = CreateTensor<float>(output_shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5, 1e-3);
}
template <DeviceType D>
void Simple12Test() {
Simple<D>({2, 2, 3, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{1, 2},
{2, 1, 1, 4},
{10, 11, 12, 13,
10, 11, 12, 13});
}
template <DeviceType D>
void Simple1Axis() {
Simple<D>({2, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23,
0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{1},
{2, 1, 3, 4},
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{-3},
{1, 1, 3, 4},
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{2},
{1, 2, 1, 4},
{4, 5, 6, 7, 16, 17, 18, 19});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{-1},
{1, 2, 3, 1},
{1.5, 5.5, 9.5, 13.5, 17.5, 21.5});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{1},
{1, 1, 3, 3},
{9, 10, 11, 12, 13, 14, 15, 16, 17});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{-2},
{1, 3, 1, 3},
{3, 4, 5, 12, 13, 14, 21, 22, 23});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{3},
{1, 3, 3, 1},
{1, 4, 7, 10, 13, 16, 19, 22, 25});
}
template <DeviceType D>
void Simple2Axis() {
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 1},
{1, 1, 3, 4},
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
Simple<D>({2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 1},
{1, 1, 4},
{10, 11, 12, 13});
Simple<D>({2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{1, 2},
{2, 1, 1},
{5.5, 17.5});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 2},
{1, 2, 1, 4},
{4, 5, 6, 7, 16, 17, 18, 19});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{1, 3},
{1, 1, 3, 1},
{7.5, 11.5, 15.5});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{1, 2},
{1, 1, 1, 3},
{12, 13, 14});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{0, 1},
{1, 1, 3, 3},
{9, 10, 11, 12, 13, 14, 15, 16, 17});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{2, 3},
{1, 3, 1, 1},
{4, 13, 22});
}
template <DeviceType D>
void Simple3Axis() {
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{1, 2, 3},
{1, 1, 1, 1},
{11.5});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 2, 3},
{1, 2, 1, 1},
{5.5, 17.5});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 1, 3},
{1, 1, 3, 1},
{7.5, 11.5, 15.5});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 1, 2},
{1, 1, 1, 4},
{10, 11, 12, 13});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{1, 2, 3},
{1, 1, 1, 1},
{13});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{0, 2, 3},
{1, 3, 1, 1},
{4, 13, 22});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{0, 1, 3},
{1, 1, 3, 1},
{10, 13, 16});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{0, 1, 2},
{1, 1, 1, 3},
{12, 13, 14});
}
} // namespace
TEST_F(ReduceMeanOpTest, CPUSimple12) {
Simple12Test<DeviceType::CPU>();
}
TEST_F(ReduceMeanOpTest, GPUSimple12) {
Simple12Test<DeviceType::GPU>();
}
TEST_F(ReduceMeanOpTest, CPUSimple1Axis) {
Simple1Axis<DeviceType::CPU>();
}
TEST_F(ReduceMeanOpTest, CPUSimple2Axis) {
Simple2Axis<DeviceType::CPU>();
}
TEST_F(ReduceMeanOpTest, CPUSimple3Axis) {
Simple3Axis<DeviceType::CPU>();
}
namespace {
template <DeviceType D, typename T>
void RandomTest(const std::vector<index_t> &input_shape,
const std::vector<int> &axis) {
testing::internal::LogToStderr();
srand(time(NULL));
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", input_shape);
std::vector<int> axis_cpu(axis.size());
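// The CPU reference path runs in NCHW, so remap the NHWC axes: 1/2 (H/W)
// become 2/3 and 3 (C) becomes 1.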
for (unsigned int i = 0; i < axis.size(); ++i) {
if (axis[i] == 1 || axis[i] == 2)
axis_cpu[i] = axis[i] + 1;
else if (axis[i] == 3)
axis_cpu[i] = 1;
else
axis_cpu[i] = axis[i];
}
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
OpDefBuilder("ReduceMean", "ReduceMeanTest")
.Input("InputNCHW")
.AddIntsArg("axis", axis_cpu)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW,
"Output", NHWC);
BufferToImage<D, T>(&net, "Input", "InputImg",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("ReduceMean", "ReduceMeanTest")
.Input("InputImg")
.AddIntsArg("axis", axis)
.Output("OutputImg")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
ImageToBuffer<D, float>(&net, "OutputImg", "OPENCLOutput",
kernels::BufferType::IN_OUT_CHANNEL);
if (DataTypeToEnum<T>::value == DT_FLOAT) {
ExpectTensorNear<float>(*net.GetTensor("Output"),
*net.GetOutput("OPENCLOutput"), 1e-5, 1e-4);
} else {
ExpectTensorNear<float>(*net.GetTensor("Output"),
*net.GetOutput("OPENCLOutput"), 1e-2, 1e-2);
}
}
} // namespace
TEST_F(ReduceMeanOpTest, GPURandomFloat) {
RandomTest<DeviceType::GPU, float>({4, 64, 64, 3}, {1, 2});
RandomTest<DeviceType::GPU, float>({2, 64, 64, 4}, {1, 2});
RandomTest<DeviceType::GPU, float>({8, 128, 128, 64}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 640, 480, 64}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 512, 512, 16}, {1, 2});
RandomTest<DeviceType::GPU, float>({8, 117, 87, 33}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 619, 450, 61}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 511, 561, 11}, {1, 2});
}
TEST_F(ReduceMeanOpTest, GPURandomHalf) {
RandomTest<DeviceType::GPU, half>({4, 64, 64, 3}, {1, 2});
RandomTest<DeviceType::GPU, half>({2, 64, 64, 4}, {1, 2});
RandomTest<DeviceType::GPU, half>({8, 128, 128, 64}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 640, 480, 64}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 512, 512, 16}, {1, 2});
RandomTest<DeviceType::GPU, half>({8, 117, 87, 33}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 619, 450, 61}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 511, 561, 11}, {1, 2});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -93,6 +93,7 @@ MaceSupportedOps = [
'Proposal',
'PSROIAlign',
'Quantize',
'ReduceMean',
'Requantize',
'Reshape',
'ResizeBilinear',
......@@ -142,6 +143,7 @@ class MaceKeyword(object):
mace_constant_value_str = 'constant_value'
mace_dims_str = 'dims'
mace_axis_str = 'axis'
mace_keepdims_str = 'keepdims'
mace_shape_str = 'shape'
mace_winograd_filter_transformed = 'is_filter_transformed'
mace_device = 'device'
......
......@@ -536,21 +536,12 @@ class TensorflowConverter(base_converter.ConverterInterface):
del op.input[1:]
reduce_dims = tf_op.inputs[1].eval()
mace_check(reduce_dims[0] == 1 and reduce_dims[1] == 2,
"Mean only supports reducing dims 1 and 2")
op.type = MaceOp.Pooling.name
pooling_type_arg = op.arg.add()
pooling_type_arg.name = MaceKeyword.mace_pooling_type_str
pooling_type_arg.i = PoolingType.AVG.value
padding_arg = op.arg.add()
padding_arg.name = MaceKeyword.mace_padding_str
padding_arg.i = PaddingMode.VALID.value
strides_arg = op.arg.add()
strides_arg.name = MaceKeyword.mace_strides_str
strides_arg.ints.extend([1, 1])
kernels_arg = op.arg.add()
kernels_arg.name = MaceKeyword.mace_kernel_str
kernels_arg.ints.extend(tf_op.inputs[0].shape.as_list()[1:3])
op.type = MaceOp.ReduceMean.name
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.ints.extend(reduce_dims)
keep_dims_arg = op.arg.add()
keep_dims_arg.name = MaceKeyword.mace_keepdims_str
keep_dims_arg.i = tf_op.get_attr(MaceKeyword.mace_keepdims_str)
self._skip_tensor.add(tf_op.inputs[1].name)
......@@ -795,6 +795,46 @@ class Transformer(base_converter.ConverterInterface):
'only supports squeeze at [2, 3]')
arg.ints[:] = [1, 2]
elif op.type == MaceOp.ReduceMean.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
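# Remap the reduce axes when the op's data format differs from the
# target: NHWC axes 1/2 (H/W) become NCHW 2/3 and axis 3 (C) becomes 1,
# and the reverse for NCHW -> NHWC.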
if ConverterUtil.data_format(
op) == DataFormat.NHWC \
and self._target_data_format == DataFormat.NCHW: # noqa
print("Transpose reduce mean args: %s(%s)"
% (op.name, op.type))
reduce_axises = list(arg.ints)
new_axises = []
for i in range(len(reduce_axises)):
idx = reduce_axises[i]
if idx == 1 or idx == 2:
new_axises.append(idx + 1)
elif idx == 3:
new_axises.append(1)
else:
new_axises.append(idx)
new_axises.sort()
arg.ints[:] = []
arg.ints.extend(new_axises)
elif ConverterUtil.data_format(
op) == DataFormat.NCHW \
and self._target_data_format == DataFormat.NHWC: # noqa
print("Transpose reduce mean args: %s(%s)"
% (op.name, op.type))
reduce_axises = list(arg.ints)
new_axises = []
for i in range(len(reduce_axises)):
idx = reduce_axises[i]
if idx == 2 or idx == 3:
new_axises.append(idx - 1)
elif idx == 1:
new_axises.append(3)
else:
new_axises.append(idx)
new_axises.sort()
arg.ints[:] = []
arg.ints.extend(new_axises)
# transpose op output shape
data_format = ConverterUtil.data_format(op)
if data_format is not None \
......