未验证 提交 106083aa 编写于 作者: S shixingbo 提交者: GitHub

Fix a bug in BroadcastConfig for KP XPU2 rec model (#42866)

上级 2ffb3371
...@@ -70,6 +70,7 @@ struct BroadcastConfig { ...@@ -70,6 +70,7 @@ struct BroadcastConfig {
int strides_out[phi::DDim::kMaxRank]; int strides_out[phi::DDim::kMaxRank];
int in_dim[phi::DDim::kMaxRank]; int in_dim[phi::DDim::kMaxRank];
int dim_after_cmp[phi::DDim::kMaxRank]; int dim_after_cmp[phi::DDim::kMaxRank];
int y_dim_after_cmp[phi::DDim::kMaxRank];
int dim_size_after_cmp = 0; int dim_size_after_cmp = 0;
int cmp_res = 0; int cmp_res = 0;
OptType cmp_type = OptType::CanNotOptimize; OptType cmp_type = OptType::CanNotOptimize;
...@@ -82,7 +83,7 @@ struct BroadcastConfig { ...@@ -82,7 +83,7 @@ struct BroadcastConfig {
HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims, HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
const std::vector<int64_t>& in_dims, const std::vector<int64_t>& in_dims,
const std::vector<int64_t>& another_in_dims, const std::vector<int64_t>& y_in_dims,
int dim_size) { int dim_size) {
std::vector<int> strides_in_tmp; std::vector<int> strides_in_tmp;
std::vector<int> strides_out_tmp; std::vector<int> strides_out_tmp;
...@@ -103,8 +104,8 @@ struct BroadcastConfig { ...@@ -103,8 +104,8 @@ struct BroadcastConfig {
memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int)); memcpy(strides_out, strides_out_tmp.data(), kDims * sizeof(int));
memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int)); memcpy(in_dim, dim_tmp.data(), kDims * sizeof(int));
cmp_res = get_mnk_for_broadcast_ops(in_dims, another_in_dims); cmp_res = get_mnk_for_broadcast_ops(in_dims, y_in_dims);
get_opt_type(another_in_dims); get_opt_type();
buf_len = get_buf_len(); buf_len = get_buf_len();
} }
...@@ -154,7 +155,7 @@ struct BroadcastConfig { ...@@ -154,7 +155,7 @@ struct BroadcastConfig {
return index_src; return index_src;
} }
void get_opt_type(const std::vector<int64_t>& y_dim_after_cmp) { void get_opt_type() {
if (dim_size_after_cmp == 1) { if (dim_size_after_cmp == 1) {
if (dim_after_cmp[0] == 1 && y_dim_after_cmp[0] != 1) { // {1} op {n} if (dim_after_cmp[0] == 1 && y_dim_after_cmp[0] != 1) { // {1} op {n}
n = y_dim_after_cmp[0]; n = y_dim_after_cmp[0];
...@@ -241,6 +242,7 @@ struct BroadcastConfig { ...@@ -241,6 +242,7 @@ struct BroadcastConfig {
int cmp_x = 0; int cmp_x = 0;
int cmp_y = 0; int cmp_y = 0;
bool is_same = false; bool is_same = false;
std::vector<int64_t> xshape_after_remove_ones = xshape; std::vector<int64_t> xshape_after_remove_ones = xshape;
std::vector<int64_t> yshape_after_remove_ones = yshape; std::vector<int64_t> yshape_after_remove_ones = yshape;
// first step: remove excess ones // first step: remove excess ones
...@@ -275,6 +277,7 @@ struct BroadcastConfig { ...@@ -275,6 +277,7 @@ struct BroadcastConfig {
} }
idx = idx + 1; idx = idx + 1;
dim_after_cmp[after_cmp_idx] = cmp_x; dim_after_cmp[after_cmp_idx] = cmp_x;
y_dim_after_cmp[after_cmp_idx] = cmp_y;
after_cmp_idx++; after_cmp_idx++;
if (idx == xshape_after_remove_ones.size()) { if (idx == xshape_after_remove_ones.size()) {
dim_size_after_cmp = after_cmp_idx; dim_size_after_cmp = after_cmp_idx;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册