提交 3d7530f4 编写于 作者: J jameswu2014

reshape2 ddim bug

上级 ca334444
......@@ -92,6 +92,31 @@ void reshape(LoDTensor *input, LoDTensor *output) {
fpga::fpga_flush(output_ptr, Hr * WCr_align * sizeof(half));
}
static inline bool reshape2_judge(const framework::DDim input_dims,const framework::DDim output_dims){
int input_dims_size = input_dims.size();
int output_dims_size = output_dims.size();
bool dims_flag2 = true;
auto temp_dims = input_dims_size > output_dims_size ? input_dims : output_dims;
int short_dims = input_dims_size > output_dims_size ? output_dims_size : input_dims_size;
for(int i = 0; i < temp_dims.size(); ++i){
if(i < short_dims){
if(input_dims[i] != output_dims[i]){
dims_flag2 = false;
break;
}
}
else{
if(temp_dims[i] != 1){
dims_flag2 = false;
break;
}
}
}
return dims_flag2;
}
template <>
void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
auto input = const_cast<LoDTensor *>(param.InputX());
......@@ -109,7 +134,19 @@ void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
}
}
output->Resize(framework::make_ddim(shape));
if (output->dims() == input->dims()) {
auto input_dims = input->dims();
auto output_dims = output->dims();
bool dims_flags = input_dims == output_dims;
bool dims_flag2 = true;
if(!dims_flags){
dims_flag2 = reshape2_judge(input_dims, output_dims);
}
if (dims_flags || dims_flag2 ) {
DLOG << "No need to reshape";
output->ShareDataWith(*input);
framework::LoD lod = input->lod();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册