From 44316e0bfb829118e6c2b8b28739f016c3ce8d43 Mon Sep 17 00:00:00 2001 From: zhaojiaying01 Date: Fri, 19 Oct 2018 14:41:26 +0800 Subject: [PATCH] update fetch kernel --- src/framework/executor.cpp | 16 +++++++++++++--- src/operators/kernel/cl/fetch_kernel.cpp | 18 +++++++++++++----- src/operators/op_param.h | 2 +- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index bc2a46ab35..4e4fc5c0cb 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -37,7 +37,7 @@ limitations under the License. */ #include "framework/cl/cl_image.h" #endif -int debug_to = 34; +int debug_to = 33; namespace paddle_mobile { namespace framework { @@ -922,7 +922,12 @@ void Executor::InitMemory() { if (var_desc->Persistable()) { CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - var->template GetMutable(); + auto tensor = var->template GetMutable(); + if (var_desc->Name() == "fetch") { + const framework::TensorDesc &desc = var_desc->Tensor_desc(); + framework::DDim ddim = framework::make_ddim(desc.Dims()); + tensor->mutable_data(ddim); + } continue; } else { cl_image = var->template GetMutable(); @@ -986,7 +991,12 @@ void Executor::InitCombineMemory() { if (var_desc->Persistable()) { CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - var->template GetMutable(); + auto tensor = var->template GetMutable(); + if (var_desc->Name() == "fetch") { + const framework::TensorDesc &desc = var_desc->Tensor_desc(); + framework::DDim ddim = framework::make_ddim(desc.Dims()); + tensor->mutable_data(ddim); + } continue; } else { cl_image = var->template GetMutable(); diff --git a/src/operators/kernel/cl/fetch_kernel.cpp b/src/operators/kernel/cl/fetch_kernel.cpp index 87a892ee9c..a84f8d82f8 100644 --- a/src/operators/kernel/cl/fetch_kernel.cpp +++ b/src/operators/kernel/cl/fetch_kernel.cpp @@ -32,11 +32,19 @@ void FetchKernel::Compute(const FetchParam ¶m) { auto input = param.InputX()->GetCLImage(); auto *out = param.Out(); - const auto &dims = param.InputX()->dims(); - const int N = dims[0]; - const int C = dims[1]; - const int in_height = dims[2]; - const int in_width = dims[3]; + const auto &dim = param.InputX()->dims(); + size_t new_dims[] = {1, 1, 1, 1}; + + for (int j = 0; j < dim.size(); ++j) { + new_dims[4 - dim.size() + j] = dim[j]; + } + + size_t N, C, in_height, in_width; + + N = new_dims[0]; + C = new_dims[1]; + in_height = new_dims[2]; + in_width = new_dims[3]; int size_ch = in_height * in_width; int size_block = size_ch * 4; diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 71a2ae5d10..4d03159751 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -953,7 +953,7 @@ class FetchParam : public OpParam { Tensor *Out() const { return out_; } static Tensor *OutFrom(const VariableNameMap &outputs, const Scope &scope) { - return GetVarValue("Out", outputs, scope); + return GetVarValue("Out", outputs, scope); } private: -- GitLab