提交 e2ae6261 编写于 作者: L liuqi

Add sync function for opencl test and benchmark.

上级 bcec92d0
...@@ -8,11 +8,9 @@ ...@@ -8,11 +8,9 @@
#include <mutex> #include <mutex>
#include <dirent.h> #include <dirent.h>
#include <errno.h>
#include "mace/core/logging.h" #include "mace/core/logging.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/runtime/opencl/opencl_wrapper.h"
namespace mace { namespace mace {
namespace { namespace {
...@@ -66,7 +64,7 @@ bool BuildProgram(OpenCLRuntime *runtime, ...@@ -66,7 +64,7 @@ bool BuildProgram(OpenCLRuntime *runtime,
}; };
*program = cl::Program(runtime->context(), sources); *program = cl::Program(runtime->context(), sources);
std::string build_options = "-Werror -cl-mad-enable -I" + path; std::string build_options = "-Werror -cl-mad-enable -cl-fast-relaxed-math -I" + path;
// TODO(heliangliang) -cl-unsafe-math-optimizations -cl-fast-relaxed-math // TODO(heliangliang) -cl-unsafe-math-optimizations -cl-fast-relaxed-math
if (program->build({runtime->device()}, build_options.c_str()) != CL_SUCCESS) { if (program->build({runtime->device()}, build_options.c_str()) != CL_SUCCESS) {
if (program->getBuildInfo<CL_PROGRAM_BUILD_STATUS>(runtime->device()) == if (program->getBuildInfo<CL_PROGRAM_BUILD_STATUS>(runtime->device()) ==
......
...@@ -20,15 +20,18 @@ namespace mace { ...@@ -20,15 +20,18 @@ namespace mace {
class OpenCLRuntime { class OpenCLRuntime {
public: public:
static OpenCLRuntime *Get(); static OpenCLRuntime *Get();
OpenCLRuntime(cl::Context context,
cl::Device device,
cl::CommandQueue command_queue);
~OpenCLRuntime();
cl::Context &context(); cl::Context &context();
cl::Device &device(); cl::Device &device();
cl::CommandQueue &command_queue(); cl::CommandQueue &command_queue();
cl::Program &program(); cl::Program &program();
private:
OpenCLRuntime(cl::Context context,
cl::Device device,
cl::CommandQueue command_queue);
~OpenCLRuntime();
OpenCLRuntime(const OpenCLRuntime&) = delete;
OpenCLRuntime &operator=(const OpenCLRuntime&) = delete;
private: private:
cl::Context context_; cl::Context context_;
......
...@@ -24,25 +24,21 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()( ...@@ -24,25 +24,21 @@ void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
auto runtime = OpenCLRuntime::Get(); auto runtime = OpenCLRuntime::Get();
auto program = runtime->program(); auto program = runtime->program();
auto batch_norm_kernel = auto _kernel = cl::Kernel(program, "batch_norm");
cl::KernelFunctor<cl::Buffer, cl::Buffer, cl::Buffer, _kernel.setArg(0, *(static_cast<const cl::Buffer *>(input->buffer())));
cl::Buffer, cl::Buffer, cl::Buffer, _kernel.setArg(1, *(static_cast<cl::Buffer *>(scale->buffer())));
int, int, cl::Buffer>(program, "batch_norm"); _kernel.setArg(2, *(static_cast<cl::Buffer *>(offset->buffer())));
cl_int error; _kernel.setArg(3, *(static_cast<cl::Buffer *>(mean->buffer())));
auto res_event = batch_norm_kernel(cl::EnqueueArgs(runtime->command_queue(), _kernel.setArg(4, *(static_cast<cl::Buffer *>(var->buffer())));
cl::NDRange(n * channel * sample_size), _kernel.setArg(5, *(static_cast<cl::Buffer *>(epsilon->buffer())));
cl::NDRange(128)), _kernel.setArg(6, static_cast<int>(sample_size));
*(static_cast<const cl::Buffer *>(input->buffer())), _kernel.setArg(7, *(static_cast<cl::Buffer *>(output->buffer())));
*(static_cast<cl::Buffer *>(scale->buffer())), _kernel.setArg(8, 32u, nullptr);
*(static_cast<cl::Buffer *>(offset->buffer())), _kernel.setArg(9, 32u, nullptr);
*(static_cast<cl::Buffer *>(mean->buffer())), cl_int error = runtime->command_queue().enqueueNDRangeKernel(
*(static_cast<cl::Buffer *>(var->buffer())), _kernel, cl::NullRange,
*(static_cast<cl::Buffer *>(epsilon->buffer())), cl::NDRange(n, channel, sample_size),
static_cast<int>(channel), cl::NDRange(1, 1, 128));
static_cast<int>(sample_size),
*(static_cast<cl::Buffer *>(output->buffer())),
error);
res_event.wait();
MACE_CHECK(error == CL_SUCCESS); MACE_CHECK(error == CL_SUCCESS);
} }
......
...@@ -4,16 +4,28 @@ void kernel batch_norm(global const float *input, ...@@ -4,16 +4,28 @@ void kernel batch_norm(global const float *input,
global const float *mean, global const float *mean,
global const float *var, global const float *var,
global const float *epsilon, global const float *epsilon,
private const int channels,
private const int pixels, private const int pixels,
global float *output) { global float *output,
int idx = get_global_id(0); __local float *new_scale,
int channel = (idx % (channels * pixels)) / pixels; __local float *new_offset) {
const int batch = get_global_id(0);
const int channel = get_global_id(1);
const int channels = get_global_size(1);
const int pixel_offset = get_global_id(2);
const unsigned int local_channel = get_local_id(1);
const int local_pixel_idx = get_local_id(2);
const float *input_ptr = input + idx; if(local_pixel_idx == 0) {
const float new_scale = scale[channel] * rsqrt(var[channel] + *epsilon); new_scale[local_channel] = scale[channel] * rsqrt(var[channel] + *epsilon);
const float new_offset = offset[channel] - mean[channel] * new_scale; new_offset[local_channel] = offset[channel] - mean[channel] * new_scale[local_channel];
float *output_ptr = output + idx; }
*output_ptr = new_scale * *input_ptr + new_offset;
barrier(CLK_LOCAL_MEM_FENCE);
const int sample_offset = (batch * channels + channel) * pixels + pixel_offset;
const float *input_ptr = input + sample_offset;
float *output_ptr = output + sample_offset;
*output_ptr = new_scale[local_channel] * *input_ptr + new_offset[local_channel];
} }
...@@ -17,6 +17,7 @@ cc_library( ...@@ -17,6 +17,7 @@ cc_library(
], ],
deps = [ deps = [
"//mace/core", "//mace/core",
"//mace/core:opencl_runtime",
"@gtest//:gtest", "@gtest//:gtest",
], ],
) )
...@@ -39,7 +40,6 @@ cc_library( ...@@ -39,7 +40,6 @@ cc_library(
"-fopenmp", "-fopenmp",
], ],
deps = [ deps = [
"//mace/core",
"//mace/kernels", "//mace/kernels",
"//mace/proto:cc_proto", "//mace/proto:cc_proto",
], ],
...@@ -72,7 +72,6 @@ cc_test( ...@@ -72,7 +72,6 @@ cc_test(
deps = [ deps = [
":ops", ":ops",
":test", ":test",
"//mace/core",
"//mace/core:test_benchmark_main", "//mace/core:test_benchmark_main",
], ],
) )
...@@ -34,11 +34,13 @@ static void BatchNorm( ...@@ -34,11 +34,13 @@ static void BatchNorm(
// Warm-up // Warm-up
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
net.RunOp(D); net.RunOp(D);
net.Sync();
} }
mace::testing::StartTiming(); mace::testing::StartTiming();
while (iters--) { while (iters--) {
net.RunOp(D); net.RunOp(D);
net.Sync();
} }
} }
......
...@@ -208,6 +208,7 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) { ...@@ -208,6 +208,7 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
// Run NEON // Run NEON
net.RunOp(DeviceType::OPENCL); net.RunOp(DeviceType::OPENCL);
net.Sync();
// Check // Check
Tensor expected; Tensor expected;
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "mace/core/common.h" #include "mace/core/common.h"
#include "mace/core/net.h" #include "mace/core/net.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
namespace mace { namespace mace {
...@@ -152,6 +153,12 @@ class OpsTestNet { ...@@ -152,6 +153,12 @@ class OpsTestNet {
return ws_.GetTensor(output_name); return ws_.GetTensor(output_name);
} }
// Waits until every command already submitted to the OpenCL command queue
// has completed, so that tests and benchmarks read finished results and
// measure full execution time rather than just enqueue time.
// Does nothing when net_ is unset.
// NOTE(review): the queue is drained whenever net_ is set; no check is made
// here on which device the op actually ran on.
void Sync() {
  if (!net_) {
    return;
  }
  OpenCLRuntime::Get()->command_queue().finish();
}
public: public:
Workspace ws_; Workspace ws_;
OperatorDef op_def_; OperatorDef op_def_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册