提交 499fa1b8 编写于 作者: X xiebaiyuan 提交者: GitHub

open cl mem optimise, split with cpu codes. fix a bug when some memory is not equal 4 . (#2161)

* open cl mem optimise, split with cpu codes.
fix a bug when some memory is not equal 4 .

* open cl mem optimise, split with cpu codes.   fix a bug when some memory is not equal 4  fix bad des
上级 36d8b5bc
...@@ -126,15 +126,24 @@ Executor<Device, T>::Executor(const Program<Device> &program, ...@@ -126,15 +126,24 @@ Executor<Device, T>::Executor(const Program<Device> &program,
printf("================[ op init profile ]==================\n"); printf("================[ op init profile ]==================\n");
PrintProfile(profile); PrintProfile(profile);
#endif #endif
ApplyMemoryOptimise(config, lod_mode);
}
template <typename Device, typename T>
void Executor<Device, T>::ApplyMemoryOptimise(
const PaddleMobileConfigInternal &config, const bool lod_mode) const {}
#ifdef PADDLE_MOBILE_CL #ifdef PADDLE_MOBILE_CL
template <>
void Executor<GPU_CL, float>::ApplyMemoryOptimise(
const PaddleMobileConfigInternal &config, const bool lod_mode) const {
if (!config.load_when_predict && !lod_mode && if (!config.load_when_predict && !lod_mode &&
config_.memory_optimization_level != NoMemoryOptimization) { config_.memory_optimization_level != NoMemoryOptimization) {
pass::MemoryOptPassCl()(program_desc_.get(), program_.scope.get(), pass::MemoryOptPassCl()(program_desc_.get(), program_.scope.get(),
config_.memory_optimization_level); config_.memory_optimization_level);
} }
#endif
} }
#endif
template <typename Device, typename T> template <typename Device, typename T>
void Executor<Device, T>::InitFeedFetchList() { void Executor<Device, T>::InitFeedFetchList() {
......
...@@ -118,6 +118,8 @@ class Executor { ...@@ -118,6 +118,8 @@ class Executor {
void PrintProfile(const vector<Executor<Device, T>::ProfInfo> &profile) const; void PrintProfile(const vector<Executor<Device, T>::ProfInfo> &profile) const;
#endif #endif
void ApplyMemoryOptimise(const PaddleMobileConfigInternal &config,
const bool lod_mode) const;
}; };
} // namespace framework } // namespace framework
......
...@@ -181,10 +181,10 @@ void MemoryOptPassCl::ShareData( ...@@ -181,10 +181,10 @@ void MemoryOptPassCl::ShareData(
const int64_t numl = tensor->numel(); const int64_t numl = tensor->numel();
auto origin_tensor_dims = tensor->dims(); auto origin_tensor_dims = tensor->dims();
PADDLE_MOBILE_ENFORCE(origin_tensor_dims.size() == 4,
"tensor dims must larger than 4");
// for super ,hack origin dims // for super ,hack origin dims
if (target_dims.size() == 4) { if (target_dims.size() == 4) {
PADDLE_MOBILE_ENFORCE(origin_tensor_dims.size() == 4,
"tensor dims must be equal to 4");
origin_tensor_dims = {origin_tensor_dims[0], origin_tensor_dims[1], origin_tensor_dims = {origin_tensor_dims[0], origin_tensor_dims[1],
target_dims[2], target_dims[3]}; target_dims[2], target_dims[3]};
tensor->Resize(origin_tensor_dims); tensor->Resize(origin_tensor_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册