diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index bc2a46ab35570e9f60f663830dddb3836e247592..4e4fc5c0cb59725b33f63fb085350b6a81e79156 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -37,7 +37,7 @@ limitations under the License. */ #include "framework/cl/cl_image.h" #endif -int debug_to = 34; +int debug_to = 33; namespace paddle_mobile { namespace framework { @@ -922,7 +922,12 @@ void Executor::InitMemory() { if (var_desc->Persistable()) { CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - var->template GetMutable(); + auto tensor = var->template GetMutable(); + if (var_desc->Name() == "fetch") { + const framework::TensorDesc &desc = var_desc->Tensor_desc(); + framework::DDim ddim = framework::make_ddim(desc.Dims()); + tensor->mutable_data(ddim); + } continue; } else { cl_image = var->template GetMutable(); @@ -986,7 +991,12 @@ void Executor::InitCombineMemory() { if (var_desc->Persistable()) { CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - var->template GetMutable(); + auto tensor = var->template GetMutable(); + if (var_desc->Name() == "fetch") { + const framework::TensorDesc &desc = var_desc->Tensor_desc(); + framework::DDim ddim = framework::make_ddim(desc.Dims()); + tensor->mutable_data(ddim); + } continue; } else { cl_image = var->template GetMutable(); diff --git a/src/operators/kernel/cl/fetch_kernel.cpp b/src/operators/kernel/cl/fetch_kernel.cpp index 87a892ee9c2a3ef9fa9584bc5e358ba4b06f7577..a84f8d82f8bacff833d7a5aa7f3bb1a0683d1f89 100644 --- a/src/operators/kernel/cl/fetch_kernel.cpp +++ b/src/operators/kernel/cl/fetch_kernel.cpp @@ -32,11 +32,19 @@ void FetchKernel::Compute(const FetchParam ¶m) { auto input = param.InputX()->GetCLImage(); auto *out = param.Out(); - const auto &dims = param.InputX()->dims(); - const int N = dims[0]; - const int C = dims[1]; - const int in_height = dims[2]; - const int in_width = dims[3]; + const auto &dim = param.InputX()->dims(); + size_t new_dims[] = {1, 1, 1, 1}; + + for (int j = 0; j < dim.size(); ++j) { + new_dims[4 - dim.size() + j] = dim[j]; + } + + size_t N, C, in_height, in_width; + + N = new_dims[0]; + C = new_dims[1]; + in_height = new_dims[2]; + in_width = new_dims[3]; int size_ch = in_height * in_width; int size_block = size_ch * 4; diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 71a2ae5d10f1b26e4f2d1d1e81d7e9157a646268..4d03159751c194e2ebb3d433bf2a0cf84377c286 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -953,7 +953,7 @@ class FetchParam : public OpParam { Tensor *Out() const { return out_; } static Tensor *OutFrom(const VariableNameMap &outputs, const Scope &scope) { - return GetVarValue("Out", outputs, scope); + return GetVarValue("Out", outputs, scope); } private: