Commit bf78cf25 authored by Liangliang He

Add OpenCL runtime

Parent b646fd87
......@@ -12,22 +12,29 @@ load("//mace:mace.bzl", "if_android")
cc_library(
name = "opencl_runtime",
srcs = glob([
"platform/opencl/cl.hpp",
"platform/opencl/cl2.hpp",
"platform/opencl/opencl_wrapper.h",
"platform/opencl/opencl_wrapper.cc",
"runtime/opencl/cl.hpp",
"runtime/opencl/cl2.hpp",
"runtime/opencl/opencl_allocator.cc",
"runtime/opencl/opencl_wrapper.cc",
"runtime/opencl/opencl_runtime.cc",
]),
hdrs = [
"runtime/opencl/opencl_allocator.h",
"runtime/opencl/opencl_runtime.h",
"runtime/opencl/opencl_wrapper.h",
],
copts = ["-std=c++11"],
deps = [
"@opencl_headers//:opencl20_headers",
"core",
"@opencl_headers//:opencl20_headers",
],
alwayslink = 1,
)
cc_binary(
name = "opencl_smoketest",
srcs = glob([
"platform/opencl/opencl_smoketest.cc",
"runtime/opencl/opencl_smoketest.cc",
]),
copts = ["-std=c++11"],
deps = [
......@@ -44,15 +51,15 @@ cc_library(
"*.h",
]),
copts = ["-std=c++11"],
deps = [
"//mace/proto:cc_proto",
"//mace/proto:stats_proto",
"//mace/utils:utils",
],
linkopts = if_android([
"-llog",
"-pie",
]),
deps = [
"//mace/proto:cc_proto",
"//mace/proto:stats_proto",
"//mace/utils",
],
)
# Main program for tests
......
......@@ -6,20 +6,21 @@
namespace mace {
static std::unique_ptr<CPUAllocator> g_cpu_allocator(new CPUAllocator());
CPUAllocator* cpu_allocator() { return g_cpu_allocator.get(); }
void SetCPUAllocator(CPUAllocator* alloc) { g_cpu_allocator.reset(alloc); }
std::map<int32_t, Allocator *> *gAllocatorRegistry() {
static std::map<int32_t, Allocator *> g_allocator_registry;
return &g_allocator_registry;
}
Allocator* GetDeviceAllocator(DeviceType type) {
switch (type) {
case DeviceType::CPU:
case DeviceType::NEON:
return cpu_allocator();
default:
MACE_CHECK(false, "device type ", type, " is not supported.");
Allocator *GetDeviceAllocator(DeviceType type) {
auto iter = gAllocatorRegistry()->find(type);
if (iter == gAllocatorRegistry()->end()) {
LOG(ERROR) << "Allocator not found for device " << type;
return nullptr;
}
return nullptr;
return iter->second;
}
MACE_REGISTER_ALLOCATOR(DeviceType::CPU, new CPUAllocator());
MACE_REGISTER_ALLOCATOR(DeviceType::NEON, new CPUAllocator());
} // namespace mace
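With the registry-based lookup above, GetDeviceAllocator() returns nullptr for a device type that was never registered (only CPU and NEON are registered here; OPENCL is registered later in opencl_allocator.cc), so callers are expected to check the result. A small illustrative helper, not part of this commit (the function name is made up):

#include "mace/core/allocator.h"

// Illustrative: fall back to the host allocator when a device type has no
// registered allocator.
mace::Allocator *GetAllocatorOrHostFallback(mace::DeviceType type) {
  mace::Allocator *alloc = mace::GetDeviceAllocator(type);  // nullptr if unregistered
  return alloc != nullptr ? alloc
                          : mace::GetDeviceAllocator(mace::DeviceType::CPU);
}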
......@@ -8,6 +8,7 @@
#include <malloc.h>
#include "mace/core/common.h"
#include "mace/core/registry.h"
#include "mace/proto/mace.pb.h"
namespace mace {
......@@ -24,17 +25,19 @@ class Allocator {
public:
Allocator() {}
virtual ~Allocator() noexcept {}
virtual void* New(size_t nbytes) = 0;
virtual void Delete(void* data) = 0;
virtual void CopyBytes(void* dst, const void* src, size_t size) = 0;
virtual void *New(size_t nbytes) = 0;
virtual void Delete(void *data) = 0;
virtual void *Map(void *buffer, size_t nbytes) = 0;
virtual void Unmap(void *buffer, void *mapper_ptr) = 0;
virtual bool OnHost() = 0;
template <typename T>
T* New(size_t num_elements) {
T *New(size_t num_elements) {
if (num_elements > (std::numeric_limits<size_t>::max() / sizeof(T))) {
return nullptr;
}
void* p = New(sizeof(T) * num_elements);
T* typed_p = reinterpret_cast<T*>(p);
void *p = New(sizeof(T) * num_elements);
T *typed_p = reinterpret_cast<T *>(p);
return typed_p;
}
};
......@@ -42,8 +45,8 @@ class Allocator {
class CPUAllocator : public Allocator {
public:
~CPUAllocator() override {}
void* New(size_t nbytes) override {
void* data = nullptr;
void *New(size_t nbytes) override {
void *data = nullptr;
#ifdef __ANDROID__
data = memalign(kMaceAlignment, nbytes);
#else
......@@ -55,33 +58,32 @@ class CPUAllocator : public Allocator {
return data;
}
void Delete(void* data) override { free(data); }
void CopyBytes(void* dst, const void* src, size_t size) override {
memcpy(dst, src, size);
}
void Delete(void *data) override { free(data); }
void *Map(void *buffer, size_t nbytes) { return buffer; }
void Unmap(void *buffer, void *mapper_ptr) {}
bool OnHost() { return true; }
};
// Get the CPU Allocator.
CPUAllocator* cpu_allocator();
// Sets the CPU allocator to the given allocator: the caller gives away the
// ownership of the pointer.
void SetCPUAllocator(CPUAllocator* alloc);
std::map<int32_t, Allocator *> *gAllocatorRegistry();
template <DeviceType D>
struct DeviceContext {};
template <>
struct DeviceContext<DeviceType::CPU> {
static Allocator* allocator() { return cpu_allocator(); }
};
Allocator *GetDeviceAllocator(DeviceType type);
template <>
struct DeviceContext<DeviceType::NEON> {
static Allocator* allocator() { return cpu_allocator(); }
struct AllocatorRegisterer {
explicit AllocatorRegisterer(DeviceType type, Allocator *alloc) {
if (gAllocatorRegistry()->count(type)) {
LOG(ERROR) << "Allocator for device type " << type
<< " registered twice. This should not happen."
<< gAllocatorRegistry()->count(type);
std::exit(1);
}
gAllocatorRegistry()->emplace(type, alloc);
}
};
Allocator* GetDeviceAllocator(DeviceType type);
#define MACE_REGISTER_ALLOCATOR(type, alloc) \
namespace { \
static AllocatorRegisterer MACE_ANONYMOUS_VARIABLE(Allocator)(type, alloc); \
}
} // namespace mace
......
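MACE_REGISTER_ALLOCATOR expands to a file-scope AllocatorRegisterer whose constructor inserts the allocator into gAllocatorRegistry() and exits on duplicate registration, so a backend only needs to subclass Allocator and register one instance. A hypothetical sketch of such a subclass (the class and whatever DeviceType value it would be registered under are illustrative, not part of this commit):

#include "mace/core/allocator.h"

// Hypothetical allocator that delegates to the host allocator while counting
// live allocations; shown only to illustrate the Allocator interface above.
class TrackingAllocator : public mace::Allocator {
 public:
  void *New(size_t nbytes) override { ++live_; return base_.New(nbytes); }
  void Delete(void *data) override { --live_; base_.Delete(data); }
  void *Map(void *buffer, size_t nbytes) override { return buffer; }  // host memory
  void Unmap(void *buffer, void *mapped_ptr) override {}
  bool OnHost() override { return true; }
  int live() const { return live_; }

 private:
  mace::CPUAllocator base_;
  int live_ = 0;
};
// A real backend would then add, at namespace scope:
//   MACE_REGISTER_ALLOCATOR(<its DeviceType value>, new TrackingAllocator());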
......@@ -6,31 +6,37 @@
namespace mace {
std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry() {
static std::map<int32_t, OperatorRegistry*> g_device_type_registry;
std::map<int32_t, OperatorRegistry *> *gDeviceTypeRegistry() {
static std::map<int32_t, OperatorRegistry *> g_device_type_registry;
return &g_device_type_registry;
}
MACE_DEFINE_REGISTRY(CPUOperatorRegistry,
OperatorBase,
const OperatorDef&,
Workspace*);
const OperatorDef &,
Workspace *);
MACE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry);
MACE_DEFINE_REGISTRY(NEONOperatorRegistry,
OperatorBase,
const OperatorDef&,
Workspace*);
const OperatorDef &,
Workspace *);
MACE_REGISTER_DEVICE_TYPE(DeviceType::NEON, NEONOperatorRegistry);
unique_ptr<OperatorBase> CreateOperator(const OperatorDef& operator_def,
Workspace* ws,
MACE_DEFINE_REGISTRY(OPENCLOperatorRegistry,
OperatorBase,
const OperatorDef &,
Workspace *);
MACE_REGISTER_DEVICE_TYPE(DeviceType::OPENCL, OPENCLOperatorRegistry);
unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
Workspace *ws,
DeviceType type) {
OperatorRegistry* registry = gDeviceTypeRegistry()->at(type);
OperatorRegistry *registry = gDeviceTypeRegistry()->at(type);
return registry->Create(operator_def.type(), operator_def, ws);
}
OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
OperatorBase::OperatorBase(const OperatorDef &operator_def, Workspace *ws)
: operator_ws_(ws),
operator_def_(std::make_shared<OperatorDef>(operator_def)) {}
......
......@@ -92,7 +92,7 @@ class Operator : public OperatorBase {
for (const string &output_str : operator_def.output()) {
outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
output_str, DeviceContext<D>::allocator(), DataTypeToEnum<T>::v())));
output_str, GetDeviceAllocator(D), DataTypeToEnum<T>::v())));
}
}
virtual bool Run() override = 0;
......@@ -160,6 +160,16 @@ MACE_DECLARE_REGISTRY(NEONOperatorRegistry,
#define REGISTER_NEON_OPERATOR(name, ...) \
MACE_REGISTER_CLASS(NEONOperatorRegistry, name, __VA_ARGS__)
MACE_DECLARE_REGISTRY(OPENCLOperatorRegistry,
OperatorBase,
const OperatorDef &,
Workspace *);
#define REGISTER_OPENCL_OPERATOR_CREATOR(key, ...) \
MACE_REGISTER_CREATOR(OPENCLOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_OPENCL_OPERATOR(name, ...) \
MACE_REGISTER_CLASS(OPENCLOperatorRegistry, name, __VA_ARGS__)
unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
Workspace *ws,
DeviceType type);
......
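The new OPENCLOperatorRegistry mirrors the CPU and NEON registries: an operator class is registered under a type name, and CreateOperator() dispatches through gDeviceTypeRegistry() based on the DeviceType. A rough sketch of how an OpenCL operator might be registered (the operator name, class, and template-parameter order are illustrative guesses; the exact MACE_REGISTER_CLASS key convention is defined elsewhere in the tree):

// Hypothetical operator, shown only to illustrate the registration flow above.
template <mace::DeviceType D, typename T>
class IdentityOp : public mace::Operator<D, T> {
 public:
  IdentityOp(const mace::OperatorDef &op_def, mace::Workspace *ws)
      : mace::Operator<D, T>(op_def, ws) {}
  bool Run() override {
    // An OpenCL kernel would be enqueued here, e.g. via
    // OpenCLRuntime::Get()->GetProgram(...).
    return true;
  }
};

// Registered into the device-specific registry; dispatch then happens via
// CreateOperator(operator_def, ws, DeviceType::OPENCL).
REGISTER_OPENCL_OPERATOR(Identity, IdentityOp<mace::DeviceType::OPENCL, float>);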
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#define CL_HPP_MINIMUM_OPENCL_VERSION 200
#define CL_HPP_TARGET_OPENCL_VERSION 200
#include "mace/core/logging.h"
#include "mace/core/platform/opencl/cl2.hpp"
#include "mace/core/platform/opencl/opencl_wrapper.h"
int main() {
LOG(INFO) << "OpenCL support: " << mace::OpenCLSupported();
if (!mace::OpenCLSupported()) return 1;
LOG(INFO) << "Start OpenCL test";
// get all platforms (drivers)
std::vector<cl::Platform> all_platforms;
cl::Platform::get(&all_platforms);
if (all_platforms.size() == 0) {
LOG(INFO) << " No OpenCL platforms found";
return 1;
}
LOG(INFO) << "Platform sizes: " << all_platforms.size();
cl::Platform default_platform = all_platforms[0];
LOG(INFO) << "Using platform: "
<< default_platform.getInfo<CL_PLATFORM_NAME>() << ", "
<< default_platform.getInfo<CL_PLATFORM_PROFILE>() << ", "
<< default_platform.getInfo<CL_PLATFORM_VERSION>();
// get default device (CPUs, GPUs) of the default platform
std::vector<cl::Device> all_devices;
default_platform.getDevices(CL_DEVICE_TYPE_ALL, &all_devices);
if (all_devices.size() == 0) {
LOG(INFO) << "No OpenCL devices found";
return 1;
}
// Use the last device
cl::Device default_device = *all_devices.rbegin();
LOG(INFO) << "Using device: " << default_device.getInfo<CL_DEVICE_NAME>()
<< ", " << default_device.getInfo<CL_DEVICE_TYPE>();
// a context is like a "runtime link" to the device and platform;
// i.e. communication is possible
cl::Context context({default_device});
// create the program that we want to execute on the device
cl::Program::Sources sources;
// calculates for each element; C = A + B
std::string kernel_code =
" void kernel simple_add(global const int* A, global const int* B, "
"global int* C, "
" global const int* N) {"
" int ID, Nthreads, n, ratio, start, stop;"
""
" ID = get_global_id(0);"
" Nthreads = get_global_size(0);"
" n = N[0];"
""
" ratio = (n / Nthreads);" // number of elements for each thread
" start = ratio * ID;"
" stop = ratio * (ID + 1);"
""
" for (int i=start; i<stop; i++)"
" C[i] = A[i] + B[i];"
" }";
sources.push_back({kernel_code.c_str(), kernel_code.length()});
cl::Program program(context, sources);
if (program.build({default_device}) != CL_SUCCESS) {
LOG(INFO) << "Error building: "
<< program.getBuildInfo<CL_PROGRAM_BUILD_LOG>(default_device);
return 1;
}
// apparently OpenCL only likes arrays ...
// N holds the number of elements in the vectors we want to add
int N[1] = {1000};
int n = N[0];
// create buffers on device (allocate space on GPU)
cl::Buffer buffer_A(context, CL_MEM_READ_WRITE, sizeof(int) * n);
cl::Buffer buffer_B(context, CL_MEM_READ_WRITE, sizeof(int) * n);
cl::Buffer buffer_C(context, CL_MEM_READ_WRITE, sizeof(int) * n);
cl::Buffer buffer_N(context, CL_MEM_READ_ONLY, sizeof(int));
// create things on here (CPU)
int A[n], B[n];
for (int i = 0; i < n; i++) {
A[i] = i;
B[i] = 2 * i;
}
// create a queue (a queue of commands that the GPU will execute)
cl::CommandQueue queue(context, default_device);
// push write commands to queue
queue.enqueueWriteBuffer(buffer_A, CL_TRUE, 0, sizeof(int) * n, A);
queue.enqueueWriteBuffer(buffer_B, CL_TRUE, 0, sizeof(int) * n, B);
queue.enqueueWriteBuffer(buffer_N, CL_TRUE, 0, sizeof(int), N);
auto simple_add =
cl::KernelFunctor<cl::Buffer, cl::Buffer, cl::Buffer, cl::Buffer>(
program, "simple_add");
cl_int error;
simple_add(cl::EnqueueArgs(queue, cl::NDRange(100), cl::NDRange(10)),
buffer_A, buffer_B, buffer_C, buffer_N, error);
if (error != 0) {
LOG(ERROR) << "Failed to execute kernel " << error;
}
int C[n];
// read result from GPU to here
queue.enqueueReadBuffer(buffer_C, CL_TRUE, 0, sizeof(int) * n, C);
bool correct = true;
for (int i = 0; i < n; i++) {
if (C[i] != A[i] + B[i]) correct = false;
}
LOG(INFO) << "OpenCL test result: " << (correct ? "correct" : "incorrect");
return 0;
}
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "CL/opencl.h"
#include "mace/core/logging.h"
#include <dlfcn.h>
#include <mutex>
namespace mace {
bool OpenCLSupported();
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/runtime/opencl/opencl_allocator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/runtime/opencl/cl2.hpp"
namespace mace {
OpenCLAllocator::OpenCLAllocator() {}
OpenCLAllocator::~OpenCLAllocator() {}
void *OpenCLAllocator::New(size_t nbytes) {
cl_int error;
cl::Buffer *buffer = new cl::Buffer(OpenCLRuntime::Get()->context(),
CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR,
nbytes, nullptr, &error);
MACE_CHECK(error == CL_SUCCESS);
MACE_CHECK_NOTNULL(buffer);
return static_cast<void *>(buffer);
}
void OpenCLAllocator::Delete(void *buffer) {
if (buffer != nullptr) {
cl::Buffer *cl_buffer = static_cast<cl::Buffer *>(buffer);
delete cl_buffer;
}
}
void *OpenCLAllocator::Map(void *buffer, size_t nbytes) {
auto cl_buffer = static_cast<cl::Buffer *>(buffer);
auto queue = OpenCLRuntime::Get()->command_queue();
// TODO (heliangliang) Non-blocking call
cl_int error;
void *mapped_ptr =
queue.enqueueMapBuffer(*cl_buffer, CL_TRUE, CL_MAP_READ | CL_MAP_WRITE, 0,
nbytes, nullptr, nullptr, &error);
MACE_CHECK(error == CL_SUCCESS);
return mapped_ptr;
}
void OpenCLAllocator::Unmap(void *buffer, void *mapped_ptr) {
auto cl_buffer = static_cast<cl::Buffer *>(buffer);
auto queue = OpenCLRuntime::Get()->command_queue();
MACE_CHECK(queue.enqueueUnmapMemObject(*cl_buffer, mapped_ptr, nullptr,
nullptr) == CL_SUCCESS);
}
bool OpenCLAllocator::OnHost() { return false; }
MACE_REGISTER_ALLOCATOR(DeviceType::OPENCL, new OpenCLAllocator());
} // namespace mace
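OpenCLAllocator::Map and Unmap wrap blocking clEnqueueMapBuffer/clEnqueueUnmapMemObject calls on the shared command queue, while New/Delete own a heap-allocated cl::Buffer behind a void*. A minimal lifecycle sketch (illustrative, not part of this commit; the function name is made up):

#include "mace/core/allocator.h"

// Allocate an OpenCL buffer, map it for host access, fill it, unmap, free.
void FillDeviceBuffer() {
  mace::Allocator *alloc = mace::GetDeviceAllocator(mace::DeviceType::OPENCL);
  const size_t nbytes = 1024 * sizeof(float);
  void *buffer = alloc->New(nbytes);                          // holds a cl::Buffer*
  float *host = static_cast<float *>(alloc->Map(buffer, nbytes));  // blocking map
  for (size_t i = 0; i < 1024; ++i) host[i] = 0.0f;           // host-side write
  alloc->Unmap(buffer, host);                                 // release the mapping
  alloc->Delete(buffer);                                      // deletes the cl::Buffer
}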
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_RUNTIME_OPENCL_OPENCL_ALLOCATOR_H_
#define MACE_CORE_RUNTIME_OPENCL_OPENCL_ALLOCATOR_H_
#include "mace/core/allocator.h"
namespace mace {
class OpenCLAllocator : public Allocator {
public:
OpenCLAllocator();
~OpenCLAllocator() override;
void *New(size_t nbytes) override;
void Delete(void *buffer) override;
void *Map(void *buffer, size_t nbytes) override;
void Unmap(void *buffer, void *mapped_ptr) override;
bool OnHost() override;
};
} // namespace mace
#endif // MACE_CORE_RUNTIME_OPENCL_OPENCL_ALLOCATOR_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <cstdlib>
#include <fstream>
#include <mutex>
#include "mace/core/logging.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/runtime/opencl/opencl_wrapper.h"
namespace mace {
namespace {
bool ReadSourceFile(const char *filename, std::string *content) {
MACE_CHECK_NOTNULL(filename);
MACE_CHECK_NOTNULL(content);
*content = "";
std::ifstream ifs(filename, std::ifstream::in);
if (!ifs.is_open()) {
LOG(ERROR) << "Failed to open file " << filename;
return false;
}
std::string line;
while (std::getline(ifs, line)) {
*content += line;
}
ifs.close();
return true;
}
bool BuildProgram(OpenCLRuntime *runtime, const char *filename, cl::Program *program) {
MACE_CHECK_NOTNULL(filename);
MACE_CHECK_NOTNULL(program);
std::string kernel_code;
if (!ReadSourceFile(filename, &kernel_code)) {
LOG(ERROR) << "Failed to read kernel source " << filename;
return false;
}
cl::Program::Sources sources;
sources.push_back({kernel_code.c_str(), kernel_code.length()});
*program = cl::Program(runtime->context(), sources);
if (program->build({runtime->device()}) != CL_SUCCESS) {
LOG(INFO) << "Error building: "
<< program->getBuildInfo<CL_PROGRAM_BUILD_LOG>(runtime->device());
return false;
}
return true;
}
} // namespace
OpenCLRuntime *OpenCLRuntime::Get() {
static std::once_flag init_once;
static OpenCLRuntime *instance = nullptr;
std::call_once(init_once, []() {
if (!mace::OpenCLLibrary::Supported()) {
LOG(ERROR) << "OpenCL not supported";
return;
}
std::vector<cl::Platform> all_platforms;
cl::Platform::get(&all_platforms);
if (all_platforms.size() == 0) {
LOG(ERROR) << "No OpenCL platforms found";
return;
}
cl::Platform default_platform = all_platforms[0];
VLOG(1) << "Using platform: "
<< default_platform.getInfo<CL_PLATFORM_NAME>() << ", "
<< default_platform.getInfo<CL_PLATFORM_PROFILE>() << ", "
<< default_platform.getInfo<CL_PLATFORM_VERSION>();
// get default device (CPUs, GPUs) of the default platform
std::vector<cl::Device> all_devices;
default_platform.getDevices(CL_DEVICE_TYPE_ALL, &all_devices);
if (all_devices.size() == 0) {
LOG(ERROR) << "No OpenCL devices found";
return;
}
bool gpu_detected = false;
cl::Device gpu_device;
for (auto device : all_devices) {
if (device.getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_GPU) {
gpu_device = device;
gpu_detected = true;
VLOG(1) << "Using device: " << device.getInfo<CL_DEVICE_NAME>();
break;
}
}
if (!gpu_detected) {
LOG(ERROR) << "No GPU device found";
return;
}
// a context is like a "runtime link" to the device and platform;
// i.e. communication is possible
cl::Context context({gpu_device});
cl::CommandQueue command_queue(context, gpu_device);
instance = new OpenCLRuntime(context, gpu_device, command_queue);
});
return instance;
}
OpenCLRuntime::OpenCLRuntime(cl::Context context,
cl::Device device,
cl::CommandQueue command_queue)
: context_(context), device_(device), command_queue_(command_queue) {}
OpenCLRuntime::~OpenCLRuntime() {}
cl::Context &OpenCLRuntime::context() { return context_; }
cl::Device &OpenCLRuntime::device() { return device_; }
cl::CommandQueue &OpenCLRuntime::command_queue() { return command_queue_; }
cl::Program OpenCLRuntime::GetProgram(const std::string &name) {
static const char *kernel_source_path = getenv("MACE_KERNEL_SOURCE_PATH");
std::string filename = name;
if (kernel_source_path != nullptr) {
filename = kernel_source_path + name;
}
std::lock_guard<std::mutex> lock(program_lock_);
// TODO (heliangliang) Support binary format
auto iter = programs_.find(name);
if (iter != programs_.end()) {
return iter->second;
} else {
cl::Program program;
MACE_CHECK(BuildProgram(this, filename.c_str(), &program));
programs_.emplace(name, program);
return program;
}
}
} // namespace mace
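GetProgram caches one compiled cl::Program per source-file name and resolves the file against the MACE_KERNEL_SOURCE_PATH environment variable; as written, the directory and name are concatenated directly, so the variable should end with a path separator. A usage sketch (illustrative, not part of this commit):

#include "mace/core/runtime/opencl/opencl_runtime.h"

void RunCachedKernel() {
  auto *runtime = mace::OpenCLRuntime::Get();   // nullptr if OpenCL is unavailable
  if (runtime == nullptr) return;
  // Assumes MACE_KERNEL_SOURCE_PATH="/path/to/kernels/" (note the trailing '/').
  cl::Program program = runtime->GetProgram("simple_add.cl");  // built once, then cached
  cl::Kernel kernel(program, "simple_add");
  // set kernel args and enqueue on runtime->command_queue() ...
}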
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_RUNTIME_OPENCL_OPENCL_RUNTIME_H_
#define MACE_CORE_RUNTIME_OPENCL_OPENCL_RUNTIME_H_
#ifndef CL_HPP_TARGET_OPENCL_VERSION
#define CL_HPP_TARGET_OPENCL_VERSION 200
#endif
#include <map>
#include <mutex>
#include "mace/core/runtime/opencl/cl2.hpp"
#include "mace/core/runtime/opencl/opencl_wrapper.h"
namespace mace {
class OpenCLRuntime {
public:
static OpenCLRuntime *Get();
OpenCLRuntime(cl::Context context,
cl::Device device,
cl::CommandQueue command_queue);
~OpenCLRuntime();
cl::Context &context();
cl::Device &device();
cl::CommandQueue &command_queue();
cl::Program GetProgram(const std::string &name);
private:
cl::Context context_;
cl::CommandQueue command_queue_;
cl::Device device_;
std::map<std::string, cl::Program> programs_;
std::mutex program_lock_;
};
} // namespace mace
#endif // MACE_CORE_RUNTIME_OPENCL_OPENCL_RUNTIME_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/logging.h"
#include "mace/core/operator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/runtime/opencl/opencl_wrapper.h"
int main() {
using namespace mace;
auto runtime = mace::OpenCLRuntime::Get();
mace::Tensor ta(GetDeviceAllocator(DeviceType::OPENCL), DataType::DT_INT32);
mace::Tensor tb(GetDeviceAllocator(DeviceType::OPENCL), DataType::DT_INT32);
mace::Tensor tc(GetDeviceAllocator(DeviceType::OPENCL), DataType::DT_INT32);
mace::Tensor tstep(GetDeviceAllocator(DeviceType::OPENCL),
DataType::DT_INT32);
int n = 1000;
std::vector<index_t> shape = {n};
ta.Resize(shape);
tb.Resize(shape);
tc.Resize(shape);
tstep.Resize({1});
int step_size = 10;
int global_size = n / step_size;
{
mace::Tensor::MappingGuard ta_mapper(&ta);
mace::Tensor::MappingGuard tb_mapper(&tb);
mace::Tensor::MappingGuard tstep_mapper(&tstep);
int32_t *a = ta.mutable_data<int32_t>();
int32_t *b = tb.mutable_data<int32_t>();
int32_t *step = tstep.mutable_data<int32_t>();
for (int i = 0; i < n; i++) {
a[i] = i;
b[i] = 2 * i;
}
step[0] = step_size;
}
auto program = runtime->GetProgram("simple_add.cl");
auto simple_add =
cl::KernelFunctor<cl::Buffer, cl::Buffer, cl::Buffer, cl::Buffer>(
program, "simple_add");
cl_int error;
simple_add(cl::EnqueueArgs(runtime->command_queue(), cl::NDRange(global_size),
cl::NullRange),
*(static_cast<cl::Buffer *>(ta.buffer())),
*(static_cast<cl::Buffer *>(tb.buffer())),
*(static_cast<cl::Buffer *>(tc.buffer())),
*(static_cast<cl::Buffer *>(tstep.buffer())), error);
if (error != 0) {
LOG(ERROR) << "Failed to execute kernel " << error;
}
{
mace::Tensor::MappingGuard ta_mapper(&ta);
mace::Tensor::MappingGuard tb_mapper(&tb);
mace::Tensor::MappingGuard tc_mapper(&tc);
int32_t *a = ta.mutable_data<int32_t>();
int32_t *b = tb.mutable_data<int32_t>();
int32_t *c = tc.mutable_data<int32_t>();
bool correct = true;
for (int i = 0; i < n; i++) {
if (c[i] != a[i] + b[i]) correct = false;
}
LOG(INFO) << "OpenCL test result: " << (correct ? "correct" : "incorrect");
}
return 0;
}
......@@ -5,7 +5,7 @@
#include "CL/opencl.h"
#include "mace/core/logging.h"
#include "mace/core/platform/opencl/opencl_wrapper.h"
#include "mace/core/runtime/opencl/opencl_wrapper.h"
#include <dlfcn.h>
#include <mutex>
......@@ -14,10 +14,14 @@
* Wrapper of OpenCL 2.0 (based on 1.2)
*/
namespace mace {
class OpenCLStub final {
namespace {
class OpenCLLibraryImpl final {
public:
static OpenCLStub &Get();
bool loaded() { return loaded_; }
static OpenCLLibraryImpl &Get();
bool Load();
void Unload();
bool loaded() { return handle_ != nullptr; }
using clGetPlatformIDsFunc = cl_int (*)(cl_uint, cl_platform_id *, cl_uint *);
using clGetPlatformInfoFunc =
......@@ -177,21 +181,23 @@ class OpenCLStub final {
#undef DEFINE_FUNC_PTR
private:
bool TryLoadAll();
bool Load(const std::string &library);
bool loaded_ = false;
void *LoadFromPath(const std::string &path);
void *handle_ = nullptr;
};
OpenCLStub &OpenCLStub::Get() {
OpenCLLibraryImpl &OpenCLLibraryImpl::Get() {
static std::once_flag load_once;
static OpenCLStub instance;
std::call_once(load_once, []() { instance.TryLoadAll(); });
static OpenCLLibraryImpl instance;
std::call_once(load_once, []() { instance.Load(); });
return instance;
}
bool OpenCLStub::TryLoadAll() {
bool OpenCLLibraryImpl::Load() {
if (loaded()) return true;
// TODO (heliangliang) Make this configurable
static const std::vector<std::string> pathes = {
// TODO (heliangliang) Benchmark 64 bit overhead
static const std::vector<std::string> paths = {
#if defined(__aarch64__)
// Qualcomm Adreno
"/system/vendor/lib64/libOpenCL.so",
......@@ -209,9 +215,11 @@ bool OpenCLStub::TryLoadAll() {
#endif
};
for (const auto &path : pathes) {
for (const auto &path : paths) {
VLOG(2) << "Loading OpenCL from " << path;
if (Load(path)) {
void *handle = LoadFromPath(path);
if (handle != nullptr) {
handle_ = handle;
return true;
}
}
......@@ -220,13 +228,22 @@ bool OpenCLStub::TryLoadAll() {
return false;
}
bool OpenCLStub::Load(const std::string &path) {
void OpenCLLibraryImpl::Unload() {
if (handle_ != nullptr) {
if (dlclose(handle_) != 0) {
LOG(ERROR) << "dlclose failed for OpenCL library";
}
handle_ = nullptr;
}
}
void *OpenCLLibraryImpl::LoadFromPath(const std::string &path) {
void *handle = dlopen(path.c_str(), RTLD_LAZY | RTLD_LOCAL);
if (handle == nullptr) {
VLOG(2) << "Failed to load OpenCL library from path " << path
<< " error code: " << dlerror();
return false;
return nullptr;
}
#define ASSIGN_FROM_DLSYM(func) \
......@@ -234,9 +251,8 @@ bool OpenCLStub::Load(const std::string &path) {
void *ptr = dlsym(handle, #func); \
if (ptr == nullptr) { \
LOG(ERROR) << "Failed to load " << #func << " from " << path; \
loaded_ = false; \
dlclose(handle); \
return false; \
return nullptr; \
} \
func = reinterpret_cast<func##Func>(ptr); \
VLOG(2) << "Loaded " << #func << " from " << path; \
......@@ -283,19 +299,22 @@ bool OpenCLStub::Load(const std::string &path) {
#undef ASSIGN_FROM_DLSYM
loaded_ = true;
// TODO (heliangliang) Call dlclose if we are dlclosed
return true;
return handle;
}
} // namespace
bool OpenCLLibrary::Supported() { return OpenCLLibraryImpl::Get().loaded(); }
void OpenCLLibrary::Load() { OpenCLLibraryImpl::Get().Load(); }
bool OpenCLSupported() { return OpenCLStub::Get().loaded(); }
void OpenCLLibrary::Unload() { OpenCLLibraryImpl::Get().Unload(); }
} // namespace mace
cl_int clGetPlatformIDs(cl_uint num_entries,
cl_platform_id *platforms,
cl_uint *num_platforms) {
auto func = mace::OpenCLStub::Get().clGetPlatformIDs;
auto func = mace::OpenCLLibraryImpl::Get().clGetPlatformIDs;
if (func != nullptr) {
return func(num_entries, platforms, num_platforms);
} else {
......@@ -307,7 +326,7 @@ cl_int clGetPlatformInfo(cl_platform_id platform,
size_t param_value_size,
void *param_value,
size_t *param_value_size_ret) {
auto func = mace::OpenCLStub::Get().clGetPlatformInfo;
auto func = mace::OpenCLLibraryImpl::Get().clGetPlatformInfo;
if (func != nullptr) {
return func(platform, param_name, param_value_size, param_value,
param_value_size_ret);
......@@ -323,7 +342,7 @@ cl_int clBuildProgram(cl_program program,
void(CL_CALLBACK *pfn_notify)(cl_program program,
void *user_data),
void *user_data) {
auto func = mace::OpenCLStub::Get().clBuildProgram;
auto func = mace::OpenCLLibraryImpl::Get().clBuildProgram;
if (func != nullptr) {
return func(program, num_devices, device_list, options, pfn_notify,
user_data);
......@@ -341,7 +360,7 @@ cl_int clEnqueueNDRangeKernel(cl_command_queue command_queue,
cl_uint num_events_in_wait_list,
const cl_event *event_wait_list,
cl_event *event) {
auto func = mace::OpenCLStub::Get().clEnqueueNDRangeKernel;
auto func = mace::OpenCLLibraryImpl::Get().clEnqueueNDRangeKernel;
if (func != nullptr) {
return func(command_queue, kernel, work_dim, global_work_offset,
global_work_size, local_work_size, num_events_in_wait_list,
......@@ -355,7 +374,7 @@ cl_int clSetKernelArg(cl_kernel kernel,
cl_uint arg_index,
size_t arg_size,
const void *arg_value) {
auto func = mace::OpenCLStub::Get().clSetKernelArg;
auto func = mace::OpenCLLibraryImpl::Get().clSetKernelArg;
if (func != nullptr) {
return func(kernel, arg_index, arg_size, arg_value);
} else {
......@@ -364,7 +383,7 @@ cl_int clSetKernelArg(cl_kernel kernel,
}
cl_int clRetainMemObject(cl_mem memobj) {
auto func = mace::OpenCLStub::Get().clRetainMemObject;
auto func = mace::OpenCLLibraryImpl::Get().clRetainMemObject;
if (func != nullptr) {
return func(memobj);
} else {
......@@ -373,7 +392,7 @@ cl_int clRetainMemObject(cl_mem memobj) {
}
cl_int clReleaseMemObject(cl_mem memobj) {
auto func = mace::OpenCLStub::Get().clReleaseMemObject;
auto func = mace::OpenCLLibraryImpl::Get().clReleaseMemObject;
if (func != nullptr) {
return func(memobj);
} else {
......@@ -387,7 +406,7 @@ cl_int clEnqueueUnmapMemObject(cl_command_queue command_queue,
cl_uint num_events_in_wait_list,
const cl_event *event_wait_list,
cl_event *event) {
auto func = mace::OpenCLStub::Get().clEnqueueUnmapMemObject;
auto func = mace::OpenCLLibraryImpl::Get().clEnqueueUnmapMemObject;
if (func != nullptr) {
return func(command_queue, memobj, mapped_ptr, num_events_in_wait_list,
event_wait_list, event);
......@@ -397,7 +416,7 @@ cl_int clEnqueueUnmapMemObject(cl_command_queue command_queue,
}
cl_int clRetainCommandQueue(cl_command_queue command_queue) {
auto func = mace::OpenCLStub::Get().clRetainCommandQueue;
auto func = mace::OpenCLLibraryImpl::Get().clRetainCommandQueue;
if (func != nullptr) {
return func(command_queue);
} else {
......@@ -411,7 +430,7 @@ cl_context clCreateContext(
void(CL_CALLBACK *pfn_notify)(const char *, const void *, size_t, void *),
void *user_data,
cl_int *errcode_ret) {
auto func = mace::OpenCLStub::Get().clCreateContext;
auto func = mace::OpenCLLibraryImpl::Get().clCreateContext;
if (func != nullptr) {
return func(properties, num_devices, devices, pfn_notify, user_data,
errcode_ret);
......@@ -425,7 +444,7 @@ cl_context clCreateContextFromType(
void(CL_CALLBACK *pfn_notify)(const char *, const void *, size_t, void *),
void *user_data,
cl_int *errcode_ret) {
auto func = mace::OpenCLStub::Get().clCreateContextFromType;
auto func = mace::OpenCLLibraryImpl::Get().clCreateContextFromType;
if (func != nullptr) {
return func(properties, device_type, pfn_notify, user_data, errcode_ret);
} else {
......@@ -434,7 +453,7 @@ cl_context clCreateContextFromType(
}
cl_int clReleaseContext(cl_context context) {
auto func = mace::OpenCLStub::Get().clReleaseContext;
auto func = mace::OpenCLLibraryImpl::Get().clReleaseContext;
if (func != nullptr) {
return func(context);
} else {
......@@ -443,7 +462,7 @@ cl_int clReleaseContext(cl_context context) {
}
cl_int clWaitForEvents(cl_uint num_events, const cl_event *event_list) {
auto func = mace::OpenCLStub::Get().clWaitForEvents;
auto func = mace::OpenCLLibraryImpl::Get().clWaitForEvents;
if (func != nullptr) {
return func(num_events, event_list);
} else {
......@@ -452,7 +471,7 @@ cl_int clWaitForEvents(cl_uint num_events, const cl_event *event_list) {
}
cl_int clReleaseEvent(cl_event event) {
auto func = mace::OpenCLStub::Get().clReleaseEvent;
auto func = mace::OpenCLLibraryImpl::Get().clReleaseEvent;
if (func != nullptr) {
return func(event);
} else {
......@@ -469,7 +488,7 @@ cl_int clEnqueueWriteBuffer(cl_command_queue command_queue,
cl_uint num_events_in_wait_list,
const cl_event *event_wait_list,
cl_event *event) {
auto func = mace::OpenCLStub::Get().clEnqueueWriteBuffer;
auto func = mace::OpenCLLibraryImpl::Get().clEnqueueWriteBuffer;
if (func != nullptr) {
return func(command_queue, buffer, blocking_write, offset, size, ptr,
num_events_in_wait_list, event_wait_list, event);
......@@ -487,7 +506,7 @@ cl_int clEnqueueReadBuffer(cl_command_queue command_queue,
cl_uint num_events_in_wait_list,
const cl_event *event_wait_list,
cl_event *event) {
auto func = mace::OpenCLStub::Get().clEnqueueReadBuffer;
auto func = mace::OpenCLLibraryImpl::Get().clEnqueueReadBuffer;
if (func != nullptr) {
return func(command_queue, buffer, blocking_read, offset, size, ptr,
num_events_in_wait_list, event_wait_list, event);
......@@ -502,7 +521,7 @@ cl_int clGetProgramBuildInfo(cl_program program,
size_t param_value_size,
void *param_value,
size_t *param_value_size_ret) {
auto func = mace::OpenCLStub::Get().clGetProgramBuildInfo;
auto func = mace::OpenCLLibraryImpl::Get().clGetProgramBuildInfo;
if (func != nullptr) {
return func(program, device, param_name, param_value_size, param_value,
param_value_size_ret);
......@@ -512,7 +531,7 @@ cl_int clGetProgramBuildInfo(cl_program program,
}
cl_int clRetainProgram(cl_program program) {
auto func = mace::OpenCLStub::Get().clRetainProgram;
auto func = mace::OpenCLLibraryImpl::Get().clRetainProgram;
if (func != nullptr) {
return func(program);
} else {
......@@ -530,7 +549,7 @@ void *clEnqueueMapBuffer(cl_command_queue command_queue,
const cl_event *event_wait_list,
cl_event *event,
cl_int *errcode_ret) {
auto func = mace::OpenCLStub::Get().clEnqueueMapBuffer;
auto func = mace::OpenCLLibraryImpl::Get().clEnqueueMapBuffer;
if (func != nullptr) {
return func(command_queue, buffer, blocking_map, map_flags, offset, size,
num_events_in_wait_list, event_wait_list, event, errcode_ret);
......@@ -546,7 +565,7 @@ cl_command_queue clCreateCommandQueueWithProperties(
cl_device_id device,
const cl_queue_properties *properties,
cl_int *errcode_ret) {
auto func = mace::OpenCLStub::Get().clCreateCommandQueueWithProperties;
auto func = mace::OpenCLLibraryImpl::Get().clCreateCommandQueueWithProperties;
if (func != nullptr) {
return func(context, device, properties, errcode_ret);
} else {
......@@ -555,7 +574,7 @@ cl_command_queue clCreateCommandQueueWithProperties(
}
cl_int clReleaseCommandQueue(cl_command_queue command_queue) {
auto func = mace::OpenCLStub::Get().clReleaseCommandQueue;
auto func = mace::OpenCLLibraryImpl::Get().clReleaseCommandQueue;
if (func != nullptr) {
return func(command_queue);
} else {
......@@ -570,7 +589,7 @@ cl_program clCreateProgramWithBinary(cl_context context,
const unsigned char **binaries,
cl_int *binary_status,
cl_int *errcode_ret) {
auto func = mace::OpenCLStub::Get().clCreateProgramWithBinary;
auto func = mace::OpenCLLibraryImpl::Get().clCreateProgramWithBinary;
if (func != nullptr) {
return func(context, num_devices, device_list, lengths, binaries,
binary_status, errcode_ret);
......@@ -583,7 +602,7 @@ cl_program clCreateProgramWithBinary(cl_context context,
}
cl_int clRetainContext(cl_context context) {
auto func = mace::OpenCLStub::Get().clRetainContext;
auto func = mace::OpenCLLibraryImpl::Get().clRetainContext;
if (func != nullptr) {
return func(context);
} else {
......@@ -596,7 +615,7 @@ cl_int clGetContextInfo(cl_context context,
size_t param_value_size,
void *param_value,
size_t *param_value_size_ret) {
auto func = mace::OpenCLStub::Get().clGetContextInfo;
auto func = mace::OpenCLLibraryImpl::Get().clGetContextInfo;
if (func != nullptr) {
return func(context, param_name, param_value_size, param_value,
param_value_size_ret);
......@@ -606,7 +625,7 @@ cl_int clGetContextInfo(cl_context context,
}
cl_int clReleaseProgram(cl_program program) {
auto func = mace::OpenCLStub::Get().clReleaseProgram;
auto func = mace::OpenCLLibraryImpl::Get().clReleaseProgram;
if (func != nullptr) {
return func(program);
} else {
......@@ -615,7 +634,7 @@ cl_int clReleaseProgram(cl_program program) {
}
cl_int clFlush(cl_command_queue command_queue) {
auto func = mace::OpenCLStub::Get().clFlush;
auto func = mace::OpenCLLibraryImpl::Get().clFlush;
if (func != nullptr) {
return func(command_queue);
} else {
......@@ -624,7 +643,7 @@ cl_int clFlush(cl_command_queue command_queue) {
}
cl_int clFinish(cl_command_queue command_queue) {
auto func = mace::OpenCLStub::Get().clFinish;
auto func = mace::OpenCLLibraryImpl::Get().clFinish;
if (func != nullptr) {
return func(command_queue);
} else {
......@@ -637,7 +656,7 @@ cl_int clGetProgramInfo(cl_program program,
size_t param_value_size,
void *param_value,
size_t *param_value_size_ret) {
auto func = mace::OpenCLStub::Get().clGetProgramInfo;
auto func = mace::OpenCLLibraryImpl::Get().clGetProgramInfo;
if (func != nullptr) {
return func(program, param_name, param_value_size, param_value,
param_value_size_ret);
......@@ -649,7 +668,7 @@ cl_int clGetProgramInfo(cl_program program,
cl_kernel clCreateKernel(cl_program program,
const char *kernel_name,
cl_int *errcode_ret) {
auto func = mace::OpenCLStub::Get().clCreateKernel;
auto func = mace::OpenCLLibraryImpl::Get().clCreateKernel;
if (func != nullptr) {
return func(program, kernel_name, errcode_ret);
} else {
......@@ -661,7 +680,7 @@ cl_kernel clCreateKernel(cl_program program,
}
cl_int clRetainKernel(cl_kernel kernel) {
auto func = mace::OpenCLStub::Get().clRetainKernel;
auto func = mace::OpenCLLibraryImpl::Get().clRetainKernel;
if (func != nullptr) {
return func(kernel);
} else {
......@@ -674,7 +693,7 @@ cl_mem clCreateBuffer(cl_context context,
size_t size,
void *host_ptr,
cl_int *errcode_ret) {
auto func = mace::OpenCLStub::Get().clCreateBuffer;
auto func = mace::OpenCLLibraryImpl::Get().clCreateBuffer;
if (func != nullptr) {
return func(context, flags, size, host_ptr, errcode_ret);
} else {
......@@ -690,7 +709,7 @@ cl_program clCreateProgramWithSource(cl_context context,
const char **strings,
const size_t *lengths,
cl_int *errcode_ret) {
auto func = mace::OpenCLStub::Get().clCreateProgramWithSource;
auto func = mace::OpenCLLibraryImpl::Get().clCreateProgramWithSource;
if (func != nullptr) {
return func(context, count, strings, lengths, errcode_ret);
} else {
......@@ -702,7 +721,7 @@ cl_program clCreateProgramWithSource(cl_context context,
}
cl_int clReleaseKernel(cl_kernel kernel) {
auto func = mace::OpenCLStub::Get().clReleaseKernel;
auto func = mace::OpenCLLibraryImpl::Get().clReleaseKernel;
if (func != nullptr) {
return func(kernel);
} else {
......@@ -715,7 +734,7 @@ cl_int clGetDeviceIDs(cl_platform_id platform,
cl_uint num_entries,
cl_device_id *devices,
cl_uint *num_devices) {
auto func = mace::OpenCLStub::Get().clGetDeviceIDs;
auto func = mace::OpenCLLibraryImpl::Get().clGetDeviceIDs;
if (func != nullptr) {
return func(platform, device_type, num_entries, devices, num_devices);
} else {
......@@ -728,7 +747,7 @@ cl_int clGetDeviceInfo(cl_device_id device,
size_t param_value_size,
void *param_value,
size_t *param_value_size_ret) {
auto func = mace::OpenCLStub::Get().clGetDeviceInfo;
auto func = mace::OpenCLLibraryImpl::Get().clGetDeviceInfo;
if (func != nullptr) {
return func(device, param_name, param_value_size, param_value,
param_value_size_ret);
......@@ -738,7 +757,7 @@ cl_int clGetDeviceInfo(cl_device_id device,
}
cl_int clRetainDevice(cl_device_id device) {
auto func = mace::OpenCLStub::Get().clRetainDevice;
auto func = mace::OpenCLLibraryImpl::Get().clRetainDevice;
if (func != nullptr) {
return func(device);
} else {
......@@ -747,7 +766,7 @@ cl_int clRetainDevice(cl_device_id device) {
}
cl_int clReleaseDevice(cl_device_id device) {
auto func = mace::OpenCLStub::Get().clReleaseDevice;
auto func = mace::OpenCLLibraryImpl::Get().clReleaseDevice;
if (func != nullptr) {
return func(device);
} else {
......@@ -756,7 +775,7 @@ cl_int clReleaseDevice(cl_device_id device) {
}
cl_int clRetainEvent(cl_event event) {
auto func = mace::OpenCLStub::Get().clRetainEvent;
auto func = mace::OpenCLLibraryImpl::Get().clRetainEvent;
if (func != nullptr) {
return func(event);
} else {
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_RUNTIME_OPENCL_OPENCL_WRAPPER_H_
#define MACE_CORE_RUNTIME_OPENCL_OPENCL_WRAPPER_H_
namespace mace {
class OpenCLLibrary {
public:
static bool Supported();
static void Load();
static void Unload();
};
} // namespace mace
#endif // MACE_CORE_RUNTIME_OPENCL_OPENCL_WRAPPER_H_
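OpenCLLibrary is the public facade over the dlopen-based loader in opencl_wrapper.cc: Supported() lazily loads libOpenCL.so on first use and reports whether the library and all required symbols were resolved, so callers can gate GPU paths at runtime. An illustrative guard, not part of this commit (the helper name is made up; DeviceType comes from the proto header already used by allocator.h):

#include "mace/core/runtime/opencl/opencl_wrapper.h"
#include "mace/proto/mace.pb.h"  // DeviceType

// Pick a device type based on whether a usable OpenCL library was found.
mace::DeviceType PreferredDevice() {
  // Supported() triggers the lazy dlopen() of libOpenCL.so on first call.
  return mace::OpenCLLibrary::Supported() ? mace::DeviceType::OPENCL
                                          : mace::DeviceType::CPU;
}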
......@@ -48,20 +48,30 @@ namespace mace {
class Tensor {
public:
Tensor()
: alloc_(cpu_allocator()), size_(0), dtype_(DT_FLOAT), data_(nullptr){};
Tensor(Allocator* a, DataType type)
: alloc_(a), size_(0), dtype_(type), data_(nullptr){};
: alloc_(GetDeviceAllocator(DeviceType::CPU)),
size_(0),
dtype_(DT_FLOAT),
buffer_(nullptr),
data_(nullptr){};
Tensor(Allocator *alloc, DataType type)
: alloc_(alloc),
size_(0),
dtype_(type),
buffer_(nullptr),
data_(nullptr){};
~Tensor() {
if (alloc_ && data_.get()) {
data_.reset();
MACE_CHECK(data_ == nullptr, "Buffer must be unmapped before destroy");
if (buffer_ != nullptr) {
MACE_CHECK_NOTNULL(alloc_);
alloc_->Delete(buffer_);
}
}
inline DataType dtype() const { return dtype_; }
inline const vector<index_t>& shape() const { return shape_; }
inline const vector<index_t> &shape() const { return shape_; }
inline index_t dim_size() const { return shape_.size(); }
......@@ -72,69 +82,93 @@ class Tensor {
inline index_t size() const { return size_; }
inline const void* raw_data() const {
MACE_CHECK(data_.get() || size_ == 0);
return data_.get();
inline const bool OnHost() const { return alloc_->OnHost(); }
/*
* Map the device buffer as CPU buffer to access the data, unmap must be
* called later
*/
inline void Map() {
if (!OnHost()) {
MACE_CHECK(buffer_ != nullptr && data_ == nullptr);
data_ = alloc_->Map(buffer_, size_ * SizeOfType());
}
}
template <typename T>
inline const T* data() const {
MACE_CHECK(
data_.get() || size_ == 0,
"The tensor is of non-zero shape, but its data is not allocated yet. ");
return static_cast<T*>(data_.get());
}
inline void* raw_mutable_data() {
if (data_.get() || size_ == 0) {
return data_.get();
} else {
CASES(dtype_, data_.reset(alloc_->New(size_ * sizeof(T)),
[this](void* ptr) { alloc_->Delete(ptr); }));
return data_.get();
/*
* Unmap the device buffer
*/
inline void Unmap() {
if (!OnHost()) {
MACE_CHECK(buffer_ != nullptr && data_ != nullptr);
alloc_->Unmap(buffer_, data_);
data_ = nullptr;
}
}
void *buffer() const { return buffer_; }
inline const void *raw_data() const {
void *data = MappedBuffer();
MACE_CHECK(data != nullptr || size_ == 0,
"The tensor is of non-zero shape, but its data is not allocated "
"or mapped yet.");
return data;
}
template <typename T>
inline T* mutable_data() {
if (size_ == 0 || data_.get()) {
return static_cast<T*>(data_.get());
}
return static_cast<T*>(raw_mutable_data());
inline const T *data() const {
return static_cast<const T *>(raw_data());
}
inline void *raw_mutable_data() {
void *data = MappedBuffer();
MACE_CHECK(data != nullptr || size_ == 0,
"The tensor is of non-zero shape, but its data is not allocated "
"or mapped yet.");
return data;
}
inline void Resize(const vector<index_t>& shape) {
template <typename T>
inline T *mutable_data() {
return static_cast<T *>(raw_mutable_data());
}
inline void Resize(const vector<index_t> &shape) {
shape_ = shape;
index_t size = NumElements();
if (size_ != size) {
size_ = size;
data_.reset();
MACE_CHECK(data_ == nullptr, "Buffer must be unmapped before resize");
alloc_->Delete(buffer_);
CASES(dtype_, buffer_ = alloc_->New(size_ * sizeof(T)));
}
}
inline void ResizeLike(const Tensor& other) { Resize(other.shape()); }
inline void ResizeLike(const Tensor &other) { Resize(other.shape()); }
inline void ResizeLike(const Tensor* other) { Resize(other->shape()); }
inline void ResizeLike(const Tensor *other) { Resize(other->shape()); }
template <typename T>
inline void Copy(const T* src, index_t size) {
inline void Copy(const T *src, index_t size) {
MACE_CHECK(size == size_, "copy src and dst with different size.");
CopyBytes(static_cast<const void*>(src), sizeof(T) * size);
CopyBytes(static_cast<const void *>(src), sizeof(T) * size);
}
template <typename SrcType, typename DstType>
inline void CopyWithCast(const SrcType* src, size_t size) {
inline void CopyWithCast(const SrcType *src, size_t size) {
MACE_CHECK(static_cast<index_t>(size) == size_,
"copy src and dst with different size.");
unique_ptr<DstType[]> buffer(new DstType[size]);
for (size_t i = 0; i < size; ++i) {
buffer[i] = static_cast<DstType>(src[i]);
}
CopyBytes(static_cast<const void*>(buffer.get()), sizeof(DstType) * size);
CopyBytes(static_cast<const void *>(buffer.get()), sizeof(DstType) * size);
}
inline void CopyBytes(const void* src, size_t size) {
alloc_->CopyBytes(raw_mutable_data(), src, size);
inline void CopyBytes(const void *src, size_t size) {
MappingGuard map_this(this);
memcpy(raw_mutable_data(), src, size);
}
inline void DebugPrint() const {
......@@ -159,24 +193,47 @@ class Tensor {
return type_size;
}
inline void Copy(const Tensor& other) {
inline void Copy(Tensor &other) {
alloc_ = other.alloc_;
dtype_ = other.dtype_;
ResizeLike(other);
const void* other_data = other.raw_data();
memcpy(raw_mutable_data(), other_data, size_ * SizeOfType());
MappingGuard map_other(&other);
CopyBytes(other.raw_data(), size_ * SizeOfType());
}
class MappingGuard {
public:
MappingGuard(Tensor *tensor) : tensor_(tensor) {
MACE_ASSERT(tensor_ != nullptr);
tensor_->Map();
}
~MappingGuard() { tensor_->Unmap(); }
private:
Tensor *tensor_;
};
private:
inline int64_t NumElements() const {
return std::accumulate(shape_.begin(), shape_.end(), 1,
std::multiplies<int64_t>());
}
Allocator* alloc_;
inline void *MappedBuffer() const {
if (OnHost()) {
return buffer_;
}
return data_;
}
Allocator *alloc_;
index_t size_;
DataType dtype_;
std::shared_ptr<void> data_;
// Raw buffer, must be mapped as host accessible data before
// read or write
void *buffer_;
// Mapped buffer
void *data_;
vector<index_t> shape_;
DISABLE_COPY_AND_ASSIGN(Tensor);
......
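Tensor::MappingGuard is an RAII wrapper around Map()/Unmap(); the OpenCL smoketest above uses it around every host-side read or write of a device tensor. A minimal sketch of the pattern (illustrative, not part of this commit; the function name is made up):

#include "mace/core/operator.h"  // brings in Tensor, as in the smoketest above

void FillTensor() {
  mace::Tensor t(mace::GetDeviceAllocator(mace::DeviceType::OPENCL),
                 mace::DataType::DT_FLOAT);
  t.Resize({16});
  {
    mace::Tensor::MappingGuard guard(&t);    // Map() on construction
    float *data = t.mutable_data<float>();   // host-visible pointer while mapped
    for (int i = 0; i < 16; ++i) data[i] = 1.0f;
  }  // Unmap() here; t.buffer() can now be handed to an OpenCL kernel
}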
void kernel simple_add(global const int *a,
global const int *b,
global int *c,
global const int *step) {
int id = get_global_id(0);
int start = step[0] * id;
int stop = start + step[0];
for (int i = start; i < stop; i++) c[i] = a[i] + b[i];
}