diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 02966b3f113f451cc38deef637d61dcf2f0bf70a..012eeec84d2fa030a9455f1729b22818aab0018f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -29,6 +29,7 @@ api_test: script: - if [ -z "$TARGET_SOCS" ]; then TARGET_SOCS=random; fi - python tools/bazel_adb_run.py --target="//mace/test:mace_api_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS + - python tools/bazel_adb_run.py --target="//mace/test:mace_api_mt_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS ops_benchmark: stage: ops_benchmark diff --git a/mace/core/file_storage.cc b/mace/core/file_storage.cc index d0ac3d7833a6e9da3417e787aa4380bb1a0d9b8b..cf5f0099d77bf90d18151b6a31493f0d05aeee81 100644 --- a/mace/core/file_storage.cc +++ b/mace/core/file_storage.cc @@ -72,6 +72,7 @@ int FileStorage::Load() { return -1; } } + utils::WriteLock lock(&data_mutex_); int fd = open(file_path_.c_str(), O_RDONLY); if (fd < 0) { if (errno == ENOENT) { @@ -148,11 +149,13 @@ int FileStorage::Load() { bool FileStorage::Insert(const std::string &key, const std::vector &value) { + utils::WriteLock lock(&data_mutex_); data_.emplace(key, value); return true; } const std::vector *FileStorage::Find(const std::string &key) { + utils::ReadLock lock(&data_mutex_); auto iter = data_.find(key); if (iter == data_.end()) return nullptr; @@ -160,6 +163,7 @@ const std::vector *FileStorage::Find(const std::string &key) { } int FileStorage::Flush() { + utils::WriteLock lock(&data_mutex_); int fd = open(file_path_.c_str(), O_WRONLY | O_CREAT, 0600); if (fd < 0) { LOG(WARNING) << "open file " << file_path_ diff --git a/mace/core/file_storage.h b/mace/core/file_storage.h index 3dee8b4e73a45429141e7c96a5ab85999b6eee30..7ff12419f749540c5efc0cabdba9508b57546237 100644 --- a/mace/core/file_storage.h +++ b/mace/core/file_storage.h @@ -20,6 +20,7 @@ #include #include "mace/public/mace_runtime.h" +#include "mace/utils/rwlock.h" namespace mace { @@ -37,6 +38,7 @@ class FileStorage : public KVStorage { private: std::string file_path_; std::map> data_; + utils::RWMutex data_mutex_; }; } // namespace mace diff --git a/mace/core/runtime/opencl/opencl_runtime.cc b/mace/core/runtime/opencl/opencl_runtime.cc index f9b0d5e2e99ce00d8a4e961ab7b9c24c56b458ed..202a23cd16a29067531bb1c173a94435ba80e210 100644 --- a/mace/core/runtime/opencl/opencl_runtime.cc +++ b/mace/core/runtime/opencl/opencl_runtime.cc @@ -512,6 +512,7 @@ void OpenCLRuntime::BuildProgramFromSource( if (this->storage_ != nullptr) { this->storage_->Insert(built_program_key, content); + std::lock_guard lock(program_map_changed_mutex_); this->program_map_changed_ = true; } @@ -565,11 +566,15 @@ cl::Kernel OpenCLRuntime::BuildKernel( } void OpenCLRuntime::SaveBuiltCLProgram() { - if (program_map_changed_ && storage_ != nullptr) { - if (storage_->Flush() != 0) { - LOG(FATAL) << "Store opencl compiled kernel to file failed"; + if (storage_ != nullptr) { + std::lock_guard lock(program_map_changed_mutex_); + if (program_map_changed_) { + if (storage_->Flush() != 0) { + LOG(FATAL) << "Store OPENCL compiled kernel to file failed." + " Please Make sure the storage directory exist."; + } + program_map_changed_ = false; } - program_map_changed_ = false; } } diff --git a/mace/core/runtime/opencl/opencl_runtime.h b/mace/core/runtime/opencl/opencl_runtime.h index 521698b7fc33d4b1ee7fa8f3e19265e511870081..238ec734a65f4fad41537c28e381a0c84802aa8c 100644 --- a/mace/core/runtime/opencl/opencl_runtime.h +++ b/mace/core/runtime/opencl/opencl_runtime.h @@ -66,7 +66,6 @@ class OpenCLRuntime { public: static OpenCLRuntime *Global(); static void Configure(GPUPerfHint, GPUPriorityHint); - static void Configure(std::shared_ptr storage_engine); cl::Context &context(); cl::Device &device(); @@ -76,21 +75,20 @@ class OpenCLRuntime { const uint64_t device_global_mem_cache_size() const; const uint32_t device_compute_units() const; - cl::Kernel BuildKernel(const std::string &program_name, - const std::string &kernel_name, - const std::set &build_options); - void GetCallStats(const cl::Event &event, CallStats *stats); uint64_t GetDeviceMaxWorkGroupSize(); uint64_t GetKernelMaxWorkGroupSize(const cl::Kernel &kernel); uint64_t GetKernelWaveSize(const cl::Kernel &kernel); const bool IsNonUniformWorkgroupsSupported(); const bool IsOutOfRangeCheckEnabled() const; - const GPUType ParseGPUType(const std::string &device_name); - const std::string ParseDeviceVersion(const std::string &device_version); - void SaveBuiltCLProgram(); const bool is_profiling_enabled() const; + cl::Kernel BuildKernel(const std::string &program_name, + const std::string &kernel_name, + const std::set &build_options); + + void SaveBuiltCLProgram(); + private: OpenCLRuntime(); ~OpenCLRuntime(); @@ -114,8 +112,12 @@ class OpenCLRuntime { const std::string &built_program_key, const std::string &build_options_str, cl::Program *program); + const GPUType ParseGPUType(const std::string &device_name); + const std::string ParseDeviceVersion(const std::string &device_version); private: + std::unique_ptr storage_; + bool is_profiling_enabled_; // All OpenCL object must be a pointer and manually deleted before unloading // OpenCL library. std::shared_ptr context_; @@ -123,16 +125,14 @@ class OpenCLRuntime { std::shared_ptr command_queue_; std::map built_program_map_; std::mutex program_build_mutex_; - GPUType gpu_type_; + std::mutex program_map_changed_mutex_; + std::string platform_info_; std::string opencl_version_; bool out_of_range_check_; - std::string platform_info_; bool program_map_changed_; - std::unique_ptr storage_; - bool is_profiling_enabled_; uint64_t device_gloabl_mem_cache_size_; uint32_t device_compute_units_; - + GPUType gpu_type_; static GPUPerfHint kGPUPerfHint; static GPUPriorityHint kGPUPriorityHint; diff --git a/mace/examples/example.cc b/mace/examples/example.cc index 63ebde3ad7e617df19aced827f004a63fa14e302..91f8cb6a7dff9c8c931209f48cae4ee71d1d7e98 100644 --- a/mace/examples/example.cc +++ b/mace/examples/example.cc @@ -165,7 +165,7 @@ bool RunModel(const std::vector &input_names, } // DO NOT USE tmp directory. - // please use APP's own directory + // Please use APP's own directory and make sure the directory exists. const std::string kernel_file_path = "/data/local/tmp/mace_run/cl"; diff --git a/mace/test/BUILD b/mace/test/BUILD index f3345cfac12dedb66cbe9f7c4a5d02a120e2113e..c76287f91ffd9db8bb781ea2bd502dcbaa93b321 100644 --- a/mace/test/BUILD +++ b/mace/test/BUILD @@ -28,3 +28,23 @@ cc_test( "@gtest//:gtest_main", ], ) + +cc_test( + name = "mace_api_mt_test", + testonly = 1, + srcs = ["mace_api_mt_test.cc"], + copts = if_openmp_enabled(["-fopenmp"]) + + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + + if_android_armv7(["-mfpu=neon"]) + + if_android_armv7(["-mfloat-abi=softfp"]) + + if_android(["-DMACE_ENABLE_OPENCL"]) + + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), + linkopts = ["-fopenmp"], + linkstatic = 1, + deps = [ + "//mace/ops:test", + "//mace/kernels:kernels", + "//mace/ops:ops", + "@gtest//:gtest_main", + ], +) diff --git a/mace/test/mace_api_mt_test.cc b/mace/test/mace_api_mt_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5510afc236e2e6f5501edaf8e68b1b0308a77104 --- /dev/null +++ b/mace/test/mace_api_mt_test.cc @@ -0,0 +1,327 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include +#include // NOLINT(build/c++11) + +#include "mace/core/operator.h" +#include "mace/kernels/conv_pool_2d_util.h" +#include "mace/ops/ops_test_util.h" +#include "mace/public/mace_runtime.h" + +namespace mace { +namespace test { + +class MaceMTAPITest : public ::testing::Test {}; + +namespace { + +void GenerateInputs(const std::vector &input_names, + const std::vector &input_shape, + std::map *inputs) { + size_t input_size = input_names.size(); + for (size_t i = 0; i < input_size; ++i) { + // Allocate input and output + int64_t input_size = + std::accumulate(input_shape.begin(), input_shape.end(), 1, + std::multiplies()); + auto buffer_in = std::shared_ptr(new float[input_size], + std::default_delete()); + // load input + std::vector input_data; + ops::test::GenerateRandomRealTypeData(input_shape, &input_data); + memcpy(buffer_in.get(), input_data.data(), input_size * sizeof(float)); + (*inputs)[input_names[i]] = mace::MaceTensor(input_shape, buffer_in); + } +} + +void GenerateOutputs(const std::vector &output_names, + const std::vector &output_shape, + std::map *outputs) { + size_t output_size = output_names.size(); + for (size_t i = 0; i < output_size; ++i) { + int64_t output_size = + std::accumulate(output_shape.begin(), output_shape.end(), 1, + std::multiplies()); + auto buffer_out = std::shared_ptr(new float[output_size], + std::default_delete()); + (*outputs)[output_names[i]] = mace::MaceTensor(output_shape, buffer_out); + } +} + +template +void BufferToImage(const std::string &input_name, + const std::string &output_name, + const int buffer_type, + const std::vector &mem_ids, + NetDef *net_def, + const int mode = NetMode::NORMAL) { + OperatorDef operator_def; + + ops::test::OpDefBuilder("BufferToImage", "BufferToImageOp") + .Input(input_name) + .Output(output_name) + .AddIntArg("buffer_type", buffer_type) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .AddIntArg("mode", mode) + .Finalize(&operator_def); + + operator_def.set_mem_id(mem_ids); + + net_def->add_op()->CopyFrom(operator_def); +} + +template +void ImageToBuffer(const std::string &input_name, + const std::string &output_name, + const int buffer_type, + NetDef *net_def) { + OperatorDef operator_def; + + ops::test::OpDefBuilder("ImageToBuffer", "ImageToBufferOp") + .Input(input_name) + .Output(output_name) + .AddIntArg("buffer_type", buffer_type) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(&operator_def); + + net_def->add_op()->CopyFrom(operator_def); +} + +template +void Conv3x3(const std::string &input_name, + const std::string &filter_name, + const std::string &output_name, + const std::vector &mem_ids, + NetDef *net_def) { + OperatorDef operator_def; + ops::test::OpDefBuilder("Conv2D", "Conv2dOp") + .Input(input_name) + .Input(filter_name) + .Output(output_name) + .AddIntsArg("strides", {1, 1}) + .AddIntArg("padding", Padding::SAME) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(&operator_def); + + operator_def.set_mem_id(mem_ids); + net_def->add_op()->CopyFrom(operator_def); +} + +template +void Relu(const std::string &input_name, + const std::string &output_name, + NetDef *net_def) { + OperatorDef operator_def; + ops::test::OpDefBuilder("Activation", "ReluTest") + .Input(input_name) + .Output(output_name) + .AddStringArg("activation", "RELU") + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(&operator_def); + + net_def->add_op()->CopyFrom(operator_def); +} + +template +void AddTensor(const std::string &name, + const std::vector &shape, + T *data, + NetDef *net_def) { + ConstTensor tensor(name, + reinterpret_cast(data), + shape, + DataTypeToEnum::value); + + net_def->mutable_tensors().push_back(tensor); +} + +template +void CheckOutputs(const NetDef &net_def, + const std::map &inputs, + const std::map &outputs) { + ops::test::OpsTestNet net; + for (auto input : inputs) { + auto input_shape = input.second.shape(); + const int64_t data_size = std::accumulate(input_shape.begin(), + input_shape.end(), 1, + std::multiplies()); + std::vector input_data(data_size); + memcpy(input_data.data(), input.second.data().get(), + data_size * sizeof(float)); + std::string input_name = MakeString("mace_input_node_", + input.first, ":0"); + net.AddInputFromArray(input_name, input.second.shape(), + input_data); + } + auto tensors = net_def.tensors(); + for (auto tensor : tensors) { + auto shape = tensor.dims(); + const int64_t data_size = std::accumulate(shape.begin(), + shape.end(), 1, + std::multiplies()); + std::vector data(data_size); + memcpy(data.data(), reinterpret_cast(tensor.data()), + data_size * sizeof(T)); + net.AddInputFromArray(tensor.name(), shape, data); + } + net.RunNet(net_def, D); + + for (auto output : outputs) { + std::unique_ptr tmp_tensor( + new Tensor(GetDeviceAllocator(DeviceType::CPU), + DataTypeToEnum::v())); + auto output_shape = output.second.shape(); + const int64_t data_size = std::accumulate(output_shape.begin(), + output_shape.end(), 1, + std::multiplies()); + tmp_tensor->Resize(output.second.shape()); + float *data = tmp_tensor->mutable_data(); + memcpy(data, output.second.data().get(), data_size * sizeof(float)); + std::string output_name = MakeString("mace_output_node_", + output.first, ":0"); + ops::test::ExpectTensorNear(*tmp_tensor, + *net.GetOutput(output_name.data()), + 1e-5); + } +} + +std::map AddMemoryOptimization( + const std::vector &input_names, + const std::vector &output_names, + const std::vector> &input_shapes, + const std::vector> &output_shapes, + NetDef *net_def) { + std::map res; + int mem_id = 0; + size_t input_shape_size = input_shapes.size(); + uint32_t in_mem_block_x = 0; + uint32_t in_mem_block_y = 0; + for (size_t i = 0; i < input_shape_size; ++i) { + in_mem_block_x = std::max(in_mem_block_x, + input_shapes[i][2] * + RoundUpDiv4(input_shapes[i][3])); + in_mem_block_y = std::max(in_mem_block_y, + input_shapes[i][0] * + input_shapes[i][1]); + } + size_t input_size = input_names.size(); + for (size_t i = 0; i < input_size; ++i) { + net_def->mutable_mem_arena().mutable_mem_block().push_back( + MemoryBlock(mem_id, in_mem_block_x, in_mem_block_y)); + res[input_names[i]] = mem_id; + mem_id++; + } + size_t output_shape_size = output_shapes.size(); + uint32_t out_mem_block_x = 0; + uint32_t out_mem_block_y = 0; + for (size_t i = 0; i < output_shape_size; ++i) { + out_mem_block_x = std::max(out_mem_block_x, + output_shapes[i][2] * + RoundUpDiv4(output_shapes[i][3])); + out_mem_block_y = std::max(out_mem_block_y, + output_shapes[i][0] * + output_shapes[i][1]); + } + size_t output_size = output_names.size(); + for (size_t i = 0; i < output_size; ++i) { + net_def->mutable_mem_arena().mutable_mem_block().push_back( + MemoryBlock(mem_id, out_mem_block_x, out_mem_block_y)); + res[output_names[i]] = mem_id; + mem_id++; + } + return res; +} + +// The height and width of input and output must be equal. +void MaceRunFunc(const int in_out_size) { + std::vector input_names; + std::vector output_names; + for (int i = 0; i < in_out_size; ++i) { + input_names.push_back(MakeString("input", i)); + output_names.push_back(MakeString("output", i)); + } + std::string filter_tensor_name = "filter"; + std::string filter_tensor_img_name = filter_tensor_name + "_image"; + + const DeviceType device = DeviceType::GPU; + + const std::vector> input_shapes = {{1, 32, 32, 16}}; + const std::vector> output_shapes = {{1, 32, 32, 16}}; + const std::vector filter_shape = {3, 3, 16, 16}; + + NetDef net_def; + + // Add memory optimization + auto mem_map = AddMemoryOptimization(input_names, output_names, + input_shapes, output_shapes, + &net_def); + + std::vector data; + ops::test::GenerateRandomRealTypeData(filter_shape, &data); + AddTensor(filter_tensor_name, filter_shape, data.data(), &net_def); + + for (size_t i = 0; i < input_names.size(); ++i) { + std::string input_name = MakeString("mace_input_node_", + input_names[i], ":0"); + BufferToImage(input_name, input_names[i], + mace::kernels::IN_OUT_CHANNEL, + {mem_map[input_names[i]]}, + &net_def); + } + BufferToImage(filter_tensor_name, filter_tensor_img_name, + mace::kernels::CONV2D_FILTER, {}, + &net_def, NetMode::INIT); + for (size_t i = 0; i < output_names.size(); ++i) { + Conv3x3(input_names[i], filter_tensor_img_name, + output_names[i], {mem_map[output_names[i]]}, + &net_def); + } + for (size_t i = 0; i < output_names.size(); ++i) { + std::string output_name = MakeString("mace_output_node_", + output_names[i], ":0"); + ImageToBuffer(output_names[i], output_name, + mace::kernels::IN_OUT_CHANNEL, &net_def); + } + + const std::string file_path ="/data/local/tmp/mace"; + std::shared_ptr storage_factory( + new FileStorageFactory(file_path)); + mace::SetKVStorageFactory(storage_factory); + + MaceEngine engine(&net_def, device, input_names, output_names); + + std::map inputs; + std::map outputs; + + for (int i = 0; i < 5; ++i) { + size_t input_shape_size = input_shapes.size(); + for (size_t j = 0; j < input_shape_size; ++j) { + inputs.clear(); + outputs.clear(); + GenerateInputs(input_names, input_shapes[j], &inputs); + GenerateOutputs(output_names, output_shapes[j], &outputs); + engine.Run(inputs, &outputs); + } + } + + CheckOutputs(net_def, inputs, outputs); +} + +} // namespace + +TEST_F(MaceMTAPITest, MultipleThread) { + const int thread_num = 10; + std::vector threads; + for (int i = 0; i < thread_num; ++i) { + threads.push_back(std::thread(MaceRunFunc, i)); + } + for (auto &t : threads) { + t.join(); + } +} + +} // namespace test +} // namespace mace diff --git a/mace/test/mace_api_test.cc b/mace/test/mace_api_test.cc index 44a5d3f6cf13624b079c0259de1cf3754700a041..73775520fc811114d4ab8e9e33bdf1a23dc5cb09 100644 --- a/mace/test/mace_api_test.cc +++ b/mace/test/mace_api_test.cc @@ -18,6 +18,7 @@ #include "mace/core/operator.h" #include "mace/kernels/conv_pool_2d_util.h" #include "mace/ops/ops_test_util.h" +#include "mace/public/mace_runtime.h" namespace mace { namespace test { @@ -337,11 +338,10 @@ TEST_F(MaceAPITest, GPUVariableInputShape) { {{1, 16, 32, 16}, {1, 32, 64, 16}}, {{1, 16, 32, 16}, {1, 32, 64, 16}}, {3, 3, 16, 16}); - MaceRun(2, - {{1, 16, 32, 16}, {1, 32, 64, 16}}, - {{1, 16, 32, 16}, {1, 32, 64, 16}}, - {3, 3, 16, 16}); + MaceRun(2, + {{1, 16, 32, 16}, {1, 32, 64, 16}}, + {{1, 16, 32, 16}, {1, 32, 64, 16}}, + {3, 3, 16, 16}); } - } // namespace test } // namespace mace diff --git a/mace/utils/BUILD b/mace/utils/BUILD index 51f289e149796b441ec5ab35a27af187f4e5833a..85e0647d2db75971ecca95f8e9a251befdfd5f26 100644 --- a/mace/utils/BUILD +++ b/mace/utils/BUILD @@ -24,6 +24,7 @@ cc_library( "timer.h", "tuner.h", "utils.h", + "rwlock.h", ], linkopts = if_android([ "-llog", diff --git a/mace/utils/rwlock.h b/mace/utils/rwlock.h new file mode 100644 index 0000000000000000000000000000000000000000..c3bd2a8845044174c725e33dbb04ff62444a3d74 --- /dev/null +++ b/mace/utils/rwlock.h @@ -0,0 +1,118 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_UTILS_RWLOCK_H_ +#define MACE_UTILS_RWLOCK_H_ + +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include "mace/utils/logging.h" + +namespace mace { +namespace utils { + +class RWMutex { + public: + RWMutex() : counter_(0), waiting_readers_(0), waiting_writers_(0) {} + ~RWMutex() = default; + RWMutex(const RWMutex &) = delete; + RWMutex(RWMutex &&) = delete; + RWMutex& operator=(const RWMutex &) = delete; + RWMutex& operator=(RWMutex &&) = delete; + + int counter_; // -1 for writer, 0 for nobody, 1~n for reader + int waiting_readers_; + int waiting_writers_; + std::mutex mutex_; + std::condition_variable reader_cv_; + std::condition_variable writer_cv_; +}; + +// Writer first +class ReadLock { + public: + explicit ReadLock(RWMutex *rw_mutex): rw_mutex_(rw_mutex) { + if (rw_mutex_ == nullptr) { + return; + } + std::unique_lock lock(rw_mutex->mutex_); + rw_mutex->waiting_readers_++; + rw_mutex->reader_cv_.wait(lock, [&]() -> bool { + return rw_mutex->waiting_writers_ == 0 && rw_mutex->counter_ >= 0; + }); + rw_mutex->waiting_readers_--; + rw_mutex->counter_++; + } + ~ReadLock() { + if (rw_mutex_ == nullptr) { + return; + } + std::unique_lock lock(rw_mutex_->mutex_); + rw_mutex_->counter_ -= 1; + if (rw_mutex_->waiting_writers_ > 0) { + if (rw_mutex_->counter_ == 0) { + rw_mutex_->writer_cv_.notify_one(); + } + } else { + rw_mutex_->reader_cv_.notify_all(); + } + } + ReadLock(const ReadLock &) = delete; + ReadLock(ReadLock &&) = delete; + ReadLock& operator=(const ReadLock &) = delete; + ReadLock& operator=(ReadLock &&) = delete; + + private: + RWMutex *rw_mutex_; +}; + +class WriteLock { + public: + explicit WriteLock(RWMutex *rw_mutex): rw_mutex_(rw_mutex) { + if (rw_mutex_ == nullptr) { + return; + } + std::unique_lock lock(rw_mutex->mutex_); + rw_mutex->waiting_writers_++; + rw_mutex->writer_cv_.wait(lock, [&]() -> bool { + return rw_mutex->counter_ == 0; + }); + rw_mutex->waiting_writers_--; + rw_mutex->counter_--; + } + ~WriteLock() { + if (rw_mutex_ == nullptr) { + return; + } + std::unique_lock lock(rw_mutex_->mutex_); + rw_mutex_->counter_ = 0; + if (rw_mutex_->waiting_writers_ > 0) { + rw_mutex_->writer_cv_.notify_one(); + } else { + rw_mutex_->reader_cv_.notify_all(); + } + } + WriteLock(const WriteLock &) = delete; + WriteLock(WriteLock &&) = delete; + WriteLock& operator=(const WriteLock &) = delete; + WriteLock& operator=(WriteLock &&) = delete; + + private: + RWMutex *rw_mutex_; +}; + +} // namespace utils +} // namespace mace + +#endif // MACE_UTILS_RWLOCK_H_