From a1f40a2c931774d855ff8c29652dcf01858d59cc Mon Sep 17 00:00:00 2001 From: zhaojiaying01 Date: Fri, 19 Oct 2018 15:26:40 +0800 Subject: [PATCH] update fetch kernel --- src/framework/executor.cpp | 14 ++------------ src/framework/operator.cpp | 2 +- src/operators/kernel/cl/fetch_kernel.cpp | 2 ++ 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index 4e4fc5c0cb..8da8c1488b 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -922,12 +922,7 @@ void Executor::InitMemory() { if (var_desc->Persistable()) { CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - auto tensor = var->template GetMutable(); - if (var_desc->Name() == "fetch") { - const framework::TensorDesc &desc = var_desc->Tensor_desc(); - framework::DDim ddim = framework::make_ddim(desc.Dims()); - tensor->mutable_data(ddim); - } + var->template GetMutable(); continue; } else { cl_image = var->template GetMutable(); @@ -991,12 +986,7 @@ void Executor::InitCombineMemory() { if (var_desc->Persistable()) { CLImage *cl_image = nullptr; if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { - auto tensor = var->template GetMutable(); - if (var_desc->Name() == "fetch") { - const framework::TensorDesc &desc = var_desc->Tensor_desc(); - framework::DDim ddim = framework::make_ddim(desc.Dims()); - tensor->mutable_data(ddim); - } + var->template GetMutable(); continue; } else { cl_image = var->template GetMutable(); diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp index ab9d4f788a..9aeed326d1 100644 --- a/src/framework/operator.cpp +++ b/src/framework/operator.cpp @@ -95,7 +95,7 @@ void OperatorBase::Run() { if (type_ == "fetch") { Tensor *tensor = vari->template GetMutable(); if (tensor) { - DLOG << type_ << " output- " << key << "=" << tensor->dims(); + DLOG << type_ << " output- " << key << "=" << *tensor; } } else { CLImage *cl_image = vari->template GetMutable(); diff --git a/src/operators/kernel/cl/fetch_kernel.cpp b/src/operators/kernel/cl/fetch_kernel.cpp index a84f8d82f8..807f883344 100644 --- a/src/operators/kernel/cl/fetch_kernel.cpp +++ b/src/operators/kernel/cl/fetch_kernel.cpp @@ -21,6 +21,8 @@ namespace operators { template <> bool FetchKernel::Init(FetchParam *param) { this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl"); + auto *out = param->Out(); + out->mutable_data(); return true; } -- GitLab