diff --git a/mace/core/runtime/opencl/opencl_runtime.cc b/mace/core/runtime/opencl/opencl_runtime.cc index a2b4c5cbe34a550d40224afb9e6c2d77aeacb980..bb2cea3b55dba1d32debfb6d30333f93259b2614 100644 --- a/mace/core/runtime/opencl/opencl_runtime.cc +++ b/mace/core/runtime/opencl/opencl_runtime.cc @@ -8,11 +8,9 @@ #include #include -#include #include "mace/core/logging.h" #include "mace/core/runtime/opencl/opencl_runtime.h" -#include "mace/core/runtime/opencl/opencl_wrapper.h" namespace mace { namespace { @@ -66,7 +64,7 @@ bool BuildProgram(OpenCLRuntime *runtime, }; *program = cl::Program(runtime->context(), sources); - std::string build_options = "-Werror -cl-mad-enable -I" + path; + std::string build_options = "-Werror -cl-mad-enable -cl-fast-relaxed-math -I" + path; // TODO(heliangliang) -cl-unsafe-math-optimizations -cl-fast-relaxed-math if (program->build({runtime->device()}, build_options.c_str()) != CL_SUCCESS) { if (program->getBuildInfo(runtime->device()) == diff --git a/mace/core/runtime/opencl/opencl_runtime.h b/mace/core/runtime/opencl/opencl_runtime.h index 057b2a80320130d322a1146698cfcb40334ca0a4..e18c9f93dae38ff8307adecbf321e65252d47427 100644 --- a/mace/core/runtime/opencl/opencl_runtime.h +++ b/mace/core/runtime/opencl/opencl_runtime.h @@ -20,15 +20,18 @@ namespace mace { class OpenCLRuntime { public: static OpenCLRuntime *Get(); - OpenCLRuntime(cl::Context context, - cl::Device device, - cl::CommandQueue command_queue); - ~OpenCLRuntime(); cl::Context &context(); cl::Device &device(); cl::CommandQueue &command_queue(); cl::Program &program(); + private: + OpenCLRuntime(cl::Context context, + cl::Device device, + cl::CommandQueue command_queue); + ~OpenCLRuntime(); + OpenCLRuntime(const OpenCLRuntime&) = delete; + OpenCLRuntime &operator=(const OpenCLRuntime&) = delete; private: cl::Context context_; diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index 8999750f49156ff82edd4b75ccd4690937bacf9d..67b810f5149ea51e2456071415857bca467abbf0 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -24,25 +24,21 @@ void BatchNormFunctor::operator()( auto runtime = OpenCLRuntime::Get(); auto program = runtime->program(); - auto batch_norm_kernel = - cl::KernelFunctor(program, "batch_norm"); - cl_int error; - auto res_event = batch_norm_kernel(cl::EnqueueArgs(runtime->command_queue(), - cl::NDRange(n * channel * sample_size), - cl::NDRange(128)), - *(static_cast(input->buffer())), - *(static_cast(scale->buffer())), - *(static_cast(offset->buffer())), - *(static_cast(mean->buffer())), - *(static_cast(var->buffer())), - *(static_cast(epsilon->buffer())), - static_cast(channel), - static_cast(sample_size), - *(static_cast(output->buffer())), - error); - res_event.wait(); + auto _kernel = cl::Kernel(program, "batch_norm"); + _kernel.setArg(0, *(static_cast(input->buffer()))); + _kernel.setArg(1, *(static_cast(scale->buffer()))); + _kernel.setArg(2, *(static_cast(offset->buffer()))); + _kernel.setArg(3, *(static_cast(mean->buffer()))); + _kernel.setArg(4, *(static_cast(var->buffer()))); + _kernel.setArg(5, *(static_cast(epsilon->buffer()))); + _kernel.setArg(6, static_cast(sample_size)); + _kernel.setArg(7, *(static_cast(output->buffer()))); + _kernel.setArg(8, 32u, nullptr); + _kernel.setArg(9, 32u, nullptr); + cl_int error = runtime->command_queue().enqueueNDRangeKernel( + _kernel, cl::NullRange, + cl::NDRange(n, channel, sample_size), + cl::NDRange(1, 1, 128)); MACE_CHECK(error == CL_SUCCESS); } diff --git a/mace/kernels/opencl/cl/batch_norm.cl b/mace/kernels/opencl/cl/batch_norm.cl index f0d1b77e1ed34f82aedf19c7c44fc32b141bbea5..d5927071f222b6ff0c0cdb7dd32e2b15979983f8 100644 --- a/mace/kernels/opencl/cl/batch_norm.cl +++ b/mace/kernels/opencl/cl/batch_norm.cl @@ -4,16 +4,28 @@ void kernel batch_norm(global const float *input, global const float *mean, global const float *var, global const float *epsilon, - private const int channels, private const int pixels, - global float *output) { - int idx = get_global_id(0); - int channel = (idx % (channels * pixels)) / pixels; + global float *output, + __local float *new_scale, + __local float *new_offset) { + const int batch = get_global_id(0); + const int channel = get_global_id(1); + const int channels = get_global_size(1); + const int pixel_offset = get_global_id(2); + const unsigned int local_channel = get_local_id(1); + const int local_pixel_idx = get_local_id(2); - const float *input_ptr = input + idx; - const float new_scale = scale[channel] * rsqrt(var[channel] + *epsilon); - const float new_offset = offset[channel] - mean[channel] * new_scale; - float *output_ptr = output + idx; - *output_ptr = new_scale * *input_ptr + new_offset; + if(local_pixel_idx == 0) { + new_scale[local_channel] = scale[channel] * rsqrt(var[channel] + *epsilon); + new_offset[local_channel] = offset[channel] - mean[channel] * new_scale[local_channel]; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + const int sample_offset = (batch * channels + channel) * pixels + pixel_offset; + + const float *input_ptr = input + sample_offset; + float *output_ptr = output + sample_offset; + *output_ptr = new_scale[local_channel] * *input_ptr + new_offset[local_channel]; } diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 83574b53373fd52226aef06ebd80f392131c6732..e823136d965670cc198ed14c0f682cbf0b152d00 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -17,6 +17,7 @@ cc_library( ], deps = [ "//mace/core", + "//mace/core:opencl_runtime", "@gtest//:gtest", ], ) @@ -39,7 +40,6 @@ cc_library( "-fopenmp", ], deps = [ - "//mace/core", "//mace/kernels", "//mace/proto:cc_proto", ], @@ -72,7 +72,6 @@ cc_test( deps = [ ":ops", ":test", - "//mace/core", "//mace/core:test_benchmark_main", ], ) diff --git a/mace/ops/batch_norm_benchmark.cc b/mace/ops/batch_norm_benchmark.cc index 2276eaeb2ecf6d677641478abed4bfde45ded078..3d17aca7ec153060629a551e4eba3829c22d4529 100644 --- a/mace/ops/batch_norm_benchmark.cc +++ b/mace/ops/batch_norm_benchmark.cc @@ -34,11 +34,13 @@ static void BatchNorm( // Warm-up for (int i = 0; i < 5; ++i) { net.RunOp(D); + net.Sync(); } mace::testing::StartTiming(); while (iters--) { net.RunOp(D); + net.Sync(); } } diff --git a/mace/ops/batch_norm_test.cc b/mace/ops/batch_norm_test.cc index ceb4963905273a0449ffac3ad5be502c333f42d1..4c5d73bbe3981ec2af972a5c370ec2794c22ac2d 100644 --- a/mace/ops/batch_norm_test.cc +++ b/mace/ops/batch_norm_test.cc @@ -208,6 +208,7 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) { // Run NEON net.RunOp(DeviceType::OPENCL); + net.Sync(); // Check Tensor expected; diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 3b2ddfe0f69dcac9f624ff67907ad640a4f5505a..678f855fc7b98b5c7ec66f13e07e75ab121e067e 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -11,6 +11,7 @@ #include "mace/core/common.h" #include "mace/core/net.h" #include "mace/core/tensor.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" namespace mace { @@ -152,6 +153,12 @@ class OpsTestNet { return ws_.GetTensor(output_name); } + void Sync() { + if (net_) { + OpenCLRuntime::Get()->command_queue().finish(); + } + } + public: Workspace ws_; OperatorDef op_def_;