From 643aad357a984750b04cde88ab59e394cafa4775 Mon Sep 17 00:00:00 2001 From: liutuo Date: Thu, 27 Sep 2018 15:35:38 +0800 Subject: [PATCH] fold sqr_diff and reduce_mean to fix out of range in half datatype --- mace/kernels/opencl/cl/reduce_mean.cl | 7 +- mace/kernels/opencl/cl/sqrdiff_mean.cl | 68 +++++++ mace/kernels/opencl/image/sqrdiff_mean.h | 176 ++++++++++++++++++ mace/kernels/opencl/sqrdiff_mean.cc | 43 +++++ mace/kernels/sqrdiff_mean.h | 109 +++++++++++ mace/ops/ops_register.cc | 18 +- mace/ops/reduce_mean_benchmark.cc | 3 +- mace/ops/reduce_mean_test.cc | 2 + mace/ops/sqrdiff_mean.cc | 42 +++++ mace/ops/sqrdiff_mean.h | 53 ++++++ mace/ops/sqrdiff_mean_benchmark.cc | 102 ++++++++++ mace/ops/sqrdiff_mean_test.cc | 174 +++++++++++++++++ .../tools/converter_tool/base_converter.py | 3 + .../tools/converter_tool/transformer.py | 29 +++ .../opencl-kernel/opencl_kernel_configure.bzl | 1 + 15 files changed, 816 insertions(+), 14 deletions(-) create mode 100644 mace/kernels/opencl/cl/sqrdiff_mean.cl create mode 100644 mace/kernels/opencl/image/sqrdiff_mean.h create mode 100644 mace/kernels/opencl/sqrdiff_mean.cc create mode 100644 mace/kernels/sqrdiff_mean.h create mode 100644 mace/ops/sqrdiff_mean.cc create mode 100644 mace/ops/sqrdiff_mean.h create mode 100644 mace/ops/sqrdiff_mean_benchmark.cc create mode 100644 mace/ops/sqrdiff_mean_test.cc diff --git a/mace/kernels/opencl/cl/reduce_mean.cl b/mace/kernels/opencl/cl/reduce_mean.cl index 674c1c64..93a318b3 100644 --- a/mace/kernels/opencl/cl/reduce_mean.cl +++ b/mace/kernels/opencl/cl/reduce_mean.cl @@ -3,7 +3,7 @@ __kernel void reduce_mean(OUT_OF_RANGE_PARAMS GLOBAL_WORK_GROUP_SIZE_DIM3 __read_only image2d_t input, - __local DATA_TYPE4 *group_sum, + __local float4 *group_sum, __private const int group_size, __private const int partial_len, __private const int remain_index, @@ -24,7 +24,7 @@ __kernel void reduce_mean(OUT_OF_RANGE_PARAMS return; #endif const int dim0_size = get_local_size(0); - DATA_TYPE4 tmp = 
(DATA_TYPE4){0, 0, 0, 0}; + float4 tmp = (float4){0, 0, 0, 0}; const int index = mad24(j, dim0_size, i); const int b = floor(k * channel_blocks_reciprocal); const int ch = mad24(b, -channel_blocks, k); @@ -47,7 +47,7 @@ __kernel void reduce_mean(OUT_OF_RANGE_PARAMS in = READ_IMAGET(input, SAMPLER, (int2)(pos_x, pos_y)); tmp = tmp + in; } - group_sum[index] = tmp; + group_sum[index] = tmp * image_size_reciprocal; #ifdef NON_QUALCOMM_ADRENO barrier(CLK_LOCAL_MEM_FENCE); @@ -59,7 +59,6 @@ __kernel void reduce_mean(OUT_OF_RANGE_PARAMS for (int l = 0; l < group_size; ++l) { out = out + group_sum[l]; } - out = out * image_size_reciprocal; WRITE_IMAGET(output, (int2)(ch, b), out); } } diff --git a/mace/kernels/opencl/cl/sqrdiff_mean.cl b/mace/kernels/opencl/cl/sqrdiff_mean.cl new file mode 100644 index 00000000..2a297bea --- /dev/null +++ b/mace/kernels/opencl/cl/sqrdiff_mean.cl @@ -0,0 +1,68 @@ +#include + +__kernel void sqrdiff_mean(OUT_OF_RANGE_PARAMS + GLOBAL_WORK_GROUP_SIZE_DIM3 + __read_only image2d_t input, + __read_only image2d_t input1, + __local float4 *group_sum, + __private const int group_size, + __private const int partial_len, + __private const int remain_index, + __private const int batch, + __private const int in_height, + __private const int in_width, + __private const float image_size_reciprocal, + __private const float in_width_reciprocal, + __private const int channel_blocks, + __private const float channel_blocks_reciprocal, + __write_only image2d_t output) { + const int i = get_local_id(0); + const int j = get_local_id(1); + const int k = get_global_id(2); + +#ifndef NON_UNIFORM_WORK_GROUP + if (k >= global_size_dim2) + return; +#endif + const int dim0_size = get_local_size(0); + float4 tmp = (float4){0, 0, 0, 0}; + const int index = mad24(j, dim0_size, i); + const int b = floor(k * channel_blocks_reciprocal); + const int ch = mad24(b, -channel_blocks, k); + + DATA_TYPE4 in; + const int valid_part_len = select(partial_len, + partial_len - 1, + 
remain_index > 0 && index >= remain_index); + const int full_offset = mul24(index, partial_len); + const int base_offset = select(full_offset, + full_offset - (index - remain_index), + valid_part_len < partial_len); + float4 diff = (float4){0, 0, 0, 0}; + DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(ch, b)); +#pragma unroll + for (int l = 0; l < valid_part_len; ++l) { + int offset = base_offset + l; + int h_id = floor(offset * in_width_reciprocal); + int w_id = mad24(h_id, -in_width, offset); + int pos_x = mad24(ch, in_width, w_id); + int pos_y = mad24(b, in_height, h_id); + in = READ_IMAGET(input, SAMPLER, (int2)(pos_x, pos_y)); + diff = in- in1; + tmp = tmp + diff * diff; + } + group_sum[index] = tmp * image_size_reciprocal; + +#ifdef NON_QUALCOMM_ADRENO + barrier(CLK_LOCAL_MEM_FENCE); +#endif + + if (i == 0 && j == 0) { + DATA_TYPE4 out = (DATA_TYPE4){0, 0, 0, 0}; +#pragma unroll + for (int l = 0; l < group_size; ++l) { + out = out + group_sum[l]; + } + WRITE_IMAGET(output, (int2)(ch, b), out); + } +} diff --git a/mace/kernels/opencl/image/sqrdiff_mean.h b/mace/kernels/opencl/image/sqrdiff_mean.h new file mode 100644 index 00000000..31959a62 --- /dev/null +++ b/mace/kernels/opencl/image/sqrdiff_mean.h @@ -0,0 +1,176 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#ifndef MACE_KERNELS_OPENCL_IMAGE_SQRDIFF_MEAN_H_ +#define MACE_KERNELS_OPENCL_IMAGE_SQRDIFF_MEAN_H_ + +#include "mace/kernels/sqrdiff_mean.h" + +#include +#include +#include +#include + +#include "mace/kernels/opencl/helper.h" + +namespace mace { +namespace kernels { +namespace opencl { +namespace image { + +template +class SqrDiffMeanKernel : public OpenCLSqrDiffMeanKernel { + public: + SqrDiffMeanKernel() {} + + MaceStatus Compute( + OpKernelContext *context, + const Tensor *input, + const Tensor *input1, + Tensor *output, + StatsFuture *future) override; + + private: + cl::Kernel kernel_; + uint32_t kwg_size_; + std::vector input_shape_; +}; + +template +MaceStatus SqrDiffMeanKernel::Compute( + OpKernelContext *context, + const Tensor *input0, + const Tensor *input1, + Tensor *output, + StatsFuture *future) { + MACE_CHECK_NOTNULL(input0); + MACE_CHECK_NOTNULL(input1); + MACE_CHECK(input0->dim(0) == input1->dim(0) && + input0->dim(3) == input1->dim(3)); + MACE_CHECK(input0->dim_size() == 4 && input1->dim_size() == 4, + "SqrDiffMean gpu only support 4-dim input"); + index_t batch = input0->dim(0); + const index_t in_height = input0->dim(1); + const index_t in_width = input0->dim(2); + const index_t channels = input0->dim(3); + const index_t channel_blocks = RoundUpDiv4(channels); + const uint32_t image_size = static_cast(in_height * in_width); + + std::vector gws(3); + std::vector lws(3); + std::vector output_shape{batch, 1, 1, channels}; + std::vector output_image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, + &output_image_shape); + MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape)); + + auto runtime = context->device()->opencl_runtime(); + MACE_OUT_OF_RANGE_DEFINITION; + + if (kernel_.get() == nullptr) { + const DataType dt = DataTypeToEnum::value; + std::set built_options; + MACE_OUT_OF_RANGE_CONFIG; + MACE_NON_UNIFORM_WG_CONFIG; + std::string kernel_name = MACE_OBFUSCATE_SYMBOL("sqrdiff_mean"); + 
built_options.emplace("-Dsqrdiff_mean=" + kernel_name); + built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpCompatibleCLCMDDt(dt)); + if (runtime->gpu_type() != GPUType::QUALCOMM_ADRENO) { + built_options.emplace("-DNON_QUALCOMM_ADRENO"); + } + MACE_RETURN_IF_ERROR(runtime->BuildKernel("sqrdiff_mean", + kernel_name, + built_options, + &kernel_)); + + kwg_size_ = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + } + + if (runtime->gpu_type() == GPUType::QUALCOMM_ADRENO) { + const uint32_t wave_size = + static_cast(runtime->GetKernelWaveSize(kernel_)); + gws = {4, (wave_size / 4), static_cast(batch * channel_blocks)}; + } else { + gws = {4, 16, static_cast(batch * channel_blocks)}; + } + lws = {gws[0], gws[1], 1}; + const int group_size = lws[0] * lws[1] * lws[2]; + const int partial_len = (image_size + group_size - 1) / group_size; + const int remain_index = image_size % group_size; + const float in_width_reciprocal = 1.f / in_width; + const float img_size_reciprocal = 1.f / (in_width * in_height); + const float channel_blk_reciprocal = 1.f / channel_blocks; + + MACE_OUT_OF_RANGE_INIT(kernel_); + if (!IsVecEqual(input_shape_, input0->shape())) { + uint32_t idx = 0; + MACE_OUT_OF_RANGE_SET_ARGS(kernel_); + MACE_SET_3D_GWS_ARGS(kernel_, gws); + kernel_.setArg(idx++, *(input0->opencl_image())); + kernel_.setArg(idx++, *(input1->opencl_image())); + kernel_.setArg(idx++, (group_size * 4 * sizeof(float)), + nullptr); + kernel_.setArg(idx++, static_cast(group_size)); + kernel_.setArg(idx++, static_cast(partial_len)); + kernel_.setArg(idx++, static_cast(remain_index)); + kernel_.setArg(idx++, static_cast(batch)); + kernel_.setArg(idx++, static_cast(in_height)); + kernel_.setArg(idx++, static_cast(in_width)); + kernel_.setArg(idx++, img_size_reciprocal); + kernel_.setArg(idx++, in_width_reciprocal); + kernel_.setArg(idx++, static_cast(channel_blocks)); + kernel_.setArg(idx++, 
channel_blk_reciprocal); + kernel_.setArg(idx++, *(output->opencl_image())); + + input_shape_ = input0->shape(); + } + + cl::Event event; + cl_int error; + if (runtime->IsNonUniformWorkgroupsSupported()) { + error = runtime->command_queue().enqueueNDRangeKernel( + kernel_, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]), + cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event); + } else { + std::vector roundup_gws(lws.size()); + for (size_t i = 0; i < lws.size(); ++i) { + roundup_gws[i] = RoundUp(gws[i], lws[i]); + } + error = runtime->command_queue().enqueueNDRangeKernel( + kernel_, cl::NullRange, + cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws[2]), + cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event); + } + MACE_CL_RET_STATUS(error); + MACE_OUT_OF_RANGE_VALIDATION; + + if (future != nullptr) { + future->wait_fn = [runtime, event](CallStats *stats) { + event.wait(); + if (stats != nullptr) { + runtime->GetCallStats(event, stats); + } + }; + } + + return MACE_SUCCESS; +} + +} // namespace image +} // namespace opencl +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_OPENCL_IMAGE_SQRDIFF_MEAN_H_ diff --git a/mace/kernels/opencl/sqrdiff_mean.cc b/mace/kernels/opencl/sqrdiff_mean.cc new file mode 100644 index 00000000..a0a6401d --- /dev/null +++ b/mace/kernels/opencl/sqrdiff_mean.cc @@ -0,0 +1,43 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mace/kernels/sqrdiff_mean.h" +#include "mace/kernels/opencl/image/sqrdiff_mean.h" + +namespace mace { +namespace kernels { + +template +SqrDiffMeanFunctor::SqrDiffMeanFunctor( + OpKernelContext *context) : OpKernel(context) { + if (context->device()->opencl_runtime()->UseImageMemory()) { + kernel_.reset(new opencl::image::SqrDiffMeanKernel()); + } else { + MACE_NOT_IMPLEMENTED; + } +} + +template +MaceStatus SqrDiffMeanFunctor::operator()( + const Tensor *input0, + const Tensor *input1, + Tensor *output, + StatsFuture *future) { + return kernel_->Compute(context_, input0, input1, output, future); +} + +template struct SqrDiffMeanFunctor; +template struct SqrDiffMeanFunctor; +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/sqrdiff_mean.h b/mace/kernels/sqrdiff_mean.h new file mode 100644 index 00000000..1c2d009c --- /dev/null +++ b/mace/kernels/sqrdiff_mean.h @@ -0,0 +1,109 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MACE_KERNELS_SQRDIFF_MEAN_H_ +#define MACE_KERNELS_SQRDIFF_MEAN_H_ + +#include +#include +#include + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/kernels/kernel.h" +#ifdef MACE_ENABLE_OPENCL +#include "mace/core/runtime/opencl/cl2_header.h" +#endif + +namespace mace { +namespace kernels { + +template +struct SqrDiffMeanFunctor : OpKernel { + explicit SqrDiffMeanFunctor(OpKernelContext *context) + : OpKernel(context) {} + + void Compute(const Tensor *input0, + const Tensor *input1, + Tensor *output) { + Tensor::MappingGuard input0_mapper(input0); + Tensor::MappingGuard input1_mapper(input1); + const T *input_ptr0 = input0->data(); + const T *input_ptr1 = input1->data(); + Tensor::MappingGuard output_map(output); + T *output_ptr = output->mutable_data(); + memset(output_ptr, 0, output->size() * sizeof(T)); + + const index_t img_size = input0->dim(2) * input0->dim(3); + const index_t bc = input0->dim(0) * input0->dim(1); +#pragma omp parallel for + for (int i = 0; i < bc; ++i) { + for (int j = 0; j < img_size; ++j) { + T diff = input_ptr0[i * img_size + j] - input_ptr1[i]; + output_ptr[i] += diff * diff; + } + output_ptr[i] /= img_size; + } + } + + MaceStatus operator()(const Tensor *input0, + const Tensor *input1, + Tensor *output, + StatsFuture *future) { + MACE_UNUSED(future); + + MACE_CHECK(input0->dim(0) == input1->dim(0) && + input0->dim(1) == input1->dim(1), + "inputs dims N and C should be the same."); + + std::vector out_shape(4); + out_shape[0] = input0->dim(0); + out_shape[1] = input0->dim(1); + out_shape[2] = 1; + out_shape[3] = 1; + + output->Resize(out_shape); + Compute(input0, input1, output); + return MACE_SUCCESS; + } +}; + +#ifdef MACE_ENABLE_OPENCL +class OpenCLSqrDiffMeanKernel { + public: + virtual MaceStatus Compute( + OpKernelContext *context, + const Tensor *input0, + const Tensor *input1, + Tensor *output, + StatsFuture *future) = 0; + MACE_VIRTUAL_EMPTY_DESTRUCTOR(OpenCLSqrDiffMeanKernel); +}; 
+template +struct SqrDiffMeanFunctor : OpKernel { + explicit SqrDiffMeanFunctor(OpKernelContext *context); + + MaceStatus operator()(const Tensor *input0, + const Tensor *input1, + Tensor *output, + StatsFuture *future); + + std::unique_ptr kernel_; +}; +#endif + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_SQRDIFF_MEAN_H_ diff --git a/mace/ops/ops_register.cc b/mace/ops/ops_register.cc index 7fda59bc..1c29386c 100644 --- a/mace/ops/ops_register.cc +++ b/mace/ops/ops_register.cc @@ -54,15 +54,16 @@ extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry); extern void Register_Reverse(OperatorRegistryBase *op_registry); extern void Register_ScalarMath(OperatorRegistryBase *op_registry); extern void Register_Shape(OperatorRegistryBase *op_registry); -extern void Register_Split(OperatorRegistryBase *op_registry); extern void Register_Softmax(OperatorRegistryBase *op_registry); -extern void Register_Stack(OperatorRegistryBase *op_registry); -extern void Register_Unstack(OperatorRegistryBase *op_registry); -extern void Register_StridedSlice(OperatorRegistryBase *op_registry); extern void Register_SpaceToBatchND(OperatorRegistryBase *op_registry); extern void Register_SpaceToDepth(OperatorRegistryBase *op_registry); +extern void Register_Split(OperatorRegistryBase *op_registry); +extern void Register_SqrDiffMean(OperatorRegistryBase *op_registry); extern void Register_Squeeze(OperatorRegistryBase *op_registry); +extern void Register_Stack(OperatorRegistryBase *op_registry); +extern void Register_StridedSlice(OperatorRegistryBase *op_registry); extern void Register_Transpose(OperatorRegistryBase *op_registry); +extern void Register_Unstack(OperatorRegistryBase *op_registry); extern void Register_WinogradInverseTransform(OperatorRegistryBase *op_registry); // NOLINT(whitespace/line_length) extern void Register_WinogradTransform(OperatorRegistryBase *op_registry); @@ -112,15 +113,16 @@ OperatorRegistry::OperatorRegistry() : 
OperatorRegistryBase() { ops::Register_Reverse(this); ops::Register_ScalarMath(this); ops::Register_Shape(this); - ops::Register_Split(this); ops::Register_Softmax(this); - ops::Register_Stack(this); - ops::Register_Unstack(this); - ops::Register_StridedSlice(this); ops::Register_SpaceToBatchND(this); ops::Register_SpaceToDepth(this); + ops::Register_Split(this); + ops::Register_Stack(this); + ops::Register_StridedSlice(this); + ops::Register_SqrDiffMean(this); ops::Register_Squeeze(this); ops::Register_Transpose(this); + ops::Register_Unstack(this); ops::Register_WinogradInverseTransform(this); ops::Register_WinogradTransform(this); diff --git a/mace/ops/reduce_mean_benchmark.cc b/mace/ops/reduce_mean_benchmark.cc index 21bf06f5..3591c9b1 100644 --- a/mace/ops/reduce_mean_benchmark.cc +++ b/mace/ops/reduce_mean_benchmark.cc @@ -83,9 +83,8 @@ void ReduceMean(int iters, int batch, int channels, MACE_BM_REDUCE_MEAN(1, 1, 512, 512); MACE_BM_REDUCE_MEAN(4, 3, 128, 128); -MACE_BM_REDUCE_MEAN(4, 3, 512, 512); +MACE_BM_REDUCE_MEAN(4, 1, 512, 512); MACE_BM_REDUCE_MEAN(16, 32, 112, 112); -MACE_BM_REDUCE_MEAN(8, 32, 112, 112); MACE_BM_REDUCE_MEAN(8, 64, 256, 256); MACE_BM_REDUCE_MEAN(1, 32, 480, 640); diff --git a/mace/ops/reduce_mean_test.cc b/mace/ops/reduce_mean_test.cc index 2b1875de..b1bbe5cc 100644 --- a/mace/ops/reduce_mean_test.cc +++ b/mace/ops/reduce_mean_test.cc @@ -388,6 +388,7 @@ TEST_F(ReduceMeanOpTest, GPURandomFloat) { RandomTest({2, 64, 64, 4}, {1, 2}); RandomTest({8, 128, 128, 64}, {1, 2}); RandomTest({1, 640, 480, 64}, {1, 2}); + RandomTest({1, 480, 640, 32}, {1, 2}); RandomTest({1, 512, 512, 16}, {1, 2}); RandomTest({8, 117, 87, 33}, {1, 2}); RandomTest({1, 619, 450, 61}, {1, 2}); @@ -399,6 +400,7 @@ TEST_F(ReduceMeanOpTest, GPURandomHalf) { RandomTest({2, 64, 64, 4}, {1, 2}); RandomTest({8, 128, 128, 64}, {1, 2}); RandomTest({1, 640, 480, 64}, {1, 2}); + RandomTest({1, 480, 640, 32}, {1, 2}); RandomTest({1, 512, 512, 16}, {1, 2}); RandomTest({8, 117, 87, 
33}, {1, 2}); RandomTest({1, 619, 450, 61}, {1, 2}); diff --git a/mace/ops/sqrdiff_mean.cc b/mace/ops/sqrdiff_mean.cc new file mode 100644 index 00000000..d8e8bd51 --- /dev/null +++ b/mace/ops/sqrdiff_mean.cc @@ -0,0 +1,42 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/sqrdiff_mean.h" + +namespace mace { +namespace ops { + +void Register_SqrDiffMean(OperatorRegistryBase *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("SqrDiffMean") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + SqrDiffMeanOp); +#ifdef MACE_ENABLE_OPENCL + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("SqrDiffMean") + .Device(DeviceType::GPU) + .TypeConstraint("T") + .Build(), + SqrDiffMeanOp); + + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("SqrDiffMean") + .Device(DeviceType::GPU) + .TypeConstraint("T") + .Build(), + SqrDiffMeanOp); +#endif +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/sqrdiff_mean.h b/mace/ops/sqrdiff_mean.h new file mode 100644 index 00000000..f021c0b2 --- /dev/null +++ b/mace/ops/sqrdiff_mean.h @@ -0,0 +1,53 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_SQRDIFF_MEAN_H_ +#define MACE_OPS_SQRDIFF_MEAN_H_ + +#include +#include + +#include "mace/core/operator.h" +#include "mace/kernels/sqrdiff_mean.h" + +namespace mace { +namespace ops { + +template +class SqrDiffMeanOp : public Operator { + public: + SqrDiffMeanOp(const OperatorDef &operator_def, OpKernelContext *context) + : Operator(operator_def, context), + functor_(context) {} + + MaceStatus Run(StatsFuture *future) override { + const Tensor *input0 = this->Input(INPUT0); + const Tensor *input1 = this->Input(INPUT1); + Tensor *output = this->Output(OUTPUT); + + return functor_(input0, input1, output, future); + } + + private: + kernels::SqrDiffMeanFunctor functor_; + + protected: + MACE_OP_INPUT_TAGS(INPUT0, INPUT1); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_SQRDIFF_MEAN_H_ diff --git a/mace/ops/sqrdiff_mean_benchmark.cc b/mace/ops/sqrdiff_mean_benchmark.cc new file mode 100644 index 00000000..f3bfd44c --- /dev/null +++ b/mace/ops/sqrdiff_mean_benchmark.cc @@ -0,0 +1,102 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/operator.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void SqrDiffMean(int iters, int batch, int channels, + int height, int width) { + mace::testing::StopTiming(); + + OpsTestNet net; + // Add input data + net.AddRandomInput("Input", {batch, height, width, channels}); + net.AddRandomInput("Input1", {batch, 1, 1, channels}); + + if (D == DeviceType::GPU) { + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "Input1", "InputImage1", + kernels::BufferType::IN_OUT_CHANNEL); + OpDefBuilder("SqrDiffMean", "SqrDiffMeanBM") + .Input("InputImage") + .Input("InputImage1") + .Output("OutputImage") + .Finalize(net.NewOperatorDef()); + } else { + net.TransformDataFormat("Input", + NHWC, + "InputNCHW", + NCHW); + net.TransformDataFormat("Input1", + NHWC, + "InputNCHW1", + NCHW); + OpDefBuilder("SqrDiffMean", "SqrDiffMeanBM") + .Input("InputNCHW") + .Input("InputNCHW1") + .Output("Output") + .Finalize(net.NewOperatorDef()); + } + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} +} // namespace + +#define MACE_BM_SQRDIFF_MEAN_MACRO(N, C, H, W, TYPE, DEVICE) \ + static void \ + MACE_BM_SQRDIFF_MEAN_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(\ + int iters) { \ + const int64_t tot = 
static_cast(iters) * N * C * H * W; \ + mace::testing::MaccProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + SqrDiffMean(iters, N, C, H, W); \ + } \ + MACE_BENCHMARK( \ + MACE_BM_SQRDIFF_MEAN_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) + +#define MACE_BM_SQRDIFF_MEAN(N, C, H, W) \ + MACE_BM_SQRDIFF_MEAN_MACRO(N, C, H, W, float, GPU); \ + MACE_BM_SQRDIFF_MEAN_MACRO(N, C, H, W, half, GPU); \ + MACE_BM_SQRDIFF_MEAN_MACRO(N, C, H, W, float, CPU); + + +MACE_BM_SQRDIFF_MEAN(1, 1, 512, 512); +MACE_BM_SQRDIFF_MEAN(4, 3, 128, 128); +MACE_BM_SQRDIFF_MEAN(4, 1, 512, 512); +MACE_BM_SQRDIFF_MEAN(8, 64, 256, 256); +MACE_BM_SQRDIFF_MEAN(1, 32, 480, 640); + + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/sqrdiff_mean_test.cc b/mace/ops/sqrdiff_mean_test.cc new file mode 100644 index 00000000..e88810bc --- /dev/null +++ b/mace/ops/sqrdiff_mean_test.cc @@ -0,0 +1,174 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class SqrDiffMeanOpTest : public OpsTestBase {}; + +namespace { +template +void Simple(const std::vector &input_shape0, + const std::vector &input0, + const std::vector &input_shape1, + const std::vector &input1, + const std::vector &output_shape, + const std::vector &output) { + // Construct graph + OpsTestNet net; + // Add input data + net.AddInputFromArray("Input0", input_shape0, input0); + net.AddInputFromArray("Input1", input_shape1, input1); + + net.TransformDataFormat("Input0", + NHWC, + "InputNCHW0", + NCHW); + net.TransformDataFormat("Input1", + NHWC, + "InputNCHW1", + NCHW); + + if (D == DeviceType::CPU) { + OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest") + .Input("InputNCHW0") + .Input("InputNCHW1") + .Output("OutputNCHW") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + + net.TransformDataFormat("OutputNCHW", + NCHW, + "Output", + NHWC); + } else { + BufferToImage(&net, "Input0", "InputImg0", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "Input1", "InputImg1", + kernels::BufferType::IN_OUT_CHANNEL); + OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest") + .Input("InputImg0") + .Input("InputImg1") + .Output("OutputImg") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + ImageToBuffer(&net, "OutputImg", "Output", + kernels::BufferType::IN_OUT_CHANNEL); + } + auto expected = net.CreateTensor(output_shape, output); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5, 1e-3); +} + +template +void Simple12Test() { + Simple({2, 2, 3, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {2, 1, 1, 4}, + {1, 1, 1, 1, 1, 1, 1, 1}, + {2, 1, 1, 4}, + {127.667, 146.667, 167.667, 190.667, + 127.667, 146.667, 167.667, 190.667}); +} + + +} // namespace + 
+TEST_F(SqrDiffMeanOpTest, CPUSimple12) { + Simple12Test(); +} + +TEST_F(SqrDiffMeanOpTest, GPUSimple12) { + Simple12Test(); +} + +namespace { +template +void RandomTest(const std::vector &input_shape0, + const std::vector &input_shape1) { + testing::internal::LogToStderr(); + srand(time(NULL)); + // Construct graph + OpsTestNet net; + // Add input data + net.AddRandomInput("Input0", input_shape0); + net.AddRandomInput("Input1", input_shape1); + + net.TransformDataFormat("Input0", NHWC, "InputNCHW0", + NCHW); + net.TransformDataFormat("Input1", NHWC, "InputNCHW1", + NCHW); + OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest") + .Input("InputNCHW0") + .Input("InputNCHW1") + .Output("OutputNCHW") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(); + net.TransformDataFormat("OutputNCHW", NCHW, + "Output", NHWC); + BufferToImage(&net, "Input0", "InputImg0", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "Input1", "InputImg1", + kernels::BufferType::IN_OUT_CHANNEL); + OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest") + .Input("InputImg0") + .Input("InputImg1") + .Output("OutputImg") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + ImageToBuffer(&net, "OutputImg", "OPENCLOutput", + kernels::BufferType::IN_OUT_CHANNEL); + if (DataTypeToEnum::value == DT_FLOAT) { + ExpectTensorNear(*net.GetTensor("Output"), + *net.GetOutput("OPENCLOutput"), 1e-4, 1e-3); + } else { + ExpectTensorNear(*net.GetTensor("Output"), + *net.GetOutput("OPENCLOutput"), 1e-2, 1e-2); + } +} +} // namespace + +TEST_F(SqrDiffMeanOpTest, GPURandomFloat) { + RandomTest({4, 64, 64, 3}, {4, 1, 1, 3}); + RandomTest({2, 64, 64, 4}, {2, 1, 1, 4}); + RandomTest({8, 128, 128, 64}, {8, 1, 1, 64}); + RandomTest({1, 640, 480, 64}, {1, 1, 1, 64}); + RandomTest({8, 117, 87, 33}, {8, 1, 1, 33}); + RandomTest({1, 619, 450, 61}, {1, 1, 1, 61}); + RandomTest({11, 511, 561, 1}, {11, 1, 1, 1}); +} + +TEST_F(SqrDiffMeanOpTest, GPURandomHalf) { + RandomTest({4, 64, 64, 3}, {4, 1, 1, 3}); + 
RandomTest({2, 64, 64, 4}, {2, 1, 1, 4}); + RandomTest({8, 128, 128, 64}, {8, 1, 1, 64}); + RandomTest({1, 640, 480, 64}, {1, 1, 1, 64}); + RandomTest({8, 117, 87, 33}, {8, 1, 1, 33}); + RandomTest({1, 619, 450, 61}, {1, 1, 1, 61}); + RandomTest({11, 511, 561, 1}, {11, 1, 1, 1}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index d677ed34..5e6c6f8e 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -122,6 +122,7 @@ MaceSupportedOps = [ 'Softmax', 'SpaceToBatchND', 'SpaceToDepth', + 'SqrDiffMean', 'Transpose', 'WinogradInverseTransform', 'WinogradTransform', @@ -217,6 +218,7 @@ class TransformerRule(Enum): REARRANGE_BATCH_TO_SPACE = 30 ADD_OPENCL_INFORMATIONS = 31 FOLD_DECONV_AND_BN = 32 + FOLD_SQRDIFF_MEAN = 33 class ConverterInterface(object): @@ -397,6 +399,7 @@ class ConverterOption(object): TransformerRule.FOLD_BIASADD, TransformerRule.FLATTEN_ATROUS_CONV, TransformerRule.FOLD_ACTIVATION, + TransformerRule.FOLD_SQRDIFF_MEAN, TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC, TransformerRule.RESHAPE_FC_WEIGHT, # Model data format related transformation diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 2f4a153b..36137635 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -78,6 +78,7 @@ class Transformer(base_converter.ConverterInterface): TransformerRule.FOLD_BIASADD: self.fold_biasadd, TransformerRule.FLATTEN_ATROUS_CONV: self.flatten_atrous_conv, TransformerRule.FOLD_ACTIVATION: self.fold_activation, + TransformerRule.FOLD_SQRDIFF_MEAN: self.fold_squared_diff_mean, TransformerRule.TRANSPOSE_FILTERS: self.transpose_filters, TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format, TransformerRule.ADD_IN_OUT_TENSOR_INFO: 
@@ -351,6 +352,55 @@ class Transformer(base_converter.ConverterInterface):
 
         return False
 
+    def fold_squared_diff_mean(self):
+        """Fold Eltwise(SQR_DIFF) -> ReduceMean into a single SqrDiffMean op.
+
+        The fold is applied only when the SQR_DIFF op has two inputs and
+        exactly one consumer, and that consumer is a single-input ReduceMean
+        over exactly the axes [1, 2] with keep_dims set.
+
+        Returns True right after folding one pair (the graph was modified),
+        and False when no pair was folded.
+        """
+        net = self._model
+        for op in net.op:
+            if op.type != MaceOp.Eltwise.name or len(op.input) != 2:
+                continue
+            elttype = ConverterUtil.get_arg(
+                op,
+                MaceKeyword.mace_element_type_str).i
+            if elttype != EltwiseType.SQR_DIFF.value or \
+                    self.consumer_count(op.output[0]) != 1:
+                continue
+            consumer_op = self._consumers[op.output[0]][0]
+            # Check the consumer type before reading its arguments: the
+            # axis/keepdims arguments are only looked up on a ReduceMean.
+            if consumer_op.type != MaceOp.ReduceMean.name or \
+                    len(consumer_op.input) != 1:
+                continue
+            axis_arg = ConverterUtil.get_arg(
+                consumer_op,
+                MaceKeyword.mace_axis_str)
+            keep_dims_arg = ConverterUtil.get_arg(
+                consumer_op,
+                MaceKeyword.mace_keepdims_str)
+            # NOTE(review): get_arg is assumed to return None for a missing
+            # argument -- confirm against ConverterUtil.
+            if axis_arg is None or keep_dims_arg is None:
+                continue
+            # Require exactly the axes [1, 2]: SqrDiffMean only reduces
+            # these two axes, and this avoids indexing a shorter axis list.
+            if list(axis_arg.ints) != [1, 2] or keep_dims_arg.i == 0:
+                continue
+            print("Fold SquaredDiff ReduceMean: %s" % op.name)
+            op.type = MaceOp.SqrDiffMean.name
+            op.output[0] = consumer_op.output[0]
+            self.replace_quantize_info(op, consumer_op)
+            self.safe_remove_node(consumer_op, op)
+            return True
+
+        return False
+
     def transform_lstmcell_zerostate(self):
         net = self._model
 
diff --git a/repository/opencl-kernel/opencl_kernel_configure.bzl b/repository/opencl-kernel/opencl_kernel_configure.bzl
index 0fe17b05..97c9639c 100644
--- a/repository/opencl-kernel/opencl_kernel_configure.bzl
+++ b/repository/opencl-kernel/opencl_kernel_configure.bzl
@@ -55,6 +55,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
     unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/softmax_buffer.cl"))
     unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/space_to_batch.cl"))
     unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/space_to_depth.cl"))
+    unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/sqrdiff_mean.cl"))
     unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/winograd_transform.cl"))
 
     python_bin_path = repository_ctx.which("python")
-- 
GitLab