未验证 提交 106083aa 编写于 作者: S shixingbo 提交者: GitHub

Fix a bug in BroadcastConfig for KP XPU2 rec model (#42866)

上级 2ffb3371
......@@ -70,6 +70,7 @@ struct BroadcastConfig {
int strides_out[phi::DDim::kMaxRank];
int in_dim[phi::DDim::kMaxRank];
int dim_after_cmp[phi::DDim::kMaxRank];
int y_dim_after_cmp[phi::DDim::kMaxRank];
int dim_size_after_cmp = 0;
int cmp_res = 0;
OptType cmp_type = OptType::CanNotOptimize;
......@@ -82,7 +83,7 @@ struct BroadcastConfig {
HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
const std::vector<int64_t>& in_dims,
const std::vector<int64_t>& another_in_dims,
const std::vector<int64_t>& y_in_dims,
int dim_size) {
std::vector<int> strides_in_tmp;
std::vector<int> strides_out_tmp;
......@@ -103,8 +104,8 @@ struct BroadcastConfig {
memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int));
memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int));
cmp_res = get_mnk_for_broadcast_ops(in_dims, another_in_dims);
get_opt_type(another_in_dims);
cmp_res = get_mnk_for_broadcast_ops(in_dims, y_in_dims);
get_opt_type();
buf_len = get_buf_len();
}
......@@ -154,7 +155,7 @@ struct BroadcastConfig {
return index_src;
}
void get_opt_type(const std::vector<int64_t>& y_dim_after_cmp) {
void get_opt_type() {
if (dim_size_after_cmp == 1) {
if (dim_after_cmp[0] == 1 && y_dim_after_cmp[0] != 1) { // {1} op {n}
n = y_dim_after_cmp[0];
......@@ -241,6 +242,7 @@ struct BroadcastConfig {
int cmp_x = 0;
int cmp_y = 0;
bool is_same = false;
std::vector<int64_t> xshape_after_remove_ones = xshape;
std::vector<int64_t> yshape_after_remove_ones = yshape;
// first step: remove excess ones
......@@ -275,6 +277,7 @@ struct BroadcastConfig {
}
idx = idx + 1;
dim_after_cmp[after_cmp_idx] = cmp_x;
y_dim_after_cmp[after_cmp_idx] = cmp_y;
after_cmp_idx++;
if (idx == xshape_after_remove_ones.size()) {
dim_size_after_cmp = after_cmp_idx;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册