提交 94d23921 编写于 作者: L liuqi

GPU memory reuse supports multiple outputs.

上级 4cf4120c
...@@ -153,7 +153,9 @@ void OperatorDef::CopyFrom(const OperatorDef &from) { ...@@ -153,7 +153,9 @@ void OperatorDef::CopyFrom(const OperatorDef &from) {
output_type_.resize(from_data_type.size()); output_type_.resize(from_data_type.size());
std::copy(from_data_type.begin(), from_data_type.end(), output_type_.begin()); std::copy(from_data_type.begin(), from_data_type.end(), output_type_.begin());
mem_id_ = from.mem_id(); auto mem_ids = from.mem_id();
mem_id_.resize(mem_ids.size());
std::copy(mem_ids.begin(), mem_ids.end(), mem_id_.begin());
// nnlib // nnlib
node_id_ = from.node_id(); node_id_ = from.node_id();
...@@ -186,13 +188,11 @@ void OperatorDef::set_type(const std::string &type_) { ...@@ -186,13 +188,11 @@ void OperatorDef::set_type(const std::string &type_) {
} }
bool OperatorDef::has_type() const { return (has_bits_ & 0x00000002u) != 0; } bool OperatorDef::has_type() const { return (has_bits_ & 0x00000002u) != 0; }
void OperatorDef::set_has_type() { has_bits_ |= 0x00000002u; } void OperatorDef::set_has_type() { has_bits_ |= 0x00000002u; }
int OperatorDef::mem_id() const { return mem_id_; } const std::vector<int> &OperatorDef::mem_id() const { return mem_id_; }
void OperatorDef::set_mem_id(const int mem_id) { void OperatorDef::set_mem_id(const std::vector<int> &value) {
set_has_mem_id(); mem_id_.resize(value.size());
mem_id_ = mem_id; std::copy(value.begin(), value.end(), mem_id_.begin());
} }
bool OperatorDef::has_mem_id() const { return (has_bits_ & 0x00000004u) != 0; }
void OperatorDef::set_has_mem_id() { has_bits_ |= 0x00000004u; }
uint32_t OperatorDef::node_id() const { return node_id_; } uint32_t OperatorDef::node_id() const { return node_id_; }
void OperatorDef::set_node_id(uint32_t node_id) { node_id_ = node_id; } void OperatorDef::set_node_id(uint32_t node_id) { node_id_ = node_id; }
uint32_t OperatorDef::op_id() const { return op_id_; } uint32_t OperatorDef::op_id() const { return op_id_; }
......
...@@ -116,7 +116,7 @@ void Workspace::CreateImageOutputTensor(const NetDef &net_def) { ...@@ -116,7 +116,7 @@ void Workspace::CreateImageOutputTensor(const NetDef &net_def) {
// As DSP may have different data output type for each op, // As DSP may have different data output type for each op,
// we stick to the same concept. // we stick to the same concept.
for (auto &op : net_def.op()) { for (auto &op : net_def.op()) {
if (op.has_mem_id()) { if (! op.mem_id().empty()){
const DataType op_dtype = static_cast<DataType>( const DataType op_dtype = static_cast<DataType>(
ArgumentHelper::GetSingleArgument<OperatorDef, int>( ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "T", static_cast<int>(DT_FLOAT))); op, "T", static_cast<int>(DT_FLOAT)));
...@@ -135,20 +135,21 @@ void Workspace::CreateImageOutputTensor(const NetDef &net_def) { ...@@ -135,20 +135,21 @@ void Workspace::CreateImageOutputTensor(const NetDef &net_def) {
} }
VLOG(3) << "Preallocate image to tensors"; VLOG(3) << "Preallocate image to tensors";
for (auto &op : net_def.op()) { for (auto &op : net_def.op()) {
if (op.has_mem_id()) { if (!op.mem_id().empty()) {
std::unique_ptr<Tensor> tensor( auto mem_ids = op.mem_id();
new Tensor(preallocated_allocator_.GetBuffer(op.mem_id()), dtype)); for (auto mem_id : mem_ids) {
std::unique_ptr<Tensor> tensor
(new Tensor(preallocated_allocator_.GetBuffer(mem_id), dtype));
tensor->SetSourceOpName(op.name()); tensor->SetSourceOpName(op.name());
VLOG(3) VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")" << "; Mem: "
<< "Tensor: " << op.name() << "(" << op.type() << ")" << mem_id << "; Image shape: "
<< "; Mem: " << op.mem_id() << "; Image shape: "
<< dynamic_cast<Image *>(tensor->UnderlyingBuffer())->image_shape()[0] << dynamic_cast<Image *>(tensor->UnderlyingBuffer())->image_shape()[0]
<< ", " << ", "
<< dynamic_cast<Image *>(tensor->UnderlyingBuffer()) << dynamic_cast<Image *>(tensor->UnderlyingBuffer())->image_shape()[1];
->image_shape()[1];
tensor_map_[op.output(0)] = std::move(tensor); tensor_map_[op.output(0)] = std::move(tensor);
} }
} }
}
} }
} // namespace mace } // namespace mace
...@@ -174,9 +174,8 @@ class OperatorDef { ...@@ -174,9 +174,8 @@ class OperatorDef {
const std::string &type() const; const std::string &type() const;
void set_type(const std::string &type_); void set_type(const std::string &type_);
bool has_type() const; bool has_type() const;
int mem_id() const; const std::vector<int> &mem_id() const;
void set_mem_id(const int mem_id); void set_mem_id(const std::vector<int> &value);
bool has_mem_id() const;
uint32_t node_id() const; uint32_t node_id() const;
void set_node_id(uint32_t node_id); void set_node_id(uint32_t node_id);
uint32_t op_id() const; uint32_t op_id() const;
...@@ -220,7 +219,7 @@ class OperatorDef { ...@@ -220,7 +219,7 @@ class OperatorDef {
std::vector<OutputShape> output_shape_; std::vector<OutputShape> output_shape_;
std::vector<DataType> output_type_; std::vector<DataType> output_type_;
int mem_id_; std::vector<int> mem_id_;
// nnlib // nnlib
uint32_t node_id_; uint32_t node_id_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册