From fb22aa74b9fb68afc4aecb107442965cdab0394a Mon Sep 17 00:00:00 2001
From: Yin Li
Date: Wed, 6 Dec 2017 22:19:38 +0800
Subject: [PATCH] Mace GPU memory sharing optimization

---
 mace/core/operator.h                  |   9 ++-
 mace/core/tensor.h                    |  27 +++++--
 mace/core/workspace.cc                |  30 +++++++-
 mace/core/workspace.h                 |   8 +-
 mace/examples/mace_run.cc             |   3 +
 mace/proto/mace.proto                 |   3 +
 mace/python/tools/BUILD               |   9 +++
 mace/python/tools/memory_optimizer.py | 102 ++++++++++++++++++++++++++
 tools/validate_gcn.sh                 |  17 +++--
 9 files changed, 193 insertions(+), 15 deletions(-)
 create mode 100644 mace/python/tools/memory_optimizer.py

diff --git a/mace/core/operator.h b/mace/core/operator.h
index 137a2e1a..4cd52fb3 100644
--- a/mace/core/operator.h
+++ b/mace/core/operator.h
@@ -91,8 +91,13 @@ class Operator : public OperatorBase {
     }
 
     for (const string &output_str : operator_def.output()) {
-      outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
-          output_str, GetDeviceAllocator(D), DataTypeToEnum<T>::v())));
+      if (ws->HasTensor(output_str)) {
+        Tensor *found_tensor = ws->GetTensor(output_str);
+        outputs_.push_back(found_tensor);
+      } else {
+        outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
+            output_str, GetDeviceAllocator(D), DataTypeToEnum<T>::v())));
+      }
     }
   }
   virtual bool Run() override = 0;
diff --git a/mace/core/tensor.h b/mace/core/tensor.h
index 2af45d6c..94f95228 100644
--- a/mace/core/tensor.h
+++ b/mace/core/tensor.h
@@ -199,14 +199,20 @@ class Tensor {
       size_ = size;
       MACE_CHECK(data_ == nullptr, "Buffer must be unmapped before resize");
 
-      if (is_image_) {
-        alloc_->DeleteImage(buffer_);
-      } else {
+      if (is_image_ && !image_shape_.empty()) {
+        MACE_ASSERT(image_shape_.size() == 2
+                        && image_shape_[0] >= image_shape[0]
+                        && image_shape_[1] >= image_shape[1],
+                    "image shape not large enough");
+      }
+      if (!is_image_ && buffer_ != nullptr) {
         alloc_->Delete(buffer_);
       }
       is_image_ = true;
-      image_shape_ = image_shape;
-      buffer_ = alloc_->NewImage(image_shape, dtype_);
+      if (image_shape_.empty()) {
+        image_shape_ = image_shape;
+        buffer_ = alloc_->NewImage(image_shape, dtype_);
+      }
     }
   }
 
@@ -226,6 +232,17 @@ class Tensor {
     }
   }
 
+  inline void AllocateImageMemory(const std::vector<size_t> &image_shape) {
+    is_image_ = true;
+    if (image_shape_ != image_shape) {
+      if (buffer_ != nullptr) {
+        alloc_->DeleteImage(buffer_);
+      }
+      image_shape_ = image_shape;
+      buffer_ = alloc_->NewImage(image_shape, dtype_);
+    }
+  }
+
   template <typename T>
   inline void Copy(const T *src, index_t size) {
     MACE_CHECK(size == size_, "copy src and dst with different size.");
diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc
index 2a172f3e..e8fc98f9 100644
--- a/mace/core/workspace.cc
+++ b/mace/core/workspace.cc
@@ -3,8 +3,8 @@
 //
 
 #include "mace/core/workspace.h"
-#include "mace/core/common.h"
 #include "mace/core/serializer.h"
+#include "mace/core/proto_utils.h"
 
 namespace mace {
 
@@ -63,6 +63,34 @@ void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) {
     tensor_map_[tensor_proto.name()] =
         serializer.Deserialize(tensor_proto, type);
   }
+  if (type == DeviceType::OPENCL) {
+    CreateImageOutputTensor(net_def);
+  }
+}
+
+void Workspace::CreateImageOutputTensor(const NetDef &net_def) {
+  if (!net_def.has_mem_arena() || net_def.mem_arena().mem_block_size() == 0) {
+    return;
+  }
+  std::map<std::string, std::shared_ptr<Tensor>> mem_tensor_map;
+  const DataType dtype = static_cast<DataType>(
+      ArgumentHelper::GetSingleArgument<OperatorDef, int>(
+          net_def.op(0),
+          "T",
+          static_cast<int>(DT_FLOAT)));
+  for (auto &mem_block : net_def.mem_arena().mem_block()) {
+    string mem_block_name = MemBlockName(mem_block.mem_id());
+    mem_tensor_map[mem_block_name].reset(new Tensor(
+        GetDeviceAllocator(DeviceType::OPENCL),
+        dtype));
+    mem_tensor_map[mem_block_name]->AllocateImageMemory({mem_block.x(),
+                                                         mem_block.y()});
+  }
+  for (auto &op : net_def.op()) {
+    if (op.has_mem_id()) {
+      tensor_map_[op.output(0)] = mem_tensor_map[MemBlockName(op.mem_id())];
+    }
+  }
 }
 
 }  // namespace mace
\ No newline at end of file
diff --git a/mace/core/workspace.h b/mace/core/workspace.h
index 291bc059..8a706b87 100644
--- a/mace/core/workspace.h
+++ b/mace/core/workspace.h
@@ -13,7 +13,7 @@
 namespace mace {
 
 class Workspace {
  public:
-  typedef map<string, unique_ptr<Tensor>> TensorMap;
+  typedef map<string, shared_ptr<Tensor>> TensorMap;
 
   Workspace() {}
@@ -33,7 +33,13 @@ class Workspace {
 
   void LoadModelTensor(const NetDef &net_def, DeviceType type);
 
+  inline std::string MemBlockName(int mem_id) const {
+    return internal::MakeString("mem_block_", mem_id);
+  }
+
  private:
+  void CreateImageOutputTensor(const NetDef &net_def);
+
   TensorMap tensor_map_;
 
   DISABLE_COPY_AND_ASSIGN(Workspace);
diff --git a/mace/examples/mace_run.cc b/mace/examples/mace_run.cc
index 8ca9765b..73fa6767 100644
--- a/mace/examples/mace_run.cc
+++ b/mace/examples/mace_run.cc
@@ -101,9 +101,12 @@ int main(int argc, char **argv) {
   }
 
   // Init model
+  VLOG(0) << "Run init";
   auto net = CreateNet(net_def, &ws, device_type, NetMode::INIT);
   net->Run();
 
+  VLOG(0) << "Run model";
+
   // run model
   net = CreateNet(net_def, &ws, device_type);
 
diff --git a/mace/proto/mace.proto b/mace/proto/mace.proto
index 37a34943..2aa79796 100644
--- a/mace/proto/mace.proto
+++ b/mace/proto/mace.proto
@@ -128,6 +128,9 @@ message NetDef {
   repeated Argument arg = 4;
   repeated TensorProto tensors = 5;
 
+  // for mem optimization
+  optional MemoryArena mem_arena = 10;
+
   // for hexagon mace-nnlib
   repeated InputInfo input_info = 100;
   repeated OutputInfo output_info = 101;
diff --git a/mace/python/tools/BUILD b/mace/python/tools/BUILD
index 964ea528..f5b1f15a 100644
--- a/mace/python/tools/BUILD
+++ b/mace/python/tools/BUILD
@@ -20,6 +20,15 @@ py_binary(
     ],
 )
 
+py_binary(
+    name = "memory_optimizer",
+    srcs = ["memory_optimizer.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//mace/proto:mace_py",
+    ],
+)
+
 py_binary(
     name = "tf_ops_stats",
     srcs = ["tf_ops_stats.py"],
diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py
new file mode 100644
index 00000000..f64df5ba
--- /dev/null
+++ b/mace/python/tools/memory_optimizer.py
@@ -0,0 +1,102 @@
+import sys
+import operator
+from functools import reduce
+from mace.proto import mace_pb2
+
+class MemoryOptimizer(object):
+    def __init__(self, net_def):
+        self.net_def = net_def
+        self.idle_mem = set()
+        self.op_mem = {}  # op_name->mem_id
+        self.mem_block = {}  # mem_id->[x, y]
+        self.total_mem_count = 0
+        self.ref_counter = {}
+
+        consumers = {}
+        for op in net_def.op:
+            if self.is_buffer_image_op(op):
+                continue
+            for ipt in op.input:
+                if ipt not in consumers:
+                    consumers[ipt] = []
+                consumers[ipt].append(op)
+        # only ref op's output tensor
+        for op in net_def.op:
+            if self.is_buffer_image_op(op):
+                continue
+            tensor_name = self._op_to_tensor(op)
+            if tensor_name in consumers:
+                self.ref_counter[tensor_name] = len(consumers[tensor_name])
+            else:
+                self.ref_counter[tensor_name] = 0
+
+    def _op_to_tensor(self, op):
+        return op.name + ':0'
+
+    def is_buffer_image_op(self, op):
+        return op.type == 'BufferToImage' or op.type == 'ImageToBuffer'
+
+    def optimize(self):
+        for op in self.net_def.op:
+            if self.is_buffer_image_op(op):
+                continue
+            if len(self.idle_mem) == 0:
+                # allocate new mem
+                mem_id = self.total_mem_count
+                self.total_mem_count += 1
+            else:
+                # reuse mem
+                mem_id = self.idle_mem.pop()
+
+            op.mem_id = mem_id
+            self.op_mem[self._op_to_tensor(op)] = mem_id
+            if mem_id not in self.mem_block:
+                self.mem_block[mem_id] = [0, 0]
+            mem_size = self.mem_block[mem_id]
+            mem_size[1] = max(mem_size[1], op.output_shape[0].dims[0] * op.output_shape[0].dims[1])
+            mem_size[0] = max(mem_size[0], op.output_shape[0].dims[2] * ((op.output_shape[0].dims[3] + 3) // 4))
+
+            # de-ref input tensor mem
+            for ipt in op.input:
+                if ipt in self.ref_counter:
+                    self.ref_counter[ipt] -= 1
+                    if self.ref_counter[ipt] == 0:
+                        self.idle_mem.add(self.op_mem[ipt])
+                    elif self.ref_counter[ipt] < 0:
+                        raise Exception('ref count is less than 0')
+
+        for mem in self.mem_block:
+            arena = self.net_def.mem_arena
+            block = arena.mem_block.add()
+            block.mem_id = mem
+            block.x = self.mem_block[mem][0]
+            block.y = self.mem_block[mem][1]
+
+        print('total op: %d' % len(self.net_def.op))
+        origin_mem_size = 0
+        optimized_mem_size = 0
+        for op in self.net_def.op:
+            if self.is_buffer_image_op(op):
+                continue
+            origin_mem_size += reduce(operator.mul, op.output_shape[0].dims, 1)
+        for mem in self.mem_block:
+            optimized_mem_size += reduce(operator.mul, self.mem_block[mem], 4)
+
+        print('origin mem: %d, optimized mem: %d' % (origin_mem_size, optimized_mem_size))
+
+if __name__ == '__main__':
+    model_file = sys.argv[1]
+    opt_model_file = sys.argv[2]
+    with open(model_file, "rb") as f:
+        net_def = mace_pb2.NetDef()
+        net_def.ParseFromString(f.read())
+        optimizer = MemoryOptimizer(net_def)
+        optimizer.optimize()
+
+    with open(opt_model_file, "wb") as f:
+        f.write(net_def.SerializeToString())
+    with open(opt_model_file + '_txt', "wb") as f:
+        net_def.ClearField('tensors')
+        f.write(str(net_def))
+
+
diff --git a/tools/validate_gcn.sh b/tools/validate_gcn.sh
index 524c752b..b62cb784 100644
--- a/tools/validate_gcn.sh
+++ b/tools/validate_gcn.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 # Must run at root dir of mace project.
-
+set +x
 Usage() {
   echo 'Usage: bash tools/validate_gcn.sh tf_model_file'
 }
@@ -13,6 +13,7 @@ fi
 TF_MODEL_FILE_PATH=$1
 MODEL_DIR=$(dirname ${TF_MODEL_FILE_PATH})
 MACE_MODEL_NAME='mace_model.pb'
+MACE_OPT_MODEL_NAME='mace_opt_model.pb'
 INPUT_FILE_NAME='model_input'
 OUTPUT_FILE_NAME='gcn.out'
 OUTPUT_LIST_FILE='gcn.list'
@@ -26,14 +27,17 @@
 python tools/validate.py --generate_data true --random_seed 1 \
     --input_file=${MODEL_DIR}/${INPUT_FILE_NAME} \
     --input_shape=512,512,3
 
 # Step 2: convert tf model to mace model
-echo "Step 2: convert tf model to mace model"
+echo "Step 2: convert tf model to mace model and optimize memory"
 bazel build //mace/python/tools:tf_converter
 bazel-bin/mace/python/tools/tf_converter --input=${TF_MODEL_FILE_PATH} \
     --output=${MODEL_DIR}/${MACE_MODEL_NAME} \
     --input_node=input \
     --output_node=GCN/br_result_2/fcn_br \
-    --data_type=DT_HALF\
+    --data_type=DT_HALF \
     --runtime=gpu
+bazel build //mace/python/tools:memory_optimizer
+bazel-bin/mace/python/tools/memory_optimizer ${MODEL_DIR}/${MACE_MODEL_NAME} \
+    ${MODEL_DIR}/${MACE_OPT_MODEL_NAME}
 
@@ -46,7 +50,7 @@ # Step 3: Run model on the phone
 bazel build -c opt --strip always mace/examples:mace_run \
 adb shell "mkdir -p ${PHONE_DATA_DIR}"
 adb shell "mkdir -p ${KERNEL_DIR}"
 adb push mace/kernels/opencl/cl/* ${KERNEL_DIR}
-adb push ${MODEL_DIR}/${MACE_MODEL_NAME} ${PHONE_DATA_DIR}
+adb push ${MODEL_DIR}/${MACE_OPT_MODEL_NAME} ${PHONE_DATA_DIR}
 adb push ${MODEL_DIR}/${INPUT_FILE_NAME} ${PHONE_DATA_DIR}
 adb push bazel-bin/mace/examples/mace_run ${PHONE_DATA_DIR}
 
@@ -56,13 +60,14 @@
adb
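--
Note (illustration, not part of the patch): memory_optimizer.py implements a
greedy buffer-sharing pass -- each op output carries a reference count of its
consumers, a block returns to an idle pool once its last consumer has run, and
every block grows to the largest image that ever occupies it. The sketch below
is a minimal, self-contained rendering of that strategy under assumed types;
Op, out_hw, and greedy_assign are hypothetical stand-ins, not the MACE proto
messages or API.

from collections import namedtuple

# Hypothetical stand-in for OperatorDef: op name, input tensor names, and the
# (width, height) of the OpenCL image its output needs.
Op = namedtuple('Op', ['name', 'inputs', 'out_hw'])

def greedy_assign(ops):
    """Return (op_name -> mem_id, mem_id -> (w, h)), reusing ids greedily."""
    # Count the consumers of every op output.
    ref_count = {op.name: 0 for op in ops}
    for op in ops:
        for ipt in op.inputs:
            if ipt in ref_count:
                ref_count[ipt] += 1

    idle, next_id = set(), 0
    assignment, blocks = {}, {}
    for op in ops:
        # Prefer an idle block; otherwise mint a new id.
        mem_id = idle.pop() if idle else next_id
        if mem_id == next_id:
            next_id += 1
        assignment[op.name] = mem_id
        # Grow the block to the largest tenant it ever holds.
        w, h = op.out_hw
        bw, bh = blocks.get(mem_id, (0, 0))
        blocks[mem_id] = (max(bw, w), max(bh, h))
        # Release inputs whose last consumer just ran.
        for ipt in op.inputs:
            if ipt in ref_count:
                ref_count[ipt] -= 1
                if ref_count[ipt] == 0:
                    idle.add(assignment[ipt])
    return assignment, blocks

if __name__ == '__main__':
    chain = [Op('a', [], (128, 64)), Op('b', ['a'], (128, 64)),
             Op('c', ['b'], (64, 32)), Op('d', ['c'], (64, 32))]
    assignment, blocks = greedy_assign(chain)
    print(assignment)  # {'a': 0, 'b': 1, 'c': 0, 'd': 1} -- a/c and b/d share
    print(blocks)      # {0: (128, 64), 1: (128, 64)}

On a straight chain the pass needs only two blocks regardless of network
depth, which is where the "origin mem" vs "optimized mem" numbers the script
prints come from. The real pass sizes each block from the op's output shape
as width W * ceil(C / 4) and height N * H, matching the RGBA-packed OpenCL
image layout that AllocateImageMemory allocates.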