Commit 67f9c391 authored by liuqi

Add RWLock and make FileStorage and OpenCLRuntime thread-safe.

Parent eb1e5131
......@@ -29,6 +29,7 @@ api_test:
script:
- if [ -z "$TARGET_SOCS" ]; then TARGET_SOCS=random; fi
- python tools/bazel_adb_run.py --target="//mace/test:mace_api_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS
- python tools/bazel_adb_run.py --target="//mace/test:mace_api_mt_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS
ops_benchmark:
stage: ops_benchmark
......
......@@ -72,6 +72,7 @@ int FileStorage::Load() {
return -1;
}
}
utils::WriteLock lock(&data_mutex_);
int fd = open(file_path_.c_str(), O_RDONLY);
if (fd < 0) {
if (errno == ENOENT) {
......@@ -148,11 +149,13 @@ int FileStorage::Load() {
bool FileStorage::Insert(const std::string &key,
const std::vector<unsigned char> &value) {
utils::WriteLock lock(&data_mutex_);
data_.emplace(key, value);
return true;
}
const std::vector<unsigned char> *FileStorage::Find(const std::string &key) {
utils::ReadLock lock(&data_mutex_);
auto iter = data_.find(key);
if (iter == data_.end()) return nullptr;
......@@ -160,6 +163,7 @@ const std::vector<unsigned char> *FileStorage::Find(const std::string &key) {
}
int FileStorage::Flush() {
utils::WriteLock lock(&data_mutex_);
int fd = open(file_path_.c_str(), O_WRONLY | O_CREAT, 0600);
if (fd < 0) {
LOG(WARNING) << "open file " << file_path_
......
......@@ -20,6 +20,7 @@
#include <vector>
#include "mace/public/mace_runtime.h"
#include "mace/utils/rwlock.h"
namespace mace {
......@@ -37,6 +38,7 @@ class FileStorage : public KVStorage {
private:
std::string file_path_;
std::map<std::string, std::vector<unsigned char>> data_;
utils::RWMutex data_mutex_;
};
} // namespace mace
......
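Taken together, the FileStorage changes follow the usual reader-writer discipline: Load(), Insert(), and Flush() mutate data_ or the backing file and therefore take utils::WriteLock, while Find() only reads and takes utils::ReadLock. The sketch below (a hypothetical SimpleKVStore, not part of this commit) expresses the same discipline with C++17's std::shared_mutex; the commit ships its own RWMutex instead, presumably because the project builds against a pre-C++17 standard.

// Hypothetical sketch of the reader-writer pattern used by FileStorage,
// written against the standard library (C++17) rather than mace/utils.
#include <map>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <vector>

class SimpleKVStore {
 public:
  // Writers take an exclusive lock, mirroring utils::WriteLock in Insert().
  void Insert(const std::string &key,
              const std::vector<unsigned char> &value) {
    std::unique_lock<std::shared_mutex> lock(mutex_);
    data_[key] = value;
  }

  // Readers take a shared lock, mirroring utils::ReadLock in Find().
  // As in FileStorage::Find, the returned pointer is only safe to use
  // while no writer is concurrently mutating the map.
  const std::vector<unsigned char> *Find(const std::string &key) const {
    std::shared_lock<std::shared_mutex> lock(mutex_);
    auto iter = data_.find(key);
    return iter == data_.end() ? nullptr : &iter->second;
  }

 private:
  mutable std::shared_mutex mutex_;
  std::map<std::string, std::vector<unsigned char>> data_;
};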
......@@ -512,6 +512,7 @@ void OpenCLRuntime::BuildProgramFromSource(
if (this->storage_ != nullptr) {
this->storage_->Insert(built_program_key, content);
std::lock_guard<std::mutex> lock(program_map_changed_mutex_);
this->program_map_changed_ = true;
}
......@@ -565,11 +566,15 @@ cl::Kernel OpenCLRuntime::BuildKernel(
}
void OpenCLRuntime::SaveBuiltCLProgram() {
if (program_map_changed_ && storage_ != nullptr) {
if (storage_->Flush() != 0) {
LOG(FATAL) << "Store opencl compiled kernel to file failed";
if (storage_ != nullptr) {
std::lock_guard<std::mutex> lock(program_map_changed_mutex_);
if (program_map_changed_) {
if (storage_->Flush() != 0) {
LOG(FATAL) << "Store OPENCL compiled kernel to file failed."
" Please Make sure the storage directory exist.";
}
program_map_changed_ = false;
}
program_map_changed_ = false;
}
}
......
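The SaveBuiltCLProgram() rewrite pairs the program_map_changed_ dirty flag with program_map_changed_mutex_: BuildProgramFromSource() sets the flag under the lock right after Insert(), and SaveBuiltCLProgram() re-checks and clears it under the same lock, so concurrent callers flush the storage at most once per change. A minimal, self-contained sketch of this flag-under-mutex pattern (hypothetical ProgramCache, standard library only, not the actual OpenCLRuntime code):

#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

class ProgramCache {
 public:
  // Called after a program has been compiled and inserted into storage.
  void MarkChanged() {
    std::lock_guard<std::mutex> lock(changed_mutex_);
    changed_ = true;
  }

  // Flushes at most once per change, even when called from many threads.
  void SaveIfChanged() {
    std::lock_guard<std::mutex> lock(changed_mutex_);
    if (changed_) {
      Flush();           // stand-in for storage_->Flush()
      changed_ = false;  // cleared under the same lock
    }
  }

 private:
  void Flush() { std::cout << "flushing compiled programs\n"; }

  std::mutex changed_mutex_;
  bool changed_ = false;
};

int main() {
  ProgramCache cache;
  cache.MarkChanged();
  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) {
    threads.emplace_back([&cache] { cache.SaveIfChanged(); });
  }
  for (auto &t : threads) t.join();  // "flushing" prints exactly once
  return 0;
}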
......@@ -66,7 +66,6 @@ class OpenCLRuntime {
public:
static OpenCLRuntime *Global();
static void Configure(GPUPerfHint, GPUPriorityHint);
static void Configure(std::shared_ptr<KVStorage> storage_engine);
cl::Context &context();
cl::Device &device();
......@@ -76,21 +75,20 @@ class OpenCLRuntime {
const uint64_t device_global_mem_cache_size() const;
const uint32_t device_compute_units() const;
cl::Kernel BuildKernel(const std::string &program_name,
const std::string &kernel_name,
const std::set<std::string> &build_options);
void GetCallStats(const cl::Event &event, CallStats *stats);
uint64_t GetDeviceMaxWorkGroupSize();
uint64_t GetKernelMaxWorkGroupSize(const cl::Kernel &kernel);
uint64_t GetKernelWaveSize(const cl::Kernel &kernel);
const bool IsNonUniformWorkgroupsSupported();
const bool IsOutOfRangeCheckEnabled() const;
const GPUType ParseGPUType(const std::string &device_name);
const std::string ParseDeviceVersion(const std::string &device_version);
void SaveBuiltCLProgram();
const bool is_profiling_enabled() const;
cl::Kernel BuildKernel(const std::string &program_name,
const std::string &kernel_name,
const std::set<std::string> &build_options);
void SaveBuiltCLProgram();
private:
OpenCLRuntime();
~OpenCLRuntime();
......@@ -114,8 +112,12 @@ class OpenCLRuntime {
const std::string &built_program_key,
const std::string &build_options_str,
cl::Program *program);
const GPUType ParseGPUType(const std::string &device_name);
const std::string ParseDeviceVersion(const std::string &device_version);
private:
std::unique_ptr<KVStorage> storage_;
bool is_profiling_enabled_;
// All OpenCL objects must be pointers and must be manually deleted before
// unloading the OpenCL library.
std::shared_ptr<cl::Context> context_;
......@@ -123,16 +125,14 @@ class OpenCLRuntime {
std::shared_ptr<cl::CommandQueue> command_queue_;
std::map<std::string, cl::Program> built_program_map_;
std::mutex program_build_mutex_;
GPUType gpu_type_;
std::mutex program_map_changed_mutex_;
std::string platform_info_;
std::string opencl_version_;
bool out_of_range_check_;
std::string platform_info_;
bool program_map_changed_;
std::unique_ptr<KVStorage> storage_;
bool is_profiling_enabled_;
uint64_t device_gloabl_mem_cache_size_;
uint32_t device_compute_units_;
GPUType gpu_type_;
static GPUPerfHint kGPUPerfHint;
static GPUPriorityHint kGPUPriorityHint;
......
......@@ -165,7 +165,7 @@ bool RunModel(const std::vector<std::string> &input_names,
}
// DO NOT USE tmp directory.
// please use APP's own directory
// Please use APP's own directory and make sure the directory exists.
const std::string kernel_file_path =
"/data/local/tmp/mace_run/cl";
......
......@@ -28,3 +28,23 @@ cc_test(
"@gtest//:gtest_main",
],
)
cc_test(
name = "mace_api_mt_test",
testonly = 1,
srcs = ["mace_api_mt_test.cc"],
copts = if_openmp_enabled(["-fopenmp"]) +
if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
if_android_armv7(["-mfpu=neon"]) +
if_android_armv7(["-mfloat-abi=softfp"]) +
if_android(["-DMACE_ENABLE_OPENCL"]) +
if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
"//mace/ops:test",
"//mace/kernels:kernels",
"//mace/ops:ops",
"@gtest//:gtest_main",
],
)
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <fstream>
#include <thread> // NOLINT(build/c++11)
#include "mace/core/operator.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
#include "mace/public/mace_runtime.h"
namespace mace {
namespace test {
class MaceMTAPITest : public ::testing::Test {};
namespace {
void GenerateInputs(const std::vector<std::string> &input_names,
const std::vector<int64_t> &input_shape,
std::map<std::string, mace::MaceTensor> *inputs) {
size_t input_size = input_names.size();
for (size_t i = 0; i < input_size; ++i) {
// Allocate input and output
int64_t input_size =
std::accumulate(input_shape.begin(), input_shape.end(), 1,
std::multiplies<int64_t>());
auto buffer_in = std::shared_ptr<float>(new float[input_size],
std::default_delete<float[]>());
// load input
std::vector<float> input_data;
ops::test::GenerateRandomRealTypeData(input_shape, &input_data);
memcpy(buffer_in.get(), input_data.data(), input_size * sizeof(float));
(*inputs)[input_names[i]] = mace::MaceTensor(input_shape, buffer_in);
}
}
void GenerateOutputs(const std::vector<std::string> &output_names,
const std::vector<int64_t> &output_shape,
std::map<std::string, mace::MaceTensor> *outputs) {
size_t output_size = output_names.size();
for (size_t i = 0; i < output_size; ++i) {
int64_t output_size =
std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int64_t>());
auto buffer_out = std::shared_ptr<float>(new float[output_size],
std::default_delete<float[]>());
(*outputs)[output_names[i]] = mace::MaceTensor(output_shape, buffer_out);
}
}
template <typename T>
void BufferToImage(const std::string &input_name,
const std::string &output_name,
const int buffer_type,
const std::vector<int> &mem_ids,
NetDef *net_def,
const int mode = NetMode::NORMAL) {
OperatorDef operator_def;
ops::test::OpDefBuilder("BufferToImage", "BufferToImageOp")
.Input(input_name)
.Output(output_name)
.AddIntArg("buffer_type", buffer_type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("mode", mode)
.Finalize(&operator_def);
operator_def.set_mem_id(mem_ids);
net_def->add_op()->CopyFrom(operator_def);
}
template <typename T>
void ImageToBuffer(const std::string &input_name,
const std::string &output_name,
const int buffer_type,
NetDef *net_def) {
OperatorDef operator_def;
ops::test::OpDefBuilder("ImageToBuffer", "ImageToBufferOp")
.Input(input_name)
.Output(output_name)
.AddIntArg("buffer_type", buffer_type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def);
}
template <typename T>
void Conv3x3(const std::string &input_name,
const std::string &filter_name,
const std::string &output_name,
const std::vector<int> &mem_ids,
NetDef *net_def) {
OperatorDef operator_def;
ops::test::OpDefBuilder("Conv2D", "Conv2dOp")
.Input(input_name)
.Input(filter_name)
.Output(output_name)
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(&operator_def);
operator_def.set_mem_id(mem_ids);
net_def->add_op()->CopyFrom(operator_def);
}
template <typename T>
void Relu(const std::string &input_name,
const std::string &output_name,
NetDef *net_def) {
OperatorDef operator_def;
ops::test::OpDefBuilder("Activation", "ReluTest")
.Input(input_name)
.Output(output_name)
.AddStringArg("activation", "RELU")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def);
}
template <typename T>
void AddTensor(const std::string &name,
const std::vector<int64_t> &shape,
T *data,
NetDef *net_def) {
ConstTensor tensor(name,
reinterpret_cast<unsigned char *>(data),
shape,
DataTypeToEnum<T>::value);
net_def->mutable_tensors().push_back(tensor);
}
template <DeviceType D, typename T>
void CheckOutputs(const NetDef &net_def,
const std::map<std::string, mace::MaceTensor> &inputs,
const std::map<std::string, mace::MaceTensor> &outputs) {
ops::test::OpsTestNet net;
for (auto input : inputs) {
auto input_shape = input.second.shape();
const int64_t data_size = std::accumulate(input_shape.begin(),
input_shape.end(), 1,
std::multiplies<int64_t>());
std::vector<float> input_data(data_size);
memcpy(input_data.data(), input.second.data().get(),
data_size * sizeof(float));
std::string input_name = MakeString("mace_input_node_",
input.first, ":0");
net.AddInputFromArray<D, float>(input_name, input.second.shape(),
input_data);
}
auto tensors = net_def.tensors();
for (auto tensor : tensors) {
auto shape = tensor.dims();
const int64_t data_size = std::accumulate(shape.begin(),
shape.end(), 1,
std::multiplies<int64_t>());
std::vector<T> data(data_size);
memcpy(data.data(), reinterpret_cast<const T *>(tensor.data()),
data_size * sizeof(T));
net.AddInputFromArray<D, T>(tensor.name(), shape, data);
}
net.RunNet(net_def, D);
for (auto output : outputs) {
std::unique_ptr<Tensor> tmp_tensor(
new Tensor(GetDeviceAllocator(DeviceType::CPU),
DataTypeToEnum<float>::v()));
auto output_shape = output.second.shape();
const int64_t data_size = std::accumulate(output_shape.begin(),
output_shape.end(), 1,
std::multiplies<int64_t>());
tmp_tensor->Resize(output.second.shape());
float *data = tmp_tensor->mutable_data<float>();
memcpy(data, output.second.data().get(), data_size * sizeof(float));
std::string output_name = MakeString("mace_output_node_",
output.first, ":0");
ops::test::ExpectTensorNear<float>(*tmp_tensor,
*net.GetOutput(output_name.data()),
1e-5);
}
}
std::map<std::string, int> AddMemoryOptimization(
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names,
const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes,
NetDef *net_def) {
std::map<std::string, int> res;
int mem_id = 0;
size_t input_shape_size = input_shapes.size();
uint32_t in_mem_block_x = 0;
uint32_t in_mem_block_y = 0;
for (size_t i = 0; i < input_shape_size; ++i) {
in_mem_block_x = std::max<uint32_t>(in_mem_block_x,
input_shapes[i][2] *
RoundUpDiv4(input_shapes[i][3]));
in_mem_block_y = std::max<uint32_t>(in_mem_block_y,
input_shapes[i][0] *
input_shapes[i][1]);
}
size_t input_size = input_names.size();
for (size_t i = 0; i < input_size; ++i) {
net_def->mutable_mem_arena().mutable_mem_block().push_back(
MemoryBlock(mem_id, in_mem_block_x, in_mem_block_y));
res[input_names[i]] = mem_id;
mem_id++;
}
size_t output_shape_size = output_shapes.size();
uint32_t out_mem_block_x = 0;
uint32_t out_mem_block_y = 0;
for (size_t i = 0; i < output_shape_size; ++i) {
out_mem_block_x = std::max<uint32_t>(out_mem_block_x,
output_shapes[i][2] *
RoundUpDiv4(output_shapes[i][3]));
out_mem_block_y = std::max<uint32_t>(out_mem_block_y,
output_shapes[i][0] *
output_shapes[i][1]);
}
size_t output_size = output_names.size();
for (size_t i = 0; i < output_size; ++i) {
net_def->mutable_mem_arena().mutable_mem_block().push_back(
MemoryBlock(mem_id, out_mem_block_x, out_mem_block_y));
res[output_names[i]] = mem_id;
mem_id++;
}
return res;
}
// The height and width of input and output must be equal.
void MaceRunFunc(const int in_out_size) {
std::vector<std::string> input_names;
std::vector<std::string> output_names;
for (int i = 0; i < in_out_size; ++i) {
input_names.push_back(MakeString("input", i));
output_names.push_back(MakeString("output", i));
}
std::string filter_tensor_name = "filter";
std::string filter_tensor_img_name = filter_tensor_name + "_image";
const DeviceType device = DeviceType::GPU;
const std::vector<std::vector<int64_t>> input_shapes = {{1, 32, 32, 16}};
const std::vector<std::vector<int64_t>> output_shapes = {{1, 32, 32, 16}};
const std::vector<int64_t> filter_shape = {3, 3, 16, 16};
NetDef net_def;
// Add memory optimization
auto mem_map = AddMemoryOptimization(input_names, output_names,
input_shapes, output_shapes,
&net_def);
std::vector<half> data;
ops::test::GenerateRandomRealTypeData<half>(filter_shape, &data);
AddTensor<half>(filter_tensor_name, filter_shape, data.data(), &net_def);
for (size_t i = 0; i < input_names.size(); ++i) {
std::string input_name = MakeString("mace_input_node_",
input_names[i], ":0");
BufferToImage<half>(input_name, input_names[i],
mace::kernels::IN_OUT_CHANNEL,
{mem_map[input_names[i]]},
&net_def);
}
BufferToImage<half>(filter_tensor_name, filter_tensor_img_name,
mace::kernels::CONV2D_FILTER, {},
&net_def, NetMode::INIT);
for (size_t i = 0; i < output_names.size(); ++i) {
Conv3x3<half>(input_names[i], filter_tensor_img_name,
output_names[i], {mem_map[output_names[i]]},
&net_def);
}
for (size_t i = 0; i < output_names.size(); ++i) {
std::string output_name = MakeString("mace_output_node_",
output_names[i], ":0");
ImageToBuffer<float>(output_names[i], output_name,
mace::kernels::IN_OUT_CHANNEL, &net_def);
}
const std::string file_path ="/data/local/tmp/mace";
std::shared_ptr<KVStorageFactory> storage_factory(
new FileStorageFactory(file_path));
mace::SetKVStorageFactory(storage_factory);
MaceEngine engine(&net_def, device, input_names, output_names);
std::map<std::string, mace::MaceTensor> inputs;
std::map<std::string, mace::MaceTensor> outputs;
for (int i = 0; i < 5; ++i) {
size_t input_shape_size = input_shapes.size();
for (size_t j = 0; j < input_shape_size; ++j) {
inputs.clear();
outputs.clear();
GenerateInputs(input_names, input_shapes[j], &inputs);
GenerateOutputs(output_names, output_shapes[j], &outputs);
engine.Run(inputs, &outputs);
}
}
CheckOutputs<DeviceType::GPU, half>(net_def, inputs, outputs);
}
} // namespace
TEST_F(MaceMTAPITest, MultipleThread) {
const int thread_num = 10;
std::vector<std::thread> threads;
for (int i = 0; i < thread_num; ++i) {
threads.push_back(std::thread(MaceRunFunc, i));
}
for (auto &t : threads) {
t.join();
}
}
} // namespace test
} // namespace mace
......@@ -18,6 +18,7 @@
#include "mace/core/operator.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
#include "mace/public/mace_runtime.h"
namespace mace {
namespace test {
......@@ -337,11 +338,10 @@ TEST_F(MaceAPITest, GPUVariableInputShape) {
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{3, 3, 16, 16});
MaceRun<float>(2,
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{3, 3, 16, 16});
MaceRun<half>(2,
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{3, 3, 16, 16});
}
} // namespace test
} // namespace mace
......@@ -24,6 +24,7 @@ cc_library(
"timer.h",
"tuner.h",
"utils.h",
"rwlock.h",
],
linkopts = if_android([
"-llog",
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_UTILS_RWLOCK_H_
#define MACE_UTILS_RWLOCK_H_
#include <condition_variable> // NOLINT(build/c++11)
#include <mutex> // NOLINT(build/c++11)
#include "mace/utils/logging.h"
namespace mace {
namespace utils {
class RWMutex {
public:
RWMutex() : counter_(0), waiting_readers_(0), waiting_writers_(0) {}
~RWMutex() = default;
RWMutex(const RWMutex &) = delete;
RWMutex(RWMutex &&) = delete;
RWMutex& operator=(const RWMutex &) = delete;
RWMutex& operator=(RWMutex &&) = delete;
int counter_; // -1 for writer, 0 for nobody, 1~n for reader
int waiting_readers_;
int waiting_writers_;
std::mutex mutex_;
std::condition_variable reader_cv_;
std::condition_variable writer_cv_;
};
// Writer first
class ReadLock {
public:
explicit ReadLock(RWMutex *rw_mutex): rw_mutex_(rw_mutex) {
if (rw_mutex_ == nullptr) {
return;
}
std::unique_lock<std::mutex> lock(rw_mutex->mutex_);
rw_mutex->waiting_readers_++;
rw_mutex->reader_cv_.wait(lock, [&]() -> bool {
return rw_mutex->waiting_writers_ == 0 && rw_mutex->counter_ >= 0;
});
rw_mutex->waiting_readers_--;
rw_mutex->counter_++;
}
~ReadLock() {
if (rw_mutex_ == nullptr) {
return;
}
std::unique_lock<std::mutex> lock(rw_mutex_->mutex_);
rw_mutex_->counter_ -= 1;
if (rw_mutex_->waiting_writers_ > 0) {
if (rw_mutex_->counter_ == 0) {
rw_mutex_->writer_cv_.notify_one();
}
} else {
rw_mutex_->reader_cv_.notify_all();
}
}
ReadLock(const ReadLock &) = delete;
ReadLock(ReadLock &&) = delete;
ReadLock& operator=(const ReadLock &) = delete;
ReadLock& operator=(ReadLock &&) = delete;
private:
RWMutex *rw_mutex_;
};
class WriteLock {
public:
explicit WriteLock(RWMutex *rw_mutex): rw_mutex_(rw_mutex) {
if (rw_mutex_ == nullptr) {
return;
}
std::unique_lock<std::mutex> lock(rw_mutex->mutex_);
rw_mutex->waiting_writers_++;
rw_mutex->writer_cv_.wait(lock, [&]() -> bool {
return rw_mutex->counter_ == 0;
});
rw_mutex->waiting_writers_--;
rw_mutex->counter_--;
}
~WriteLock() {
if (rw_mutex_ == nullptr) {
return;
}
std::unique_lock<std::mutex> lock(rw_mutex_->mutex_);
rw_mutex_->counter_ = 0;
if (rw_mutex_->waiting_writers_ > 0) {
rw_mutex_->writer_cv_.notify_one();
} else {
rw_mutex_->reader_cv_.notify_all();
}
}
WriteLock(const WriteLock &) = delete;
WriteLock(WriteLock &&) = delete;
WriteLock& operator=(const WriteLock &) = delete;
WriteLock& operator=(WriteLock &&) = delete;
private:
RWMutex *rw_mutex_;
};
} // namespace utils
} // namespace mace
#endif // MACE_UTILS_RWLOCK_H_
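For reference, a small usage sketch of the API above (illustrative only, not part of the commit, assuming the header is on the include path): any number of ReadLock holders may proceed concurrently, a WriteLock holder excludes everyone else, and because readers wait while waiting_writers_ > 0, pending writers are served before new readers.

#include <thread>
#include <vector>

#include "mace/utils/rwlock.h"

namespace {

mace::utils::RWMutex table_mutex;
std::vector<int> table;  // shared data guarded by table_mutex

void Reader() {
  mace::utils::ReadLock lock(&table_mutex);  // shared with other readers
  int sum = 0;
  for (int v : table) sum += v;              // read-only access
  (void)sum;
}

void Writer(int value) {
  mace::utils::WriteLock lock(&table_mutex);  // exclusive access
  table.push_back(value);
}

}  // namespace

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) {
    threads.emplace_back(Reader);
    threads.emplace_back(Writer, i);
  }
  for (auto &t : threads) t.join();
  return 0;
}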