diff --git a/paddle/fluid/operators/jit/gen/vbroadcast.cc b/paddle/fluid/operators/jit/gen/vbroadcast.cc index 31deb1643056231de5b617c57b8761dc55f904d4..3f9fbdbd821acae0940c5a7b8d9a5eb2432712ff 100644 --- a/paddle/fluid/operators/jit/gen/vbroadcast.cc +++ b/paddle/fluid/operators/jit/gen/vbroadcast.cc @@ -37,36 +37,33 @@ void VBroadcastJitCode::genCode() { } // protect param_h - const size_t width_in_byte = sizeof(float) * w_; mov(reg_height, param_h); - int acc_num_regs = 0; - for (int num_regs : groups) { + Label l_next_h; + xor_(reg_h_i, reg_h_i); + mov(reg_ptr_dst_i, param_dst); + L(l_next_h); + { mov(reg_ptr_src_i, param_src); - add(reg_ptr_src_i, acc_num_regs * block_size); - size_t w_offset = 0; - for (int reg_i = 0; reg_i < num_regs; ++reg_i) { - vmovups(ymm_t(reg_i), ptr[reg_ptr_src_i + w_offset]); - w_offset += block_size; - } + for (int num_regs : groups) { + size_t w_offset = 0; + for (int reg_i = 0; reg_i < num_regs; ++reg_i) { + vmovups(ymm_t(reg_i), ptr[reg_ptr_src_i + w_offset]); + w_offset += block_size; + } + add(reg_ptr_src_i, num_regs * block_size); - Label l_next_h; - xor_(reg_h_i, reg_h_i); - mov(reg_ptr_dst_i, param_dst); - add(reg_ptr_dst_i, acc_num_regs * block_size); - L(l_next_h); - { w_offset = 0; for (int reg_i = 0; reg_i < num_regs; ++reg_i) { vmovups(ptr[reg_ptr_dst_i + w_offset], ymm_t(reg_i)); w_offset += block_size; } - add(reg_ptr_dst_i, width_in_byte); - inc(reg_h_i); - cmp(reg_h_i, reg_height); - jl(l_next_h, T_NEAR); - } // end of l_next_h - acc_num_regs += num_regs; - } // end of groups + add(reg_ptr_dst_i, num_regs * block_size); + } // end of groups + inc(reg_h_i); + cmp(reg_h_i, reg_height); + jl(l_next_h, T_NEAR); + } // end of l_next_h + postCode(); }