未验证 提交 4543ca91 编写于 作者: 傅剑寒 提交者: GitHub

fix index_put bug when index is multi-dim bool tensor (#55191)

* fix index_put bug when index is multi-dim bool tensor

* fix name error
上级 ab8fd13b
......@@ -195,7 +195,7 @@ void IndexPutGradKernel(const Context& dev_ctx,
std::vector<DenseTensor> tmp_res_indices_v;
std::vector<DenseTensor> range_tensor_v;
for (int i = indices.size(); i < x.dims().size(); ++i) {
for (int i = int_indices_v.size(); i < x.dims().size(); ++i) {
range_tensor_v.emplace_back(funcs::GetRangeTensor<int64_t, Context>(
dev_ctx, x.dims()[i], phi::DataType::INT64));
}
......
......@@ -127,7 +127,7 @@ void IndexPutKernel(const Context& dev_ctx,
std::vector<DenseTensor> range_tensor_v;
const DenseTensor* ptr_value = nullptr;
for (int i = indices.size(); i < x.dims().size(); ++i) {
for (int i = int_indices_v.size(); i < x.dims().size(); ++i) {
range_tensor_v.emplace_back(funcs::GetRangeTensor<int64_t, Context>(
dev_ctx, x.dims()[i], phi::DataType::INT64));
}
......
......@@ -227,7 +227,7 @@ void IndexPutGradKernel(const Context& dev_ctx,
std::vector<DenseTensor> tmp_res_indices_v;
std::vector<DenseTensor> range_tensor_v;
for (int i = indices.size(); i < x.dims().size(); ++i) {
for (int i = int_indices_v.size(); i < x.dims().size(); ++i) {
range_tensor_v.emplace_back(funcs::GetRangeCudaTensor<int64_t, Context>(
dev_ctx, x.dims()[i], phi::DataType::INT64));
}
......
......@@ -128,7 +128,7 @@ void IndexPutKernel(const Context& dev_ctx,
std::vector<DenseTensor> range_tensor_v;
const DenseTensor* ptr_value = nullptr;
for (int i = indices.size(); i < x.dims().size(); ++i) {
for (int i = int_indices_v.size(); i < x.dims().size(); ++i) {
range_tensor_v.emplace_back(funcs::GetRangeCudaTensor<int64_t, Context>(
dev_ctx, x.dims()[i], phi::DataType::INT64));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册