Commit 827e349a authored by J jameswu2014

reshape2-format-modify

Parent 6c9aa421
......
@@ -80,7 +80,7 @@ void FetchKernel<FPGA, float>::Compute(const FetchParam<FPGA> &param) {
   auto output = &param.Out()->at(col);
   if (input->type() == type_id<float>()) {
     output->ShareDataWith(*input);
-    // output dims equal to input dim
+    // output dims == input dim
     output->Resize(input->dims());
     return;
   }
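For readers outside the codebase: ShareDataWith aliases the input's buffer instead of copying it, and Resize only rewrites the dims metadata, so this branch is a zero-copy fast path. A toy analogue of that aliasing, with a hypothetical Tensor type that is not part of paddle-mobile:

#include <iostream>
#include <memory>
#include <utility>
#include <vector>

// Toy tensor whose buffer can be aliased; ShareDataWith points this
// tensor at another tensor's storage, Resize only rewrites the dims
// metadata. Illustrative only -- not the paddle-mobile implementation.
struct Tensor {
  std::shared_ptr<std::vector<float>> data;
  std::vector<int> dims;
  void ShareDataWith(const Tensor &other) { data = other.data; }
  void Resize(std::vector<int> d) { dims = std::move(d); }
};

int main() {
  Tensor in{std::make_shared<std::vector<float>>(6, 1.0f), {2, 3}};
  Tensor out;
  out.ShareDataWith(in);  // zero-copy: out aliases in's buffer
  out.Resize(in.dims);    // output dims == input dims, as in the hunk above
  (*out.data)[0] = 42.0f;
  std::cout << (*in.data)[0] << "\n";  // prints 42: one shared buffer
  return 0;
}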
......
......
@@ -92,69 +92,67 @@ void reshape(LoDTensor *input, LoDTensor *output) {
   fpga::fpga_flush(output_ptr, Hr * WCr_align * sizeof(half));
 }
 
-static inline bool reshape2_judge(const framework::DDim input_dims,const framework::DDim output_dims){
+static inline bool reshape2_judge(const framework::DDim input_dims,
+                                  const framework::DDim output_dims) {
   int input_dims_size = input_dims.size();
   int output_dims_size = output_dims.size();
   bool dims_flag2 = true;
-  auto temp_dims = input_dims_size > output_dims_size ? input_dims : output_dims;
-  int short_dims = input_dims_size > output_dims_size ? output_dims_size : input_dims_size;
-  for(int i = 0; i < temp_dims.size(); ++i){
-    if(i < short_dims){
-      if(input_dims[i] != output_dims[i]){
+  auto temp_dims = input_dims_size > output_dims_size ?
+                   input_dims : output_dims;
+  int short_dims = input_dims_size > output_dims_size ?
+                   output_dims_size : input_dims_size;
+  for (int i = 0; i < temp_dims.size(); ++i) {
+    if (i < short_dims) {
+      if (input_dims[i] != output_dims[i]) {
         dims_flag2 = false;
         break;
       }
-    }
-    else{
-      if(temp_dims[i] != 1){
+    } else {
+      if (temp_dims[i] != 1) {
         dims_flag2 = false;
         break;
       }
     }
   }
   return dims_flag2;
 }
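What the new helper decides: two shapes are treated as equivalent when they agree on their common leading axes and every extra axis of the longer shape is 1, i.e. the reshape is a pure metadata change. A minimal standalone sketch of that check, using std::vector<int64_t> in place of framework::DDim (shapes_equivalent is a hypothetical name, not from the patch):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors reshape2_judge: compare the common leading axes, then require
// every remaining axis of the longer shape to be exactly 1.
static bool shapes_equivalent(const std::vector<int64_t> &in,
                              const std::vector<int64_t> &out) {
  const auto &longer = in.size() > out.size() ? in : out;
  const size_t short_size = std::min(in.size(), out.size());
  for (size_t i = 0; i < longer.size(); ++i) {
    if (i < short_size) {
      if (in[i] != out[i]) return false;  // common axes must agree
    } else if (longer[i] != 1) {
      return false;  // extra axes of the longer shape must all be 1
    }
  }
  return true;
}

int main() {
  // {2, 3, 4} -> {2, 3, 4, 1, 1}: only trailing 1s, buffer can be shared.
  std::cout << shapes_equivalent({2, 3, 4}, {2, 3, 4, 1, 1}) << "\n";  // 1
  // {2, 3, 4} -> {2, 12}: a real layout change, the FPGA path must repack.
  std::cout << shapes_equivalent({2, 3, 4}, {2, 12}) << "\n";          // 0
  return 0;
}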
 
 template <>
 void Reshape2Kernel<FPGA, float>::Compute(const Reshape2Param<FPGA> &param) {
   auto input = const_cast<LoDTensor *>(param.InputX());
   auto output = param.Out();
   auto shape = param.Shape();
   auto num_in = framework::product(input->dims());
   auto num_shape = framework::product(framework::make_ddim(shape));
   PADDLE_MOBILE_ENFORCE(num_shape != 0, "0 index is not supported");
   for (int i = 0; i < shape.size(); i++) {
     if (shape[i] == -1) {
       shape[i] = static_cast<int>(-num_in / num_shape);
       break;
     }
   }
   output->Resize(framework::make_ddim(shape));
   auto input_dims = input->dims();
   auto output_dims = output->dims();
   bool dims_flags = input_dims == output_dims;
   bool dims_flag2 = true;
-  if(!dims_flags){
-    dims_flag2 = reshape2_judge(input_dims, output_dims);
-  }
-  if (dims_flags || dims_flag2 ) {
+  if (!dims_flags) {
+    dims_flag2 = reshape2_judge(input_dims, output_dims);
+  }
+  if (dims_flags || dims_flag2) {
     DLOG << "No need to reshape";
     output->ShareDataWith(*input);
     framework::LoD lod = input->lod();
     output->set_lod(lod);
     return;
   }
   reshape(input, output);
   //
 }
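One step worth spelling out: num_shape is the product over the requested shape including the -1 entry, so it comes out negative, and -num_in / num_shape recovers the missing extent. A compilable sketch of just that arithmetic (infer_shape is a hypothetical helper, not a paddle-mobile API):

#include <cstdint>
#include <iostream>
#include <vector>

// Resolves a single -1 entry in a requested shape, mirroring the loop in
// Reshape2Kernel::Compute above.
std::vector<int> infer_shape(int64_t num_in, std::vector<int> shape) {
  int64_t num_shape = 1;
  for (int d : shape) num_shape *= d;  // negative when a -1 is present
  // The kernel enforces num_shape != 0 before dividing ("0 index is not
  // supported"); well-formed input is assumed here.
  for (auto &d : shape) {
    if (d == -1) {
      // Signs cancel: -num_in / num_shape == num_in / |num_shape|.
      d = static_cast<int>(-num_in / num_shape);
      break;
    }
  }
  return shape;
}

int main() {
  // 24 input elements reshaped to {-1, 4}: num_shape = -4, the -1 becomes 6.
  for (int d : infer_shape(24, {-1, 4})) std::cout << d << " ";  // 6 4
  std::cout << "\n";
  return 0;
}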
......