diff --git a/paddle/fluid/operators/scatter.h b/paddle/fluid/operators/scatter.h
index 3f6bfff5db4b719dfe3d8b229ee12e9bd8b0db83..6d9d1863c27ec53beeee86ebd14c01c4ee914e92 100644
--- a/paddle/fluid/operators/scatter.h
+++ b/paddle/fluid/operators/scatter.h
@@ -136,6 +136,7 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
     memset(result_p_output + slice_size * index_, 0, slice_bytes);
   }
 
+  // if not in overwrite mode, need to init output data
   for (int i = 0; i < index_size; ++i) {
     const IndexT& index_ = p_index[i];
     elementwise_inner_add<T, IndexT>(ctx, p_src, p_output, result_p_output, src,
diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu
index e17617b40da356d74bdffcf53a6c9189d13c64f1..6c4da760ce828e49b55c5d488958e1039fe62702 100644
--- a/paddle/fluid/operators/scatter_op.cu
+++ b/paddle/fluid/operators/scatter_op.cu
@@ -33,7 +33,22 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
     bool overwrite = ctx.Attr<bool>("overwrite");
 
     Out->ShareDataWith(*X);
-    GPUScatterAssign<T>(ctx, *Updates, *Ids, Out, overwrite);
+    // use template class to support int32_t and int64_t
+    const auto &index_type = Ids->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        "scatter_op Index holds the wrong type, it holds %s, but desires to be "
+        "%s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    if (index_type == framework::proto::VarType::INT32) {
+      GPUScatterAssign<T, int32_t>(ctx, *Updates, *Ids, Out, overwrite);
+    } else {
+      GPUScatterAssign<T, int64_t>(ctx, *Updates, *Ids, Out, overwrite);
+    }
   }
 };
 
@@ -54,7 +69,23 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
       // Gradient by Gather: dUpdates = dO[Ids]
-      GPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      const auto &index_type = Ids->type();
+      bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                              index_type == framework::proto::VarType::INT64;
+      PADDLE_ENFORCE_EQ(
+          index_type_match, true,
+          "scatter_op Index holds the wrong type, it holds %s, but desires to "
+          "be %s or %s",
+          paddle::framework::DataTypeToString(index_type),
+          paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+          paddle::framework::DataTypeToString(
+              framework::proto::VarType::INT64));
+      // Gradient by Gather: dUpdates = dO[Ids]
+      if (index_type == framework::proto::VarType::INT32) {
+        GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      } else {
+        GPUGather<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      }
     }
   }
 };
diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h
index 3b6184de77f4fc05aa2f2900ebc656ed06a8edfc..97254f817d9856aca9ffe1a101551b902541d9cf 100644
--- a/paddle/fluid/operators/scatter_op.h
+++ b/paddle/fluid/operators/scatter_op.h
@@ -81,7 +81,22 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
       // Gradient by Gather: dUpdates = dO[Ids]
-      CPUGather<T>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      const auto &index_type = Ids->type();
+      bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                              index_type == framework::proto::VarType::INT64;
+      PADDLE_ENFORCE_EQ(
+          index_type_match, true,
+          "scatter_op index holds the wrong type, it holds %s, but desires to "
+          "be %s or %s",
+          paddle::framework::DataTypeToString(index_type),
+          paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+          paddle::framework::DataTypeToString(
+              framework::proto::VarType::INT64));
+      if (index_type == framework::proto::VarType::INT32) {
+        CPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      } else {
+        CPUGather<T, int64_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+      }
     }
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py
index 9c60a1182852ba1c524f7185a2786c9a8943315f..999b7ea88bad6345bddad4cec92d510facd142dc 100644
--- a/python/paddle/fluid/tests/unittests/test_scatter_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -131,5 +131,47 @@ class TestScatterOp3(OpTest):
         self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
 
 
+class TestScatterOp4(OpTest):
+    def setUp(self):
+        self.op_type = "scatter"
+        ref_np = np.ones((3, 3)).astype("float32")
+        index_np = np.array([1, 2]).astype("int64")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['Updates'], 'Out', in_place=True)
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestScatterOp5(OpTest):
+    def setUp(self):
+        self.op_type = "scatter"
+        ref_np = np.ones((3, 3)).astype("float32")
+        index_np = np.array([1, 2]).astype("int64")
+        updates_np = np.random.random((2, 3)).astype("float32")
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, atol=1e-3)
+
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+
+
 if __name__ == "__main__":
     unittest.main()
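For reference, below the patch is a minimal end-to-end sketch of what the new tests exercise: feeding an int64 index tensor into the scatter layer, which this change routes to the int64_t kernel instantiation instead of rejecting it. The fluid.layers.data/fluid.layers.scatter calls, their arguments, and the CPU executor setup are assumptions based on the Fluid 1.x Python API of this era; they are not part of the patch itself.

import numpy as np
import paddle.fluid as fluid

# Build a small program: a 3x3 base tensor, an int64 index vector, and updates.
x = fluid.layers.data(name='x', shape=[3, 3], dtype='float32',
                      append_batch_size=False)
ids = fluid.layers.data(name='ids', shape=[2], dtype='int64',
                        append_batch_size=False)
updates = fluid.layers.data(name='updates', shape=[2, 3], dtype='float32',
                            append_batch_size=False)
# With this patch, the int64 'ids' input dispatches to the int64_t kernel.
out = fluid.layers.scatter(x, ids, updates)

exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(fluid.default_main_program(),
               feed={
                   'x': np.ones((3, 3)).astype('float32'),
                   'ids': np.array([1, 2]).astype('int64'),
                   'updates': np.random.random((2, 3)).astype('float32'),
               },
               fetch_list=[out])
# Rows 1 and 2 of the result now hold the update values; row 0 stays all ones.
print(res)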