diff --git a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h index 3799b9d4892f85f751f6007f3ea852febfe382c1..1e5dfe2a542b0f71956efe02924a1d34429e949d 100644 --- a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h +++ b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h @@ -70,6 +70,7 @@ struct BroadcastConfig { int strides_out[phi::DDim::kMaxRank]; int in_dim[phi::DDim::kMaxRank]; int dim_after_cmp[phi::DDim::kMaxRank]; + int y_dim_after_cmp[phi::DDim::kMaxRank]; int dim_size_after_cmp = 0; int cmp_res = 0; OptType cmp_type = OptType::CanNotOptimize; @@ -82,7 +83,7 @@ struct BroadcastConfig { HOSTDEVICE BroadcastConfig(const std::vector& out_dims, const std::vector& in_dims, - const std::vector& another_in_dims, + const std::vector& y_in_dims, int dim_size) { std::vector strides_in_tmp; std::vector strides_out_tmp; @@ -103,8 +104,8 @@ struct BroadcastConfig { memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int)); memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int)); - cmp_res = get_mnk_for_broadcast_ops(in_dims, another_in_dims); - get_opt_type(another_in_dims); + cmp_res = get_mnk_for_broadcast_ops(in_dims, y_in_dims); + get_opt_type(); buf_len = get_buf_len(); } @@ -154,7 +155,7 @@ struct BroadcastConfig { return index_src; } - void get_opt_type(const std::vector& y_dim_after_cmp) { + void get_opt_type() { if (dim_size_after_cmp == 1) { if (dim_after_cmp[0] == 1 && y_dim_after_cmp[0] != 1) { // {1} op {n} n = y_dim_after_cmp[0]; @@ -241,6 +242,7 @@ struct BroadcastConfig { int cmp_x = 0; int cmp_y = 0; bool is_same = false; + std::vector xshape_after_remove_ones = xshape; std::vector yshape_after_remove_ones = yshape; // first step: remove excess ones @@ -275,6 +277,7 @@ struct BroadcastConfig { } idx = idx + 1; dim_after_cmp[after_cmp_idx] = cmp_x; + y_dim_after_cmp[after_cmp_idx] = cmp_y; after_cmp_idx++; if (idx == xshape_after_remove_ones.size()) { dim_size_after_cmp = after_cmp_idx;