未验证 提交 f48b1264 编写于 作者: L limingshu 提交者: GitHub

Performance fix for broadcast kernel [Part3] (#45854)

* first commit

* fix some bugs in code

* fix bugs

* to optimize merge one dimension feature
上级 8dde7aea
...@@ -213,9 +213,10 @@ struct DimensionsTransform { ...@@ -213,9 +213,10 @@ struct DimensionsTransform {
} }
} }
}; };
int swap_idx = 0; for (auto i = 0; i < dim_size; ++i) {
bool has_seq_one = FindSequentialOneDim(&swap_idx); int swap_idx = 0;
if (has_seq_one) { bool has_seq_one = FindSequentialOneDim(&swap_idx);
if (!has_seq_one) break;
merge_ptr = merge_sequential_one_dims; merge_ptr = merge_sequential_one_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N); MergeDimensions<MergeFunctor>(merge_ptr, N);
std::swap(in_dims[swap_idx], in_dims[0]); std::swap(in_dims[swap_idx], in_dims[0]);
...@@ -508,7 +509,6 @@ void BroadcastKernelForDifferentVecSize( ...@@ -508,7 +509,6 @@ void BroadcastKernelForDifferentVecSize(
"functions is %d.", "functions is %d.",
outs->size(), outs->size(),
NumOuts)); NumOuts));
// mergedim and get vec_size // mergedim and get vec_size
const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis); const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
phi::Array<kps::details::BroadcastConfig, kArity> configs; phi::Array<kps::details::BroadcastConfig, kArity> configs;
......
...@@ -85,7 +85,7 @@ struct FastDivMod { ...@@ -85,7 +85,7 @@ struct FastDivMod {
struct BroadcastConfig { struct BroadcastConfig {
FastDivMod divmoders[phi::DDim::kMaxRank]; FastDivMod divmoders[phi::DDim::kMaxRank];
uint32_t strides[phi::DDim::kMaxRank]; uint32_t strides[phi::DDim::kMaxRank];
int kDims; int kDims{0};
HOSTDEVICE BroadcastConfig() {} HOSTDEVICE BroadcastConfig() {}
HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims, HOSTDEVICE BroadcastConfig(const std::vector<int64_t>& out_dims,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册