提交 a1f40a2c 编写于 作者: Z zhaojiaying01

update fetch kernel

上级 71a1cf75
......@@ -922,12 +922,7 @@ void Executor<GPU_CL, Precision::FP32>::InitMemory() {
if (var_desc->Persistable()) {
CLImage *cl_image = nullptr;
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
auto tensor = var->template GetMutable<framework::LoDTensor>();
if (var_desc->Name() == "fetch") {
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = framework::make_ddim(desc.Dims());
tensor->mutable_data<float>(ddim);
}
var->template GetMutable<framework::LoDTensor>();
continue;
} else {
cl_image = var->template GetMutable<framework::CLImage>();
......@@ -991,12 +986,7 @@ void Executor<GPU_CL, Precision::FP32>::InitCombineMemory() {
if (var_desc->Persistable()) {
CLImage *cl_image = nullptr;
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
auto tensor = var->template GetMutable<framework::LoDTensor>();
if (var_desc->Name() == "fetch") {
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = framework::make_ddim(desc.Dims());
tensor->mutable_data<float>(ddim);
}
var->template GetMutable<framework::LoDTensor>();
continue;
} else {
cl_image = var->template GetMutable<framework::CLImage>();
......
......@@ -95,7 +95,7 @@ void OperatorBase<Dtype>::Run() {
if (type_ == "fetch") {
Tensor *tensor = vari->template GetMutable<framework::LoDTensor>();
if (tensor) {
DLOG << type_ << " output- " << key << "=" << tensor->dims();
DLOG << type_ << " output- " << key << "=" << *tensor;
}
} else {
CLImage *cl_image = vari->template GetMutable<framework::CLImage>();
......
......@@ -21,6 +21,8 @@ namespace operators {
/**
 * One-time setup for the GPU_CL / float specialization of the fetch kernel.
 *
 * Calls cl_helper_.AddKernel with "fetch" and "fetch_kernel.cl", then calls
 * mutable_data<float>() on the output object returned by param->Out().
 * NOTE(review): presumably this ensures the output tensor's float storage
 * exists before the kernel first runs -- confirm against the tensor class.
 *
 * @param param  Fetch operator parameters; dereferenced without a null check.
 * @return Always true.
 */
template <>
bool FetchKernel<GPU_CL, float>::Init(FetchParam<GPU_CL> *param) {
  // Kernel registration comes first, exactly as in the original ordering.
  this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");

  // The returned data pointer is intentionally discarded; only the call's
  // effect on the output object matters here.
  param->Out()->mutable_data<float>();

  return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册