diff --git a/mace/core/operator.cc b/mace/core/operator.cc
index 9920e5434265183f1f45a7eec924c5e92f61a482..cde9baa9932ac5789ac4d17596c0da439eabcaa8 100644
--- a/mace/core/operator.cc
+++ b/mace/core/operator.cc
@@ -98,6 +98,7 @@ extern void Register_Pooling(OperatorRegistry *op_registry);
 extern void Register_Proposal(OperatorRegistry *op_registry);
 extern void Register_PSROIAlign(OperatorRegistry *op_registry);
 extern void Register_Quantize(OperatorRegistry *op_registry);
+extern void Register_ReduceMean(OperatorRegistry *op_registry);
 extern void Register_Requantize(OperatorRegistry *op_registry);
 extern void Register_Reshape(OperatorRegistry *op_registry);
 extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
@@ -145,6 +146,7 @@ OperatorRegistry::OperatorRegistry() {
   ops::Register_Proposal(this);
   ops::Register_PSROIAlign(this);
   ops::Register_Quantize(this);
+  ops::Register_ReduceMean(this);
   ops::Register_Requantize(this);
   ops::Register_Reshape(this);
   ops::Register_ResizeBilinear(this);
diff --git a/mace/kernels/opencl/cl/reduce_mean.cl b/mace/kernels/opencl/cl/reduce_mean.cl
new file mode 100644
index 0000000000000000000000000000000000000000..ceaac871699c5fe3714208140e7533cc6b52fbb2
--- /dev/null
+++ b/mace/kernels/opencl/cl/reduce_mean.cl
@@ -0,0 +1,66 @@
+#include <common.h>
+
+__kernel void reduce_mean(KERNEL_ERROR_PARAMS
+                          GLOBAL_WORK_GROUP_SIZE_DIM3
+                          __read_only image2d_t input,
+                          __local float4 *group_sum,
+                          __private const int group_size,
+                          __private const int partial_len,
+                          __private const int remain_index,
+                          __private const int batch,
+                          __private const int in_height,
+                          __private const int in_width,
+                          __private const float in_height_r,
+                          __private const float in_width_r,
+                          __private const int channel_blocks,
+                          __write_only image2d_t output) {
+  const int i = get_local_id(0);
+  const int j = get_local_id(1);
+  const int k = get_global_id(2);
+
+#ifndef NON_UNIFORM_WORK_GROUP
+  if (i >= local_size_dim0 || j >= local_size_dim1 || k >= global_size_dim2)
+    return;
+  const int dim0_size = local_size_dim0;
+#else
+  const int dim0_size = get_local_size(0);
+#endif
+  DATA_TYPE4 tmp = (DATA_TYPE4){0, 0, 0, 0};
+  const int index = j * dim0_size + i;
+  const int b = k / channel_blocks;
+  const int ch = k - b * channel_blocks;
+
+  DATA_TYPE4 in;
+  const int valid_part_len = select(partial_len,
+                                    partial_len - 1,
+                                    remain_index > 0 && index >= remain_index);
+  const int full_offset = index * partial_len;
+  const int base_offset = select(full_offset,
+                                 full_offset - (index - remain_index),
+                                 valid_part_len < partial_len);
+#pragma unroll
+  for (int l = 0; l < valid_part_len; ++l) {
+    int offset = base_offset + l;
+    int h_id = floor(offset * in_width_r);
+    int w_id = offset - h_id * in_width;
+    int pos_x = mad24(ch, in_width, w_id);
+    int pos_y = mad24(b, in_height, h_id);
+    in = READ_IMAGET(input, SAMPLER, (int2)(pos_x, pos_y));
+    tmp = tmp + in;
+  }
+  group_sum[index] = tmp;
+
+#ifdef NON_QUALCOMM_ADRENO
+  barrier(CLK_LOCAL_MEM_FENCE);
+#endif
+
+  if (i == 0 && j == 0) {
+    DATA_TYPE4 out = (DATA_TYPE4){0, 0, 0, 0};
+#pragma unroll
+    for (int l = 0; l < group_size; ++l) {
+      out = out + group_sum[l];
+    }
+    out = out * in_height_r * in_width_r;
+    WRITE_IMAGET(output, (int2)(ch, b), out);
+  }
+}
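The kernel above parallelizes one output pixel (one batch x channel-block) across a whole work-group: each work-item accumulates a contiguous slice of the flat H*W index range, and work-item (0, 0) then folds the per-item sums held in local memory. The slice arithmetic (partial_len, remain_index, base_offset) is easy to misread, so here is a standalone C++ sketch of the same arithmetic — not part of the patch — whose assertions check that every flat offset in [0, H*W) is visited exactly once:

#include <cassert>
#include <vector>

// Mirrors the host-side partial_len/remain_index computation and the
// kernel-side select() logic: the first remain_index work-items take
// partial_len offsets, the rest take partial_len - 1.
void CheckPartition(int image_size, int group_size) {
  const int partial_len = (image_size + group_size - 1) / group_size;
  const int remain_index = image_size % group_size;
  std::vector<int> visits(image_size, 0);
  for (int index = 0; index < group_size; ++index) {
    const bool shorter = remain_index > 0 && index >= remain_index;
    const int valid_part_len = shorter ? partial_len - 1 : partial_len;
    const int full_offset = index * partial_len;
    const int base_offset =
        shorter ? full_offset - (index - remain_index) : full_offset;
    for (int l = 0; l < valid_part_len; ++l) ++visits[base_offset + l];
  }
  for (int v : visits) assert(v == 1);  // every pixel covered exactly once
}

int main() {
  CheckPartition(512 * 512, 64);  // remain_index == 0
  CheckPartition(117 * 87, 64);   // remain_index > 0
  return 0;
}

The barrier is compiled out on Qualcomm GPUs; together with the wave-size-derived local size chosen by the host code below, the apparent intent is that a work-group maps onto a single wave, which executes in lockstep and needs no explicit synchronization.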
diff --git a/mace/kernels/opencl/reduce_mean_opencl.cc b/mace/kernels/opencl/reduce_mean_opencl.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a8737c7fa77b9f01a76efe371fb546830e3d8bd9
--- /dev/null
+++ b/mace/kernels/opencl/reduce_mean_opencl.cc
@@ -0,0 +1,156 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/kernels/reduce_mean.h"
+#include "mace/core/runtime/opencl/cl2_header.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/kernels/opencl/helper.h"
+#include "mace/utils/tuner.h"
+
+namespace mace {
+namespace kernels {
+
+template <typename T>
+MaceStatus ReduceMeanFunctor<DeviceType::GPU, T>::operator()(
+    const Tensor *input,
+    Tensor *output,
+    StatsFuture *future) {
+  MACE_CHECK_NOTNULL(input);
+  MACE_CHECK(keep_dims_, "reduce mean gpu only supports keep_dims.");
+  MACE_CHECK(input->dim_size() == 4,
+             "reduce mean gpu only supports 4-dim input");
+  MACE_CHECK(axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2,
+             "reduce mean gpu only supports reducing over axes [1, 2]");
+  index_t batch = input->dim(0);
+  const index_t in_height = input->dim(1);
+  const index_t in_width = input->dim(2);
+  const index_t channels = input->dim(3);
+  const index_t channel_blocks = RoundUpDiv4(channels);
+  const uint32_t image_size = static_cast<uint32_t>(in_height * in_width);
+
+  auto runtime = OpenCLRuntime::Global();
+  std::vector<uint32_t> gws(3);
+  std::vector<uint32_t> lws(3);
+  std::vector<index_t> output_shape{batch, 1, 1, channels};
+  std::vector<size_t> output_image_shape;
+  CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
+                  &output_image_shape);
+  MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));
+  if (kernel_.get() == nullptr) {
+    const DataType dt = DataTypeToEnum<T>::value;
+    std::set<std::string> built_options;
+    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("reduce_mean");
+    built_options.emplace("-Dreduce_mean=" + kernel_name);
+    if (input->dtype() == output->dtype()) {
+      built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
+      built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
+      built_options.emplace(dt == DT_HALF ? "-DFP16" : "");
+    } else {
+      built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
+      built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
+    }
+    if (runtime->gpu_type() != GPUType::QUALCOMM_ADRENO) {
+      built_options.emplace("-DNON_QUALCOMM_ADRENO");
+    }
+    if (runtime->IsOutOfRangeCheckEnabled()) {
+      built_options.emplace("-DOUT_OF_RANGE_CHECK");
+      kernel_error_ = std::move(std::unique_ptr<Buffer>(
+          new Buffer(GetDeviceAllocator(DeviceType::GPU))));
+      MACE_RETURN_IF_ERROR(kernel_error_->Allocate(1));
+      kernel_error_->Map(nullptr);
+      *(kernel_error_->mutable_data<char>()) = 0;
+      kernel_error_->UnMap();
+    }
+    if (runtime->IsNonUniformWorkgroupsSupported()) {
+      built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
+    }
+    kernel_ = runtime->BuildKernel("reduce_mean", kernel_name, built_options);
+    // Query the work-group limit only after the kernel has been built.
+    kwg_size_ =
+        static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
+  }
+
+  if (runtime->gpu_type() == GPUType::QUALCOMM_ADRENO) {
+    const uint32_t wave_size =
+        static_cast<uint32_t>(runtime->GetKernelWaveSize(kernel_));
+    gws = {4, (wave_size / 4), static_cast<uint32_t>(batch * channel_blocks)};
+  } else {
+    gws = {4, 16, static_cast<uint32_t>(batch * channel_blocks)};
+  }
+  lws = {gws[0], gws[1], 1};
+  const int group_size = lws[0] * lws[1] * lws[2];
+  const int partial_len = (image_size + group_size - 1) / group_size;
+  const int remain_index = image_size % group_size;
+  const float in_width_r = 1.f / in_width;
+  const float in_height_r = 1.f / in_height;
+
+  if (!IsVecEqual(input_shape_, input->shape())) {
+    uint32_t idx = 0;
+    if (runtime->IsOutOfRangeCheckEnabled()) {
+      kernel_.setArg(idx++,
+                     *(static_cast<cl::Buffer *>(kernel_error_->buffer())));
+    }
+    if (!runtime->IsNonUniformWorkgroupsSupported()) {
+      kernel_.setArg(idx++, gws[0]);
+      kernel_.setArg(idx++, gws[1]);
+      kernel_.setArg(idx++, gws[2]);
+    }
+    kernel_.setArg(idx++, *(input->opencl_image()));
+    // Size-only argument (null pointer): reserves __local float4[group_size]
+    // for the kernel's group_sum buffer.
+    kernel_.setArg(idx++, (group_size * 4 * sizeof(float)), nullptr);
+    kernel_.setArg(idx++, static_cast<int32_t>(group_size));
+    kernel_.setArg(idx++, static_cast<int32_t>(partial_len));
+    kernel_.setArg(idx++, static_cast<int32_t>(remain_index));
+    kernel_.setArg(idx++, static_cast<int32_t>(batch));
+    kernel_.setArg(idx++, static_cast<int32_t>(in_height));
+    kernel_.setArg(idx++, static_cast<int32_t>(in_width));
+    kernel_.setArg(idx++, in_height_r);
+    kernel_.setArg(idx++, in_width_r);
+    kernel_.setArg(idx++, static_cast<int32_t>(channel_blocks));
+    kernel_.setArg(idx++, *(output->opencl_image()));
+
+    input_shape_ = input->shape();
+  }
+
+  cl::Event event;
+  cl_int error;
+  if (runtime->IsNonUniformWorkgroupsSupported()) {
+    error = runtime->command_queue().enqueueNDRangeKernel(
+        kernel_, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
+        cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
+  } else {
+    std::vector<uint32_t> roundup_gws(lws.size());
+    for (size_t i = 0; i < lws.size(); ++i) {
+      roundup_gws[i] = RoundUp(gws[i], lws[i]);
+    }
+    error = runtime->command_queue().enqueueNDRangeKernel(
+        kernel_, cl::NullRange,
+        cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws[2]),
+        cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
+  }
+  if (runtime->IsOutOfRangeCheckEnabled()) {
+    kernel_error_->Map(nullptr);
+    char *kerror_code = kernel_error_->mutable_data<char>();
+    MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
+    kernel_error_->UnMap();
+  }
+  MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
+
+  if (future != nullptr) {
+    future->wait_fn = [runtime, event](CallStats *stats) {
+      event.wait();
+      if (stats != nullptr) {
+        runtime->GetCallStats(event, stats);
+      }
+    };
+  }
+
+  return MACE_SUCCESS;
+}
+
+template struct ReduceMeanFunctor<DeviceType::GPU, float>;
+template struct ReduceMeanFunctor<DeviceType::GPU, half>;
+}  // namespace kernels
+}  // namespace mace
diff --git a/mace/kernels/reduce_mean.h b/mace/kernels/reduce_mean.h
new file mode 100644
index 0000000000000000000000000000000000000000..2b250e365e3e5dc5cdcd07ad95c03b264adbb9d6
--- /dev/null
+++ b/mace/kernels/reduce_mean.h
@@ -0,0 +1,230 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_KERNELS_REDUCE_MEAN_H_
+#define MACE_KERNELS_REDUCE_MEAN_H_
+
+#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+#include <arm_neon.h>
+#endif
+#include <algorithm>
+#include <cstring>
+#include <memory>
+#include <vector>
+
+#include "mace/core/future.h"
+#include "mace/core/runtime/opencl/cl2_header.h"
+#include "mace/core/tensor.h"
+
+namespace mace {
+namespace kernels {
+
+struct ReduceFunctorBase {
+  ReduceFunctorBase(const std::vector<int> &axis,
+                    const bool keep_dims)
+      : keep_dims_(keep_dims),
+        axis_(axis) {}
+  bool keep_dims_;
+  bool reduce_first_axis_;
+  const std::vector<int> axis_;
+  std::vector<int> data_reshape_;
+  std::vector<index_t> out_shape_;
+};
+
+template <DeviceType D, typename T>
+struct ReduceMeanFunctor : ReduceFunctorBase {
+  ReduceMeanFunctor(const std::vector<int> &axis,
+                    const bool keep_dims)
+      : ReduceFunctorBase(axis, keep_dims) {}
+
+  void Simplify(const Tensor *input,
+                const bool keep_dims) {
+    std::vector<bool> bitmap(static_cast<size_t>(input->dim_size()), false);
+    if (axis_.size() == 0) {
+      for (int i = 0; i < input->dim_size(); ++i) {
+        bitmap[i] = true;
+      }
+    } else {
+      for (unsigned int i = 0; i < axis_.size(); ++i) {
+        const int index = axis_[i] >= 0 ?
+                          axis_[i] :
+                          axis_[i] + input->dim_size();
+        bitmap[index] = true;
+      }
+    }
+    out_shape_.clear();
+    for (unsigned int i = 0; i < input->dim_size(); ++i) {
+      if (!bitmap[i]) {
+        out_shape_.push_back(input->dim(i));
+      } else if (keep_dims) {
+        out_shape_.push_back(1);
+      }
+    }
+    // Merge consecutive dims that are either all reduced or all kept, so
+    // Compute only has to handle at most four collapsed dimensions.
+    data_reshape_.clear();
+    unsigned int dim_index = 0;
+    for (; dim_index < input->dim_size(); ++dim_index) {
+      if (input->dim(dim_index) != 1) break;
+    }
+    if (dim_index >= input->dim_size()) {
+      reduce_first_axis_ = true;
+    } else {
+      reduce_first_axis_ = bitmap[dim_index];
+      data_reshape_.push_back(input->dim(dim_index));
+      ++dim_index;
+      for (; dim_index < input->dim_size(); ++dim_index) {
+        const int n = input->dim(dim_index);
+        if (n == 1) {
+          bitmap[dim_index] = bitmap[dim_index - 1];
+        }
+        if (bitmap[dim_index - 1] != bitmap[dim_index]) {
+          data_reshape_.push_back(n);
+        } else {
+          data_reshape_.back() *= n;
+        }
+      }
+    }
+  }
+
+  void Compute(const Tensor *input, Tensor *output) {
+    Tensor::MappingGuard input_mapper(input);
+    const T *input_ptr = input->data<T>();
+    Tensor::MappingGuard output_map(output);
+    T *output_ptr = output->mutable_data<T>();
+    memset(output_ptr, 0, output->size() * sizeof(T));
+    switch (data_reshape_.size()) {
+      case 1:
+        if (reduce_first_axis_) {
+          T sum = 0;
+#pragma omp parallel for reduction(+:sum)
+          for (int i = 0; i < data_reshape_[0]; ++i) {
+            sum = sum + input_ptr[i];
+          }
+          output_ptr[0] = sum / data_reshape_[0];
+        } else {
+#pragma omp parallel for
+          for (int i = 0; i < data_reshape_[0]; ++i) {
+            output_ptr[i] = input_ptr[i];
+          }
+        }
+        break;
+      case 2:
+        if (reduce_first_axis_) {
+#pragma omp parallel for
+          for (int i = 0; i < data_reshape_[1]; ++i) {
+            for (int j = 0; j < data_reshape_[0]; ++j) {
+              output_ptr[i] += input_ptr[j * data_reshape_[1] + i];
+            }
+            output_ptr[i] /= data_reshape_[0];
+          }
+        } else {
+#pragma omp parallel for
+          for (int i = 0; i < data_reshape_[0]; ++i) {
+            for (int j = 0; j < data_reshape_[1]; ++j) {
+              output_ptr[i] += input_ptr[i * data_reshape_[1] + j];
+            }
+            output_ptr[i] /= data_reshape_[1];
+          }
+        }
+        break;
+      case 3:
+        if (reduce_first_axis_) {
+#pragma omp parallel for
+          for (int i = 0; i < data_reshape_[1]; ++i) {
+            for (int j = 0; j < data_reshape_[2]; ++j) {
+              for (int k = 0; k < data_reshape_[0]; ++k) {
+                output_ptr[i] +=
+                    input_ptr[(k * data_reshape_[1] + i) * data_reshape_[2]
+                        + j];
+              }
+            }
+            output_ptr[i] /= (data_reshape_[0] * data_reshape_[2]);
+          }
+        } else {
+#pragma omp parallel for collapse(2)
+          for (int i = 0; i < data_reshape_[0]; ++i) {
+            for (int j = 0; j < data_reshape_[2]; ++j) {
+              for (int k = 0; k < data_reshape_[1]; ++k) {
+                output_ptr[i * data_reshape_[2] + j] +=
+                    input_ptr[(i * data_reshape_[1] + k) * data_reshape_[2]
+                        + j];
+              }
+              output_ptr[i * data_reshape_[2] + j] /= data_reshape_[1];
+            }
+          }
+        }
+        break;
+      case 4:
+        if (reduce_first_axis_) {
+#pragma omp parallel for collapse(2)
+          for (int i = 0; i < data_reshape_[1]; ++i) {
+            for (int j = 0; j < data_reshape_[3]; ++j) {
+              for (int k = 0; k < data_reshape_[2]; ++k) {
+                for (int t = 0; t < data_reshape_[0]; ++t) {
+                  output_ptr[i * data_reshape_[3] + j] +=
+                      input_ptr[((t * data_reshape_[1] + i) *
+                          data_reshape_[2] + k) * data_reshape_[3] + j];
+                }
+              }
+              output_ptr[i * data_reshape_[3] + j] /=
+                  (data_reshape_[0] * data_reshape_[2]);
+            }
+          }
+        } else {
+#pragma omp parallel for collapse(2)
+          for (int i = 0; i < data_reshape_[0]; ++i) {
+            for (int j = 0; j < data_reshape_[2]; ++j) {
+              for (int k = 0; k < data_reshape_[1]; ++k) {
+                for (int t = 0; t < data_reshape_[3]; ++t) {
+                  output_ptr[i * data_reshape_[2] + j] +=
+                      input_ptr[((i * data_reshape_[1] + k) *
+                          data_reshape_[2] + j) * data_reshape_[3] + t];
+                }
+              }
+              output_ptr[i * data_reshape_[2] + j] /=
+                  (data_reshape_[1] * data_reshape_[3]);
+            }
+          }
+        }
+        break;
+      default:
+        MACE_CHECK(false, "not implemented in mace")
+            << "data reshape size: " << data_reshape_.size()
+            << ", reduce first axis: " << reduce_first_axis_;
+        break;
+    }
+  }
+
+  MaceStatus operator()(const Tensor *input,
+                        Tensor *output,
+                        StatsFuture *future) {
+    MACE_UNUSED(future);
+    Simplify(input, keep_dims_);
+    output->Resize(out_shape_);
+    Compute(input, output);
+    return MACE_SUCCESS;
+  }
+};
+
+#ifdef MACE_ENABLE_OPENCL
+template <typename T>
+struct ReduceMeanFunctor<DeviceType::GPU, T> : ReduceFunctorBase {
+  ReduceMeanFunctor(const std::vector<int> &axis,
+                    const bool keep_dims)
+      : ReduceFunctorBase(axis, keep_dims) {}
+
+  MaceStatus operator()(const Tensor *input,
+                        Tensor *output_tensor,
+                        StatsFuture *future);
+
+  cl::Kernel kernel_;
+  uint32_t kwg_size_;
+  std::unique_ptr<Buffer> kernel_error_;
+  std::vector<index_t> input_shape_;
+};
+#endif
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_REDUCE_MEAN_H_
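Simplify above collapses adjacent dimensions that share the same reduced/kept status, which is why Compute only needs four data_reshape_ cases. A standalone sketch of that merging rule with one worked example — not part of the patch, written against plain vectors instead of the Tensor API:

#include <cassert>
#include <vector>

// Collapses runs of dims with equal reduced/kept status; size-1 dims inherit
// the status of their left neighbor, exactly as in Simplify above.
std::vector<int> Collapse(const std::vector<int> &shape,
                          std::vector<bool> bitmap,  // true = reduced dim
                          bool *reduce_first_axis) {
  std::vector<int> reshape;
  size_t d = 0;
  while (d < shape.size() && shape[d] == 1) ++d;  // skip leading 1s
  if (d == shape.size()) {
    *reduce_first_axis = true;
    return reshape;
  }
  *reduce_first_axis = bitmap[d];
  reshape.push_back(shape[d]);
  for (++d; d < shape.size(); ++d) {
    if (shape[d] == 1) bitmap[d] = bitmap[d - 1];
    if (bitmap[d] != bitmap[d - 1]) reshape.push_back(shape[d]);
    else reshape.back() *= shape[d];
  }
  return reshape;
}

int main() {
  // Shape {2, 3, 4, 5} reduced over axes {2, 3} collapses to {2*3, 4*5},
  // with the first collapsed dim kept, so Compute takes the 2-dim path.
  bool first = false;
  auto r = Collapse({2, 3, 4, 5}, {false, false, true, true}, &first);
  assert(!first && r == (std::vector<int>{6, 20}));
  return 0;
}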
diff --git a/mace/kernels/strided_slice.h b/mace/kernels/strided_slice.h
index eafe6acc7b7a21ddc0bf0f0089ebae7133f0ba0e..8f09610b98ea374567b25524b957593b194bdc86 100644
--- a/mace/kernels/strided_slice.h
+++ b/mace/kernels/strided_slice.h
@@ -33,7 +33,7 @@ struct StridedSliceFunctor {
                       int ellipsis_mask,
                       int new_axis_mask,
                       int shrink_axis_mask,
-                      bool is_slice = false)
+                      bool is_slice)
       : begin_mask_(begin_mask),
         end_mask_(end_mask),
         ellipsis_mask_(ellipsis_mask),
diff --git a/mace/ops/pooling_benchmark.cc b/mace/ops/pooling_benchmark.cc
index d0da9b47b52735698be2d62473a58428731197e2..dec2b53a7acdb2b40667be124413cb3be708e74c 100644
--- a/mace/ops/pooling_benchmark.cc
+++ b/mace/ops/pooling_benchmark.cc
@@ -103,13 +103,15 @@ void Pooling(int iters,
                    ##DEVICE)
 
 #define MACE_BM_POOLING(N, C, H, W, K, S, PA, PO)       \
-  MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU); \
-  MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, GPU);
+  MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, GPU); \
+  MACE_BM_POOLING_MACRO(N, C, H, W, K, S, PA, PO, CPU);
+
 MACE_BM_POOLING(1, 3, 129, 129, 2, 2, SAME, MAX);
 MACE_BM_POOLING(1, 3, 257, 257, 2, 2, SAME, MAX);
 MACE_BM_POOLING(1, 3, 513, 513, 2, 2, SAME, MAX);
 MACE_BM_POOLING(1, 3, 1025, 1025, 2, 2, SAME, MAX);
+MACE_BM_POOLING(1, 32, 480, 640, 480, 640, VALID, AVG);
 
 }  // namespace test
 }  // namespace ops
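For context on the new AVG/VALID pooling benchmark: an average pooling whose window spans the entire H x W map produces the same numbers as ReduceMean over axes {1, 2}, so the two benchmarks measure interchangeable implementations of global mean pooling. A small plain-array check — not part of the patch — with the reduce-mean side written in the collapsed {N, H*W, C} form that Compute's case 3 uses:

#include <cassert>
#include <cmath>
#include <cstdlib>
#include <vector>

int main() {
  const int N = 2, H = 4, W = 5, C = 3;
  std::vector<float> x(N * H * W * C);
  for (auto &v : x) v = static_cast<float>(std::rand()) / RAND_MAX;

  // Global average pooling: one VALID window covering all of H x W (NHWC).
  std::vector<float> pool(N * C, 0.f);
  for (int n = 0; n < N; ++n)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        for (int c = 0; c < C; ++c)
          pool[n * C + c] += x[((n * H + h) * W + w) * C + c] / (H * W);

  // ReduceMean over the collapsed middle dimension of {N, H*W, C}.
  std::vector<float> mean(N * C, 0.f);
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < H * W; ++k)
      for (int c = 0; c < C; ++c)
        mean[n * C + c] += x[(n * H * W + k) * C + c];
  for (auto &v : mean) v /= (H * W);

  for (int i = 0; i < N * C; ++i)
    assert(std::fabs(pool[i] - mean[i]) < 1e-5f);
  return 0;
}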
diff --git a/mace/ops/reduce_mean.cc b/mace/ops/reduce_mean.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d940077ca4310e7563a2f850fbdf273d2a2edd12
--- /dev/null
+++ b/mace/ops/reduce_mean.cc
@@ -0,0 +1,32 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/reduce_mean.h"
+
+namespace mace {
+namespace ops {
+
+void Register_ReduceMean(OperatorRegistry *op_registry) {
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ReduceMean")
+                                          .Device(DeviceType::CPU)
+                                          .TypeConstraint<float>("T")
+                                          .Build(),
+                         ReduceMeanOp<DeviceType::CPU, float>);
+#ifdef MACE_ENABLE_OPENCL
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ReduceMean")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<float>("T")
+                                          .Build(),
+                         ReduceMeanOp<DeviceType::GPU, float>);
+
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ReduceMean")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<half>("T")
+                                          .Build(),
+                         ReduceMeanOp<DeviceType::GPU, half>);
+#endif
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/reduce_mean.h b/mace/ops/reduce_mean.h
new file mode 100644
index 0000000000000000000000000000000000000000..4a317259730ef437978e630f4182d2c50c3ad0bc
--- /dev/null
+++ b/mace/ops/reduce_mean.h
@@ -0,0 +1,52 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_OPS_REDUCE_MEAN_H_
+#define MACE_OPS_REDUCE_MEAN_H_
+
+#include <string>
+#include <vector>
+
+#include "mace/core/operator.h"
+#include "mace/kernels/reduce_mean.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class ReduceMeanOp : public Operator<D, T> {
+ public:
+  ReduceMeanOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws),
+        functor_(OperatorBase::GetRepeatedArgs<int>("axis"),
+                 OperatorBase::GetOptionalArg<bool>("keepdims", true)) {}
+
+  MaceStatus Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(INPUT);
+    const std::vector<int> axis =
+        OperatorBase::GetRepeatedArgs<int>("axis");
+    const int left = static_cast<int>(input->dim_size() * -1);
+    const int right = static_cast<int>(input->dim_size());
+    if (axis.size()) {
+      for (unsigned int i = 0; i < axis.size(); ++i) {
+        MACE_CHECK(axis[i] > left && axis[i] < right, "Axis is over range.");
+      }
+    }
+    Tensor *output = this->Output(OUTPUT);
+    return functor_(input, output, future);
+  }
+
+ private:
+  kernels::ReduceMeanFunctor<D, T> functor_;
+
+ protected:
+  MACE_OP_INPUT_TAGS(INPUT);
+  MACE_OP_OUTPUT_TAGS(OUTPUT);
+};
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_REDUCE_MEAN_H_
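Run() above accepts axis values in the open interval (-dims, dims); negative values count from the last dimension and are normalized inside Simplify by adding dim_size. A minimal sketch of that convention — the helper name is illustrative, not from the patch:

#include <cassert>

// Mirrors the MACE_CHECK bound in Run() and the normalization in Simplify:
// index = axis >= 0 ? axis : axis + dims.
int NormalizeAxis(int axis, int dims) {
  assert(axis > -dims && axis < dims);
  return axis >= 0 ? axis : axis + dims;
}

int main() {
  assert(NormalizeAxis(-1, 4) == 3);  // last dim of a 4-D tensor
  assert(NormalizeAxis(-3, 4) == 1);  // matches the {-3} test case below
  assert(NormalizeAxis(2, 4) == 2);
  return 0;
}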
diff --git a/mace/ops/reduce_mean_benchmark.cc b/mace/ops/reduce_mean_benchmark.cc
new file mode 100644
index 0000000000000000000000000000000000000000..272df97d7c2eeeb96e5bdcd9f8ff41d0e7495246
--- /dev/null
+++ b/mace/ops/reduce_mean_benchmark.cc
@@ -0,0 +1,85 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+namespace {
+template <DeviceType D, typename T>
+void ReduceMean(int iters, int batch, int channels,
+                int height, int width) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+  // Add input data
+  net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
+
+  if (D == DeviceType::GPU) {
+    BufferToImage<D, T>(&net, "Input", "InputImage",
+                        kernels::BufferType::IN_OUT_CHANNEL);
+    OpDefBuilder("ReduceMean", "ReduceMeanBM")
+        .Input("InputImage")
+        .AddIntsArg("axis", {1, 2})
+        .Output("OutputImage")
+        .Finalize(net.NewOperatorDef());
+  } else {
+    net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC,
+                                                    "InputNCHW", NCHW);
+    OpDefBuilder("ReduceMean", "ReduceMeanBM")
+        .Input("InputNCHW")
+        .AddIntsArg("axis", {2, 3})
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+  }
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+  net.Sync();
+}
+}  // namespace
+
+#define MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, TYPE, DEVICE)          \
+  static void                                                        \
+      MACE_BM_REDUCE_MEAN_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
+          int iters) {                                               \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
+    mace::testing::MaccProcessed(tot);                               \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));              \
+    ReduceMean<DEVICE, TYPE>(iters, N, C, H, W);                     \
+  }                                                                  \
+  MACE_BENCHMARK(                                                    \
+      MACE_BM_REDUCE_MEAN_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
+
+#define MACE_BM_REDUCE_MEAN(N, C, H, W)              \
+  MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, float, GPU); \
+  MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, half, GPU);  \
+  MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, float, CPU);
+
+MACE_BM_REDUCE_MEAN(1, 1, 512, 512);
+MACE_BM_REDUCE_MEAN(4, 3, 128, 128);
+MACE_BM_REDUCE_MEAN(4, 3, 512, 512);
+MACE_BM_REDUCE_MEAN(16, 32, 112, 112);
+MACE_BM_REDUCE_MEAN(8, 32, 112, 112);
+MACE_BM_REDUCE_MEAN(8, 64, 256, 256);
+MACE_BM_REDUCE_MEAN(1, 32, 480, 640);
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/reduce_mean_test.cc b/mace/ops/reduce_mean_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..025f9c04143912490c4f66df768bb5a13bf7b08b
--- /dev/null
+++ b/mace/ops/reduce_mean_test.cc
@@ -0,0 +1,383 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class ReduceMeanOpTest : public OpsTestBase {};
+
+namespace {
+template <DeviceType D>
+void Simple(const std::vector<index_t> &input_shape,
+            const std::vector<float> &input,
+            const std::vector<int> &axis,
+            const std::vector<index_t> &output_shape,
+            const std::vector<float> &output) {
+  // Construct graph
+  OpsTestNet net;
+  // Add input data
+  net.AddInputFromArray<D, float>("Input", input_shape, input);
+
+  if (D == DeviceType::CPU) {
+    OpDefBuilder("ReduceMean", "ReduceMeanTest")
+        .Input("Input")
+        .AddIntsArg("axis", axis)
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+  } else {
+    BufferToImage<D, float>(&net, "Input", "InputImg",
+                            kernels::BufferType::IN_OUT_CHANNEL);
+    OpDefBuilder("ReduceMean", "ReduceMeanTest")
+        .Input("InputImg")
+        .AddIntsArg("axis", axis)
+        .Output("OutputImg")
+        .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+    ImageToBuffer<D, float>(&net, "OutputImg", "Output",
+                            kernels::BufferType::IN_OUT_CHANNEL);
+  }
+  auto expected = CreateTensor<float>(output_shape, output);
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5, 1e-3);
+}
+
+template <DeviceType D>
+void Simple12Test() {
+  Simple<D>({2, 2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+             0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {1, 2},
+            {2, 1, 1, 4},
+            {10, 11, 12, 13,
+             10, 11, 12, 13});
+}
+
+template <DeviceType D>
+void Simple1Axis() {
+  Simple<D>({2, 2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+             0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {1},
+            {2, 1, 3, 4},
+            {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
+             6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
+  Simple<D>({1, 2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {-3},
+            {1, 1, 3, 4},
+            {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
+  Simple<D>({1, 2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {2},
+            {1, 2, 1, 4},
+            {4, 5, 6, 7, 16, 17, 18, 19});
+  Simple<D>({1, 2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {-1},
+            {1, 2, 3, 1},
+            {1.5, 5.5, 9.5, 13.5, 17.5, 21.5});
+  Simple<D>({1, 3, 3, 3},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8,
+             9, 10, 11, 12, 13, 14, 15, 16, 17,
+             18, 19, 20, 21, 22, 23, 24, 25, 26},
+            {1},
+            {1, 1, 3, 3},
+            {9, 10, 11, 12, 13, 14, 15, 16, 17});
+  Simple<D>({1, 3, 3, 3},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8,
+             9, 10, 11, 12, 13, 14, 15, 16, 17,
+             18, 19, 20, 21, 22, 23, 24, 25, 26},
+            {-2},
+            {1, 3, 1, 3},
+            {3, 4, 5, 12, 13, 14, 21, 22, 23});
+  Simple<D>({1, 3, 3, 3},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8,
+             9, 10, 11, 12, 13, 14, 15, 16, 17,
+             18, 19, 20, 21, 22, 23, 24, 25, 26},
+            {3},
+            {1, 3, 3, 1},
+            {1, 4, 7, 10, 13, 16, 19, 22, 25});
+}
+
+template <DeviceType D>
+void Simple2Axis() {
+  Simple<D>({1, 2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {0, 1},
+            {1, 1, 3, 4},
+            {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
+  Simple<D>({2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {0, 1},
+            {1, 1, 4},
+            {10, 11, 12, 13});
+  Simple<D>({2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {1, 2},
+            {2, 1, 1},
+            {5.5, 17.5});
+  Simple<D>({1, 2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {0, 2},
+            {1, 2, 1, 4},
+            {4, 5, 6, 7, 16, 17, 18, 19});
+  Simple<D>({1, 2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {1, 3},
+            {1, 1, 3, 1},
+            {7.5, 11.5, 15.5});
+  Simple<D>({1, 3, 3, 3},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8,
+             9, 10, 11, 12, 13, 14, 15, 16, 17,
+             18, 19, 20, 21, 22, 23, 24, 25, 26},
+            {1, 2},
+            {1, 1, 1, 3},
+            {12, 13, 14});
+  Simple<D>({1, 3, 3, 3},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8,
+             9, 10, 11, 12, 13, 14, 15, 16, 17,
+             18, 19, 20, 21, 22, 23, 24, 25, 26},
+            {0, 1},
+            {1, 1, 3, 3},
+            {9, 10, 11, 12, 13, 14, 15, 16, 17});
+  Simple<D>({1, 3, 3, 3},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8,
+             9, 10, 11, 12, 13, 14, 15, 16, 17,
+             18, 19, 20, 21, 22, 23, 24, 25, 26},
+            {2, 3},
+            {1, 3, 1, 1},
+            {4, 13, 22});
+}
+
+template <DeviceType D>
+void Simple3Axis() {
+  Simple<D>({1, 2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {1, 2, 3},
+            {1, 1, 1, 1},
+            {11.5});
+  Simple<D>({1, 2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {0, 2, 3},
+            {1, 2, 1, 1},
+            {5.5, 17.5});
+  Simple<D>({1, 2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {0, 1, 3},
+            {1, 1, 3, 1},
+            {7.5, 11.5, 15.5});
+  Simple<D>({1, 2, 3, 4},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+             12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+            {0, 1, 2},
+            {1, 1, 1, 4},
+            {10, 11, 12, 13});
+  Simple<D>({1, 3, 3, 3},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8,
+             9, 10, 11, 12, 13, 14, 15, 16, 17,
+             18, 19, 20, 21, 22, 23, 24, 25, 26},
+            {1, 2, 3},
+            {1, 1, 1, 1},
+            {13});
+  Simple<D>({1, 3, 3, 3},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8,
+             9, 10, 11, 12, 13, 14, 15, 16, 17,
+             18, 19, 20, 21, 22, 23, 24, 25, 26},
+            {0, 2, 3},
+            {1, 3, 1, 1},
+            {4, 13, 22});
+  Simple<D>({1, 3, 3, 3},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8,
+             9, 10, 11, 12, 13, 14, 15, 16, 17,
+             18, 19, 20, 21, 22, 23, 24, 25, 26},
+            {0, 1, 3},
+            {1, 1, 3, 1},
+            {10, 13, 16});
+  Simple<D>({1, 3, 3, 3},
+            {0, 1, 2, 3, 4, 5, 6, 7, 8,
+             9, 10, 11, 12, 13, 14, 15, 16, 17,
+             18, 19, 20, 21, 22, 23, 24, 25, 26},
+            {0, 1, 2},
+            {1, 1, 1, 3},
+            {12, 13, 14});
+}
+
+}  // namespace
+
+TEST_F(ReduceMeanOpTest, CPUSimple12) {
+  Simple12Test<DeviceType::CPU>();
+}
+
+TEST_F(ReduceMeanOpTest, GPUSimple12) {
+  Simple12Test<DeviceType::GPU>();
+}
+
+TEST_F(ReduceMeanOpTest, CPUSimple1Axis) {
+  Simple1Axis<DeviceType::CPU>();
+}
+
+TEST_F(ReduceMeanOpTest, CPUSimple2Axis) {
+  Simple2Axis<DeviceType::CPU>();
+}
+
+TEST_F(ReduceMeanOpTest, CPUSimple3Axis) {
+  Simple3Axis<DeviceType::CPU>();
+}
+
+namespace {
+template <DeviceType D, typename T>
+void RandomTest(const std::vector<index_t> &input_shape,
+                const std::vector<int> &axis) {
+  testing::internal::LogToStderr();
+  srand(time(NULL));
+  // Construct graph
+  OpsTestNet net;
+  // Add input data
+  net.AddRandomInput<D, float>("Input", input_shape);
+
+  // Map the NHWC axes onto the NCHW layout used by the CPU reference run.
+  std::vector<int> axis_cpu(axis.size());
+  for (unsigned int i = 0; i < axis.size(); ++i) {
+    if (axis[i] == 1 || axis[i] == 2)
+      axis_cpu[i] = axis[i] + 1;
+    else if (axis[i] == 3)
+      axis_cpu[i] = 1;
+    else
+      axis_cpu[i] = axis[i];
+  }
+
+  net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC,
+                                                  "InputNCHW", NCHW);
+  OpDefBuilder("ReduceMean", "ReduceMeanTest")
+      .Input("InputNCHW")
+      .AddIntsArg("axis", axis_cpu)
+      .Output("OutputNCHW")
+      .Finalize(net.NewOperatorDef());
+  // Run on CPU
+  net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW,
+                                                  "Output", NHWC);
+  BufferToImage<D, T>(&net, "Input", "InputImg",
+                      kernels::BufferType::IN_OUT_CHANNEL);
+  OpDefBuilder("ReduceMean", "ReduceMeanTest")
+      .Input("InputImg")
+      .AddIntsArg("axis", axis)
+      .Output("OutputImg")
+      .Finalize(net.NewOperatorDef());
+  // Run on GPU
+  net.RunOp(D);
+  ImageToBuffer<D, float>(&net, "OutputImg", "OPENCLOutput",
+                          kernels::BufferType::IN_OUT_CHANNEL);
+  if (DataTypeToEnum<T>::value == DT_FLOAT) {
+    ExpectTensorNear<float>(*net.GetTensor("Output"),
+                            *net.GetOutput("OPENCLOutput"), 1e-5, 1e-4);
+  } else {
+    ExpectTensorNear<float>(*net.GetTensor("Output"),
+                            *net.GetOutput("OPENCLOutput"), 1e-2, 1e-2);
+  }
+}
+}  // namespace
+
+TEST_F(ReduceMeanOpTest, GPURandomFloat) {
+  RandomTest<DeviceType::GPU, float>({4, 64, 64, 3}, {1, 2});
+  RandomTest<DeviceType::GPU, float>({2, 64, 64, 4}, {1, 2});
+  RandomTest<DeviceType::GPU, float>({8, 128, 128, 64}, {1, 2});
+  RandomTest<DeviceType::GPU, float>({1, 640, 480, 64}, {1, 2});
+  RandomTest<DeviceType::GPU, float>({1, 512, 512, 16}, {1, 2});
+  RandomTest<DeviceType::GPU, float>({8, 117, 87, 33}, {1, 2});
+  RandomTest<DeviceType::GPU, float>({1, 619, 450, 61}, {1, 2});
+  RandomTest<DeviceType::GPU, float>({1, 511, 561, 11}, {1, 2});
+}
+
+TEST_F(ReduceMeanOpTest, GPURandomHalf) {
+  RandomTest<DeviceType::GPU, half>({4, 64, 64, 3}, {1, 2});
+  RandomTest<DeviceType::GPU, half>({2, 64, 64, 4}, {1, 2});
+  RandomTest<DeviceType::GPU, half>({8, 128, 128, 64}, {1, 2});
+  RandomTest<DeviceType::GPU, half>({1, 640, 480, 64}, {1, 2});
+  RandomTest<DeviceType::GPU, half>({1, 512, 512, 16}, {1, 2});
+  RandomTest<DeviceType::GPU, half>({8, 117, 87, 33}, {1, 2});
+  RandomTest<DeviceType::GPU, half>({1, 619, 450, 61}, {1, 2});
+  RandomTest<DeviceType::GPU, half>({1, 511, 561, 11}, {1, 2});
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index 40affd554bc5191cbf30fddb4b0ab421404a695e..78226c52d3173940e47c67617389539a599788ae 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -93,6 +93,7 @@ MaceSupportedOps = [
     'Proposal',
     'PSROIAlign',
     'Quantize',
+    'ReduceMean',
     'Requantize',
     'Reshape',
     'ResizeBilinear',
@@ -142,6 +143,7 @@ class MaceKeyword(object):
     mace_constant_value_str = 'constant_value'
     mace_dims_str = 'dims'
     mace_axis_str = 'axis'
+    mace_keepdims_str = 'keepdims'
     mace_shape_str = 'shape'
     mace_winograd_filter_transformed = 'is_filter_transformed'
     mace_device = 'device'
OpDefBuilder("ReduceMean", "ReduceMeanTest") + .Input("InputImg") + .AddIntsArg("axis", axis) + .Output("OutputImg") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + ImageToBuffer(&net, "OutputImg", "OPENCLOutput", + kernels::BufferType::IN_OUT_CHANNEL); + if (DataTypeToEnum::value == DT_FLOAT) { + ExpectTensorNear(*net.GetTensor("Output"), + *net.GetOutput("OPENCLOutput"), 1e-5, 1e-4); + } else { + ExpectTensorNear(*net.GetTensor("Output"), + *net.GetOutput("OPENCLOutput"), 1e-2, 1e-2); + } +} +} // namespace + +TEST_F(ReduceMeanOpTest, GPURandomFloat) { + RandomTest({4, 64, 64, 3}, {1, 2}); + RandomTest({2, 64, 64, 4}, {1, 2}); + RandomTest({8, 128, 128, 64}, {1, 2}); + RandomTest({1, 640, 480, 64}, {1, 2}); + RandomTest({1, 512, 512, 16}, {1, 2}); + RandomTest({8, 117, 87, 33}, {1, 2}); + RandomTest({1, 619, 450, 61}, {1, 2}); + RandomTest({1, 511, 561, 11}, {1, 2}); +} + +TEST_F(ReduceMeanOpTest, GPURandomHalf) { + RandomTest({4, 64, 64, 3}, {1, 2}); + RandomTest({2, 64, 64, 4}, {1, 2}); + RandomTest({8, 128, 128, 64}, {1, 2}); + RandomTest({1, 640, 480, 64}, {1, 2}); + RandomTest({1, 512, 512, 16}, {1, 2}); + RandomTest({8, 117, 87, 33}, {1, 2}); + RandomTest({1, 619, 450, 61}, {1, 2}); + RandomTest({1, 511, 561, 11}, {1, 2}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 40affd554bc5191cbf30fddb4b0ab421404a695e..78226c52d3173940e47c67617389539a599788ae 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -93,6 +93,7 @@ MaceSupportedOps = [ 'Proposal', 'PSROIAlign', 'Quantize', + 'ReduceMean', 'Requantize', 'Reshape', 'ResizeBilinear', @@ -142,6 +143,7 @@ class MaceKeyword(object): mace_constant_value_str = 'constant_value' mace_dims_str = 'dims' mace_axis_str = 'axis' + mace_keepdims_str = 'keepdims' mace_shape_str = 'shape' mace_winograd_filter_transformed = 'is_filter_transformed' mace_device = 'device' diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 3e9335dd82628996a38f684890a7733bdb929633..9076658d37c628915f8bdb062ea1675481bd07ec 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -536,21 +536,12 @@ class TensorflowConverter(base_converter.ConverterInterface): del op.input[1:] reduce_dims = tf_op.inputs[1].eval() - mace_check(reduce_dims[0] == 1 and reduce_dims[1] == 2, - "Mean only support reduce dim 1, 2") - - op.type = MaceOp.Pooling.name - pooling_type_arg = op.arg.add() - pooling_type_arg.name = MaceKeyword.mace_pooling_type_str - pooling_type_arg.i = PoolingType.AVG.value - padding_arg = op.arg.add() - padding_arg.name = MaceKeyword.mace_padding_str - padding_arg.i = PaddingMode.VALID.value - strides_arg = op.arg.add() - strides_arg.name = MaceKeyword.mace_strides_str - strides_arg.ints.extend([1, 1]) - kernels_arg = op.arg.add() - kernels_arg.name = MaceKeyword.mace_kernel_str - kernels_arg.ints.extend(tf_op.inputs[0].shape.as_list()[1:3]) + op.type = MaceOp.ReduceMean.name + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.ints.extend(reduce_dims) + keep_dims_arg = op.arg.add() + keep_dims_arg.name = MaceKeyword.mace_keepdims_str + keep_dims_arg.i = tf_op.get_attr(MaceKeyword.mace_keepdims_str) self._skip_tensor.add(tf_op.inputs[1].name) diff --git 
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index d5d0c0cf5da239ea979c3675b6a7f9583bc0fcfa..b176e29dea18d11e54e8414dda7983d57cf0c530 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -795,6 +795,46 @@ class Transformer(base_converter.ConverterInterface):
                             'only support squeeze at at [2, 3]')
                     arg.ints[:] = [1, 2]
+            elif op.type == MaceOp.ReduceMean.name:
+                for arg in op.arg:
+                    if arg.name == MaceKeyword.mace_axis_str:
+                        if ConverterUtil.data_format(
+                                op) == DataFormat.NHWC \
+                                and self._target_data_format == DataFormat.NCHW:  # noqa
+                            print("Transpose reduce mean args: %s(%s)"
+                                  % (op.name, op.type))
+                            reduce_axises = list(arg.ints)
+                            new_axises = []
+                            for i in range(len(reduce_axises)):
+                                idx = reduce_axises[i]
+                                if idx == 1 or idx == 2:
+                                    new_axises.append(idx + 1)
+                                elif idx == 3:
+                                    new_axises.append(1)
+                                else:
+                                    new_axises.append(idx)
+                            new_axises.sort()
+                            arg.ints[:] = []
+                            arg.ints.extend(new_axises)
+                        elif ConverterUtil.data_format(
+                                op) == DataFormat.NCHW \
+                                and self._target_data_format == DataFormat.NHWC:  # noqa
+                            print("Transpose reduce mean args: %s(%s)"
+                                  % (op.name, op.type))
+                            reduce_axises = list(arg.ints)
+                            new_axises = []
+                            for i in range(len(reduce_axises)):
+                                idx = reduce_axises[i]
+                                if idx == 2 or idx == 3:
+                                    new_axises.append(idx - 1)
+                                elif idx == 1:
+                                    new_axises.append(3)
+                                else:
+                                    new_axises.append(idx)
+                            new_axises.sort()
+                            arg.ints[:] = []
+                            arg.ints.extend(new_axises)
             # transpose op output shape
             data_format = ConverterUtil.data_format(op)
             if data_format is not None \
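The transformer above renumbers ReduceMean axes when the graph layout changes: the NHWC-to-NCHW direction sends H (1) and W (2) to 2 and 3 and C (3) to 1, which is the same mapping the unit test applies when building axis_cpu. A standalone C++ sketch of that mapping — not part of the patch:

#include <algorithm>
#include <cassert>
#include <vector>

// NHWC (N=0, H=1, W=2, C=3) -> NCHW (N=0, C=1, H=2, W=3) axis renumbering.
std::vector<int> NhwcToNchwAxes(std::vector<int> axes) {
  for (int &a : axes) {
    if (a == 1 || a == 2) a += 1;  // H, W shift right
    else if (a == 3) a = 1;        // C moves to the front
  }                                 // batch (0) is unchanged
  std::sort(axes.begin(), axes.end());
  return axes;
}

int main() {
  // The {1, 2} spatial mean used by the GPU path becomes {2, 3} on CPU.
  assert(NhwcToNchwAxes({1, 2}) == (std::vector<int>{2, 3}));
  assert(NhwcToNchwAxes({3}) == (std::vector<int>{1}));
  return 0;
}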