未验证 提交 f48b1264 编写于 作者: L limingshu 提交者: GitHub

Performance fix for broadcast kernel [Part3] (#45854)

* first commit

* fix some bugs in code

* fix bugs

* to optimize merge one dimension feature
上级 8dde7aea
......@@ -213,9 +213,10 @@ struct DimensionsTransform {
}
}
};
int swap_idx = 0;
bool has_seq_one = FindSequentialOneDim(&swap_idx);
if (has_seq_one) {
for (auto i = 0; i < dim_size; ++i) {
int swap_idx = 0;
bool has_seq_one = FindSequentialOneDim(&swap_idx);
if (!has_seq_one) break;
merge_ptr = merge_sequential_one_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N);
std::swap(in_dims[swap_idx], in_dims[0]);
......@@ -508,7 +509,6 @@ void BroadcastKernelForDifferentVecSize(
"functions is %d.",
outs->size(),
NumOuts));
// mergedim and get vec_size
const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
phi::Array<kps::details::BroadcastConfig, kArity> configs;
......
......@@ -85,7 +85,7 @@ struct FastDivMod {
struct BroadcastConfig {
FastDivMod divmoders[phi::DDim::kMaxRank];
uint32_t strides[phi::DDim::kMaxRank];
int kDims;
int kDims{0};
HOSTDEVICE BroadcastConfig() {}
HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册