提交 dc7c1455 编写于 作者: Z zhaojiaying01

update fetch kernel

上级 98f3e10f
...@@ -37,7 +37,7 @@ limitations under the License. */ ...@@ -37,7 +37,7 @@ limitations under the License. */
#include "framework/cl/cl_image.h" #include "framework/cl/cl_image.h"
#endif #endif
int debug_to = 34; int debug_to = 33;
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
...@@ -922,7 +922,12 @@ void Executor<GPU_CL, Precision::FP32>::InitMemory() { ...@@ -922,7 +922,12 @@ void Executor<GPU_CL, Precision::FP32>::InitMemory() {
if (var_desc->Persistable()) { if (var_desc->Persistable()) {
CLImage *cl_image = nullptr; CLImage *cl_image = nullptr;
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
var->template GetMutable<framework::LoDTensor>(); auto tensor = var->template GetMutable<framework::LoDTensor>();
if (var_desc->Name() == "fetch") {
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = framework::make_ddim(desc.Dims());
tensor->mutable_data<float>(ddim);
}
continue; continue;
} else { } else {
cl_image = var->template GetMutable<framework::CLImage>(); cl_image = var->template GetMutable<framework::CLImage>();
...@@ -986,7 +991,12 @@ void Executor<GPU_CL, Precision::FP32>::InitCombineMemory() { ...@@ -986,7 +991,12 @@ void Executor<GPU_CL, Precision::FP32>::InitCombineMemory() {
if (var_desc->Persistable()) { if (var_desc->Persistable()) {
CLImage *cl_image = nullptr; CLImage *cl_image = nullptr;
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
var->template GetMutable<framework::LoDTensor>(); auto tensor = var->template GetMutable<framework::LoDTensor>();
if (var_desc->Name() == "fetch") {
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = framework::make_ddim(desc.Dims());
tensor->mutable_data<float>(ddim);
}
continue; continue;
} else { } else {
cl_image = var->template GetMutable<framework::CLImage>(); cl_image = var->template GetMutable<framework::CLImage>();
......
...@@ -32,11 +32,19 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) { ...@@ -32,11 +32,19 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
auto input = param.InputX()->GetCLImage(); auto input = param.InputX()->GetCLImage();
auto *out = param.Out(); auto *out = param.Out();
const auto &dims = param.InputX()->dims(); const auto &dim = param.InputX()->dims();
const int N = dims[0]; size_t new_dims[] = {1, 1, 1, 1};
const int C = dims[1];
const int in_height = dims[2]; for (int j = 0; j < dim.size(); ++j) {
const int in_width = dims[3]; new_dims[4 - dim.size() + j] = dim[j];
}
size_t N, C, in_height, in_width;
N = new_dims[0];
C = new_dims[1];
in_height = new_dims[2];
in_width = new_dims[3];
int size_ch = in_height * in_width; int size_ch = in_height * in_width;
int size_block = size_ch * 4; int size_block = size_ch * 4;
......
...@@ -953,7 +953,7 @@ class FetchParam : public OpParam { ...@@ -953,7 +953,7 @@ class FetchParam : public OpParam {
Tensor *Out() const { return out_; } Tensor *Out() const { return out_; }
static Tensor *OutFrom(const VariableNameMap &outputs, const Scope &scope) { static Tensor *OutFrom(const VariableNameMap &outputs, const Scope &scope) {
return GetVarValue<Tensor>("Out", outputs, scope); return GetVarValue<LoDTensor>("Out", outputs, scope);
} }
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册