未验证 提交 364c4442 编写于 作者: W wawltor 提交者: GitHub

Add the support the int64 data type of `scatter_op` input Index(#18804) (#19508)

* test=develop
Fix the scatter op bug when use the add mode, and support the int64 data type of scatter_op Index(#18804).

* test=develop
Remove the PADDLE_ENFORCE and use PADDLE_ENFORCE_EQ

* test=develop
Remove the fix bug of scatter_add, and just add the support of int64 in scatter_add

* test=develop
Add the test case for scatter op, the test case just for index int64
上级 9c885708
...@@ -136,6 +136,7 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src, ...@@ -136,6 +136,7 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
memset(result_p_output + slice_size * index_, 0, slice_bytes); memset(result_p_output + slice_size * index_, 0, slice_bytes);
} }
// if not in overwrite mode, need to init output data
for (int i = 0; i < index_size; ++i) { for (int i = 0; i < index_size; ++i) {
const IndexT& index_ = p_index[i]; const IndexT& index_ = p_index[i];
elementwise_inner_add<T, IndexT>(ctx, p_src, p_output, result_p_output, src, elementwise_inner_add<T, IndexT>(ctx, p_src, p_output, result_p_output, src,
......
...@@ -33,7 +33,22 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> { ...@@ -33,7 +33,22 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
bool overwrite = ctx.Attr<bool>("overwrite"); bool overwrite = ctx.Attr<bool>("overwrite");
Out->ShareDataWith(*X); Out->ShareDataWith(*X);
GPUScatterAssign<T>(ctx, *Updates, *Ids, Out, overwrite); // use template class to support int32_t and int64_t
const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"scatter_op Index holds the wrong type, it holds %s, but desires to be "
"%s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
if (index_type == framework::proto::VarType::INT32) {
GPUScatterAssign<T, int32_t>(ctx, *Updates, *Ids, Out, overwrite);
} else {
GPUScatterAssign<T, int64_t>(ctx, *Updates, *Ids, Out, overwrite);
}
} }
}; };
...@@ -54,7 +69,23 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -54,7 +69,23 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
if (dUpdates) { if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace()); dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids] // Gradient by Gather: dUpdates = dO[Ids]
GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates); const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"scatter_op Index holds the wrong type, it holds %s, but desires to "
"be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64));
// Gradient by Gather: dUpdates = dO[Ids]
if (index_type == framework::proto::VarType::INT32) {
GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
} else {
GPUGather<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
}
} }
} }
}; };
......
...@@ -81,7 +81,22 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> { ...@@ -81,7 +81,22 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
if (dUpdates) { if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace()); dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Ids] // Gradient by Gather: dUpdates = dO[Ids]
CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates); const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
"scatter_op index holds the wrong type, it holds %s, but desires to "
"be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64));
if (index_type == framework::proto::VarType::INT32) {
CPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
} else {
CPUGather<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
}
} }
} }
}; };
......
...@@ -131,5 +131,47 @@ class TestScatterOp3(OpTest): ...@@ -131,5 +131,47 @@ class TestScatterOp3(OpTest):
self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True) self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
class TestScatterOp4(OpTest):
def setUp(self):
self.op_type = "scatter"
ref_np = np.ones((3, 3)).astype("float32")
index_np = np.array([1, 2]).astype("int64")
updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = updates_np
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Updates'], 'Out', in_place=True)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestScatterOp5(OpTest):
def setUp(self):
self.op_type = "scatter"
ref_np = np.ones((3, 3)).astype("float32")
index_np = np.array([1, 2]).astype("int64")
updates_np = np.random.random((2, 3)).astype("float32")
output_np = np.copy(ref_np)
output_np[index_np] = updates_np
self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
self.outputs = {'Out': output_np}
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-3)
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册