提交 827e349a 编写于 作者: J jameswu2014

reshape2-format-modify

上级 6c9aa421
...@@ -80,7 +80,7 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) { ...@@ -80,7 +80,7 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
auto output = &param.Out()->at(col); auto output = &param.Out()->at(col);
if (input->type() == type_id<float>()) { if (input->type() == type_id<float>()) {
output->ShareDataWith(*input); output->ShareDataWith(*input);
// output dims equal to input dim // output dims == input dim
output->Resize(input->dims()); output->Resize(input->dims());
return; return;
} }
......
...@@ -92,69 +92,67 @@ void reshape(LoDTensor *input, LoDTensor *output) { ...@@ -92,69 +92,67 @@ void reshape(LoDTensor *input, LoDTensor *output) {
fpga::fpga_flush(output_ptr, Hr * WCr_align * sizeof(half)); fpga::fpga_flush(output_ptr, Hr * WCr_align * sizeof(half));
} }
static inline bool reshape2_judge(const framework::DDim input_dims,const framework::DDim output_dims){ static inline bool reshape2_judge(const framework::DDim input_dims,
int input_dims_size = input_dims.size(); const framework::DDim output_dims) {
int output_dims_size = output_dims.size(); int input_dims_size = input_dims.size();
bool dims_flag2 = true; int output_dims_size = output_dims.size();
auto temp_dims = input_dims_size > output_dims_size ? input_dims : output_dims; bool dims_flag2 = true;
int short_dims = input_dims_size > output_dims_size ? output_dims_size : input_dims_size; auto temp_dims = input_dims_size > output_dims_size ?
for(int i = 0; i < temp_dims.size(); ++i){ input_dims : output_dims;
if(i < short_dims){ int short_dims = input_dims_size > output_dims_size ?
if(input_dims[i] != output_dims[i]){ output_dims_size : input_dims_size;
dims_flag2 = false; for (int i = 0; i < temp_dims.size(); ++i) {
break; if (i < short_dims) {
} if (input_dims[i] != output_dims[i]) {
} dims_flag2 = false;
else{ break;
if(temp_dims[i] != 1){ }
dims_flag2 = false; } else {
break; if (temp_dims[i] != 1) {
} dims_flag2 = false;
} break;
} }
return dims_flag2; }
}
return dims_flag2;
} }
template <> template <>
void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) { void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
auto input = const_cast<LoDTensor *>(param.InputX()); auto input = const_cast<LoDTensor *>(param.InputX());
auto output = param.Out(); auto output = param.Out();
auto shape = param.Shape(); auto shape = param.Shape();
auto num_in = framework::product(input->dims()); auto num_in = framework::product(input->dims());
auto num_shape = framework::product(framework::make_ddim(shape)); auto num_shape = framework::product(framework::make_ddim(shape));
PADDLE_MOBILE_ENFORCE(num_shape != 0, "0 index is not supported"); PADDLE_MOBILE_ENFORCE(num_shape != 0, "0 index is not supported");
for (int i = 0; i < shape.size(); i++) { for (int i = 0; i < shape.size(); i++) {
if (shape[i] == -1) { if (shape[i] == -1) {
shape[i] = static_cast<int>(-num_in / num_shape); shape[i] = static_cast<int>(-num_in / num_shape);
break; break;
}
} }
} output->Resize(framework::make_ddim(shape));
output->Resize(framework::make_ddim(shape));
auto input_dims = input->dims(); auto input_dims = input->dims();
auto output_dims = output->dims(); auto output_dims = output->dims();
bool dims_flags = input_dims == output_dims; bool dims_flags = input_dims == output_dims;
bool dims_flag2 = true; bool dims_flag2 = true;
if(!dims_flags){ if (!dims_flags) {
dims_flag2 = reshape2_judge(input_dims, output_dims); dims_flag2 = reshape2_judge(input_dims, output_dims);
} }
if (dims_flags || dims_flag2) {
DLOG << "No need to reshape";
if (dims_flags || dims_flag2 ) { output->ShareDataWith(*input);
DLOG << "No need to reshape"; framework::LoD lod = input->lod();
output->ShareDataWith(*input); output->set_lod(lod);
framework::LoD lod = input->lod(); return;
output->set_lod(lod); }
return;
}
reshape(input, output); reshape(input, output);
// //
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册