提交 827e349a 编写于 作者: J jameswu2014

reshape2-format-modify

上级 6c9aa421
...@@ -80,7 +80,7 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) { ...@@ -80,7 +80,7 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
auto output = &param.Out()->at(col); auto output = &param.Out()->at(col);
if (input->type() == type_id<float>()) { if (input->type() == type_id<float>()) {
output->ShareDataWith(*input); output->ShareDataWith(*input);
// output dims equal to input dim // output dims == input dim
output->Resize(input->dims()); output->Resize(input->dims());
return; return;
} }
......
...@@ -92,29 +92,29 @@ void reshape(LoDTensor *input, LoDTensor *output) { ...@@ -92,29 +92,29 @@ void reshape(LoDTensor *input, LoDTensor *output) {
fpga::fpga_flush(output_ptr, Hr * WCr_align * sizeof(half)); fpga::fpga_flush(output_ptr, Hr * WCr_align * sizeof(half));
} }
// Decides whether a reshape between the two shapes can be skipped.
// Returns true when both shapes agree on every index they have in common
// and every remaining entry of the higher-rank shape equals 1; returns
// false as soon as one of those checks fails.
static inline bool reshape2_judge(const framework::DDim input_dims,
                                  const framework::DDim output_dims) {
  const int input_rank = input_dims.size();
  const int output_rank = output_dims.size();
  const bool input_is_longer = input_rank > output_rank;

  // The shape with more entries supplies the tail that must be all ones;
  // on equal rank the output shape is chosen and the tail is empty.
  const framework::DDim &longer_dims =
      input_is_longer ? input_dims : output_dims;
  const int common_rank = input_is_longer ? output_rank : input_rank;
  const int longer_rank = longer_dims.size();

  // Leading entries shared by both shapes must match exactly.
  for (int idx = 0; idx < common_rank; ++idx) {
    if (input_dims[idx] != output_dims[idx]) {
      return false;
    }
  }

  // Entries past the shorter shape's rank must all be 1.
  for (int idx = common_rank; idx < longer_rank; ++idx) {
    if (longer_dims[idx] != 1) {
      return false;
    }
  }

  return true;
}
template <> template <>
...@@ -141,12 +141,10 @@ void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) { ...@@ -141,12 +141,10 @@ void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
bool dims_flags = input_dims == output_dims; bool dims_flags = input_dims == output_dims;
bool dims_flag2 = true; bool dims_flag2 = true;
if(!dims_flags){ if (!dims_flags) {
dims_flag2 = reshape2_judge(input_dims, output_dims); dims_flag2 = reshape2_judge(input_dims, output_dims);
} }
if (dims_flags || dims_flag2) {
if (dims_flags || dims_flag2 ) {
DLOG << "No need to reshape"; DLOG << "No need to reshape";
output->ShareDataWith(*input); output->ShareDataWith(*input);
framework::LoD lod = input->lod(); framework::LoD lod = input->lod();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册