From 4d379957d3c5acf7bdc210a124ebca220f7525ee Mon Sep 17 00:00:00 2001 From: yiicy Date: Mon, 9 Mar 2020 21:43:06 +0800 Subject: [PATCH] [OPENCL] add instance norm kernel and ut, test=develop (#3122) add instance norm kernel and ut --- .../cl_kernel/image/instance_norm_kernel.cl | 79 ++++++ lite/kernels/opencl/CMakeLists.txt | 5 +- .../opencl/instance_norm_image_compute.cc | 188 ++++++++++++++ .../instance_norm_image_compute_test.cc | 240 ++++++++++++++++++ 4 files changed, 511 insertions(+), 1 deletion(-) create mode 100644 lite/backends/opencl/cl_kernel/image/instance_norm_kernel.cl create mode 100644 lite/kernels/opencl/instance_norm_image_compute.cc create mode 100644 lite/kernels/opencl/instance_norm_image_compute_test.cc diff --git a/lite/backends/opencl/cl_kernel/image/instance_norm_kernel.cl b/lite/backends/opencl/cl_kernel/image/instance_norm_kernel.cl new file mode 100644 index 0000000000..b5346e3af4 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/instance_norm_kernel.cl @@ -0,0 +1,79 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + + +__kernel void instance_norm(__read_only image2d_t input, + __write_only image2d_t output, + __read_only image2d_t scale, + __read_only image2d_t bias, + const float epsilon, + const int in_h, + const int in_w){ + __local CL_DTYPE4 saved_mean[1024]; + __local CL_DTYPE4 saved_variance[1024]; + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + const int gidx = get_group_id(0); + const int gidy = get_group_id(1); + const int spatial_size = in_h * in_w; + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + CL_DTYPE4 mean = (CL_DTYPE4)(0.f, 0.f, 0.f, 0.f); + CL_DTYPE4 variance = (CL_DTYPE4)(0.f, 0.f, 0.f, 0.f); + CL_DTYPE4 vepsilon = (CL_DTYPE4)(epsilon, epsilon, epsilon, epsilon); + const int x_offset = gidx * in_w; + const int y_offset = gidy * in_h; + int2 coor; + for (int i = lid; i < spatial_size; i += lsize) { + coor.x = i % in_w + x_offset; + coor.y = i / in_w + y_offset; + CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coor); + mean += pixel; + variance += pixel * pixel; + } + saved_mean[lid] = mean; + saved_variance[lid] = variance; + barrier(CLK_LOCAL_MEM_FENCE); + + //! do reduction + int dynamic_size = lsize >> 1; + for (; dynamic_size > 0; dynamic_size >>= 1){ + if (lid < dynamic_size) { + saved_mean[lid] += saved_mean[lid + dynamic_size]; + saved_variance[lid] += saved_variance[lid + dynamic_size]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + mean = saved_mean[0] / spatial_size; + variance = saved_variance[0] / spatial_size - mean * mean; + variance = rsqrt(variance + vepsilon); + + //! do instance norm + coor.x = gidx; + coor.y = gidy; + CL_DTYPE4 vscale = READ_IMG_TYPE(CL_DTYPE_CHAR, scale, sampler, coor); + vscale *= variance; + CL_DTYPE4 vbias = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coor); + for (int i = lid; i < spatial_size; i += lsize) { + coor.x = i % in_w + x_offset; + coor.y = i / in_w + y_offset; + CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coor); + pixel = (pixel - mean) * vscale + vbias; + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, coor, pixel); + } +} diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index b6a4059c06..b98348f685 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -29,6 +29,7 @@ add_kernel(scale_opencl OPENCL basic SRCS scale_image_compute.cc DEPS ${cl_kerne add_kernel(grid_sampler_opencl OPENCL basic SRCS grid_sampler_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(lrn_opencl OPENCL basic SRCS lrn_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(bilinear_interp_opencl OPENCL basic SRCS bilinear_interp_image_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(instance_norm_opencl OPENCL basic SRCS instance_norm_image_compute.cc DEPS ${cl_kernel_deps}) # extra # wait to add ... @@ -84,7 +85,9 @@ lite_cc_test(test_lrn_image_opencl SRCS lrn_image_compute_test.cc lite_cc_test(test_bilinear_interp_image_opencl SRCS bilinear_interp_image_compute_test.cc DEPS bilinear_interp_opencl op_registry program context) - + +lite_cc_test(test_instance_norm_image_opencl SRCS instance_norm_image_compute_test.cc + DEPS instance_norm_opencl op_registry program context) ###################### # buffer kernel # ###################### diff --git a/lite/kernels/opencl/instance_norm_image_compute.cc b/lite/kernels/opencl/instance_norm_image_compute.cc new file mode 100644 index 0000000000..d90acdb02d --- /dev/null +++ b/lite/kernels/opencl/instance_norm_image_compute.cc @@ -0,0 +1,188 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/backends/opencl/cl_half.h" +#include "lite/backends/opencl/cl_image_converter.h" +#include "lite/backends/opencl/cl_include.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/logging.h" +#include "lite/utils/replace_stl/stream.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { +class InstanceNormImageCompute : public KernelLite { + public: + using param_t = operators::InstanceNormParam; + + std::string doc() const override { + return "InstanceNorm using cl::Image2D(ImageDefault/RGBA), kFP16"; + } + + void PrepareForRun() override { + instance_norm_param_ = param_.get_mutable(); + auto channel = instance_norm_param_->scale->dims()[0]; + auto batch = instance_norm_param_->x->dims()[0]; + int64_t cgroup = (channel + 3) / 4; + int64_t cround = cgroup * 4; + std::vector scale_img(cround * batch); + std::vector bias_img(cround * batch); + const float* scale_data = instance_norm_param_->scale->data(); + const float* bias_data = instance_norm_param_->bias->data(); + //! init scale_img bias_img data + for (int i = 0; i < channel; ++i) { + scale_img[i] = Float2Half(scale_data[i]); + bias_img[i] = Float2Half(bias_data[i]); + } + for (int i = channel; i < cround; ++i) { + scale_img[i] = Float2Half(0.f); + bias_img[i] = Float2Half(0.f); + } + for (int i = 1; i < batch; ++i) { + memcpy(scale_img.data() + i * cround, + scale_img.data(), + cround * sizeof(half_t)); + memcpy(bias_img.data() + i * cround, + bias_img.data(), + cround * sizeof(half_t)); + } + DDim scale_img_size{{cgroup, batch}}; + scale_image_.mutable_data( + scale_img_size[0], scale_img_size[1], scale_img.data()); + bias_image_.mutable_data( + scale_img_size[0], scale_img_size[1], bias_img.data()); + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "image/instance_norm_kernel.cl", build_options_); + VLOG(1) << "kernel_func_name_:" << kernel_func_name_; + } + + void Run() override { + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + + auto* x = instance_norm_param_->x; + auto* out = instance_norm_param_->out; + auto in_dims = x->dims(); + + int batch = in_dims[0]; + int channel = in_dims[1]; + int in_h = in_dims[2]; + int in_w = in_dims[3]; + + VLOG(4) << "x->target():" << TargetToStr(x->target()); + VLOG(4) << "out->target():" << TargetToStr(out->target()); + VLOG(4) << "x->dims():" << in_dims; + + auto out_image_shape = InitImageDimInfoWith(in_dims); + auto* x_img = x->data(); + + auto* out_img = out->mutable_data( + out_image_shape["width"], out_image_shape["height"]); + VLOG(4) << "out_image_shape[w,h]: " << out_image_shape["width"] << " " + << out_image_shape["height"]; + + VLOG(4) << "in_h: " << in_h << ", in_w: " << in_w; + + int threads = 512; + int group_size_x = (channel + 3) / 4; + int group_size_y = batch; + auto local_work_size = cl::NDRange{static_cast(threads), + static_cast(1), + static_cast(1)}; + auto global_work_size = + cl::NDRange{static_cast(group_size_x * threads), + static_cast(group_size_y), + static_cast(1)}; + VLOG(4) << "local_work_size:[2D]:" << local_work_size[0] << " " + << local_work_size[1] << " " << local_work_size[2]; + VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " " + << global_work_size[1] << " " << global_work_size[2]; + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + auto* scale_img = scale_image_.data(); + auto* bias_img = bias_image_.data(); + float epsilon = instance_norm_param_->epsilon; + int arg_idx = 0; + + cl_int status = kernel.setArg(arg_idx++, *x_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, *out_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, *scale_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, *bias_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, epsilon); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, in_h); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, in_w); + CL_CHECK_FATAL(status); + + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + local_work_size, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(out_img, event_); + } + + protected: + param_t* instance_norm_param_{nullptr}; + std::string kernel_func_name_{"instance_norm"}; + std::string build_options_{"-DCL_DTYPE_half"}; + std::shared_ptr event_{new cl::Event}; + Tensor scale_image_; + Tensor bias_image_; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +namespace ocl = paddle::lite::kernels::opencl; +REGISTER_LITE_KERNEL(instance_norm, + kOpenCL, + kFP16, + kImageDefault, + ocl::InstanceNormImageCompute, + ImageDefault) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Y", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/opencl/instance_norm_image_compute_test.cc b/lite/kernels/opencl/instance_norm_image_compute_test.cc new file mode 100644 index 0000000000..63d172f5ed --- /dev/null +++ b/lite/kernels/opencl/instance_norm_image_compute_test.cc @@ -0,0 +1,240 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/profile/timer.h" +#include "lite/core/tensor.h" +#include "lite/kernels/opencl/test_helper.h" + +#define FP16_MAX_DIFF (5e-3) +DEFINE_int32(warmup, 0, "warmup times"); +DEFINE_int32(repeats, 1, "repeats times"); + +using paddle::lite::profile::Timer; + +namespace paddle { +namespace lite { +void instance_norm_ref(Tensor* x, + Tensor* y, + Tensor* scale, + Tensor* bias, + Tensor* saved_mean, + Tensor* saved_variance, + float epsilon) { + auto x_data = x->data(); + auto scale_data = scale->data(); + auto bias_data = bias->data(); + auto y_data = y->mutable_data(); + auto saved_mean_data = saved_mean->mutable_data(); + auto saved_variance_data = saved_variance->mutable_data(); + int n = x->dims()[0]; + int c = x->dims()[1]; + int spatial_size = x->dims()[2] * x->dims()[3]; + + // compute mean + for (int i = 0; i < n * c; ++i) { + const float* x_ptr = x_data + i * spatial_size; + float sum = 0.f; + for (int j = 0; j < spatial_size; ++j) { + sum += x_ptr[j]; + } + saved_mean_data[i] = sum / spatial_size; + } + // compute variance + for (int i = 0; i < n * c; ++i) { + const float* x_ptr = x_data + i * spatial_size; + float sum = 0.f; + for (int j = 0; j < spatial_size; ++j) { + sum += (x_ptr[j] - saved_mean_data[i]) * (x_ptr[j] - saved_mean_data[i]); + } + saved_variance_data[i] = 1.f / sqrtf(sum / spatial_size + epsilon); + } + // compute out + for (int i = 0; i < n * c; ++i) { + const float* x_ptr = x_data + i * spatial_size; + float* y_ptr = y_data + i * spatial_size; + float scale_val = scale_data[i % c]; + float bias_val = bias_data[i % c]; + for (int j = 0; j < spatial_size; ++j) { + y_ptr[j] = + scale_val * (x_ptr[j] - saved_mean_data[i]) * saved_variance_data[i] + + bias_val; + } + } +} + +// #define INSTANCE_NORM_FP16_LOOP_TEST +// #define INSTANCE_NORM_FP16_PRINT_RESULT +TEST(instance_norm_image2d, compute) { +#ifdef INSTANCE_NORM_FP16_LOOP_TEST + for (auto n : {1, 3}) { + for (auto c : {1, 3, 8, 32, 65}) { + for (auto h : {4, 20, 64, 112, 224}) { + for (auto w : {2, 20, 64, 112, 224}) { +#else + const int n = 1; + const int c = 32; + const int h = 224; + const int w = 224; +#endif // INSTANCE_NORM_FP16_LOOP_TEST + + LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " " + << h << " " << w << " ========"; + + auto kernels = + KernelRegistry::Global().Create("instance_norm", + TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + auto kernel = std::move(kernels.front()); + LOG(INFO) << "get kernel:" << kernel->doc(); + + lite::Tensor x, out, out_ref, scale, bias, saved_mean, saved_variance; + operators::InstanceNormParam param; + param.x = &x; + param.out = &out; + param.scale = &scale; + param.bias = &bias; + param.saved_mean = &saved_mean; + param.saved_variance = &saved_variance; + param.epsilon = 1e-5; + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + kernel->SetParam(param); + std::unique_ptr instance_context(new KernelContext); + context->As().CopySharedTo( + &(instance_context->As())); + kernel->SetContext(std::move(instance_context)); + + const DDim in_dim = DDim(std::vector{n, c, h, w}); + x.Resize(in_dim); + out.Resize(in_dim); + out_ref.Resize(in_dim); + scale.Resize({c}); + bias.Resize({c}); + saved_mean.Resize({n * c}); + saved_variance.Resize({n * c}); + auto* x_data = x.mutable_data(); + auto* scale_data = scale.mutable_data(); + auto* bias_data = bias.mutable_data(); + auto* saved_mean_data = saved_mean.mutable_data(); + auto* saved_variance_data = saved_variance.mutable_data(); + std::default_random_engine engine; + std::uniform_real_distribution dist(-1, 1); + int sum = n * c * h * w; + for (int i = 0; i < sum; ++i) { + x_data[i] = dist(engine); + } + for (int i = 0; i < c; ++i) { + scale_data[i] = dist(engine); + bias_data[i] = dist(engine); + } + //! run reference instance norm + instance_norm_ref( + &x, &out_ref, &scale, &bias, &saved_mean, &saved_variance, 1e-5); + LOG(INFO) << "prepare input"; + CLImageConverterDefault* default_converter = + new CLImageConverterDefault(); + DDim x_image_shape = default_converter->InitImageDimInfoWith(in_dim); + LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " " + << x_image_shape[1]; + std::vector x_image_data(x_image_shape.production() * + 4); // 4 : RGBA + default_converter->NCHWToImage(x_data, x_image_data.data(), in_dim); + auto* x_image = x.mutable_data( + x_image_shape[0], x_image_shape[1], x_image_data.data()); + + auto* out_image = out.mutable_data( + x_image_shape[0], x_image_shape[1]); + + //! warm up + for (int i = 0; i < FLAGS_warmup; ++i) { + kernel->Launch(); + } + context->As().cl_context()->GetCommandQueue().finish(); + //! compute + Timer t0; + t0.Start(); + for (int i = 0; i < FLAGS_repeats; ++i) { + kernel->Launch(); + } + context->As().cl_context()->GetCommandQueue().finish(); + t0.Stop(); + double gops = 6 * sum; + LOG(INFO) << "avg time: " << t0.LapTimes().Avg() / FLAGS_repeats + << " ms, " + << "avg GOPs: " + << 1e-6 * gops * FLAGS_repeats / t0.LapTimes().Avg() + << " GOPs"; + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + half_t* out_image_data = new half_t[x_image_shape.production() * 4]; + TargetWrapperCL::ImgcpySync(out_image_data, + out_image, + x_image_shape[0], + x_image_shape[1], + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + float* out_data = new float[x_image_shape.production() * 4]; + default_converter->ImageToNCHW( + out_image_data, out_data, x_image_shape, in_dim); +// result +#ifdef INSTANCE_NORM_FP16_PRINT_RESULT + LOG(INFO) << "---- print kernel result (input -> output) ----"; + for (int eidx = 0; eidx < in_dim.production(); ++eidx) { + std::cout << x_data[eidx] << " -> " << out_data[eidx] << std::endl; + } +#endif // INSTANCE_NORM_FP16_PRINT_RESULT + auto* out_ref_data = out_ref.data(); + for (int i = 0; i < in_dim.production(); i++) { + auto abs_diff = abs(out_data[i] - out_ref_data[i]); + auto relative_diff = + COMPUTE_RELATIVE_DIFF(out_data[i], out_ref_data[i]); + EXPECT_EQ( + (relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF), + true); + if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) { + LOG(ERROR) << "error idx:" << i << ", in_data[" << i + << "]: " << x_data[i] << ", out_data[" << i + << "]: " << out_data[i] << ", out_ref[" << i + << "]: " << out_ref_data[i] + << ", abs_diff: " << abs_diff + << ", relative_diff: " << relative_diff + << ", FP16_MAX_DIFF: " << FP16_MAX_DIFF; + } + } + delete[] out_data; + delete[] out_image_data; +#ifdef INSTANCE_NORM_FP16_LOOP_TEST + } // w + } // h + } // c + } // n +#else +// nothing to do. +#endif +} + +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(instance_norm, kOpenCL, kFP16, kImageDefault, ImageDefault); -- GitLab