提交 2d50eea4 编写于 作者: L Liangliang He

Merge branch 'rm-buffer' into 'master'

Remove unsed buffer

See merge request !197
...@@ -530,9 +530,13 @@ MaceEngine::MaceEngine(const NetDef *net_def, DeviceType device_type) : ...@@ -530,9 +530,13 @@ MaceEngine::MaceEngine(const NetDef *net_def, DeviceType device_type) :
if (!net->Run()) { if (!net->Run()) {
LOG(FATAL) << "Net init run failed"; LOG(FATAL) << "Net init run failed";
} }
ws_->RemoveUnsedTensor();
ws_->CreateTensor("mace_input_node:0", ws_->CreateTensor("mace_input_node:0",
GetDeviceAllocator(device_type_), GetDeviceAllocator(device_type_),
DT_FLOAT); DT_FLOAT);
ws_->CreateTensor("mace_output_node:0",
GetDeviceAllocator(device_type_),
DT_FLOAT);
net_ = std::move(CreateNet(op_registry_, *net_def, ws_.get(), device_type)); net_ = std::move(CreateNet(op_registry_, *net_def, ws_.get(), device_type));
} }
} }
...@@ -548,14 +552,8 @@ bool MaceEngine::Run(const float *input, ...@@ -548,14 +552,8 @@ bool MaceEngine::Run(const float *input,
const std::vector<index_t> &input_shape, const std::vector<index_t> &input_shape,
float *output) { float *output) {
MACE_CHECK(output != nullptr, "output ptr cannot be NULL"); MACE_CHECK(output != nullptr, "output ptr cannot be NULL");
Tensor *input_tensor = Tensor *input_tensor = ws_->GetTensor("mace_input_node:0");
ws_->CreateTensor("mace_input_node:0", Tensor *output_tensor = ws_->GetTensor("mace_output_node:0");
GetDeviceAllocator(device_type_),
DT_FLOAT);
Tensor *output_tensor =
ws_->CreateTensor("mace_output_node:0",
GetDeviceAllocator(device_type_),
DT_FLOAT);
input_tensor->Resize(input_shape); input_tensor->Resize(input_shape);
{ {
Tensor::MappingGuard input_guard(input_tensor); Tensor::MappingGuard input_guard(input_tensor);
......
...@@ -70,6 +70,7 @@ class Tensor { ...@@ -70,6 +70,7 @@ class Tensor {
dtype_(DT_FLOAT), dtype_(DT_FLOAT),
buffer_(nullptr), buffer_(nullptr),
data_(nullptr), data_(nullptr),
unused_(false),
is_image_(false){}; is_image_(false){};
Tensor(Allocator *alloc, DataType type) Tensor(Allocator *alloc, DataType type)
...@@ -78,6 +79,7 @@ class Tensor { ...@@ -78,6 +79,7 @@ class Tensor {
dtype_(type), dtype_(type),
buffer_(nullptr), buffer_(nullptr),
data_(nullptr), data_(nullptr),
unused_(false),
is_image_(false){}; is_image_(false){};
~Tensor() { ~Tensor() {
...@@ -114,6 +116,8 @@ class Tensor { ...@@ -114,6 +116,8 @@ class Tensor {
inline index_t raw_size() const { return size_ * SizeOfType(); } inline index_t raw_size() const { return size_ * SizeOfType(); }
inline const bool unused() const { return unused_; }
inline int64_t NumElements() const { inline int64_t NumElements() const {
return std::accumulate(shape_.begin(), shape_.end(), 1, return std::accumulate(shape_.begin(), shape_.end(), 1,
std::multiplies<int64_t>()); std::multiplies<int64_t>());
...@@ -303,6 +307,10 @@ class Tensor { ...@@ -303,6 +307,10 @@ class Tensor {
} }
} }
inline void MarkUnused() {
this->unused_ = true;
}
class MappingGuard { class MappingGuard {
public: public:
MappingGuard(const Tensor *tensor) : tensor_(tensor) { MappingGuard(const Tensor *tensor) : tensor_(tensor) {
...@@ -343,6 +351,7 @@ class Tensor { ...@@ -343,6 +351,7 @@ class Tensor {
mutable void *data_; mutable void *data_;
vector<index_t> shape_; vector<index_t> shape_;
// Image for opencl // Image for opencl
bool unused_;
bool is_image_; bool is_image_;
std::vector<size_t> image_shape_; std::vector<size_t> image_shape_;
......
...@@ -47,6 +47,18 @@ const Tensor *Workspace::GetTensor(const string &name) const { ...@@ -47,6 +47,18 @@ const Tensor *Workspace::GetTensor(const string &name) const {
return nullptr; return nullptr;
} }
void Workspace::RemoveUnsedTensor() {
auto iter = tensor_map_.begin();
auto end_iter = tensor_map_.end();
while(iter != end_iter) {
auto old_iter = iter++;
if(old_iter->second->unused()) {
tensor_map_.erase(old_iter);
}
}
}
Tensor *Workspace::GetTensor(const string &name) { Tensor *Workspace::GetTensor(const string &name) {
return const_cast<Tensor *>( return const_cast<Tensor *>(
static_cast<const Workspace *>(this)->GetTensor(name)); static_cast<const Workspace *>(this)->GetTensor(name));
......
...@@ -23,6 +23,8 @@ class Workspace { ...@@ -23,6 +23,8 @@ class Workspace {
bool RemoveTensor(const string &name); bool RemoveTensor(const string &name);
void RemoveUnsedTensor();
inline bool HasTensor(const string &name) const { inline bool HasTensor(const string &name) const {
return tensor_map_.count(name); return tensor_map_.count(name);
} }
......
...@@ -27,7 +27,6 @@ struct BufferToImageFunctor : BufferToImageFunctorBase{ ...@@ -27,7 +27,6 @@ struct BufferToImageFunctor : BufferToImageFunctorBase{
StatsFuture *future) { StatsFuture *future) {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
} }
bool i2b_;
}; };
template<typename T> template<typename T>
......
...@@ -19,6 +19,7 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(Tensor *buffer, ...@@ -19,6 +19,7 @@ void BufferToImageFunctor<DeviceType::OPENCL, T>::operator()(Tensor *buffer,
if (!i2b_) { if (!i2b_) {
CalImage2DShape(buffer->shape(), type, image_shape); CalImage2DShape(buffer->shape(), type, image_shape);
image->ResizeImage(buffer->shape(), image_shape); image->ResizeImage(buffer->shape(), image_shape);
buffer->MarkUnused();
} else { } else {
image_shape = image->image_shape(); image_shape = image->image_shape();
buffer->Resize(image->shape()); buffer->Resize(image->shape());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册