提交 257a7dc8 编写于 作者: J Jiaying Zhao 提交者: GitHub

Merge pull request #1156 from smilejames/opencl

update fetch kernel
...@@ -922,12 +922,7 @@ void Executor<GPU_CL, Precision::FP32>::InitMemory() { ...@@ -922,12 +922,7 @@ void Executor<GPU_CL, Precision::FP32>::InitMemory() {
if (var_desc->Persistable()) { if (var_desc->Persistable()) {
CLImage *cl_image = nullptr; CLImage *cl_image = nullptr;
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
auto tensor = var->template GetMutable<framework::LoDTensor>(); var->template GetMutable<framework::LoDTensor>();
if (var_desc->Name() == "fetch") {
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = framework::make_ddim(desc.Dims());
tensor->mutable_data<float>(ddim);
}
continue; continue;
} else { } else {
cl_image = var->template GetMutable<framework::CLImage>(); cl_image = var->template GetMutable<framework::CLImage>();
...@@ -991,12 +986,7 @@ void Executor<GPU_CL, Precision::FP32>::InitCombineMemory() { ...@@ -991,12 +986,7 @@ void Executor<GPU_CL, Precision::FP32>::InitCombineMemory() {
if (var_desc->Persistable()) { if (var_desc->Persistable()) {
CLImage *cl_image = nullptr; CLImage *cl_image = nullptr;
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
auto tensor = var->template GetMutable<framework::LoDTensor>(); var->template GetMutable<framework::LoDTensor>();
if (var_desc->Name() == "fetch") {
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = framework::make_ddim(desc.Dims());
tensor->mutable_data<float>(ddim);
}
continue; continue;
} else { } else {
cl_image = var->template GetMutable<framework::CLImage>(); cl_image = var->template GetMutable<framework::CLImage>();
......
...@@ -95,7 +95,7 @@ void OperatorBase<Dtype>::Run() { ...@@ -95,7 +95,7 @@ void OperatorBase<Dtype>::Run() {
if (type_ == "fetch") { if (type_ == "fetch") {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>(); Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
if (tensor) { if (tensor) {
DLOG << type_ << " output- " << key << "=" << tensor->dims(); DLOG << type_ << " output- " << key << "=" << *tensor;
} }
} else { } else {
CLImage *cl_image = vari->template GetMutable<framework::CLImage>(); CLImage *cl_image = vari->template GetMutable<framework::CLImage>();
......
...@@ -21,6 +21,8 @@ namespace operators { ...@@ -21,6 +21,8 @@ namespace operators {
template <> template <>
bool FetchKernel<GPU_CL, float>::Init(FetchParam<GPU_CL> *param) { bool FetchKernel<GPU_CL, float>::Init(FetchParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl"); this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
auto *out = param->Out();
out->mutable_data<float>();
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册