From 9514b4aa5fec9b302416743325e272b42ebbdbf8 Mon Sep 17 00:00:00 2001
From: ShenLiang
Date: Fri, 22 Jan 2021 11:54:47 +0800
Subject: [PATCH] Fix scatter grad bug (#30604)

---
 paddle/fluid/operators/scatter.cu.h             | 36 +++++++++++++++++--
 paddle/fluid/operators/scatter.h                | 18 ++++++++++
 paddle/fluid/operators/scatter_nd_add_op.cu     |  1 -
 paddle/fluid/operators/scatter_nd_add_op.h      |  3 +-
 paddle/fluid/operators/scatter_op.cc            |  6 +---
 paddle/fluid/operators/scatter_op.cu            | 36 +++++++++++--------
 paddle/fluid/operators/scatter_op.h             | 34 ++++++++++--------
 .../tests/unittests/test_scatter_nd_op.py       | 14 ++++----
 .../fluid/tests/unittests/test_scatter_op.py    | 14 ++++----
 9 files changed, 108 insertions(+), 54 deletions(-)

diff --git a/paddle/fluid/operators/scatter.cu.h b/paddle/fluid/operators/scatter.cu.h
index 7890d50e109..b116a78891a 100644
--- a/paddle/fluid/operators/scatter.cu.h
+++ b/paddle/fluid/operators/scatter.cu.h
@@ -28,8 +28,7 @@ using Tensor = framework::Tensor;
 
 template <typename T, typename IndexT = int>
 __global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
-                                      size_t index_size, size_t slice_size,
-                                      bool overwrite) {
+                                      size_t index_size, size_t slice_size) {
   CUDA_KERNEL_LOOP(i, index_size * slice_size) {
     int indices_i = i / slice_size;
     int slice_i = i - indices_i * slice_size;  // offset inside the slice
@@ -129,7 +128,7 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
     ScatterInitCUDAKernel<T, IndexT><<<
         grid, block, 0,
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
-        p_index, p_output, index_size, slice_size, overwrite);
+        p_index, p_output, index_size, slice_size);
   }
 
   ScatterCUDAKernel<T, IndexT><<<
@@ -138,6 +137,37 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
       p_src, p_index, p_output, index_size, slice_size, overwrite);
 }
 
+// This function is only for the gradient of scatter w.r.t. X;
+// the gradient w.r.t. Updates is computed with gather instead.
+template <typename T, typename IndexT = int>
+void GPUScatterGradForX(const platform::DeviceContext& ctx,
+                        const Tensor& index, Tensor* output) {
+  IndexT index_size = index.dims()[0];
+  auto dst_dims = output->dims();
+  // slice size
+  IndexT slice_size = 1;
+  for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
+  const IndexT* p_index = index.data<IndexT>();
+  T* p_output = output->data<T>();
+  const size_t& slice_bytes = slice_size * sizeof(T);
+
+  // set block and grid num
+  int64_t block = 512;
+  int64_t n = slice_size * index_size;
+  int64_t height = (n + block - 1) / block;
+
+  int64_t max_grid_dimx =
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
+          .GetCUDAMaxGridDimSize()
+          .x;
+  int64_t grid = height < max_grid_dimx ? height : max_grid_dimx;
+
+  ScatterInitCUDAKernel<T, IndexT><<<
+      grid, block, 0,
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+      p_index, p_output, index_size, slice_size);
+}
+
 template <typename T, typename IndexT = int>
 void GPUScatterNdAdd(const framework::ExecutionContext& context,
                      const Tensor& update, const Tensor& index,
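The gradient rule that the new GPUScatterGradForX kernel (and its CPU twin below) implements is easy to state in NumPy. The following sketch is illustrative only: the function names are invented here, and it assumes overwrite-mode scatter with non-repeating indices.

    import numpy as np

    def scatter_forward(x, ids, updates):
        # overwrite-mode scatter: out = x, then out[ids] = updates
        out = x.copy()
        out[ids] = updates
        return out

    def scatter_backward(d_out, ids):
        # dUpdates = dOut[ids] (a gather)
        d_updates = d_out[ids]
        # dX = dOut with the scattered rows zeroed: those rows of X were
        # overwritten in the forward pass, so no gradient flows to them.
        # Before this patch, dX was a plain copy of dOut.
        d_x = d_out.copy()
        d_x[ids] = 0
        return d_x, d_updates

    x = np.random.rand(4, 3)
    ids = np.array([1, 3])
    updates = np.random.rand(2, 3)
    d_out = np.ones_like(scatter_forward(x, ids, updates))
    d_x, d_updates = scatter_backward(d_out, ids)
    assert (d_x[ids] == 0).all() and (d_x[[0, 2]] == 1).all()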
diff --git a/paddle/fluid/operators/scatter.h b/paddle/fluid/operators/scatter.h
index 7325df85c46..cfa88b9808d 100644
--- a/paddle/fluid/operators/scatter.h
+++ b/paddle/fluid/operators/scatter.h
@@ -171,6 +171,24 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
   }
 }
 
+// This function is only for the gradient of scatter w.r.t. X;
+// the gradient w.r.t. Updates is computed with gather instead.
+template <typename T, typename IndexT = int>
+void CPUScatterGradForX(const platform::DeviceContext& ctx,
+                        const Tensor& index, Tensor* output) {
+  int index_size = index.dims()[0];
+  auto dst_dims = output->dims();
+  const IndexT* p_index = index.data<IndexT>();
+  T* p_output = output->data<T>();
+  size_t slice_size = 1;
+  for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
+  const size_t slice_bytes = slice_size * sizeof(T);
+  for (int i = 0; i < index_size; ++i) {
+    const IndexT& index_ = p_index[i];
+    memset(p_output + slice_size * index_, 0, slice_bytes);
+  }
+}
+
 template <typename T, typename IndexT = int>
 void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
                   const Tensor& index, Tensor* output) {
diff --git a/paddle/fluid/operators/scatter_nd_add_op.cu b/paddle/fluid/operators/scatter_nd_add_op.cu
index fb9bc9a045d..ec2a0201de6 100644
--- a/paddle/fluid/operators/scatter_nd_add_op.cu
+++ b/paddle/fluid/operators/scatter_nd_add_op.cu
@@ -65,7 +65,6 @@ class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
     auto *Ids = ctx.Input<Tensor>("Index");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
     if (dX) {
-      // In place gradient: dX = dO
       framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
     }
     if (dUpdates) {
diff --git a/paddle/fluid/operators/scatter_nd_add_op.h b/paddle/fluid/operators/scatter_nd_add_op.h
index 2c8cf0210a1..904b8a421d0 100644
--- a/paddle/fluid/operators/scatter_nd_add_op.h
+++ b/paddle/fluid/operators/scatter_nd_add_op.h
@@ -71,8 +71,7 @@ class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
     if (dX) {
-      // In place gradient: dX = dO
-      framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
+      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
     }
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc
index 8ee5aa312f7..3fc40d41c30 100644
--- a/paddle/fluid/operators/scatter_op.cc
+++ b/paddle/fluid/operators/scatter_op.cc
@@ -138,9 +138,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ScatterGradNoNeedBufferVarsInferer,
                                     "Updates");
 
 DECLARE_INPLACE_OP_INFERER(ScatterInplaceInferer, {"X", "Out"});
-DECLARE_INPLACE_OP_INFERER(ScatterGradInplaceInferer,
-                           {framework::GradVarName("Out"),
-                            framework::GradVarName("X")});
 
 }  // namespace operators
 }  // namespace paddle
@@ -151,8 +148,7 @@ REGISTER_OPERATOR(scatter, ops::ScatterOp, ops::ScatterOpMaker,
                   ops::ScatterGradMaker<paddle::framework::OpDesc>,
                   ops::ScatterGradMaker<paddle::imperative::OpBase>,
                   ops::ScatterInplaceInferer);
 REGISTER_OPERATOR(scatter_grad, ops::ScatterGradOp,
-                  ops::ScatterGradNoNeedBufferVarsInferer,
-                  ops::ScatterGradInplaceInferer);
+                  ops::ScatterGradNoNeedBufferVarsInferer);
 REGISTER_OP_CPU_KERNEL(scatter, ops::ScatterOpKernel<float>,
                        ops::ScatterOpKernel<double>, ops::ScatterOpKernel<int>,
                        ops::ScatterOpKernel<int64_t>);
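A note on the scatter_op.cc change: ScatterGradInplaceInferer let dX reuse dOut's buffer, which was valid while dX was a verbatim copy of dOut. The fixed kernels modify dX after the copy, so sharing storage would corrupt dOut for its other consumers. A NumPy illustration of that aliasing hazard (illustrative, not part of the patch):

    import numpy as np

    ids = np.array([1, 3])

    # Old in-place scheme: dX aliases dOut, so zeroing rows of dX
    # silently clobbers dOut as well.
    d_out = np.ones((4, 3))
    d_x = d_out                     # shared storage
    d_x[ids] = 0
    assert (d_out[1] == 0).all()    # dOut corrupted

    # New scheme: dX gets its own buffer first (TensorCopy), then is edited.
    d_out = np.ones((4, 3))
    d_x = d_out.copy()
    d_x[ids] = 0
    assert (d_out[1] == 1).all()    # dOut intact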
diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu
index e6745ae97a9..1556099d6f1 100644
--- a/paddle/fluid/operators/scatter_op.cu
+++ b/paddle/fluid/operators/scatter_op.cu
@@ -67,27 +67,33 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+
+    const auto &index_type = Ids->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        platform::errors::InvalidArgument(
+            "scatter_op index holds the wrong type, it holds [%s], "
+            "but desires to be [%s] or [%s]",
+            paddle::framework::DataTypeToString(index_type),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT32),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT64)));
+
     if (dX) {
-      // In place gradient: dX = dO
       framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
+      if (index_type == framework::proto::VarType::INT32) {
+        GPUScatterGradForX<T, int32_t>(ctx.device_context(), *Ids, dX);
+      } else {
+        GPUScatterGradForX<T, int64_t>(ctx.device_context(), *Ids, dX);
+      }
     }
+
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
       // Gradient by Gather: dUpdates = dO[Ids]
-      const auto &index_type = Ids->type();
-      bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                              index_type == framework::proto::VarType::INT64;
-      PADDLE_ENFORCE_EQ(
-          index_type_match, true,
-          platform::errors::InvalidArgument(
-              "scatter_op Index holds the wrong type, it holds [%s], "
-              "but desires to be [%s] or [%s]",
-              paddle::framework::DataTypeToString(index_type),
-              paddle::framework::DataTypeToString(
-                  framework::proto::VarType::INT32),
-              paddle::framework::DataTypeToString(
-                  framework::proto::VarType::INT64)));
-      // Gradient by Gather: dUpdates = dO[Ids]
       if (index_type == framework::proto::VarType::INT32) {
         GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
       } else {
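End to end, the fixed behavior can be checked from Python. This is a sketch assuming a Paddle 2.x dygraph environment, not a test from the patch:

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.random.rand(4, 3).astype('float32'),
                         stop_gradient=False)
    index = paddle.to_tensor(np.array([1, 3]).astype('int64'))
    updates = paddle.to_tensor(np.random.rand(2, 3).astype('float32'),
                               stop_gradient=False)

    # overwrite=True replaces rows 1 and 3 of x with updates
    out = paddle.scatter(x, index, updates, overwrite=True)
    out.sum().backward()

    # With this fix the overwritten rows of x get zero gradient;
    # before it, dX was a plain copy of the all-ones upstream gradient.
    print(x.grad)        # rows 1 and 3 are zero
    print(updates.grad)  # all ones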
diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h
index 9c00ac7e9c2..185398bed10 100644
--- a/paddle/fluid/operators/scatter_op.h
+++ b/paddle/fluid/operators/scatter_op.h
@@ -79,26 +79,32 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
+    const auto &index_type = Ids->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        platform::errors::InvalidArgument(
+            "scatter_op index holds the wrong type, it holds [%s], "
+            "but desires to be [%s] or [%s]",
+            paddle::framework::DataTypeToString(index_type),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT32),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT64)));
+
     if (dX) {
-      // In place gradient: dX = dO
       framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
+      if (index_type == framework::proto::VarType::INT32) {
+        CPUScatterGradForX<T, int32_t>(ctx.device_context(), *Ids, dX);
+      } else {
+        CPUScatterGradForX<T, int64_t>(ctx.device_context(), *Ids, dX);
+      }
     }
+
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
       // Gradient by Gather: dUpdates = dO[Ids]
-      const auto &index_type = Ids->type();
-      bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                              index_type == framework::proto::VarType::INT64;
-      PADDLE_ENFORCE_EQ(
-          index_type_match, true,
-          platform::errors::InvalidArgument(
-              "scatter_op index holds the wrong type, it holds [%s], "
-              "but desires to be [%s] or [%s]",
-              paddle::framework::DataTypeToString(index_type),
-              paddle::framework::DataTypeToString(
-                  framework::proto::VarType::INT32),
-              paddle::framework::DataTypeToString(
-                  framework::proto::VarType::INT64)));
       if (index_type == framework::proto::VarType::INT32) {
         CPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
       } else {
diff --git a/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py b/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py
index 90aae939a61..35bb4487c6a 100644
--- a/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py
@@ -78,7 +78,7 @@ class TestScatterNdAddSimpleOp(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')
 
 
 class TestScatterNdAddWithEmptyIndex(OpTest):
@@ -101,7 +101,7 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')
 
 
 class TestScatterNdAddWithHighRankSame(OpTest):
@@ -111,11 +111,11 @@ class TestScatterNdAddWithHighRankSame(OpTest):
 
     def setUp(self):
         self.op_type = "scatter_nd_add"
-        shape = (10, 9, 8, 1, 15)
+        shape = (3, 2, 2, 1, 10)
         ref_np = np.random.rand(*shape).astype("float64")
         index_np = np.vstack(
             [np.random.randint(
-                0, s, size=150) for s in shape]).T.astype("int32")
+                0, s, size=100) for s in shape]).T.astype("int32")
         update_shape = judge_update_shape(ref_np, index_np)
         updates_np = np.random.rand(*update_shape).astype("float64")
         expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
@@ -127,7 +127,7 @@ class TestScatterNdAddWithHighRankSame(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')
 
 
 class TestScatterNdAddWithHighRankDiff(OpTest):
@@ -137,7 +137,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
 
     def setUp(self):
         self.op_type = "scatter_nd_add"
-        shape = (10, 9, 8, 1, 15)
+        shape = (8, 2, 2, 1, 10)
         ref_np = np.random.rand(*shape).astype("double")
         index = np.vstack([np.random.randint(0, s, size=500) for s in shape]).T
         index_np = index.reshape([10, 5, 10, 5]).astype("int64")
@@ -152,7 +152,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')
 
 
 #Test Python API
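The scatter_nd tests above compare against numpy_scatter_nd_add, a reference helper defined elsewhere in the test file; the shrunken shapes plausibly keep the now-heavier numeric gradient check over both 'X' and 'Updates' affordable. A helper with the semantics these tests rely on could look like this (an illustrative sketch, not the file's actual code):

    import numpy as np

    def scatter_nd_add_ref(ref, index, updates):
        # index: (..., K) integers; each row of the flattened index
        # addresses a slice ref[i_0, ..., i_{K-1}].
        # updates: shape index.shape[:-1] + ref.shape[K:]
        out = ref.copy()
        k = index.shape[-1]
        flat_index = index.reshape(-1, k)
        flat_updates = updates.reshape((-1,) + ref.shape[k:])
        for row, upd in zip(flat_index, flat_updates):
            out[tuple(row)] += upd   # duplicate rows accumulate
        return out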
diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py
index e2f012e9a63..c40ca3941ac 100644
--- a/python/paddle/fluid/tests/unittests/test_scatter_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -37,7 +37,7 @@ class TestScatterOp(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(["X", "Updates"], "Out")
 
 
 class TestScatterOp0(OpTest):
@@ -56,7 +56,7 @@ class TestScatterOp0(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(["X", "Updates"], "Out")
 
 
 class TestScatterOp1(OpTest):
@@ -78,7 +78,7 @@ class TestScatterOp1(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(["X", "Updates"], "Out")
 
 
 @unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -102,7 +102,7 @@ class TestScatterOp2(OpTest):
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
 
 
 @unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -130,7 +130,7 @@ class TestScatterOp3(OpTest):
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
 
 
 class TestScatterOp4(OpTest):
@@ -148,7 +148,7 @@ class TestScatterOp4(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')
 
 
 @unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -172,7 +172,7 @@ class TestScatterOp5(OpTest):
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
 
 
 class TestScatterAPI(unittest.TestCase):
-- 
GitLab