From 364c44422ed0e9027befc835be592a2e60982b05 Mon Sep 17 00:00:00 2001
From: wawltor <980627148@qq.com>
Date: Wed, 4 Sep 2019 13:42:08 +0800
Subject: [PATCH] Add support for the int64 data type of `scatter_op` input
 Index (#18804) (#19508)

* test=develop
  Fix the scatter op bug when using the add mode, and support the int64
  data type of the scatter_op Index (#18804).

* test=develop
  Remove the PADDLE_ENFORCE and use PADDLE_ENFORCE_EQ.

* test=develop
  Remove the scatter_add bug fix, and just add int64 support to
  scatter_add.

* test=develop
  Add test cases for the scatter op, covering the int64 index.
---
 paddle/fluid/operators/scatter.h              |  1 +
 paddle/fluid/operators/scatter_op.cu          | 35 +++++++++++++++-
 paddle/fluid/operators/scatter_op.h           | 17 +++++++-
 .../fluid/tests/unittests/test_scatter_op.py  | 42 +++++++++++++++++++
 4 files changed, 92 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/scatter.h b/paddle/fluid/operators/scatter.h
index 3f6bfff5db..6d9d1863c2 100644
--- a/paddle/fluid/operators/scatter.h
+++ b/paddle/fluid/operators/scatter.h
@@ -136,6 +136,7 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
     memset(result_p_output + slice_size * index_, 0, slice_bytes);
   }
 
+  // if not in overwrite mode, need to init output data
   for (int i = 0; i < index_size; ++i) {
     const IndexT& index_ = p_index[i];
     elementwise_inner_add<T, IndexT>(ctx, p_src, p_output, result_p_output, src,
diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu
index e17617b40d..6c4da760ce 100644
--- a/paddle/fluid/operators/scatter_op.cu
+++ b/paddle/fluid/operators/scatter_op.cu
@@ -33,7 +33,22 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
     bool overwrite = ctx.Attr<bool>("overwrite");
 
     Out->ShareDataWith(*X);
-    GPUScatterAssign<T>(ctx, *Updates, *Ids, Out, overwrite);
+    // use template class to support int32_t and int64_t
+    const auto &index_type = Ids->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        "scatter_op Index holds the wrong type, it holds %s, but desires to be "
+        "%s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    if (index_type == framework::proto::VarType::INT32) {
+      GPUScatterAssign<T, int32_t>(ctx, *Updates, *Ids, Out, overwrite);
+    } else {
+      GPUScatterAssign<T, int64_t>(ctx, *Updates, *Ids, Out, overwrite);
+    }
   }
 };
 
@@ -54,7 +69,23 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
-      // Gradient by Gather: dUpdates = dO[Ids]
-      GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      const auto &index_type = Ids->type();
+      bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                              index_type == framework::proto::VarType::INT64;
+      PADDLE_ENFORCE_EQ(
+          index_type_match, true,
+          "scatter_op Index holds the wrong type, it holds %s, but desires to "
+          "be %s or %s",
+          paddle::framework::DataTypeToString(index_type),
+          paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+          paddle::framework::DataTypeToString(
+              framework::proto::VarType::INT64));
+      // Gradient by Gather: dUpdates = dO[Ids]
+      if (index_type == framework::proto::VarType::INT32) {
+        GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      } else {
+        GPUGather<T, int64_t>(ctx.device_context(), *dOut, *Ids,
+                              dUpdates);
+      }
     }
   }
 };
diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h
index 3b6184de77..97254f817d 100644
--- a/paddle/fluid/operators/scatter_op.h
+++ b/paddle/fluid/operators/scatter_op.h
@@ -81,7 +81,22 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
       // Gradient by Gather: dUpdates = dO[Ids]
-      CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      const auto &index_type = Ids->type();
+      bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                              index_type == framework::proto::VarType::INT64;
+      PADDLE_ENFORCE_EQ(
+          index_type_match, true,
+          "scatter_op index holds the wrong type, it holds %s, but desires to "
+          "be %s or %s",
+          paddle::framework::DataTypeToString(index_type),
+          paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+          paddle::framework::DataTypeToString(
+              framework::proto::VarType::INT64));
+      if (index_type == framework::proto::VarType::INT32) {
+        CPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      } else {
+        CPUGather<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      }
     }
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py
index 9c60a11828..999b7ea88b 100644
--- a/python/paddle/fluid/tests/unittests/test_scatter_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -131,5 +131,47 @@ class TestScatterOp3(OpTest):
         self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
 
 
+class TestScatterOp4(OpTest):
+    def setUp(self):
+        self.op_type = "scatter"
+        ref_np = np.ones((3, 3)).astype("float32")
+        index_np = np.array([1, 2]).astype("int64")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['Updates'], 'Out', in_place=True)
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestScatterOp5(OpTest):
+    def setUp(self):
+        self.op_type = "scatter"
+        ref_np = np.ones((3, 3)).astype("float32")
+        index_np = np.array([1, 2]).astype("int64")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, atol=1e-3)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(place, ['Updates'], 'Out',
+                                       in_place=True)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
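
A note on the semantics the kernels above implement: in overwrite mode the op
writes the rows of Updates over the indexed rows of X, while in add mode the
indexed rows are first zero-initialized (the memset loop in ScatterAssignAdd,
paddle/fluid/operators/scatter.h) and then accumulated, so duplicate indices
sum up. Below is a minimal NumPy sketch of both modes; the helper name
scatter_ref is hypothetical and exists only for illustration, not as part of
this patch.

    import numpy as np

    def scatter_ref(x, ids, updates, overwrite=True):
        # Hypothetical NumPy reference for the scatter op's two modes.
        out = np.copy(x)
        if overwrite:
            # overwrite mode: out[ids[i]] = updates[i]; the last write wins
            for i, idx in enumerate(ids):
                out[idx] = updates[i]
        else:
            # add mode: zero-init every referenced row (mirrors the memset
            # loop in scatter.h), then accumulate, so duplicate ids sum up
            out[np.unique(ids)] = 0
            for i, idx in enumerate(ids):
                out[idx] += updates[i]
        return out

    # Same expectation that TestScatterOp4/5 build with an int64 index:
    x = np.ones((3, 3), dtype="float32")
    ids = np.array([1, 2], dtype="int64")
    updates = np.random.random((2, 3)).astype("float32")
    expected = np.copy(x)
    expected[ids] = updates
    assert np.allclose(scatter_ref(x, ids, updates), expected)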
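
For completeness, here is a sketch of how the new index type surfaces at the
Python layer. It assumes the Fluid 1.x API of the same era (fluid.layers.data,
fluid.layers.scatter, fluid.Executor); the exact layer signatures are an
assumption for illustration, not something this patch adds.

    import numpy as np
    import paddle.fluid as fluid

    # Build a small graph that scatters `updates` into rows 1 and 2 of `x`
    # through an int64 index tensor -- the case this patch enables.
    x = fluid.layers.data(
        name="x", shape=[3, 3], dtype="float32", append_batch_size=False)
    ids = fluid.layers.data(
        name="ids", shape=[2], dtype="int64", append_batch_size=False)
    updates = fluid.layers.data(
        name="updates", shape=[2, 3], dtype="float32", append_batch_size=False)
    out = fluid.layers.scatter(x, ids, updates)

    exe = fluid.Executor(fluid.CPUPlace())
    res, = exe.run(
        fluid.default_main_program(),
        feed={
            "x": np.ones((3, 3), dtype="float32"),
            "ids": np.array([1, 2], dtype="int64"),
            "updates": np.random.random((2, 3)).astype("float32"),
        },
        fetch_list=[out])
    print(res)  # rows 1 and 2 replaced by `updates`

Declaring ids as int32 instead exercises the other branch of the new if/else
dispatch in the kernels.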