From 6b65f42b5894675b598bb0192a5038a38704ec54 Mon Sep 17 00:00:00 2001
From: luxuhui
Date: Tue, 12 May 2020 19:06:59 +0800
Subject: [PATCH] feature: support GroupNorm op

N/A

Signed-off-by: Luxuhui
---
 mace/core/ops/operator.h                  |   2 +-
 mace/libmace/mace.cc                      |   2 +-
 mace/ops/argmax.cc                        |   2 +-
 mace/ops/eltwise.cc                       |   2 +-
 mace/ops/group_norm.cc                    | 195 ++++++++++++++++
 mace/ops/mvnorm.cc                        |   5 +-
 mace/ops/opencl/cl/mvnorm.cl              |  95 ++++++--
 mace/ops/opencl/image/mvnorm.cc           | 219 +++++++++++-------
 mace/ops/opencl/image/mvnorm.h            |  67 ++++--
 mace/ops/opencl/mvnorm.h                  |   1 -
 mace/ops/registry/ops_registry.cc         |   2 +
 mace/utils/tuner.h                        |   2 +-
 test/ccunit/mace/ops/group_norm_test.cc   | 262 ++++++++++++++++++++++
 third_party/caffe/caffe.proto             |   4 +-
 tools/python/transform/base_converter.py  |  14 +-
 tools/python/transform/caffe_converter.py |  19 ++
 tools/python/transform/shape_inference.py |   1 +
 17 files changed, 761 insertions(+), 133 deletions(-)
 create mode 100644 mace/ops/group_norm.cc
 create mode 100644 test/ccunit/mace/ops/group_norm_test.cc

diff --git a/mace/core/ops/operator.h b/mace/core/ops/operator.h
index 2a23b4a9..0525934a 100644
--- a/mace/core/ops/operator.h
+++ b/mace/core/ops/operator.h
@@ -65,7 +65,7 @@ class Operation {
   }

   bool ExistArg(const std::string &name) const {
-    MACE_CHECK(operator_def_, "operator_def was null!");
+    MACE_CHECK(operator_def_, "operator_def is null!");
    return ProtoArgHelper::ExistArg(*operator_def_, name);
   }

diff --git a/mace/libmace/mace.cc b/mace/libmace/mace.cc
index dcd61557..d7b79d54 100644
--- a/mace/libmace/mace.cc
+++ b/mace/libmace/mace.cc
@@ -893,7 +893,7 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
       int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,
                                             std::multiplies<int64_t>());
       MACE_CHECK(output_size <= output->second.impl_->buffer_size)
-          << "Output size exceeds buffer size: shape"
+          << output_tensor->name() << " Output size exceeds buffer size: shape"
           << MakeString(shape) << " vs buffer size "
           << output->second.impl_->buffer_size;
       output->second.impl_->shape = shape;
diff --git a/mace/ops/argmax.cc b/mace/ops/argmax.cc
index 91a92509..c00e4603 100644
--- a/mace/ops/argmax.cc
+++ b/mace/ops/argmax.cc
@@ -176,7 +176,7 @@ class ArgMaxOp : public Operation {
   const FrameworkType model_type_;
   // for Caffe
   const bool has_axis_;
-  const bool top_k_;
+  const int top_k_;
   const bool out_val_;

   // for ONNX and TENSORFLOW
diff --git a/mace/ops/eltwise.cc b/mace/ops/eltwise.cc
index 6dba080a..39f7e648 100644
--- a/mace/ops/eltwise.cc
+++ b/mace/ops/eltwise.cc
@@ -953,7 +953,7 @@ class EltwiseOp : public Operation {
       swapped = !swapped;
     }

-    // convert tensor for caffe's boardcast
+    // convert tensor for caffe's broadcast
     if (!has_data_format_ && input0->dim_size() == 4) {
       if (input1->dim_size() == 2) {
         const_cast<Tensor *>(input1)->Reshape(
diff --git a/mace/ops/group_norm.cc b/mace/ops/group_norm.cc
new file mode 100644
index 00000000..2e06d1f5
--- /dev/null
+++ b/mace/ops/group_norm.cc
@@ -0,0 +1,195 @@
+// Copyright 2020 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cmath>
+#include <numeric>
+#include <vector>
+
+#include "mace/core/ops/operator.h"
+#include "mace/core/registry/ops_registry.h"
+#include "mace/ops/activation.h"
+#include "mace/ops/delegator/activation.h"
+#include "mace/utils/memory.h"
+
+#ifdef MACE_ENABLE_OPENCL
+#include "mace/ops/opencl/image/mvnorm.h"
+#endif  // MACE_ENABLE_OPENCL
+
+namespace mace {
+namespace ops {
+
+template<DeviceType D, class T>
+class GroupNormOp;
+
+template<class T>
+class GroupNormOp<DeviceType::CPU, T> : public Operation {
+ public:
+  explicit GroupNormOp(OpConstructContext *context)
+      : Operation(context),
+        eps_(Operation::GetOptionalArg<float>("epsilon",
+                                              static_cast<float>(1e-5))),
+        group_num_(Operation::GetOptionalArg<int>("group_num", 32)) {}
+
+  MaceStatus Run(OpContext *context) override {
+    const Tensor *input = Input(INPUT);
+    Tensor *output = Output(OUTPUT);
+    MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional, got ",
+               input->dim_size());
+    const std::vector<index_t> &input_shape = input->shape();
+    MACE_RETURN_IF_ERROR(output->Resize(input_shape));
+
+    const auto batch = input_shape[0];
+    const auto channel = input_shape[1];
+    const auto height = input_shape[2];
+    const auto width = input_shape[3];
+    MACE_CHECK(channel % group_num_ == 0,
+               "channel count must be divisible by group_num: ",
+               channel, " vs ", group_num_);
+    const auto group_size = channel / group_num_;
+
+    Tensor::MappingGuard guard_input(input);
+    Tensor::MappingGuard guard_output(output);
+
+    const T *input_data = input->data<T>();
+    T *output_data = output->mutable_data<T>();
+    const auto outer_loop = batch * group_num_;
+    const auto inner_loop = group_size * height * width;
+    utils::ThreadPool &thread_pool =
+        context->device()->cpu_runtime()->thread_pool();
+
+    auto *scratch_buffer = context->device()->scratch_buffer();
+    scratch_buffer->Rewind();
+    auto scratch_buffer_size = outer_loop * sizeof(float) * 2;
+    MACE_RETURN_IF_ERROR(scratch_buffer->GrowSize(scratch_buffer_size));
+    float *mean_ptr = scratch_buffer->mutable_data<float>();
+    float *variance_ptr = mean_ptr + outer_loop;
+
+    // compute EX
+    thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
+      for (index_t i = start; i < end; i += step) {
+        const auto offset = inner_loop * i;
+        mean_ptr[i] = std::accumulate(input_data + offset,
+                                      input_data + offset + inner_loop,
+                                      static_cast<float>(0.0f));
+        mean_ptr[i] /= inner_loop;
+      }
+    }, 0, outer_loop, 1);
+
+    // compute (X - EX)^2
+    thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
+                              index_t start1, index_t end1, index_t step1) {
+      for (index_t i = start0; i < end0; i += step0) {
+        const auto offset = i * inner_loop;
+        for (index_t j = start1; j < end1; j += step1) {
+          const auto idx = offset + j;
+          const auto x_ex = input_data[idx] - mean_ptr[i];
+          output_data[idx] = x_ex * x_ex;
+        }
+      }
+    }, 0, outer_loop, 1, 0, inner_loop, 1);
+
+    // compute (E((X - EX)^2) + eps_)^0.5
+    thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
+      for (index_t i = start; i < end; i += step) {
+        auto output_data_base = output_data + inner_loop * i;
+        variance_ptr[i] = std::accumulate(output_data_base,
+                                          output_data_base + inner_loop,
+                                          static_cast<float>(0.0f));
+        variance_ptr[i] = std::pow(variance_ptr[i] / inner_loop + eps_, 0.5f);
+      }
+    }, 0, outer_loop, 1);
+
+    // compute (X - EX) / ((E((X - EX)^2) + eps_)^0.5)
+    thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
+                              index_t start1, index_t end1, index_t step1) {
+      for (index_t i = start0; i < end0; i += step0) {
+        const auto offset = i * inner_loop;
+        for (index_t j = start1; j < end1; j += step1) {
+          output_data[offset + j] =
+              (input_data[offset + j] - mean_ptr[i]) / variance_ptr[i];
+        }
+      }
+    }, 0, outer_loop, 1, 0, inner_loop, 1);
+
+    return MaceStatus::MACE_SUCCESS;
+  }
+
+ private:
+  const float eps_;
+  const int group_num_;
+
+  MACE_OP_INPUT_TAGS(INPUT);
+  MACE_OP_OUTPUT_TAGS(OUTPUT);
+};
+
+#ifdef MACE_ENABLE_OPENCL
+template<>
+class GroupNormOp<DeviceType::GPU, float> : public Operation {
+ public:
+  explicit GroupNormOp(OpConstructContext *context) : Operation(context) {
+    const auto group_num = Operation::GetOptionalArg<int>("group_num", 32);
+    const auto eps = Operation::GetOptionalArg<float>(
+        "epsilon", static_cast<float>(1e-5));
+    if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
+      kernel_ = make_unique<opencl::image::MVNormKernel>(
+          true, opencl::image::MeanType::GROUP_CHANNELS, eps, group_num);
+    } else {
+      MACE_NOT_IMPLEMENTED;
+    }
+  }
+
+  MaceStatus Run(OpContext *context) override {
+    const Tensor *input = this->Input(INPUT);
+    Tensor *output = this->Output(OUTPUT);
+    MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional, got ",
+               input->dim_size());
+    MACE_RETURN_IF_ERROR(output->ResizeLike(input));
+
+    return kernel_->Compute(context, input, output);
+  }
+
+ private:
+  std::unique_ptr<OpenCLMVNormKernel> kernel_;
+  MACE_OP_INPUT_TAGS(INPUT);
+  MACE_OP_OUTPUT_TAGS(OUTPUT);
+};
+#endif  // MACE_ENABLE_OPENCL
+
+void RegisterGroupNorm(OpRegistry *op_registry) {
+  MACE_REGISTER_OP(op_registry, "GroupNorm", GroupNormOp,
+                   DeviceType::CPU, float);
+  MACE_REGISTER_BF16_OP(op_registry, "GroupNorm", GroupNormOp, DeviceType::CPU);
+  MACE_REGISTER_GPU_OP(op_registry, "GroupNorm", GroupNormOp);
+  MACE_REGISTER_OP_CONDITION(
+      op_registry, OpConditionBuilder("GroupNorm").SetDevicePlacerFunc(
+          [](OpConditionContext *context) -> std::set<DeviceType> {
+            auto op = context->operator_def();
+            if (op->output_shape_size() != op->output_size()) {
+              return {DeviceType::CPU, DeviceType::GPU};
+            }
+
+            const int group_num =
+                ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
+                    *op, "group_num", 32);
+            auto output_channels = op->output_shape(0).dims()[3];
+            const int group_size = output_channels / group_num;
+            if (group_size % 4 == 0) {
+              return {DeviceType::CPU, DeviceType::GPU};
+            }
+
+            return {DeviceType::CPU};
+          }));
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/mvnorm.cc b/mace/ops/mvnorm.cc
index 11bcff21..f593f6cc 100644
--- a/mace/ops/mvnorm.cc
+++ b/mace/ops/mvnorm.cc
@@ -145,9 +145,12 @@ class MVNormOp : public Operation {
         Operation::GetOptionalArg<bool>("across_channels", false);
     auto eps = Operation::GetOptionalArg<float>("epsilon", 1e-9);

+    auto mean_type = across_channels ?
+                     opencl::image::MeanType::ACROSS_CHANNELS :
+                     opencl::image::MeanType::SINGLE_CHANNEL;
     if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
       kernel_ = make_unique<opencl::image::MVNormKernel>(
-          normalize_variance, across_channels, eps);
+          normalize_variance, mean_type, eps);
     } else {
       MACE_NOT_IMPLEMENTED;
     }
diff --git a/mace/ops/opencl/cl/mvnorm.cl b/mace/ops/opencl/cl/mvnorm.cl
index 850b1958..14f74efb 100644
--- a/mace/ops/opencl/cl/mvnorm.cl
+++ b/mace/ops/opencl/cl/mvnorm.cl
@@ -1,14 +1,15 @@
 #include <common.h>

-DATA_TYPE4 compute_mean_image(image2d_t input, const int width_idx,
-                              const int hb_idx, const int chan_blks,
-                              const int height, const int width) {
+DATA_TYPE4 compute_mean_image(image2d_t input, const int height,
+                              const int width, const int chan_blks,
+                              const int group_blks,
+                              const int batch_idx, const int chan_blk_idx) {
   DATA_TYPE4 total = 0.0f;
   DATA_TYPE4 mean = 0.0f;
-  const int hb_base = mul24(hb_idx / height, height);
-  const int wc_blks = mul24(width, chan_blks);
+  const int hb_base = mul24(batch_idx, height);
 #ifdef ACROSS_CHANNELS
+  const int wc_blks = mul24(width, chan_blks);
   for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
     for (int pos = 0; pos < wc_blks; ++pos) {
       DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, h_idx));
@@ -16,7 +17,23 @@ DATA_TYPE4 compute_mean_image(image2d_t input, const int width_idx,
     }
   }
   DATA_TYPE total_value = total.x + total.y + total.z + total.w;
-  DATA_TYPE mean_value = total_value / (DATA_TYPE)(mul24(mul24(height, wc_blks), 4));
+  DATA_TYPE mean_value =
+      total_value / (DATA_TYPE)(mul24(mul24(height, wc_blks), 4));
   mean = (DATA_TYPE4){mean_value, mean_value, mean_value, mean_value};
+#else
+#ifdef GROUP_CHANNELS
+  const int group_base = chan_blk_idx / group_blks * group_blks;
+  const int wg_blks_start = mul24(width, group_base);
+  const int wg_blks_end = wg_blks_start + group_blks * width;
+  for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
+    for (int pos = wg_blks_start; pos < wg_blks_end; ++pos) {
+      DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, h_idx));
+      total += in_data;
+    }
+  }
+  DATA_TYPE total_value = total.x + total.y + total.z + total.w;
+  const int total_num = mul24(mul24(height, wg_blks_end - wg_blks_start), 4);
+  DATA_TYPE mean_value = total_value / (DATA_TYPE)(total_num);
+  mean = (DATA_TYPE4){mean_value, mean_value, mean_value, mean_value};
 #else
   for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
@@ -27,15 +44,42 @@ DATA_TYPE4 compute_mean_image(image2d_t input, const int width_idx,
     }
   }
   mean = total / mul24(height, width);
-#endif
+#endif  // GROUP_CHANNELS
+#endif  // ACROSS_CHANNELS
   return mean;
 }

+__kernel void mvnorm_compute_mean_value(OUT_OF_RANGE_PARAMS
+                                        GLOBAL_WORK_GROUP_SIZE_DIM2
+                                        __read_only image2d_t input,
+                                        __private const int height,
+                                        __private const int width,
+                                        __private const int group_blks,
+                                        __write_only image2d_t output) {
+  const int chan_blk_idx = get_global_id(0);
+  const int batch_idx = get_global_id(1);
+
+#ifndef NON_UNIFORM_WORK_GROUP
+  if (chan_blk_idx >= global_size_dim0 || batch_idx >= global_size_dim1) {
+    return;
+  }
+#endif
+
+  const int chan_blks = global_size_dim0;
+  const int batch = global_size_dim1;
+
+  DATA_TYPE4 mean = compute_mean_image(input, height, width, chan_blks,
+                                       group_blks, batch_idx, chan_blk_idx);
+  WRITE_IMAGET(output, (int2)(chan_blk_idx, batch_idx), mean);
+}
+
 __kernel void mvnorm_mean(OUT_OF_RANGE_PARAMS
                           GLOBAL_WORK_GROUP_SIZE_DIM3
                           __read_only image2d_t input,
+                          __read_only image2d_t mean_image,  // E(X)
                           __private const int height,
+                          __private const int group_blks,
                           __write_only image2d_t output) {
   const int chan_blk_idx = get_global_id(0);
   const int width_idx = get_global_id(1);
@@ -54,18 +98,20 @@ __kernel void mvnorm_mean(OUT_OF_RANGE_PARAMS
   const int pos = mad24(chan_blk_idx, width, width_idx);
   DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
-  DATA_TYPE4 mean = compute_mean_image(input, width_idx,
-                                       hb_idx, chan_blks, height, width);
+  DATA_TYPE4 mean = READ_IMAGET(
+      mean_image, SAMPLER, (int2)(chan_blk_idx, hb_idx / height));
   in_data -= mean;

   WRITE_IMAGET(output, (int2)(pos, hb_idx), in_data);
 }

+// compute the (X - EX)^2
 __kernel void mvnorm_vn_step1(OUT_OF_RANGE_PARAMS
                               GLOBAL_WORK_GROUP_SIZE_DIM3
                               __read_only image2d_t input,
-                              __write_only image2d_t mean_image,  // E(X)
+                              __read_only image2d_t mean_image,  // E(X)
                               __write_only image2d_t square_image,  // (X - EX)^2
-                              __private const int height) {
+                              __private const int height,
+                              __private const int group_blks) {
   const int chan_blk_idx = get_global_id(0);
   const int width_idx = get_global_id(1);
   const int hb_idx = get_global_id(2);
@@ -82,23 +128,21 @@ __kernel void mvnorm_vn_step1(OUT_OF_RANGE_PARAMS
   const int pos = mad24(chan_blk_idx, width, width_idx);
   DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
-  DATA_TYPE4 mean = compute_mean_image(input, width_idx,
-                                       hb_idx, chan_blks, height, width);
+  DATA_TYPE4 mean =
+      READ_IMAGET(mean_image, SAMPLER, (int2)(chan_blk_idx, hb_idx / height));
   in_data = in_data - mean;
-  DATA_TYPE4 pow_data = in_data * in_data;
-  if (hb_idx == 0 && width_idx == 0) {
-    WRITE_IMAGET(mean_image, (int2)(chan_blk_idx, 0), mean);
-  }
+  DATA_TYPE4 pow_data = in_data * in_data;

   WRITE_IMAGET(square_image, (int2)(pos, hb_idx), pow_data);
 }
-
+// compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
 __kernel void mvnorm_vn_step2(OUT_OF_RANGE_PARAMS
                               GLOBAL_WORK_GROUP_SIZE_DIM3
                               __read_only image2d_t input,
                               __read_only image2d_t mean_image,  // E(X)
-                              __read_only image2d_t square_image,  // (X - EX)^2
+                              __read_only image2d_t mean_image_sqr,  // E((X - EX)^2)
                               __private const int height,
+                              __private const int group_blks,
                               __private const float eps,
                               __write_only image2d_t output) {
   const int chan_blk_idx = get_global_id(0);
@@ -115,15 +159,20 @@ __kernel void mvnorm_vn_step2(OUT_OF_RANGE_PARAMS
   const int chan_blks = global_size_dim0;
   const int width = global_size_dim1;

-  DATA_TYPE4 mean = READ_IMAGET(mean_image, SAMPLER, (int2)(chan_blk_idx, 0));
+  DATA_TYPE4 mean = READ_IMAGET(
+      mean_image, SAMPLER, (int2)(chan_blk_idx, hb_idx / height));
   const int pos = mad24(chan_blk_idx, width, width_idx);
   DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
   in_data = in_data - mean;
-  DATA_TYPE4 mean_v = compute_mean_image(square_image, width_idx,
-                                         hb_idx, chan_blks, height, width);
+  DATA_TYPE4 mean_sqr = READ_IMAGET(
+      mean_image_sqr, SAMPLER, (int2)(chan_blk_idx, hb_idx / height));

-  DATA_TYPE4 norm_data = in_data / (sqrt(mean_v) + eps);
+#ifdef GROUP_CHANNELS
+  DATA_TYPE4 norm_data = in_data / sqrt(mean_sqr + eps);
+#else
+  DATA_TYPE4 norm_data = in_data / (sqrt(mean_sqr) + eps);
+#endif

   WRITE_IMAGET(output, (int2)(pos, hb_idx), norm_data);
 }
diff --git a/mace/ops/opencl/image/mvnorm.cc b/mace/ops/opencl/image/mvnorm.cc
index b0356bf6..de409310 100644
--- a/mace/ops/opencl/image/mvnorm.cc
+++ b/mace/ops/opencl/image/mvnorm.cc
@@ -29,101 +29,167 @@ namespace {
 MaceStatus BuildMVNKernel(OpenCLRuntime *runtime, cl::Kernel *kernel,
                           const char *kernel_name,
                           std::set<std::string> *built_options,
-                          bool across_channel) {
-  std::stringstream micro_name;
-  micro_name << "-Dmvnorm=" << kernel_name;
-  built_options->emplace(micro_name.str());
+                          MeanType mean_type) {
+  std::stringstream macro_name;
+  macro_name << "-Dmvnorm=" << kernel_name;
+  built_options->emplace(macro_name.str());
   built_options->emplace("-DDATA_TYPE=" + DtToCLDt(DT_FLOAT));
   built_options->emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DT_FLOAT));
-  if (across_channel) {
+  if (mean_type == MeanType::ACROSS_CHANNELS) {
     built_options->emplace("-DACROSS_CHANNELS");
+  } else if (mean_type == MeanType::GROUP_CHANNELS) {
+    built_options->emplace("-DGROUP_CHANNELS");
   }
   MACE_RETURN_IF_ERROR(runtime->BuildKernel("mvnorm", kernel_name,
                                             *built_options, kernel));
   return MaceStatus::MACE_SUCCESS;
 }

-std::unique_ptr<Image> CreateImage(
-    OpContext *context, const DataType dt,
-    const std::vector<index_t> &buffer_shape) {
-  std::unique_ptr<Image> image =
-      make_unique<Image>(context->device()->allocator());
-  std::vector<size_t> shape;
-  OpenCLUtil::CalImage2DShape(
-      buffer_shape, OpenCLBufferType::IN_OUT_CHANNEL, &shape);
-  MACE_CHECK(image->Allocate(shape, dt) == MaceStatus::MACE_SUCCESS);
-  VLOG(1) << "MVNormKernel::CreateImage allocate image_:" << MakeString(shape);
-
-  return image;
-}
-
 }  // namespace

 MVNormKernel::MVNormKernel(bool normalize_variance,
-                           bool across_channels, float eps)
+                           MeanType mean_type, float eps, int group_num)
     : normalize_variance_(normalize_variance),
-      across_channels_(across_channels),
-      eps_(eps) {}
-
-void MVNormKernel::CheckImage(OpContext *context, const DataType dt,
-                              const std::vector<index_t> &square_shape,
-                              const std::vector<index_t> &mean_shape) {
-  if (square_image_ == nullptr) {
-    square_image_ = CreateImage(context, dt, square_shape);
-  }
-
-  if (mean_image_ == nullptr) {
-    mean_image_ = CreateImage(context, dt, mean_shape);
-  }
-}
+      mean_type_(mean_type),
+      eps_(eps),
+      group_num_(group_num) {}

-MaceStatus MVNormKernel::Compute(OpContext
-                                     *context,
-                                 const Tensor *input, Tensor
-                                 *output) {
+MaceStatus MVNormKernel::Compute(OpContext *context,
+                                 const Tensor *input, Tensor *output) {
   const auto batch = input->dim(0);
   const auto height = input->dim(1);
   const auto width = input->dim(2);
   const auto channels = input->dim(3);

+  index_t group_blocks = 0;
+  if (mean_type_ == MeanType::GROUP_CHANNELS) {
+    MACE_CHECK(group_num_ > 0, "group_num_ should be > 0, got ", group_num_);
+    const index_t group = channels / group_num_;
+    MACE_CHECK(group > 0 && group % 4 == 0,
+               "group size ", group, " must be a positive multiple of 4");
+    group_blocks = group / 4;
+  }
+
+  return DoCompute(context, input, output, batch,
+                   height, width, channels, group_blocks);
+}
+
+MaceStatus MVNormKernel::DoCompute(
+    OpContext *context, const Tensor *input, Tensor *output,
+    const index_t batch, const index_t height, const index_t width,
+    const index_t channels, const index_t group_blocks) {
   const index_t channel_blocks = RoundUpDiv4(channels);

   const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
                            static_cast<uint32_t>(width),
                            static_cast<uint32_t>(height * batch)};
   auto runtime = context->device()->gpu_runtime()->opencl_runtime();

+  const std::vector<index_t> mean_shape = {batch, 1, 1, channels};
+  std::vector<size_t> mean_image_shape;
+  OpenCLUtil::CalImage2DShape(mean_shape, OpenCLBufferType::IN_OUT_HEIGHT,
+                              &mean_image_shape);
+  ScratchImageManager *scratch_manager =
+      context->device()->gpu_runtime()->scratch_image_manager();
+  ScratchImage scratch_mean_image(scratch_manager);
+  auto mace_mean_image = scratch_mean_image.Scratch(
+      context->device()->allocator(), mean_image_shape, input->dtype());
+  cl::Image *mean_image = static_cast<cl::Image *>(mace_mean_image->buffer());
+
   if (normalize_variance_) {
-    const std::vector<index_t> &square_shape = input->buffer_shape();
-    const std::vector<index_t> mean_shape = {1, 1, 1, channels};
-    CheckImage(context, input->dtype(), square_shape, mean_shape);
-    // compute the (X - EX)^2
+    ScratchImage scratch_mean_image_sqr(scratch_manager);
+    auto mace_mean_image_sqr = scratch_mean_image_sqr.Scratch(
+        context->device()->allocator(), mean_image_shape, input->dtype());
+    cl::Image *mean_image_sqr =
+        static_cast<cl::Image *>(mace_mean_image_sqr->buffer());
+    // compute the EX
+    MACE_RETURN_IF_ERROR(ExecuteMeanValueKernel(
+        context, runtime, batch, height, width, channel_blocks, group_blocks,
+        input->opencl_image(), mean_image));
+    // compute (X - EX)^2 to output
     MACE_RETURN_IF_ERROR(ExecuteVarianceNormStep1Kernel(
-        context, runtime, gws, input));
+        context, runtime, gws, height, group_blocks, input->opencl_image(),
+        mean_image, output->opencl_image()));
+    // compute E((X - EX)^2) to mean_image_sqr
+    MACE_RETURN_IF_ERROR(ExecuteMeanValueKernel(
+        context, runtime, batch, height, width, channel_blocks, group_blocks,
+        output->opencl_image(), mean_image_sqr));
     // compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
     MACE_RETURN_IF_ERROR(ExecuteVarianceNormStep2Kernel(
-        context, runtime, gws, input, output));
+        context, runtime, gws, height, group_blocks, input->opencl_image(),
+        mean_image, mean_image_sqr, output->opencl_image()));
   } else {
+    // compute the EX
+    MACE_RETURN_IF_ERROR(ExecuteMeanValueKernel(
+        context, runtime, batch, height, width, channel_blocks, group_blocks,
+        input->opencl_image(), mean_image));
+    // compute the (X - EX)
     MACE_RETURN_IF_ERROR(ExecuteMeanNormKernel(
-        context, runtime, gws, input, output));
+        context, runtime, gws, height, group_blocks, input->opencl_image(),
+        mean_image, output->opencl_image()));
   }

-  return
-      MaceStatus::MACE_SUCCESS;
+  return MaceStatus::MACE_SUCCESS;
+}
+
+MaceStatus MVNormKernel::ExecuteMeanValueKernel(OpContext *context,
+                                                OpenCLRuntime *runtime,
+                                                const index_t batch,
+                                                const index_t height,
+                                                const index_t width,
+                                                const index_t channel_blocks,
+                                                const index_t group_blocks,
+                                                const cl::Image *input_image,
+                                                cl::Image *output_image) {
+  MACE_OUT_OF_RANGE_DEFINITION;
+  if (kernel_mean_.get() == nullptr) {
+    std::set<std::string> built_options;
+    MACE_OUT_OF_RANGE_CONFIG;
+    MACE_NON_UNIFORM_WG_CONFIG;
+    MACE_RETURN_IF_ERROR(
+        BuildMVNKernel(runtime, &kernel_mean_, "mvnorm_compute_mean_value",
+                       &built_options, mean_type_));
+    kwg_size_mean_ = static_cast<uint32_t>(
+        runtime->GetKernelMaxWorkGroupSize(kernel_mean_));
+  }
+
+  const uint32_t gws[2] = {static_cast<uint32_t>(channel_blocks),
+                           static_cast<uint32_t>(batch)};
+  const std::vector<uint32_t> lws = {static_cast<uint32_t>(kwg_size_mean_) / 8,
+                                     8, 0};
+  MACE_OUT_OF_RANGE_INIT(kernel_mean_);
+  uint32_t idx = 0;
+  MACE_OUT_OF_RANGE_SET_ARGS(kernel_mean_);
+  MACE_SET_2D_GWS_ARGS(kernel_mean_, gws);
+  kernel_mean_.setArg(idx++, *input_image);
+  kernel_mean_.setArg(idx++, static_cast<int>(height));
+  kernel_mean_.setArg(idx++, static_cast<int>(width));
+  kernel_mean_.setArg(idx++, static_cast<int>(group_blocks));
+  kernel_mean_.setArg(idx++, *output_image);
+
+  std::string tuning_key = Concat(
+      "mvnorm_compute_mean_opencl_kernel", gws[0],
+      gws[1], normalize_variance_, mean_type_);
+
+  MACE_RETURN_IF_ERROR(TuningOrRun2DKernel(runtime, kernel_mean_, tuning_key,
+                                           gws, lws, context->future()));
+  MACE_OUT_OF_RANGE_VALIDATION;
+  return MaceStatus::MACE_SUCCESS;
 }

 MaceStatus MVNormKernel::ExecuteMeanNormKernel(OpContext *context,
                                                OpenCLRuntime *runtime,
                                                const uint32_t (&gws)[3],
-                                               const Tensor *input,
-                                               Tensor *output) {
-  const auto height = input->dim(1);
+                                               const index_t height,
+                                               const index_t group_blocks,
+                                               const cl::Image *input,
+                                               const cl::Image *mean_image,
+                                               cl::Image *output) {
   MACE_OUT_OF_RANGE_DEFINITION;
   if (kernel_step1_.get() == nullptr) {
     std::set<std::string> built_options;
     MACE_OUT_OF_RANGE_CONFIG;
     MACE_NON_UNIFORM_WG_CONFIG;
     MACE_RETURN_IF_ERROR(BuildMVNKernel(runtime, &kernel_step1_, "mvnorm_mean",
-                                        &built_options, across_channels_));
+                                        &built_options, mean_type_));
     kwg_size_step1_ = static_cast<uint32_t>(
         runtime->GetKernelMaxWorkGroupSize(kernel_step1_));
   }
@@ -132,14 +198,16 @@ MaceStatus MVNormKernel::ExecuteMeanNormKernel(OpContext *context,
   uint32_t idx = 0;
   MACE_OUT_OF_RANGE_SET_ARGS(kernel_step1_);
   MACE_SET_3D_GWS_ARGS(kernel_step1_, gws);
-  kernel_step1_.setArg(idx++, *(input->opencl_image()));
+  kernel_step1_.setArg(idx++, *input);
+  kernel_step1_.setArg(idx++, *mean_image);
   kernel_step1_.setArg(idx++, static_cast<int>(height));
-  kernel_step1_.setArg(idx++, *(output->opencl_image()));
+  kernel_step1_.setArg(idx++, static_cast<int>(group_blocks));
+  kernel_step1_.setArg(idx++, *output);

   std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_step1_);
-  std::string
-      tuning_key = Concat("mvnorm_mean_opencl_kernel", gws[0], gws[1], gws[2],
-                          normalize_variance_, across_channels_);
+  std::string tuning_key = Concat("mvnorm_mean_opencl_kernel", gws[0], gws[1],
+                                  gws[2], normalize_variance_,
+                                  mean_type_, group_blocks);

   MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_step1_, tuning_key,
                                            gws, lws, context->future()));
@@ -147,12 +215,11 @@ MaceStatus MVNormKernel::ExecuteMeanNormKernel(OpContext *context,
   return MaceStatus::MACE_SUCCESS;
 }

-// The first step of compute Variance Norm, compute the (X - EX)^2
-// store them into the square_image_
+// compute the (X - EX)^2
 MaceStatus MVNormKernel::ExecuteVarianceNormStep1Kernel(
     OpContext *context, OpenCLRuntime *runtime,
-    const uint32_t (&gws)[3], const Tensor *input) {
-  const auto height = input->dim(1);
+    const uint32_t (&gws)[3], const index_t height, const index_t group_blocks,
+    const cl::Image *input, const cl::Image *mean_image, cl::Image *output) {
   MACE_OUT_OF_RANGE_DEFINITION;
   if (kernel_step1_.get() == nullptr) {
     std::set<std::string> built_options;
@@ -160,7 +227,7 @@ MaceStatus MVNormKernel::ExecuteVarianceNormStep1Kernel(
     MACE_NON_UNIFORM_WG_CONFIG;
     MACE_RETURN_IF_ERROR(BuildMVNKernel(runtime, &kernel_step1_,
                                         "mvnorm_vn_step1",
-                                        &built_options, across_channels_));
+                                        &built_options, mean_type_));
     kwg_size_step1_ = static_cast<uint32_t>(
         runtime->GetKernelMaxWorkGroupSize(kernel_step1_));
   }
@@ -169,18 +236,17 @@ MaceStatus MVNormKernel::ExecuteVarianceNormStep1Kernel(
   uint32_t idx = 0;
   MACE_OUT_OF_RANGE_SET_ARGS(kernel_step1_);
   MACE_SET_3D_GWS_ARGS(kernel_step1_, gws);
-  kernel_step1_.setArg(idx++, *(input->opencl_image()));
-  cl::Image *mean_image = static_cast<cl::Image *>(mean_image_->buffer());
+  kernel_step1_.setArg(idx++, *input);
   kernel_step1_.setArg(idx++, *mean_image);
-  cl::Image *square_image = static_cast<cl::Image *>(square_image_->buffer());
-  kernel_step1_.setArg(idx++, *square_image);
+  kernel_step1_.setArg(idx++, *output);
   kernel_step1_.setArg(idx++, static_cast<int>(height));
+  kernel_step1_.setArg(idx++, static_cast<int>(group_blocks));

   std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_step1_);
   std::string tuning_key = Concat("mvnorm_v_step1_opencl_kernel", gws[0],
                                   gws[1], gws[2], normalize_variance_,
-                                  across_channels_);
+                                  mean_type_);

   MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_step1_, tuning_key,
                                            gws, lws, context->future()));
@@ -188,12 +254,12 @@ MaceStatus MVNormKernel::ExecuteVarianceNormStep1Kernel(
   return MaceStatus::MACE_SUCCESS;
 }

-// The second step of compute Variance Norm, read the (X - EX)^2 from
-// square_image_ and compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
+// compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
 MaceStatus MVNormKernel::ExecuteVarianceNormStep2Kernel(
     OpContext *context, OpenCLRuntime *runtime, const uint32_t (&gws)[3],
-    const Tensor *input, Tensor *output) {
-  const auto height = input->dim(1);
+    const index_t height, const index_t group_blocks,
+    const cl::Image *input, const cl::Image *mean_image,
+    const cl::Image *mean_image_sqr, cl::Image *output) {
   MACE_OUT_OF_RANGE_DEFINITION;
   if (kernel_step2_.get() == nullptr) {
     std::set<std::string> built_options;
@@ -201,7 +267,7 @@ MaceStatus MVNormKernel::ExecuteVarianceNormStep2Kernel(
     MACE_NON_UNIFORM_WG_CONFIG;
     MACE_RETURN_IF_ERROR(BuildMVNKernel(runtime, &kernel_step2_,
                                         "mvnorm_vn_step2",
-                                        &built_options, across_channels_));
+                                        &built_options, mean_type_));
     kwg_size_step2_ = static_cast<uint32_t>(
         runtime->GetKernelMaxWorkGroupSize(kernel_step2_));
   }
@@ -210,20 +276,19 @@ MaceStatus MVNormKernel::ExecuteVarianceNormStep2Kernel(
   uint32_t idx = 0;
   MACE_OUT_OF_RANGE_SET_ARGS(kernel_step2_);
   MACE_SET_3D_GWS_ARGS(kernel_step2_, gws);
-  kernel_step2_.setArg(idx++, *(input->opencl_image()));
-  cl::Image *mean_image = static_cast<cl::Image *>(mean_image_->buffer());
+  kernel_step2_.setArg(idx++, *input);
   kernel_step2_.setArg(idx++, *mean_image);
-  cl::Image *square_image = static_cast<cl::Image *>(square_image_->buffer());
-  kernel_step2_.setArg(idx++, *square_image);
+  kernel_step2_.setArg(idx++, *mean_image_sqr);
   kernel_step2_.setArg(idx++, static_cast<int>(height));
+  kernel_step2_.setArg(idx++, static_cast<int>(group_blocks));
   kernel_step2_.setArg(idx++, static_cast<float>(eps_));
-  kernel_step2_.setArg(idx++, *(output->opencl_image()));
+  kernel_step2_.setArg(idx++, *output);

   std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_step2_);
   std::string tuning_key = Concat("mvnorm_v_step2_opencl_kernel", gws[0],
                                   gws[1], gws[2], normalize_variance_,
-                                  across_channels_);
+                                  mean_type_);

   MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_step2_, tuning_key,
                                            gws, lws, context->future()));
diff --git a/mace/ops/opencl/image/mvnorm.h b/mace/ops/opencl/image/mvnorm.h
index 5752167e..93fe424d 100644
--- a/mace/ops/opencl/image/mvnorm.h
+++ b/mace/ops/opencl/image/mvnorm.h
@@ -28,48 +28,79 @@ namespace ops {
 namespace opencl {
 namespace image {

+enum MeanType {
+  SINGLE_CHANNEL,
+  GROUP_CHANNELS,
+  ACROSS_CHANNELS,
+};
+
 class MVNormKernel : public OpenCLMVNormKernel {
  public:
-  explicit MVNormKernel(bool normalize_variance_,
-                        bool across_channels, float eps);
+  explicit MVNormKernel(bool normalize_variance, MeanType mean_type,
+                        float eps, int group_num = 0);
   ~MVNormKernel() = default;

   MaceStatus Compute(
       OpContext *context, const Tensor *input, Tensor *output) override;

  private:
-  void CheckImage(OpContext *context, const DataType dt,
-                  const std::vector<index_t> &square_shape,
-                  const std::vector<index_t> &mean_shape);
+  MaceStatus DoCompute(OpContext *context, const Tensor *input,
+                       Tensor *output, const index_t batch,
+                       const index_t height, const index_t width,
+                       const index_t channels, const index_t group_blocks);
+
+  MaceStatus ExecuteMeanValueKernel(OpContext *context,
+                                    OpenCLRuntime *runtime,
+                                    const index_t batch,
+                                    const index_t height,
+                                    const index_t width,
+                                    const index_t channel_blocks,
+                                    const index_t group_blocks,
+                                    const cl::Image *input_image,
+                                    cl::Image *output_image);
+
   MaceStatus ExecuteMeanNormKernel(OpContext *context,
                                    OpenCLRuntime *runtime,
                                    const uint32_t (&gws)[3],
-                                   const Tensor *input,
-                                   Tensor *output);
+                                   const index_t height,
+                                   const index_t group_blocks,
+                                   const cl::Image *input,
+                                   const cl::Image *mean_image,
+                                   cl::Image *output);
+
+  // compute the (X - EX)^2
   MaceStatus ExecuteVarianceNormStep1Kernel(OpContext *context,
                                             OpenCLRuntime *runtime,
                                             const uint32_t (&gws)[3],
-                                            const Tensor *input);
+                                            const index_t height,
+                                            const index_t group_blocks,
+                                            const cl::Image *input,
+                                            const cl::Image *mean_image,
+                                            cl::Image *output);
+
+  // compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
   MaceStatus ExecuteVarianceNormStep2Kernel(OpContext *context,
                                             OpenCLRuntime *runtime,
                                             const uint32_t (&gws)[3],
-                                            const Tensor *input,
-                                            Tensor *output);
+                                            const index_t height,
+                                            const index_t group_blocks,
+                                            const cl::Image *input,
+                                            const cl::Image *mean_image,
+                                            const cl::Image *mean_image_sqr,
+                                            cl::Image *output);

  private:
-  bool normalize_variance_;
-  bool across_channels_;
-  float eps_;
+  const bool normalize_variance_;
+  const MeanType mean_type_;
+  const float eps_;
+  const int group_num_;

+  cl::Kernel kernel_mean_;
+  uint32_t kwg_size_mean_;
   cl::Kernel kernel_step1_;
   uint32_t kwg_size_step1_;
   cl::Kernel kernel_step2_;
   uint32_t kwg_size_step2_;
-
-  // the cache of (X - EX)^2
-  std::unique_ptr<Image> square_image_;
-  // the cache of EX
-  std::unique_ptr<Image> mean_image_;
 };

 }  // namespace image
diff --git a/mace/ops/opencl/mvnorm.h b/mace/ops/opencl/mvnorm.h
index 433ef0a2..5e96c759 100644
--- a/mace/ops/opencl/mvnorm.h
+++ b/mace/ops/opencl/mvnorm.h
@@ -16,7 +16,6 @@
 #define MACE_OPS_OPENCL_MVNORM_H_

 #include "mace/public/mace.h"
-#include "mace/utils/math.h"

 namespace mace {

diff --git a/mace/ops/registry/ops_registry.cc b/mace/ops/registry/ops_registry.cc
index 2f6e8c73..9135fa00 100644
--- a/mace/ops/registry/ops_registry.cc
+++ b/mace/ops/registry/ops_registry.cc
@@ -41,6 +41,7 @@ extern void RegisterExtractPooling(OpRegistry *op_registry);
 extern void RegisterFill(OpRegistry *op_registry);
 extern void RegisterFullyConnected(OpRegistry *op_registry);
 extern void RegisterGather(OpRegistry *op_registry);
+extern void RegisterGroupNorm(OpRegistry *op_registry);
 extern void RegisterIdentity(OpRegistry *op_registry);
 extern void RegisterIfDefined(OpRegistry *op_registry);
 extern void RegisterInferConv2dShape(OpRegistry *op_registry);
@@ -120,6 +121,7 @@ void RegisterAllOps(OpRegistry *registry) {
   ops::RegisterFill(registry);
   ops::RegisterFullyConnected(registry);
   ops::RegisterGather(registry);
+  ops::RegisterGroupNorm(registry);
   ops::RegisterIdentity(registry);
   ops::RegisterIfDefined(registry);
   ops::RegisterInferConv2dShape(registry);
diff --git a/mace/utils/tuner.h b/mace/utils/tuner.h
index f7c09769..b715464f 100644
--- a/mace/utils/tuner.h
+++ b/mace/utils/tuner.h
@@ -95,7 +95,7 @@ class Tuner {
       std::vector<param_type> opt_param = default_param;
       RetType res = Tune(param_generator, func, timer, &opt_param);
       VLOG(3) << "Tuning " << param_key
-              << " retult: " << MakeString(opt_param);
+              << " result: " << MakeString(opt_param);
       param_table_[obfucated_param_key] = opt_param;
       return res;
     } else {
diff --git a/test/ccunit/mace/ops/group_norm_test.cc b/test/ccunit/mace/ops/group_norm_test.cc
new file mode 100644
index 00000000..9a85a23f
--- /dev/null
+++ b/test/ccunit/mace/ops/group_norm_test.cc
@@ -0,0 +1,262 @@
+// Copyright 2020 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/core/types.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class GroupNormOpTest : public OpsTestBase {};
+
+namespace {
+template <DeviceType D, typename T>
+void TestGroupNorm(const std::vector<index_t> &input_shape,
+                   const std::vector<float> &input,
+                   int group_num,
+                   const std::vector<float> &output) {
+  OpsTestNet net;
+  net.AddInputFromArray<D, float>(MakeString("Input"), input_shape, input);
+
+  if (D == DeviceType::CPU) {
+    net.TransformDataFormat<DeviceType::CPU, float>(
+        "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
+  }
+  OpDefBuilder("GroupNorm", "GroupNormTest")
+      .Input(D == DeviceType::CPU ? "InputNCHW" : "Input")
+      .AddIntArg("group_num", group_num)
+      .Output(D == DeviceType::CPU ? "OutputNCHW" : "Output")
+      .Finalize(net.NewOperatorDef());
+
+  net.RunOp(D);
+
+  if (D == DeviceType::CPU) {
+    net.TransformDataFormat<DeviceType::CPU, float>(
+        "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
+  }
+
+  net.AddInputFromArray<D, float>("ExpectedOutput", input_shape, output);
+  if (DataTypeToEnum<T>::value == DT_HALF) {
+    ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
+                            *net.GetOutput("Output"), 1e-2, 1e-2);
+  } else {
+    ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
+                            *net.GetOutput("Output"), 1e-3);
+  }
+}
+}  // namespace
+
+TEST_F(GroupNormOpTest, SimpleTestCPU) {
+  TestGroupNorm<DeviceType::CPU, float>(
+      {1, 1, 2, 64},
+      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+       3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+       5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+       7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+       9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
+       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+       3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+       5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+       7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+       9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
+       1, 2, 3, 4, 5, 6, 7, 8},
+      16,
+      {-1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
+       -1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
+       1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
+       -1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
+       1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
+       -1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
+       1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
+       -1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
+       1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
+       -1.09104, -0.654625, -0.218208, 0.80467, 0.928465, 1.05226,
+       1.17606, -1.52746, -1.09104, -0.654625, -0.218208, 0.218208,
+       0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
+       1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
+       0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
+       1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
+       0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
+       1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
+       0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
+       1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
+       0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
+       1.52746, -1.17606, -1.05226, -0.928465, -0.80467, 0.218208,
+       0.654625, 1.09104, 1.52746});
+}
+
+
+TEST_F(GroupNormOpTest, SimpleTestOpenCL) {
+  TestGroupNorm<DeviceType::GPU, float>(
+      {1, 1, 2, 64},
+      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+       3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+       5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+       7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+       9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
+       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+       3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+       5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+       7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
+       9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
+       1, 2, 3, 4, 5, 6, 7, 8},
+      16,
+      {-1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
+       -1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
+       1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
+       -1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
+       1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
+       -1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
+       1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
+       -1.09104, -0.654625, -0.218208, 0.468507, 0.780844, 1.09318,
+       1.40552, -1.52746, -1.09104, -0.654625, -0.218208, -1.52746,
+       -1.09104, -0.654625, -0.218208, 0.80467, 0.928465, 1.05226,
+       1.17606, -1.52746, -1.09104, -0.654625, -0.218208, 0.218208,
+       0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
+       1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
+       0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
+       1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
+       0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
+       1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
+       0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
+       1.52746, -1.40552, -1.09318, -0.780844, -0.468507, 0.218208,
+       0.654625, 1.09104, 1.52746, 0.218208, 0.654625, 1.09104,
+       1.52746, -1.17606, -1.05226, -0.928465, -0.80467, 0.218208,
+       0.654625, 1.09104, 1.52746});
+}
+
+
+TEST_F(GroupNormOpTest, SimpleRandomOPENCL) {
+  static unsigned int seed = time(NULL);
+  index_t batch = 1 + rand_r(&seed) % 5;
+  index_t group = 4 + 4 * (rand_r(&seed) % 3);
+  index_t group_num = 2 + rand_r(&seed) % 16;
+  index_t channels = group * group_num;
+  index_t height = 64;
+  index_t width = 64;
+
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput<DeviceType::GPU, float>("Input",
+                                             {batch, height, width, channels});
+
+  net.TransformDataFormat<DeviceType::CPU, float>(
+      "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
+
+  // Construct graph
+  OpDefBuilder("GroupNorm", "GroupNormTest")
+      .Input("InputNCHW")
+      .AddFloatArg("epsilon", 1e-3)
+      .Output("OutputNCHW")
+      .AddIntArg("group_num", group_num)
+      .Finalize(net.NewOperatorDef());
+
+  // run cpu
+  net.RunOp();
+
+  net.TransformDataFormat<DeviceType::CPU, float>(
+      "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
+
+  // Check
+  auto expected = net.CreateTensor<float>();
+  expected->Copy(*net.GetOutput("Output"));
+
+  // Run on opencl
+  OpDefBuilder("GroupNorm", "GroupNormTest")
+      .Input("Input")
+      .AddFloatArg("epsilon", 1e-3)
+      .Output("Output")
+      .AddIntArg("group_num", group_num)
+      .Finalize(net.NewOperatorDef());
+
+  net.Setup(DeviceType::GPU);
+
+  // Tuning
+  setenv("MACE_TUNING", "1", 1);
+  net.Run();
+  unsetenv("MACE_TUNING");
+
+  // Run on opencl
+  net.Run();
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"),
+                          1e-5, 1e-4);
+}
+
+
+TEST_F(GroupNormOpTest, SimpleRandomHalfOPENCL) {
+  // generate random input
+  static unsigned int seed = time(NULL);
+  index_t batch = 1 + rand_r(&seed) % 5;
+  index_t group = 4 + 4 * (rand_r(&seed) % 16);
+  index_t group_num = 2 + rand_r(&seed) % 16;
+  index_t channels = group * group_num;
+  index_t height = 64;
+  index_t width = 64;
+
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput<DeviceType::GPU, float>("Input",
+                                             {batch, height, width, channels});
+
+  net.TransformDataFormat<DeviceType::CPU, float>(
+      "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
+
+  // Construct graph
+  OpDefBuilder("GroupNorm", "GroupNormTest")
+      .Input("InputNCHW")
+      .AddFloatArg("epsilon", 1e-3)
+      .Output("OutputNCHW")
+      .AddIntArg("group_num", group_num)
+      .Finalize(net.NewOperatorDef());
+
+  // run cpu
+  net.RunOp();
+
+  net.TransformDataFormat<DeviceType::CPU, float>(
+      "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
+
+  // Check
+  auto expected = net.CreateTensor<float>();
+  expected->Copy(*net.GetOutput("Output"));
+
+  // Run on opencl
+  OpDefBuilder("GroupNorm", "GroupNormTest")
+      .Input("Input")
+      .AddFloatArg("epsilon", 1e-3)
+      .Output("Output")
+      .AddIntArg("group_num", group_num)
+      .AddIntArg("T", static_cast<int>(DataType::DT_HALF))
+      .Finalize(net.NewOperatorDef());
+
+  net.Setup(DeviceType::GPU);
+
+  // Tuning
+  setenv("MACE_TUNING", "1", 1);
+  net.Run();
+  unsetenv("MACE_TUNING");
+
+  // Run on opencl
+  net.Run();
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"),
+                          1e-1, 1e-2);
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/third_party/caffe/caffe.proto b/third_party/caffe/caffe.proto
index d4e03464..db3e370b 100644
--- a/third_party/caffe/caffe.proto
+++ b/third_party/caffe/caffe.proto
@@ -1955,7 +1955,7 @@ message GroupNormParameter {
 }

 message ResizeNearestParameter {
-  optional float height_scale=1 [default = 2.0];
-  optional float width_scale =2 [default = 2.0];
+  optional float height_scale = 1 [default = 2.0];
+  optional float width_scale = 2 [default = 2.0];
 }

diff --git a/tools/python/transform/base_converter.py b/tools/python/transform/base_converter.py
index 5c81e04a..c9a889ae 100644
--- a/tools/python/transform/base_converter.py
+++ b/tools/python/transform/base_converter.py
@@ -49,7 +49,6 @@ class ActivationType(Enum):
     TANH = 4
     SIGMOID = 5
     LEAKYRELU = 6
-    RELU6 = 7


 class EltwiseType(Enum):
@@ -101,6 +100,7 @@ MaceSupportedOps = [
     'Concat',
     'Conv2D',
     'Crop',
+    'Cumsum',
     'Deconv2D',
     'DepthToSpace',
     'DepthwiseConv2d',
@@ -112,15 +112,18 @@ MaceSupportedOps = [
     'Fill',
     'FullyConnected',
     'Gather',
+    'GroupNorm',
     'Identity',
     'IfDefined',
     'InferConv2dShape',
     'KaldiBatchNorm',
     'LocalResponseNorm',
+    'LpNorm',
     'LSTMCell',
     'LstmNonlinear',
     'DynamicLSTM',
     'MatMul',
+    'MVNorm',
     'OneHot',
     'Pad',
     'PadContext',
@@ -154,11 +157,8 @@ MaceSupportedOps = [
     'Subsample',
     'SumGroup',
     'TargetRMSNorm',
-    'Transpose',
-    'Cumsum',
     'Tile',
-    'LpNorm',
-    'MVNorm',
+    'Transpose',
 ]

 MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)
@@ -178,7 +178,8 @@ MaceFixedDataFormatOps = [MaceOp.BatchNorm,
                           MaceOp.SpaceToBatchND,
                           MaceOp.SpaceToDepth,
                           MaceOp.LpNorm,
-                          MaceOp.MVNorm]
+                          MaceOp.MVNorm,
+                          MaceOp.GroupNorm]

 MaceTransposableDataFormatOps = [MaceOp.Activation,
                                  MaceOp.AddN,
@@ -255,6 +256,7 @@ class MaceKeyword(object):
     mace_opencl_mem_type = "opencl_mem_type"
     mace_framework_type_str = "framework_type"
     mace_group_str = "group"
+    mace_group_num_str = "group_num"
     mace_wino_arg_str = "wino_block_size"
     mace_quantize_flag_arg_str = "quantize_flag"
     mace_epsilon_str = 'epsilon'
diff --git a/tools/python/transform/caffe_converter.py b/tools/python/transform/caffe_converter.py
index 35ce6dc6..09fd9c4b 100644
--- a/tools/python/transform/caffe_converter.py
+++ b/tools/python/transform/caffe_converter.py
@@ -188,6 +188,7 @@
             'InnerProduct': self.convert_fully_connected,
             'Interp': self.convert_interp,
             'BatchNorm': self.convert_folded_batchnorm,
+            'GroupNorm': self.convert_group_norm,
             'Crop': self.convert_crop,
             'Scale': self.convert_scale,
             'ShuffleChannel': self.convert_channel_shuffle,
@@ -555,6 +556,24 @@ class CaffeConverter(base_converter.ConverterInterface):
         op.input.extend([name for name in input_names])
         op.output[:] = scale_op.layer.top[:]

+    def convert_group_norm(self, caffe_op):
+        op = self.convert_general_op(caffe_op)
+        op.type = MaceOp.GroupNorm.name
+
+        epsilon_arg = op.arg.add()
+        epsilon_arg.name = MaceKeyword.mace_epsilon_str
+        group_num_arg = op.arg.add()
+        group_num_arg.name = MaceKeyword.mace_group_num_str
+
+        if hasattr(caffe_op, 'layer') and \
+                hasattr(caffe_op.layer, 'group_norm_param'):
+            param = caffe_op.layer.group_norm_param
+            epsilon_arg.f = param.eps
+            group_num_arg.i = param.group_num
+        else:
+            epsilon_arg.f = 1e-5
+            group_num_arg.i = 32
+
     def convert_pooling(self, caffe_op):
         op = self.convert_general_op(caffe_op)
         param = caffe_op.layer.pooling_param
diff --git a/tools/python/transform/shape_inference.py b/tools/python/transform/shape_inference.py
index e93f862d..4bd3731e 100644
--- a/tools/python/transform/shape_inference.py
+++ b/tools/python/transform/shape_inference.py
@@ -38,6 +38,7 @@ class ShapeInference(object):
             MaceOp.DepthwiseDeconv2d.name: self.infer_shape_deconv,
             MaceOp.Eltwise.name: self.infer_shape_eltwise,
             MaceOp.BatchNorm.name: self.infer_shape_general,
+            MaceOp.GroupNorm.name: self.infer_shape_general,
             MaceOp.AddN.name: self.infer_shape_general,
             MaceOp.Activation.name: self.infer_shape_general,
             MaceOp.Pooling.name: self.infer_shape_conv_pool_shape,
-- 
GitLab
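
Editor's note (not part of the patch): the CPU GroupNormOp above normalizes each (batch, group) slice of an NCHW tensor to zero mean and unit variance, with mean_ptr/variance_ptr holding one statistic per slice. Below is a minimal NumPy sketch of the same computation, handy for checking the expected values in group_norm_test.cc. The function name and the toy shapes are hypothetical; the sketch assumes NCHW input and channel % group_num == 0, which is exactly what the op's MACE_CHECKs enforce.

    import numpy as np

    def group_norm_ref(x, group_num, eps=1e-5):
        # x: NCHW, matching the CPU kernel's input layout.
        n, c, h, w = x.shape
        assert c % group_num == 0, "channels must be divisible by group_num"
        # outer_loop = n * group_num slices; inner_loop = (c / group_num) * h * w.
        g = x.reshape(n, group_num, (c // group_num) * h * w)
        mean = g.mean(axis=2, keepdims=True)                 # EX
        var = ((g - mean) ** 2).mean(axis=2, keepdims=True)  # E((X - EX)^2)
        # (X - EX) / ((E((X - EX)^2) + eps)^0.5), as in the final Compute2D pass.
        return ((g - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)

    x = np.arange(2 * 8 * 2 * 2, dtype=np.float32).reshape(2, 8, 2, 2)
    y = group_norm_ref(x, group_num=4)
    print(y.mean(), y.std())  # per-group mean ~ 0, std ~ 1

Note that the op applies no per-channel scale or offset; in Caffe models those typically come from a separate Scale layer, which caffe_converter.py already handles via convert_scale.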