diff --git a/mobile/src/framework/executor.cpp b/mobile/src/framework/executor.cpp index 7e5c15c9b729933534db919b5cc773433b61336c..c1ff6ee29b1f7be509e309bbbc302aab3eff209c 100644 --- a/mobile/src/framework/executor.cpp +++ b/mobile/src/framework/executor.cpp @@ -126,15 +126,24 @@ Executor::Executor(const Program &program, printf("================[ op init profile ]==================\n"); PrintProfile(profile); #endif + ApplyMemoryOptimise(config, lod_mode); +} + +template +void Executor::ApplyMemoryOptimise( + const PaddleMobileConfigInternal &config, const bool lod_mode) const {} #ifdef PADDLE_MOBILE_CL +template <> +void Executor::ApplyMemoryOptimise( + const PaddleMobileConfigInternal &config, const bool lod_mode) const { if (!config.load_when_predict && !lod_mode && config_.memory_optimization_level != NoMemoryOptimization) { pass::MemoryOptPassCl()(program_desc_.get(), program_.scope.get(), config_.memory_optimization_level); } -#endif } +#endif template void Executor::InitFeedFetchList() { diff --git a/mobile/src/framework/executor.h b/mobile/src/framework/executor.h index 4f108c993c0ff9bda94b11cdebc3cb13af41be03..ebb16f697b39391cd5f405c565285c1bd37dfad5 100644 --- a/mobile/src/framework/executor.h +++ b/mobile/src/framework/executor.h @@ -118,6 +118,8 @@ class Executor { void PrintProfile(const vector::ProfInfo> &profile) const; #endif + void ApplyMemoryOptimise(const PaddleMobileConfigInternal &config, + const bool lod_mode) const; }; } // namespace framework diff --git a/mobile/src/pass/memory_optimize_cl.cpp b/mobile/src/pass/memory_optimize_cl.cpp index 81e41beaaf1d632e6993fabf1f2f60e9d1e63acb..355123349d645075fd2ccc37144144da7d332a8f 100644 --- a/mobile/src/pass/memory_optimize_cl.cpp +++ b/mobile/src/pass/memory_optimize_cl.cpp @@ -181,10 +181,10 @@ void MemoryOptPassCl::ShareData( const int64_t numl = tensor->numel(); auto origin_tensor_dims = tensor->dims(); - PADDLE_MOBILE_ENFORCE(origin_tensor_dims.size() == 4, - "tensor dims must larger than 4"); // for super ,hack origin dims if (target_dims.size() == 4) { + PADDLE_MOBILE_ENFORCE(origin_tensor_dims.size() == 4, + "tensor dims must be equal to 4"); origin_tensor_dims = {origin_tensor_dims[0], origin_tensor_dims[1], target_dims[2], target_dims[3]}; tensor->Resize(origin_tensor_dims);