提交 dc7c1455 编写于 作者: Z zhaojiaying01

update fetch kernel

上级 98f3e10f
......@@ -37,7 +37,7 @@ limitations under the License. */
#include "framework/cl/cl_image.h"
#endif
int debug_to = 34;
int debug_to = 33;
namespace paddle_mobile {
namespace framework {
......@@ -922,7 +922,12 @@ void Executor<GPU_CL, Precision::FP32>::InitMemory() {
if (var_desc->Persistable()) {
CLImage *cl_image = nullptr;
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
var->template GetMutable<framework::LoDTensor>();
auto tensor = var->template GetMutable<framework::LoDTensor>();
if (var_desc->Name() == "fetch") {
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = framework::make_ddim(desc.Dims());
tensor->mutable_data<float>(ddim);
}
continue;
} else {
cl_image = var->template GetMutable<framework::CLImage>();
......@@ -986,7 +991,12 @@ void Executor<GPU_CL, Precision::FP32>::InitCombineMemory() {
if (var_desc->Persistable()) {
CLImage *cl_image = nullptr;
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
var->template GetMutable<framework::LoDTensor>();
auto tensor = var->template GetMutable<framework::LoDTensor>();
if (var_desc->Name() == "fetch") {
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = framework::make_ddim(desc.Dims());
tensor->mutable_data<float>(ddim);
}
continue;
} else {
cl_image = var->template GetMutable<framework::CLImage>();
......
......@@ -32,11 +32,19 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
auto input = param.InputX()->GetCLImage();
auto *out = param.Out();
const auto &dims = param.InputX()->dims();
const int N = dims[0];
const int C = dims[1];
const int in_height = dims[2];
const int in_width = dims[3];
const auto &dim = param.InputX()->dims();
size_t new_dims[] = {1, 1, 1, 1};
for (int j = 0; j < dim.size(); ++j) {
new_dims[4 - dim.size() + j] = dim[j];
}
size_t N, C, in_height, in_width;
N = new_dims[0];
C = new_dims[1];
in_height = new_dims[2];
in_width = new_dims[3];
int size_ch = in_height * in_width;
int size_block = size_ch * 4;
......
......@@ -953,7 +953,7 @@ class FetchParam : public OpParam {
Tensor *Out() const { return out_; }
static Tensor *OutFrom(const VariableNameMap &outputs, const Scope &scope) {
return GetVarValue<Tensor>("Out", outputs, scope);
return GetVarValue<LoDTensor>("Out", outputs, scope);
}
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册