提交 b7d1b433 编写于 作者: B bjjwwang

fix bsf memory error

上级 1cb60d3f
......@@ -88,8 +88,14 @@ bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) {
// 此时 lod 为空。
tensor_out.lod = batchTask._batch_out[fetchvar_index].lod;
// resize all batch memory at one time
size_t databuf_size = fetchvar_batch * fetchvar_bytesize_index;
tensor_out.data.Resize(databuf_size);
void* databuf_data = MempoolWrapper::instance().malloc(databuf_size,memoryPtr);
paddle::PaddleBuf paddleBuf(databuf_data, databuf_size);
tensor_out.data = paddleBuf;
//tensor_out.data.Resize(databuf_size);
} else {
// 当taskmeta_num = 1时,由于同时只有一个taskMeta操作task
// 不涉及线程安全问题,所以此时可以直接由taskMeta->task->resize->copy
......@@ -209,7 +215,7 @@ void TaskExecutor<TaskT>::stop() {
template <typename TaskT>
TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
const void* inVectorT_ptr,
void* outVectorT_ptr) { // NOLINT
void* outVectorT_ptr, MempoolRegion* memoryPtr) { // NOLINT
TaskT* task = butil::get_object<TaskT>();
if (!task) {
LOG(ERROR) << "Failed get TaskT from object pool";
......@@ -235,7 +241,8 @@ TaskHandler<TaskT> TaskExecutor<TaskT>::schedule(
task->read_fd = fds[0];
task->write_fd = fds[1];
task->owner_tid = ::syscall(SYS_gettid);
task->memoryPtr = memoryPtr;
//task->_bspec_key = _bspec_key;
task->inVectorT_ptr = (const InVectorT*)inVectorT_ptr;
task->outVectorT_ptr = (OutVectorT*)outVectorT_ptr;
if (!task->task_init()) {
......@@ -403,9 +410,9 @@ int TaskExecutor<TaskT>::work(ThreadContext<TaskT>* context) {
template <typename InItemT, typename OutItemT>
bool TaskManager<InItemT, OutItemT>::schedule(const void* in,
void* out) { // NOLINT
void* out, MempoolRegion* memoryPtr) { // NOLINT
TaskHandler<TaskT> handler =
TaskExecutorVector<TaskT>::instance()[_model_index].schedule(in, out);
TaskExecutorVector<TaskT>::instance()[_model_index].schedule(in, out, memoryPtr);
if (handler.valid()) {
_task_owned = handler;
......
......@@ -38,6 +38,8 @@ namespace im {
namespace bsf {
static const size_t DEFAULT_BATCH_SIZE = 100;
typedef baidu::paddle_serving::predictor::MempoolWrapper MempoolWrapper;
typedef baidu::paddle_serving::predictor::MempoolRegion MempoolRegion;
// InItemT is paddle::PaddleTensor
// InVectorT std::vector<paddle::PaddleTensor>
......@@ -61,6 +63,7 @@ struct Task {
typedef Task<InItemT, OutItemT> TaskT;
typedef std::vector<size_t> ShapeVector;
typedef std::vector<ShapeVector> VectorOfShapeVector;
typedef baidu::paddle_serving::predictor::MempoolWrapper MempoolWrapper;
int read_fd;
int write_fd;
......@@ -79,6 +82,7 @@ struct Task {
bool fetch_init;
// taskmeta_num * set_feed_lod_index.size()
std::vector<OutVectorT> outLodTensorVector;
MempoolRegion* memoryPtr;
Task() {
read_fd = -1;
......@@ -364,7 +368,12 @@ struct Task {
}
// 一次性扩容PaddleTensor中的data和lod
paddle::PaddleTensor& fetchVarTensor = (*outVectorT_ptr)[feedvar_index];
fetchVarTensor.data.Resize(data_length);
void* databuf_data = MempoolWrapper::instance().malloc(data_length,memoryPtr);
paddle::PaddleBuf paddleBuf(databuf_data, data_length);
fetchVarTensor.data = paddleBuf;
//fetchVarTensor.data.Resize(data_length);
// task中的lod补0
if (fetchVarTensor.lod.size() <= 0) {
fetchVarTensor.lod.push_back({0});
......@@ -625,8 +634,10 @@ class BatchTasks {
paddleTensor.lod = _batch_in_lod[feedvar_index];
paddleTensor.shape = feedVarTensor.shape;
paddleTensor.shape[0] = _total_shape0_batch_in[feedvar_index];
paddleTensor.data.Resize(feedvar_bytesize *
_total_shape0_batch_in[feedvar_index]);
size_t databuf_size = feedvar_bytesize * _total_shape0_batch_in[feedvar_index];
void* databuf_data = MempoolWrapper::instance().malloc(databuf_size);
paddle::PaddleBuf paddleBuf(databuf_data, databuf_size);
paddleTensor.data = paddleBuf;
_batch_in.push_back(paddleTensor);
}
......@@ -733,16 +744,27 @@ class BatchTasks {
// 此时,无法分辨是否是天然nobatch,此时set_fetch_nobatch_index会漏掉
// 后续希望在其他地方能够区分两者。
if (fetchvar_batch_size(fetchvar_index) != _total_fetch_batch) {
if(fetchvar_batch_size(fetchvar_index) <= 0){
// which means error.
if (fetchvar_batch_size(fetchvar_index) != 1 &&
_total_fetch_batch != 1) {
return false;
} else {
}else if(fetchvar_batch_size(fetchvar_index) == 1){
// which means fetchvar shape[0] = 1.
// shape[0] does not change with batch
set_fetch_nobatch_index.insert(fetchvar_index);
_total_fetch_batch =
std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch);
}else if(_total_fetch_batch == 1){
//这时意味着,之前的fetchvar shape[0] 全部都= 1
//当前的fetchvar shape[0] > 1
//所以,之前的都是no_batch
for(size_t temp_index = fetchvar_index-1; temp_index >= 0; --temp_index){
set_fetch_nobatch_index.insert(fetchvar_index);
}
_total_fetch_batch =
std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch);
}else{
// which means error.
return false;
}
}
// 将lod fetchvar index加入到vector中。
......@@ -824,7 +846,11 @@ class BatchTasks {
task->outLodTensorVector[taskmeta_index][fetch_lod_index];
size_t length = fetchvar_bytesize_index * shape0_length;
fetchVarTensor.shape[0] = shape0_length;
fetchVarTensor.data.Resize(length);
void* databuf_data = MempoolWrapper::instance().malloc(length,task->memoryPtr);
paddle::PaddleBuf paddleBuf(databuf_data, length);
fetchVarTensor.data = paddleBuf;
//fetchVarTensor.data.Resize(length);
void* dst_ptr = fetchVarTensor.data.data();
void* source_ptr = _batch_out[fetchvar_index].data.data() +
shape0_index_start * fetchvar_bytesize_index;
......@@ -850,7 +876,12 @@ class BatchTasks {
(*task->outVectorT_ptr)[fetchvar_index];
size_t length = fetchvar_bytesize_index * shape0_length;
fetchVarTensor.shape[0] = shape0_length;
fetchVarTensor.data.Resize(length);
void* databuf_data = MempoolWrapper::instance().malloc(length,task->memoryPtr);
paddle::PaddleBuf paddleBuf(databuf_data, length);
fetchVarTensor.data = paddleBuf;
//fetchVarTensor.data.Resize(length);
void* dst_ptr = fetchVarTensor.data.data();
void* source_ptr = _batch_out[fetchvar_index].data.data() +
shape0_index_start * fetchvar_bytesize_index;
......@@ -1076,7 +1107,7 @@ class TaskExecutor {
int work(ThreadContext<TaskT>* context);
TaskHandler<TaskT> schedule(const void*, void*);
TaskHandler<TaskT> schedule(const void*, void*, MempoolRegion* memoryPtr);
bool move_task_to_batch(BatchTasks<TaskT>& batchTask); // NOLINT
......@@ -1159,7 +1190,7 @@ class TaskManager {
~TaskManager() { wait(); }
bool schedule(const void* in, void* out); // NOLINT
bool schedule(const void* in, void* out, MempoolRegion* memoryPtr); // NOLINT
void wait();
inline void clear() { wait(); }
......
......@@ -98,8 +98,7 @@ int ReloadableInferEngine::infer(const void* in,
im::bsf::TaskManager<paddle::PaddleTensor, paddle::PaddleTensor> task_manager(
_model_index);
task_manager.schedule(in, out);
task_manager.schedule(in, out, MempoolWrapper::instance().get_thread_memory_ptr());
task_manager.wait();
return 0;
}
......
......@@ -19,30 +19,6 @@ namespace baidu {
namespace paddle_serving {
namespace predictor {
// why we need MempoolRegion
// because we need to release the resource.
// so we need both Mempool and Region.
// Mempool is a wrapper class for us to use memory more safely.
// Region is the RAII class.
struct MempoolRegion {
MempoolRegion(im::fugue::memory::Region* region, im::Mempool* mempool)
: _region(region), _mempool(mempool) {}
im::fugue::memory::Region* region() { return _region; }
im::Mempool* mempool() { return _mempool; }
im::fugue::memory::Region* _region;
im::Mempool* _mempool;
~MempoolRegion() {
if (_region) {
delete _region;
_region = NULL;
}
if (_mempool) {
delete _mempool;
_mempool = NULL;
}
}
};
int MempoolWrapper::initialize() {
if (THREAD_KEY_CREATE(&_bspec_key, NULL) != 0) {
......@@ -112,6 +88,28 @@ void* MempoolWrapper::malloc(size_t size) {
return mempool->malloc(size);
}
void* MempoolWrapper::malloc(size_t size, MempoolRegion* my_mempool_region) {
MempoolRegion* mempool_region = my_mempool_region;
if (mempool_region == NULL) {
LOG(WARNING) << "THREAD_GETSPECIFIC() returned NULL";
return NULL;
}
im::Mempool* mempool = mempool_region->mempool();
if (!mempool) {
LOG(WARNING) << "Cannot malloc memory:" << size
<< ", since mempool is not thread initialized";
return NULL;
}
return mempool->malloc(size);
}
MempoolRegion* MempoolWrapper::get_thread_memory_ptr(){
MempoolRegion* mempool_region =
(MempoolRegion*)THREAD_GETSPECIFIC(_bspec_key);
return mempool_region;
}
void MempoolWrapper::free(void* p, size_t size) {
MempoolRegion* mempool_region =
(MempoolRegion*)THREAD_GETSPECIFIC(_bspec_key);
......
......@@ -21,6 +21,30 @@ namespace baidu {
namespace paddle_serving {
namespace predictor {
// why we need MempoolRegion
// because we need to release the resource.
// so we need both Mempool and Region.
// Mempool is a wrapper class for us to use memory more safely.
// Region is the RAII class.
struct MempoolRegion {
MempoolRegion(im::fugue::memory::Region* region, im::Mempool* mempool)
: _region(region), _mempool(mempool) {}
im::fugue::memory::Region* region() { return _region; }
im::Mempool* mempool() { return _mempool; }
im::fugue::memory::Region* _region;
im::Mempool* _mempool;
~MempoolRegion() {
if (_region) {
delete _region;
_region = NULL;
}
if (_mempool) {
delete _mempool;
_mempool = NULL;
}
}
};
class MempoolWrapper {
public:
MempoolWrapper() {}
......@@ -38,6 +62,10 @@ class MempoolWrapper {
void* malloc(size_t size);
void* malloc(size_t size, MempoolRegion* my_mempool_region);
MempoolRegion* get_thread_memory_ptr();
void free(void* p, size_t size);
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册