Commit fb22aa74 authored by Yin Li

Mace GPU memory sharing optimization

Parent 22581f22
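This commit adds an offline pass (memory_optimizer.py, below) that assigns each op's output tensor a shared mem_id based on tensor lifetimes, and a runtime path (Workspace::CreateImageOutputTensor) that pre-allocates one OpenCL image per mem_id and lets every tensor with that id share it. A minimal sketch of the reuse strategy follows; it is not the shipped script, and the helper name assign_mem is hypothetical:

# Minimal sketch of lifetime-based mem_id reuse (hypothetical helper, not the
# shipped memory_optimizer.py). ref_count maps tensor name -> remaining consumers.
idle_mem = set()    # mem_ids whose backing tensor is no longer referenced
tensor_mem = {}     # output tensor name -> mem_id
total_mem = [0]     # next fresh mem_id

def assign_mem(op_name, op_inputs, ref_count):
    if idle_mem:
        mem_id = idle_mem.pop()      # reuse an image block that is now idle
    else:
        mem_id = total_mem[0]        # otherwise open a new block
        total_mem[0] += 1
    tensor_mem[op_name + ':0'] = mem_id
    for ipt in op_inputs:            # release inputs consumed for the last time
        if ipt in ref_count:
            ref_count[ipt] -= 1
            if ref_count[ipt] == 0:
                idle_mem.add(tensor_mem[ipt])
    return mem_id

Ops whose output lifetimes never overlap therefore map to the same mem_id, which is what the runtime exploits when it builds shared image tensors at model-load time.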
......@@ -91,8 +91,13 @@ class Operator : public OperatorBase {
}
for (const string &output_str : operator_def.output()) {
outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
output_str, GetDeviceAllocator(D), DataTypeToEnum<T>::v())));
if (ws->HasTensor(output_str)) {
outputs_.push_back(ws->GetTensor(output_str));
} else {
outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
output_str, GetDeviceAllocator(D), DataTypeToEnum<T>::v())));
}
}
}
virtual bool Run() override = 0;
......
......@@ -199,14 +199,20 @@ class Tensor {
size_ = size;
MACE_CHECK(data_ == nullptr, "Buffer must be unmapped before resize");
if (is_image_) {
alloc_->DeleteImage(buffer_);
} else {
if (is_image_ && !image_shape_.empty()) {
MACE_ASSERT(image_shape_.size() == 2
&& image_shape_[0] >= image_shape[0]
&& image_shape_[1] >= image_shape[1],
"image shape not large enough");
}
if (!is_image_ && buffer_ != nullptr) {
alloc_->Delete(buffer_);
}
is_image_ = true;
image_shape_ = image_shape;
buffer_ = alloc_->NewImage(image_shape, dtype_);
if (image_shape_.empty()) {
image_shape_ = image_shape;
buffer_ = alloc_->NewImage(image_shape, dtype_);
}
}
}
......@@ -226,6 +232,17 @@ class Tensor {
}
}
inline void AllocateImageMemory(const std::vector<size_t> &image_shape) {
is_image_ = true;
if (image_shape_ != image_shape) {
if (buffer_ != nullptr) {
alloc_->DeleteImage(buffer_);
}
image_shape_ = image_shape;
buffer_ = alloc_->NewImage(image_shape, dtype_);
}
}
template <typename T>
inline void Copy(const T *src, index_t size) {
MACE_CHECK(size == size_, "copy src and dst with different size.");
......
......@@ -3,8 +3,8 @@
//
#include "mace/core/workspace.h"
#include "mace/core/common.h"
#include "mace/core/serializer.h"
#include "mace/core/proto_utils.h"
namespace mace {
......@@ -63,6 +63,34 @@ void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) {
tensor_map_[tensor_proto.name()] =
serializer.Deserialize(tensor_proto, type);
}
if (type == DeviceType::OPENCL) {
CreateImageOutputTensor(net_def);
}
}
void Workspace::CreateImageOutputTensor(const NetDef &net_def) {
if (!net_def.has_mem_arena() || net_def.mem_arena().mem_block_size() == 0) {
return;
}
std::map<std::string, std::shared_ptr<Tensor>> mem_tensor_map;
const DataType dtype = static_cast<DataType>(
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
net_def.op(0),
"T",
static_cast<int>(DT_FLOAT)));
for (auto &mem_block: net_def.mem_arena().mem_block()) {
string mem_block_name = MemBlockName(mem_block.mem_id());
mem_tensor_map[mem_block_name].reset(new Tensor(
GetDeviceAllocator(DeviceType::OPENCL),
dtype));
mem_tensor_map[mem_block_name]->AllocateImageMemory({mem_block.x(),
mem_block.y()});
}
for (auto &op: net_def.op()) {
if (op.has_mem_id()) {
tensor_map_[op.output(0)] = mem_tensor_map[MemBlockName(op.mem_id())];
}
}
}
} // namespace mace
\ No newline at end of file
......@@ -13,7 +13,7 @@ namespace mace {
class Workspace {
public:
typedef map<string, unique_ptr<Tensor>> TensorMap;
typedef map<string, std::shared_ptr<Tensor>> TensorMap;
Workspace() {}
......@@ -33,7 +33,13 @@ class Workspace {
void LoadModelTensor(const NetDef &net_def, DeviceType type);
inline std::string MemBlockName(int mem_id) const {
return internal::MakeString("mem_block_", mem_id);
};
private:
void CreateImageOutputTensor(const NetDef &net_def);
TensorMap tensor_map_;
DISABLE_COPY_AND_ASSIGN(Workspace);
......
......@@ -101,9 +101,12 @@ int main(int argc, char **argv) {
}
// Init model
VLOG(0) << "Run init";
auto net = CreateNet(net_def, &ws, device_type, NetMode::INIT);
net->Run();
VLOG(0) << "Run model";
// run model
net = CreateNet(net_def, &ws, device_type);
......
......@@ -128,6 +128,9 @@ message NetDef {
repeated Argument arg = 4;
repeated TensorProto tensors = 5;
// for mem optimization
optional MemoryArena mem_arena = 10;
// for hexagon mace-nnlib
repeated InputInfo input_info = 100;
repeated OutputInfo output_info = 101;
......
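For illustration, the snippet below builds the structures the optimizer writes back into a NetDef. The mem_id field on OperatorDef and the MemoryBlock fields (mem_id, x, y) follow the usage in memory_optimizer.py below; the MemoryArena/MemoryBlock message definitions themselves are not shown in this diff, so treat this as a sketch, not a spec:

# Sketch of the data the optimizer stores in NetDef.mem_arena (field names
# taken from memory_optimizer.py; 'conv1' is a hypothetical op name).
from mace.proto import mace_pb2

net_def = mace_pb2.NetDef()
block = net_def.mem_arena.mem_block.add()
block.mem_id = 0
block.x = 4096          # image width:  W * ceil(C / 4)
block.y = 512           # image height: N * H

op = net_def.op.add()
op.name = 'conv1'
op.mem_id = 0           # this op's output will share image block 0
print(net_def)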
......@@ -20,6 +20,15 @@ py_binary(
],
)
py_binary(
name = "memory_optimizer",
srcs = ["memory_optimizer.py"],
srcs_version = "PY2AND3",
deps = [
"//mace/proto:mace_py",
],
)
py_binary(
name = "tf_ops_stats",
srcs = ["tf_ops_stats.py"],
......
import sys
import operator
from functools import reduce  # builtin in Python 2; imported for Python 3 compatibility

from mace.proto import mace_pb2


class MemoryOptimizer(object):
    def __init__(self, net_def):
        self.net_def = net_def
        self.idle_mem = set()
        self.op_mem = {}      # output tensor name -> mem_id
        self.mem_block = {}   # mem_id -> [x, y]
        self.total_mem_count = 0
        self.ref_counter = {}

        consumers = {}
        for op in net_def.op:
            if self.is_buffer_image_op(op):
                continue
            for ipt in op.input:
                if ipt not in consumers:
                    consumers[ipt] = []
                consumers[ipt].append(op)
        # only count references to op output tensors
        for op in net_def.op:
            if self.is_buffer_image_op(op):
                continue
            tensor_name = self._op_to_tensor(op)
            if tensor_name in consumers:
                self.ref_counter[tensor_name] = len(consumers[tensor_name])
            else:
                self.ref_counter[tensor_name] = 0

    def _op_to_tensor(self, op):
        return op.name + ':0'

    def is_buffer_image_op(self, op):
        return op.type == 'BufferToImage' or op.type == 'ImageToBuffer'

    def optimize(self):
        for op in self.net_def.op:
            if self.is_buffer_image_op(op):
                continue
            if len(self.idle_mem) == 0:
                # allocate a new mem block
                mem_id = self.total_mem_count
                self.total_mem_count += 1
            else:
                # reuse an idle mem block
                mem_id = self.idle_mem.pop()

            op.mem_id = mem_id
            self.op_mem[self._op_to_tensor(op)] = mem_id
            if mem_id not in self.mem_block:
                self.mem_block[mem_id] = [0, 0]
            # OpenCL image extents: y = N * H, x = W * ceil(C / 4)
            mem_size = self.mem_block[mem_id]
            mem_size[1] = max(mem_size[1],
                              op.output_shape[0].dims[0] * op.output_shape[0].dims[1])
            mem_size[0] = max(mem_size[0],
                              op.output_shape[0].dims[2] * ((op.output_shape[0].dims[3] + 3) // 4))

            # de-ref input tensors; a block becomes idle once its last consumer has run
            for ipt in op.input:
                if ipt in self.ref_counter:
                    self.ref_counter[ipt] -= 1
                    if self.ref_counter[ipt] == 0:
                        self.idle_mem.add(self.op_mem[ipt])
                    elif self.ref_counter[ipt] < 0:
                        raise Exception('ref count is less than 0')

        for mem in self.mem_block:
            arena = self.net_def.mem_arena
            block = arena.mem_block.add()
            block.mem_id = mem
            block.x = self.mem_block[mem][0]
            block.y = self.mem_block[mem][1]

        print('total op: %d' % len(self.net_def.op))
        origin_mem_size = 0
        optimized_mem_size = 0
        for op in self.net_def.op:
            if self.is_buffer_image_op(op):
                continue
            origin_mem_size += reduce(operator.mul, op.output_shape[0].dims, 1)
        for mem in self.mem_block:
            # each image pixel holds 4 values, hence the initial factor of 4
            optimized_mem_size += reduce(operator.mul, self.mem_block[mem], 4)
        print('origin mem: %d, optimized mem: %d' %
              (origin_mem_size, optimized_mem_size))


if __name__ == '__main__':
    model_file = sys.argv[1]
    opt_model_file = sys.argv[2]
    with open(model_file, "rb") as f:
        net_def = mace_pb2.NetDef()
        net_def.ParseFromString(f.read())
        optimizer = MemoryOptimizer(net_def)
        optimizer.optimize()
        with open(opt_model_file, "wb") as f:
            f.write(net_def.SerializeToString())
        with open(opt_model_file + '_txt', "w") as f:
            net_def.ClearField('tensors')
            f.write(str(net_def))
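As a short worked example of the sizing above (the shapes are hypothetical): an NHWC output [1, 512, 512, 32] maps to an image of width W * ceil(C / 4) = 4096 and height N * H = 512, and ops whose lifetimes never overlap fold into one block sized to the element-wise max; the factor of 4 in the final count mirrors the reduce(..., 4) in the optimizer, since each image pixel packs 4 values.

# Worked example of the image-block sizing used above (hypothetical shapes).
def image_block(nhwc):
    n, h, w, c = nhwc
    return [w * ((c + 3) // 4), h * n]    # [x, y]

a = image_block([1, 512, 512, 32])        # -> [4096, 512]
b = image_block([1, 512, 512, 16])        # -> [2048, 512]
# If the two outputs never coexist they share one block sized to the max,
# holding x * y * 4 values:
shared = [max(s, t) for s, t in zip(a, b)]
print(shared, shared[0] * shared[1] * 4)  # [4096, 512] 8388608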
#!/bin/bash
# Must run at root dir of mace project.
set +x
Usage() {
echo 'Usage: bash tools/validate_gcn.sh tf_model_file'
}
......@@ -13,6 +13,7 @@ fi
TF_MODEL_FILE_PATH=$1
MODEL_DIR=$(dirname ${TF_MODEL_FILE_PATH})
MACE_MODEL_NAME='mace_model.pb'
MACE_OPT_MODEL_NAME='mace_opt_model.pb'
INPUT_FILE_NAME='model_input'
OUTPUT_FILE_NAME='gcn.out'
OUTPUT_LIST_FILE='gcn.list'
......@@ -26,14 +27,17 @@ python tools/validate.py --generate_data true --random_seed 1 \
--input_shape=512,512,3
# Step 2: convert tf model to mace model
echo "Step 2: convert tf model to mace model"
echo "Step 2: convert tf model to mace model and optimize memory"
bazel build //mace/python/tools:tf_converter
bazel-bin/mace/python/tools/tf_converter --input=${TF_MODEL_FILE_PATH} \
--output=${MODEL_DIR}/${MACE_MODEL_NAME} \
--input_node=input \
--output_node=GCN/br_result_2/fcn_br \
--data_type=DT_HALF\
--data_type=DT_HALF \
--runtime=gpu
bazel build mace/python/tools:memory_optimizer
bazel-bin/mace/python/tools/memory_optimizer ${MODEL_DIR}/${MACE_MODEL_NAME} \
${MODEL_DIR}/${MACE_OPT_MODEL_NAME}
# Step 3: Run model on the phone
......@@ -46,7 +50,7 @@ bazel build -c opt --strip always mace/examples:mace_run \
adb shell "mkdir -p ${PHONE_DATA_DIR}"
adb shell "mkdir -p ${KERNEL_DIR}"
adb push mace/kernels/opencl/cl/* ${KERNEL_DIR}
adb push ${MODEL_DIR}/${MACE_MODEL_NAME} ${PHONE_DATA_DIR}
adb push ${MODEL_DIR}/${MACE_OPT_MODEL_NAME} ${PHONE_DATA_DIR}
adb push ${MODEL_DIR}/${INPUT_FILE_NAME} ${PHONE_DATA_DIR}
adb push bazel-bin/mace/examples/mace_run ${PHONE_DATA_DIR}
......@@ -56,13 +60,14 @@ adb </dev/null shell MACE_RUN_PARAMETER_PATH=${PHONE_DATA_DIR}/mace_run.config \
MACE_KERNEL_PATH=$KERNEL_DIR \
OMP_NUM_THREADS=$num_threads \
${PHONE_DATA_DIR}/mace_run \
--model=${PHONE_DATA_DIR}/${MACE_MODEL_NAME} \
--model=${PHONE_DATA_DIR}/${MACE_OPT_MODEL_NAME} \
--input=mace_input_node \
--output=mace_output_node \
--input_shape=1,512,512,3 \
--input_file=${PHONE_DATA_DIR}/${INPUT_FILE_NAME} \
--output_file=${PHONE_DATA_DIR}/${OUTPUT_FILE_NAME} \
--device=OPENCL
--device=OPENCL \
--round=1
# Step 4: pull the mace run result.
echo "Step 4: pull the mace run result."
......