diff --git a/mace/core/BUILD b/mace/core/BUILD index 63a30357240f99df964d1c23d2884739f61827db..5adf9010ec01e7d69292a92de492eaaf64f4654c 100644 --- a/mace/core/BUILD +++ b/mace/core/BUILD @@ -21,26 +21,46 @@ cc_library( ]), copts = ["-std=c++11"], deps = [ - "core", + ":logging", "@opencl_headers//:opencl20_headers", ], alwayslink = 1, ) + cc_library( - name = "core", - srcs = glob([ - "*.cc", - ]), - hdrs = glob([ - "*.h", - ]), + name = "logging", + srcs = [ + "logging.cc", + ], + hdrs = [ + "logging.h", + ], copts = ["-std=c++11"], linkopts = if_android([ "-llog", + ]), +) + +cc_library( + name = "core", + srcs = glob( + ["*.cc",], + exclude=[ + "logging.cc" + ]), + hdrs = glob( + ["*.h"], + exclude=[ + "logging.h" + ]), + copts = ["-std=c++11"], + linkopts = if_android([ "-pie", ]), deps = [ + ":logging", + ":opencl_runtime", "//mace/proto:cc_proto", "//mace/proto:stats_proto", "//mace/utils", diff --git a/mace/core/allocator.cc b/mace/core/allocator.cc index d05c45b352e37e2e7c67226aee28441a15c665b8..707ea4cb0e0a3dd267e229b6a7f52e39d42e9773 100644 --- a/mace/core/allocator.cc +++ b/mace/core/allocator.cc @@ -3,6 +3,7 @@ // #include "mace/core/allocator.h" +#include "mace/core/opencl_allocator.h" namespace mace { @@ -22,5 +23,6 @@ Allocator *GetDeviceAllocator(DeviceType type) { MACE_REGISTER_ALLOCATOR(DeviceType::CPU, new CPUAllocator()); MACE_REGISTER_ALLOCATOR(DeviceType::NEON, new CPUAllocator()); +MACE_REGISTER_ALLOCATOR(DeviceType::OPENCL, new OpenCLAllocator()); } // namespace mace diff --git a/mace/core/net.cc b/mace/core/net.cc index 22a2fd11ba014dcebe0cd03cd1e031dc251d1a49..2707ec5145b3c9dc3f94e0d2e17db6f7d2682d90 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -4,6 +4,7 @@ #include "mace/core/net.h" #include "mace/utils/utils.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" namespace mace { @@ -15,7 +16,7 @@ NetBase::NetBase(const std::shared_ptr &net_def, SimpleNet::SimpleNet(const std::shared_ptr &net_def, Workspace *ws, DeviceType type) - : NetBase(net_def, ws, type) { + : NetBase(net_def, ws, type), device_type_(type){ VLOG(1) << "Constructing SimpleNet " << net_def->name(); for (int idx = 0; idx < net_def->op_size(); ++idx) { const auto &operator_def = net_def->op(idx); @@ -47,6 +48,8 @@ bool SimpleNet::Run(RunMetadata *run_metadata) { LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def()); return false; } + if (device_type_ == DeviceType::OPENCL) + OpenCLRuntime::Get()->command_queue().finish(); if (op_stats) { op_stats->set_op_end_rel_micros(NowInMicroSec() - op_stats->all_start_micros()); diff --git a/mace/core/net.h b/mace/core/net.h index 541f1b8292eb67a60cd20a80b5526e2d993b4c63..013ca715cafce82242ad148d8ff12c7df8fd9fb4 100644 --- a/mace/core/net.h +++ b/mace/core/net.h @@ -40,6 +40,7 @@ class SimpleNet : public NetBase { protected: vector > operators_; + DeviceType device_type_; DISABLE_COPY_AND_ASSIGN(SimpleNet); }; diff --git a/mace/core/runtime/opencl/opencl_allocator.cc b/mace/core/opencl_allocator.cc similarity index 93% rename from mace/core/runtime/opencl/opencl_allocator.cc rename to mace/core/opencl_allocator.cc index d40432e2a2bb18b264d641d55f66868a26dafc22..5d96fe7dd7bb2e8c62dc2990e50a2dcf6b5e47ed 100644 --- a/mace/core/runtime/opencl/opencl_allocator.cc +++ b/mace/core/opencl_allocator.cc @@ -3,7 +3,7 @@ // #include "mace/core/runtime/opencl/cl2_header.h" -#include "mace/core/runtime/opencl/opencl_allocator.h" +#include "mace/core/opencl_allocator.h" #include "mace/core/runtime/opencl/opencl_runtime.h" namespace mace { @@ -49,6 +49,5 @@ void OpenCLAllocator::Unmap(void *buffer, void *mapped_ptr) { bool OpenCLAllocator::OnHost() { return false; } -MACE_REGISTER_ALLOCATOR(DeviceType::OPENCL, new OpenCLAllocator()); } // namespace mace diff --git a/mace/core/runtime/opencl/opencl_allocator.h b/mace/core/opencl_allocator.h similarity index 100% rename from mace/core/runtime/opencl/opencl_allocator.h rename to mace/core/opencl_allocator.h diff --git a/mace/kernels/BUILD b/mace/kernels/BUILD index e002b1418576be3ddbb6f7feaa3584dff18a17ae..7a3f3007e31915e1652b7e8e72b0e16d58e99bcd 100644 --- a/mace/kernels/BUILD +++ b/mace/kernels/BUILD @@ -20,7 +20,6 @@ cc_library( linkopts = if_android(["-lm"]), deps = [ "//mace/core", - "//mace/core:opencl_runtime", "//mace/utils", "//mace/utils:tuner", ], diff --git a/mace/kernels/opencl/cl/conv_2d_1x1.cl b/mace/kernels/opencl/cl/conv_2d_1x1.cl index d371698862460087fa67812282cc72b181e431df..8025074f9d51c1698aa7e93bcc78649cee776a5a 100644 --- a/mace/kernels/opencl/cl/conv_2d_1x1.cl +++ b/mace/kernels/opencl/cl/conv_2d_1x1.cl @@ -24,33 +24,87 @@ __kernel void conv_2d_1x1_naive(__global const float *input, /* n, c, h, w */ } } +#define vec_conv_2d_1x1_s1 \ + float4 in0 = vload4(0, input_ptr); \ + float4 in1 = vload4(0, input_ptr + in_pixel); \ + float4 in2 = vload4(0, input_ptr + 2 * in_pixel); \ + float4 in3 = vload4(0, input_ptr + 3 * in_pixel); + + +#define vec_conv_2d_1x1_s2 \ + float4 in00 = vload4(0, input_ptr); \ + float3 in01 = vload3(0, input_ptr + 4); \ + float4 in10 = vload4(0, input_ptr + in_pixel); \ + float3 in11 = vload3(0, input_ptr + in_pixel + 4); \ + float4 in20 = vload4(0, input_ptr + 2 * in_pixel); \ + float3 in21 = vload3(0, input_ptr + 2 * in_pixel + 4);\ + float4 in30 = vload4(0, input_ptr + 3 * in_pixel); \ + float3 in31 = vload3(0, input_ptr + 3 * in_pixel + 4); \ + float4 in0 = (float4)(in00.s02, in01.s02); \ + float4 in1 = (float4)(in10.s02, in11.s02); \ + float4 in2 = (float4)(in20.s02, in21.s02); \ + float4 in3 = (float4)(in30.s02, in31.s02); + + +#define vec_conv_2d_1x1_compute_loop \ + for (int oc = 0; oc < 4; ++oc) { \ + float4 weights = vload4(0, filter_ptr + oc * in_chan_num); \ + float4 out = vload4(0, output_ptr + oc * out_pixel); \ + out += in0 * weights.x; \ + out += in1 * weights.y; \ + out += in2 * weights.z; \ + out += in3 * weights.w; \ + vstore4(out, 0, output_ptr + oc * out_pixel); \ + } + +#define vec_conv_2d_1x1_compute \ + float4 weights = vload4(0, filter_ptr); \ + float4 out = vload4(0, output_ptr); \ + out += in0 * weights.x; \ + out += in1 * weights.y; \ + out += in2 * weights.z; \ + out += in3 * weights.w; \ + vstore4(out, 0, output_ptr); + __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */ __global const float *filter, /* o, i, kh, kw */ __global const float *bias, /* o */ __global float *output, /* n, c, h, w */ __private const int in_chan_num, __private const int out_chan_num, - __private const int pixel_num) { + __private const int in_height, + __private const int in_width, + __private const int out_height, + __private const int out_width, + __private const int stride) { int batch = get_global_id(0); int out_chan_blk = get_global_id(1); int out_pixel_blk = get_global_id(2); + const int in_pixel = in_height * in_width; + const int out_pixel = out_height * out_width; + + const int round_out_width = (out_width + 3) / 4; + const int out_pixel_height = out_pixel_blk / round_out_width; + const int out_pixel_width = out_pixel_blk % round_out_width; + const int out_chan_begin = out_chan_blk * 4; const int out_chan_end = min(out_chan_begin + 4, out_chan_num); - const int out_pixel_begin = out_pixel_blk * 4; - const int out_pixel_end = min(out_pixel_begin + 4, pixel_num); + const int out_pixel_begin = out_pixel_height * out_width + out_pixel_width * 4; + const int out_pixel_end = min(out_pixel_begin + 4, (out_pixel_height + 1) * out_width); + const int in_pixel_begin = out_pixel_height * stride * in_width + out_pixel_width * stride * 4; - const int in_offset = batch * in_chan_num * pixel_num; - const int out_offset = batch * out_chan_num * pixel_num; + const int in_offset = batch * in_chan_num * in_pixel; + const int out_offset = batch * out_chan_num * out_pixel; - const float *input_base = input + in_offset + out_pixel_begin; + const float *input_base = input + in_offset + in_pixel_begin; float *output_base = output + out_offset + out_pixel_begin; int out_chan_len = out_chan_end - out_chan_begin; int pixel_len = out_pixel_end - out_pixel_begin; for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { - float *output_ptr = output_base + out_chan * pixel_num; + float *output_ptr = output_base + out_chan * out_pixel; float bias_value = bias == NULL ? 0 : bias[out_chan]; for (int p = 0; p < pixel_len; ++p) { output_ptr[p] = bias_value; @@ -59,53 +113,51 @@ __kernel void conv_2d_1x1_v2(__global const float *input, /* n, c, h, w */ int in_chan = 0; if (pixel_len == 4) { - for (; in_chan + 3 < in_chan_num; in_chan += 4) { - const float *input_ptr = input_base + in_chan * pixel_num; - int out_chan = out_chan_begin; - for (; out_chan + 3 < out_chan_end; out_chan += 4) { - const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; - float *output_ptr = output_base + out_chan * pixel_num; - float4 in0 = vload4(0, input_ptr); - float4 in1 = vload4(0, input_ptr + pixel_num); - float4 in2 = vload4(0, input_ptr + 2 * pixel_num); - float4 in3 = vload4(0, input_ptr + 3 * pixel_num); - #pragma unroll - for (int oc = 0; oc < 4; ++oc) { - float4 weights = vload4(0, filter_ptr + oc * in_chan_num); - float4 out = vload4(0, output_ptr + oc * pixel_num); - out += in0 * weights.x; - out += in1 * weights.y; - out += in2 * weights.z; - out += in3 * weights.w; - vstore4(out, 0, output_ptr + oc * pixel_num); + if (stride == 1) { + for (; in_chan + 3 < in_chan_num; in_chan += 4) { + const float *input_ptr = input_base + in_chan * in_pixel; + int out_chan = out_chan_begin; + for (; out_chan + 3 < out_chan_end; out_chan += 4) { + const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; + float *output_ptr = output_base + out_chan * out_pixel; + vec_conv_2d_1x1_s1; + vec_conv_2d_1x1_compute_loop; + } + for (; out_chan < out_chan_end; ++out_chan) { + const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; + float *output_ptr = output_base + out_chan * out_pixel; + vec_conv_2d_1x1_s1; + vec_conv_2d_1x1_compute; } } - for (; out_chan < out_chan_end; ++out_chan) { - const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; - float *output_ptr = output_base + out_chan * pixel_num; - float4 weights = vload4(0, filter_ptr); - float4 in0 = vload4(0, input_ptr); - float4 in1 = vload4(0, input_ptr + pixel_num); - float4 in2 = vload4(0, input_ptr + 2 * pixel_num); - float4 in3 = vload4(0, input_ptr + 3 * pixel_num); - float4 out = vload4(0, output_ptr); - out += in0 * weights.x; - out += in1 * weights.y; - out += in2 * weights.z; - out += in3 * weights.w; - vstore4(out, 0, output_ptr); + } else if (stride == 2) { + for (; in_chan + 3 < in_chan_num; in_chan += 4) { + const float *input_ptr = input_base + in_chan * in_pixel; + int out_chan = out_chan_begin; + for (; out_chan + 3 < out_chan_end; out_chan += 4) { + const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; + float *output_ptr = output_base + out_chan * out_pixel; + vec_conv_2d_1x1_s2; + vec_conv_2d_1x1_compute_loop; + } + for (; out_chan < out_chan_end; ++out_chan) { + const float* filter_ptr = filter + out_chan * in_chan_num + in_chan; + float *output_ptr = output_base + out_chan * out_pixel; + vec_conv_2d_1x1_s2; + vec_conv_2d_1x1_compute; + } } } } for (; in_chan < in_chan_num; ++in_chan) { - const float *input_ptr = input_base + in_chan * pixel_num; + const float *input_ptr = input_base + in_chan * in_pixel; for (int out_chan = out_chan_begin; out_chan < out_chan_end; ++out_chan) { float weights = filter[out_chan * in_chan_num + in_chan]; - float *output_ptr = output_base + out_chan * pixel_num; + float *output_ptr = output_base + out_chan * out_pixel; for (int p = 0; p < pixel_len; ++p) { - float in = input_ptr[p]; + float in = input_ptr[p*stride]; output_ptr[p] += in * weights; } } diff --git a/mace/kernels/opencl/cl/conv_2d_3x3.cl b/mace/kernels/opencl/cl/conv_2d_3x3.cl index b3f7735d5f6ac78e465d3bfefd2ab6aeed903250..317daaafe18f557f1469e9dde4fcbab460176109 100644 --- a/mace/kernels/opencl/cl/conv_2d_3x3.cl +++ b/mace/kernels/opencl/cl/conv_2d_3x3.cl @@ -41,14 +41,19 @@ void kernel conv_2d_3x3(global const float *input, if (pixels == 4) { float4 res = bias == NULL ? 0 : (float4)bias[i]; - for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) { - const float* input_ptr = input_base + in_chan_idx * in_pixel; - const float* filter_ptr = filter_base + in_chan_idx * 9; - if (stride_w == 1) { + + if (stride_w == 1) { + for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) { + const float* input_ptr = input_base + in_chan_idx * in_pixel; + const float* filter_ptr = filter_base + in_chan_idx * 9; res += conv1x3_s1(input_ptr + 0 * in_width, filter_ptr + 0 * 3); res += conv1x3_s1(input_ptr + 1 * in_width, filter_ptr + 1 * 3); res += conv1x3_s1(input_ptr + 2 * in_width, filter_ptr + 2 * 3); - } else { + } + } else { + for (int in_chan_idx = 0; in_chan_idx < in_chan_num; ++in_chan_idx) { + const float* input_ptr = input_base + in_chan_idx * in_pixel; + const float* filter_ptr = filter_base + in_chan_idx * 9; res += conv1x3_s2(input_ptr + 0 * in_width, filter_ptr + 0 * 3); res += conv1x3_s2(input_ptr + 1 * in_width, filter_ptr + 1 * 3); res += conv1x3_s2(input_ptr + 2 * in_width, filter_ptr + 2 * 3); diff --git a/mace/kernels/opencl/conv_2d_opencl.cc b/mace/kernels/opencl/conv_2d_opencl.cc index 2ff4a9c50da2e533d92d7e6dece0db285f91406f..ffb0314549f46ca64fee2ab4c88bc630459d0592 100644 --- a/mace/kernels/opencl/conv_2d_opencl.cc +++ b/mace/kernels/opencl/conv_2d_opencl.cc @@ -10,6 +10,9 @@ namespace kernels { extern void Conv2dOpenclK1x1S1(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output); +extern void Conv2dOpenclK1x1S2(const Tensor *input, const Tensor *filter, + const Tensor *bias, Tensor *output); + extern void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter, const Tensor *bias, Tensor *output); @@ -24,7 +27,7 @@ void Conv2dFunctor::operator()(const Tensor *input, const Tensor *bias, Tensor *output); // Selection matrix: kernel_size x stride_size static const Conv2dOpenclFunction selector[5][2] = { - {Conv2dOpenclK1x1S1, nullptr}, + {Conv2dOpenclK1x1S1, Conv2dOpenclK1x1S2}, {nullptr, nullptr}, {Conv2dOpenclK3x3S1, Conv2dOpenclK3x3S2}, {nullptr, nullptr}, diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc index ba784d0552bd3f5a67558ab1392905db35ae2c4a..0c043b8c8758da3079e041f03d875a43fb2fd200 100644 --- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc +++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc @@ -45,6 +45,7 @@ void Conv1x1Naive(const Tensor *input, void Conv1x1V2(const Tensor *input, const Tensor *filter, const Tensor *bias, + const int stride, Tensor *output) { const index_t batch = output->dim(0); const index_t channels = output->dim(1); @@ -54,9 +55,8 @@ void Conv1x1V2(const Tensor *input, auto runtime = OpenCLRuntime::Get(); auto program = runtime->program(); - const index_t pixels = height * width; const index_t channel_blocks = (channels + 3) / 4; - const index_t pixel_blocks = (pixels + 3) / 4; + const index_t pixel_blocks = (width + 3) / 4 * height; // TODO KernelFunctor has an extra clReleaseCommandQueue due to a copy // TODO check wired clReleaseCommandQueue latency @@ -77,7 +77,11 @@ void Conv1x1V2(const Tensor *input, conv_2d_kernel.setArg(idx++, *(static_cast(output->buffer()))); conv_2d_kernel.setArg(idx++, static_cast(input_channels)); conv_2d_kernel.setArg(idx++, static_cast(channels)); - conv_2d_kernel.setArg(idx++, static_cast(pixels)); + conv_2d_kernel.setArg(idx++, static_cast(input->dim(2))); + conv_2d_kernel.setArg(idx++, static_cast(input->dim(3))); + conv_2d_kernel.setArg(idx++, static_cast(height)); + conv_2d_kernel.setArg(idx++, static_cast(width)); + conv_2d_kernel.setArg(idx++, stride); auto command_queue = runtime->command_queue(); cl_int error = command_queue.enqueueNDRangeKernel( @@ -189,7 +193,16 @@ extern void Conv2dOpenclK1x1S1(const Tensor *input, MACE_CHECK(input_batch == batch && input_height == height && input_width == width); - Conv1x1V2(input, filter, bias, output); + Conv1x1V2(input, filter, bias, 1, output); +}; + +extern void Conv2dOpenclK1x1S2(const Tensor *input, + const Tensor *filter, + const Tensor *bias, + Tensor *output) { + MACE_CHECK(input->dim(0) == output->dim(0)); + + Conv1x1V2(input, filter, bias, 2, output); }; } // namespace kernels diff --git a/mace/ops/BUILD b/mace/ops/BUILD index e823136d965670cc198ed14c0f682cbf0b152d00..683d87abfbb5c013a758147298643eeadd0cf4b5 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -17,7 +17,6 @@ cc_library( ], deps = [ "//mace/core", - "//mace/core:opencl_runtime", "@gtest//:gtest", ], ) diff --git a/mace/tools/benchmark/benchmark_model.cc b/mace/tools/benchmark/benchmark_model.cc index 6ecfc1f4ed417e13d91c161a1bf9149be147684d..d4ae7b5d3bf754bca7d3d369b966ccea22ffdab4 100644 --- a/mace/tools/benchmark/benchmark_model.cc +++ b/mace/tools/benchmark/benchmark_model.cc @@ -42,6 +42,7 @@ bool SplitAndParseToInts(const string &str, tmp = tmp.substr(next_offset + 1); } } + return true; } } // namespace str_util @@ -254,6 +255,10 @@ int Main(int argc, char **argv) { stats_options.show_summary = show_summary; stats.reset(new StatSummarizer(stats_options)); + DeviceType device_type; + DeviceType_Parse(device, &device_type); + VLOG(0) << device_type; + // load model std::ifstream model_file_stream(model_file, std::ios::in | std::ios::binary); if (!model_file_stream.is_open()) { @@ -265,29 +270,30 @@ int Main(int argc, char **argv) { model_file_stream.close(); Workspace ws; - ws.LoadModelTensor(net_def, DeviceType::CPU); + ws.LoadModelTensor(net_def, device_type); // Load inputs for (size_t i = 0; i < inputs_count; ++i) { Tensor *input_tensor = - ws.CreateTensor(input_layers[i], GetDeviceAllocator(DeviceType::CPU), DT_FLOAT); + ws.CreateTensor(input_layers[i], GetDeviceAllocator(device_type), DT_FLOAT); vector shapes; str_util::SplitAndParseToInts(input_layer_shapes[i], ',', &shapes); input_tensor->Resize(shapes); - float *input_data = input_tensor->mutable_data(); - - // load input - if (i < input_layer_files.size()) { - std::ifstream in_file(input_layer_files[i], - std::ios::in | std::ios::binary); - in_file.read(reinterpret_cast(input_data), - input_tensor->size() * sizeof(float)); - in_file.close(); + { + Tensor::MappingGuard input_guard(input_tensor); + float *input_data = input_tensor->mutable_data(); + + // load input + if (i < input_layer_files.size()) { + std::ifstream in_file(input_layer_files[i], + std::ios::in | std::ios::binary); + in_file.read(reinterpret_cast(input_data), + input_tensor->size() * sizeof(float)); + in_file.close(); + } } } // create net - DeviceType device_type; - DeviceType_Parse(device, &device_type); auto net = CreateNet(net_def, &ws, device_type); int64_t warmup_time_us = 0;