Commit 0719bf08 authored by zhangyang0701

add reshape2 for FPGA track

Parent f473957a
......@@ -37,6 +37,13 @@ void format_image(framework::Tensor *image_tensor) {
  }
}
void format_ofm(framework::Tensor *ofm_tensor) {
  if (ofm_tensor->type() == typeid(float)) {
    format_fp32_ofm(ofm_tensor);
  } else {
    format_fp16_ofm(ofm_tensor);
  }
}
void format_fp16_ofm(framework::Tensor *ofm_tensor) {
  auto dims = ofm_tensor->dims();
  size_t memory_size = 0;
......
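As context for format_fp16_ofm (which, per the header comment below, only allocates memory): the aligned NHWC layout pads every W*C row up to the FPGA image alignment, so the buffer is larger than the raw element count suggests. A minimal standalone sketch of that size computation, re-declaring align_to_x locally and assuming IMAGE_ALIGNMENT is 16 (editor's illustration, not code from this commit):

#include <cstddef>
#include <cstdint>
#include <iostream>

// Assumed to mirror fpga::align_to_x; IMAGE_ALIGNMENT is assumed to be 16.
constexpr int kImageAlignment = 16;
inline int align_to_x(int n, int x) { return (n + x - 1) / x * x; }

int main() {
  // An NCHW [1, 5, 2, 3] tensor stored as NHWC with each W*C row padded.
  const int N = 1, C = 5, H = 2, W = 3;
  const size_t bytes_fp16 =
      N * H * align_to_x(W * C, kImageAlignment) * sizeof(int16_t);
  std::cout << bytes_fp16 << " bytes\n";  // 2 rows * 16 elements * 2 bytes = 64
  return 0;
}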
......@@ -23,6 +23,7 @@ namespace paddle_mobile {
namespace fpga {
void format_image(framework::Tensor* image_tensor);
void format_ofm(framework::Tensor* ofm_tensor);
void format_fp16_ofm(framework::Tensor* ofm_tensor); // only allocate memory
void format_fp16_ofm(framework::Tensor* ofm_tensor, framework::DDim dims);
void format_fp32_ofm(framework::Tensor* ofm_tensor);
......
......@@ -25,7 +25,6 @@ bool Reshape2Kernel<FPGA, float>::Init(Reshape2Param<FPGA> *param) {
  auto input = const_cast<LoDTensor *>(param->InputX());
  auto output = param->Out();
  auto shape = param->Shape();
  output->ShareDataWith(*input);
  auto num_in = framework::product(input->dims());
  auto num_shape = framework::product(framework::make_ddim(shape));
......@@ -38,22 +37,79 @@ bool Reshape2Kernel<FPGA, float>::Init(Reshape2Param<FPGA> *param) {
    }
  }
  output->Resize(framework::make_ddim(shape));
  output->set_type(input->type());
  fpga::format_ofm(output);
  DLOG << "input: " << input;
  DLOG << "output: " << output;
  return true;
}
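The shape handling in Init above (partially elided in this hunk) follows the common reshape convention: a single -1 entry in the target shape is inferred from the total element count. A hedged sketch of that convention (editor's illustration; infer_shape is a hypothetical helper, not this kernel's code):

#include <cassert>
#include <vector>

// Hypothetical helper showing the usual reshape -1 inference rule.
std::vector<int> infer_shape(int num_in, std::vector<int> shape) {
  int known = 1, neg = -1;
  for (int i = 0; i < static_cast<int>(shape.size()); i++) {
    if (shape[i] == -1) {
      neg = i;  // at most one -1 entry is allowed
    } else {
      known *= shape[i];
    }
  }
  if (neg >= 0) shape[neg] = num_in / known;
  return shape;
}

int main() {
  auto s = infer_shape(30, {-1, 10});  // 30 elements, shape [-1, 10]
  assert(s[0] == 3 && s[1] == 10);
  return 0;
}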
void reshape(LoDTensor *input, LoDTensor *output) {
  // Subscript r means "after reshape"
  // TODO(zhangyang): verify this function
  float *input_ptr_f = nullptr, *output_ptr_f = nullptr;
  half *input_ptr_h = nullptr, *output_ptr_h = nullptr;
  bool is_float = false;
  if (input->type() == typeid(float)) {
    input_ptr_f = input->data<float>();
    output_ptr_f = output->data<float>();
    is_float = true;
  } else {
    input_ptr_h = input->data<half>();
    output_ptr_h = output->data<half>();
  }
  auto C = static_cast<int>(input->dims()[1]);
  auto H = static_cast<int>(input->dims()[2]);
  auto W = static_cast<int>(input->dims()[3]);
  auto Cr = static_cast<int>(output->dims()[1]);
  auto Hr = static_cast<int>(output->dims()[2]);
  auto Wr = static_cast<int>(output->dims()[3]);
  PADDLE_MOBILE_ENFORCE(C * H * W == Cr * Hr * Wr, "Dims don't match");
  auto WC = W * C;
  auto WC_align = fpga::align_to_x(WC, IMAGE_ALIGNMENT);
  auto HW = H * W;
  auto WCr = Wr * Cr;
  auto WCr_align = fpga::align_to_x(WCr, IMAGE_ALIGNMENT);
  auto HWr = Hr * Wr;
  int offset_align = 0;
  int offset_r = 0, offset_align_r = 0;
  int cr = 0, hr = 0, wr = 0;
  for (int h = 0; h < H; h++) {
    int offset0 = h * WC_align;
    for (int w = 0; w < W; w++) {
      int offset1 = w * C + offset0;
      for (int c = 0; c < C; c++) {
        // Aligned NHWC offset of element (h, w, c) in the input.
        offset_align = offset1 + c;
        // Logical CHW rank of the same element.
        offset_r = c * HW + h * W + w;
        // Decompose that rank into the reshaped indices (cr, hr, wr).
        cr = offset_r / HWr;
        hr = offset_r % HWr / Wr;
        wr = offset_r % Wr;
        // Aligned NHWC offset of (hr, wr, cr) in the output.
        offset_align_r = hr * WCr_align + wr * Cr + cr;
        if (is_float) {
          output_ptr_f[offset_align_r] = input_ptr_f[offset_align];
        } else {
          output_ptr_h[offset_align_r] = input_ptr_h[offset_align];
        }
      }
    }
  }
}
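To make the index arithmetic in reshape() concrete, here is a minimal standalone sketch (editor's illustration, not part of this commit; align_to_x is re-declared locally and IMAGE_ALIGNMENT is assumed to be 16). It traces one element from its aligned-NHWC input offset, through its logical CHW rank, to its aligned-NHWC output offset:

#include <cassert>

// Assumed to mirror fpga::align_to_x; IMAGE_ALIGNMENT is assumed to be 16.
constexpr int kAlign = 16;
inline int align_to_x(int n, int x) { return (n + x - 1) / x * x; }

int main() {
  // Input CHW = (5, 2, 3) reshaped to CHW = (2, 3, 5); both hold 30 elements.
  const int C = 5, H = 2, W = 3;
  const int Cr = 2, Hr = 3, Wr = 5;
  const int WC_align = align_to_x(W * C, kAlign);     // 16
  const int WCr_align = align_to_x(Wr * Cr, kAlign);  // 16

  const int h = 1, w = 2, c = 4;                // input element (h, w, c)
  int offset_align = h * WC_align + w * C + c;  // aligned NHWC offset: 30
  int offset_r = c * H * W + h * W + w;         // logical CHW rank: 29

  int cr = offset_r / (Hr * Wr);       // 29 / 15 = 1
  int hr = offset_r % (Hr * Wr) / Wr;  // 14 / 5  = 2
  int wr = offset_r % Wr;              // 29 % 5  = 4
  int offset_align_r = hr * WCr_align + wr * Cr + cr;  // 2*16 + 4*2 + 1 = 41

  assert(offset_align == 30 && offset_r == 29 && offset_align_r == 41);
  return 0;
}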
template <>
void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
  auto input = const_cast<LoDTensor *>(param.InputX());
  auto output = param.Out();
  auto shape = param.Shape();
  if (output->type() != typeid(half)) {
    DLOG << "wrong type";
  }
  auto num_in = framework::product(input->dims());
  auto num_shape = framework::product(framework::make_ddim(shape));
  PADDLE_MOBILE_ENFORCE(num_shape != 0, "0 index is not supported");
......@@ -65,10 +121,12 @@ void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
    }
  }
  output->Resize(framework::make_ddim(shape));
  DLOG << output;
  if (output->dims() == input->dims()) {
    DLOG << "No need to reshape";
    return;
  }
  reshape(input, output);
}
......