Unverified commit 9514b4aa, authored by ShenLiang and committed by GitHub

Fix scatter grad bug (#30604)

Parent 1f5841c2
@@ -28,8 +28,7 @@ using Tensor = framework::Tensor;
template <typename T, typename IndexT = int>
__global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
-size_t index_size, size_t slice_size,
-bool overwrite) {
+size_t index_size, size_t slice_size) {
CUDA_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
@@ -129,7 +128,7 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
ScatterInitCUDAKernel<T, IndexT><<<
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
-p_index, p_output, index_size, slice_size, overwrite);
+p_index, p_output, index_size, slice_size);
}
ScatterCUDAKernel<T, IndexT><<<
@@ -138,6 +137,37 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
p_src, p_index, p_output, index_size, slice_size, overwrite);
}
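For context, scatter's forward writes Updates into a copy of X at the rows named by Ids. The init kernel above zeroes the indexed rows unconditionally, which is why the `overwrite` flag could be dropped from its signature. A minimal NumPy sketch of the two forward modes (illustrative, not Paddle's API):

```python
import numpy as np

def scatter(x, ids, updates, overwrite=True):
    """NumPy reference for the scatter forward (names are illustrative)."""
    out = x.copy()
    if overwrite:
        out[ids] = updates            # with duplicate ids, the last write wins
    else:
        out[ids] = 0.0                # zero-init, as ScatterInitCUDAKernel does
        np.add.at(out, ids, updates)  # then accumulate, duplicates included
    return out
```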
+// This function is only for the gradient of scatter's X input;
+// the gradient of Updates uses gather instead.
+template <typename T, typename IndexT = int>
+void GPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index,
+Tensor* output) {
+IndexT index_size = index.dims()[0];
+auto dst_dims = output->dims();
+// slice size
+IndexT slice_size = 1;
+for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
+const IndexT* p_index = index.data<IndexT>();
+T* p_output = output->data<T>();
+const size_t& slice_bytes = slice_size * sizeof(T);
+// set block and grid num
+int64_t block = 512;
+int64_t n = slice_size * index_size;
+int64_t height = (n + block - 1) / block;
+int64_t max_grid_dimx =
+reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
+.GetCUDAMaxGridDimSize()
+.x;
+int64_t grid = height < max_grid_dimx ? height : max_grid_dimx;
+ScatterInitCUDAKernel<T, IndexT><<<
+grid, block, 0,
+reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+p_index, p_output, index_size, slice_size);
+}
template <typename DeviceContext, typename T, typename IndexT = int>
void GPUScatterNdAdd(const framework::ExecutionContext& context,
const Tensor& update, const Tensor& index,
......
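The new GPUScatterGradForX reuses ScatterInitCUDAKernel to zero the slices of dX selected by index; capping the grid at GetCUDAMaxGridDimSize().x is safe because CUDA_KERNEL_LOOP is a grid-stride loop. Its net effect on a dX the caller has already filled with dOut is sketched below in NumPy (names are illustrative):

```python
import numpy as np

def scatter_grad_for_x(dout, ids):
    """Net effect of TensorCopy + GPUScatterGradForX on dX (NumPy sketch)."""
    dx = dout.copy()   # dX starts as a copy of dOut ...
    dx[ids] = 0.0      # ... then every slice named by Ids is zeroed
    return dx
```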
@@ -171,6 +171,24 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
}
}
+// This function is only for the gradient of scatter's X input;
+// the gradient of Updates uses gather instead.
+template <typename T, typename IndexT = int>
+void CPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index,
+Tensor* output) {
+int index_size = index.dims()[0];
+auto dst_dims = output->dims();
+const IndexT* p_index = index.data<IndexT>();
+T* p_output = output->data<T>();
+size_t slice_size = 1;
+for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
+const size_t slice_bytes = slice_size * sizeof(T);
+for (int i = 0; i < index_size; ++i) {
+const IndexT& index_ = p_index[i];
+memset(p_output + slice_size * index_, 0, slice_bytes);
+}
+}
template <typename T, typename IndexT = int>
void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
const Tensor& index, Tensor* output) {
......
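CPUScatterGradForX does the same zeroing on the host with one memset of slice_bytes per indexed row; duplicate indices are harmless since zeroing a row twice is idempotent. A NumPy sketch of the slice arithmetic (assumed shapes, for illustration only):

```python
import numpy as np

# Illustrative shapes only. slice_size is the product of all dims after the
# first, i.e. how many elements each memset clears.
dout = np.random.rand(6, 4, 5)
ids = np.array([1, 3, 3], dtype=np.int64)  # duplicate index: zeroed twice, same result
slice_size = dout[0].size                  # 4 * 5 = 20 elements per row
slice_bytes = slice_size * dout.itemsize   # 20 * 8 = 160 bytes per memset
dx = dout.copy()
dx[ids] = 0.0                              # vectorized form of the memset loop
```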
@@ -65,7 +65,6 @@ class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
auto *Ids = ctx.Input<Tensor>("Index");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
if (dX) {
-// In place gradient: dX = dO
framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
}
if (dUpdates) {
......
@@ -71,8 +71,7 @@ class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
if (dX) {
-// In place gradient: dX = dO
-framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
+framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
}
if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace());
......
@@ -138,9 +138,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ScatterGradNoNeedBufferVarsInferer,
"Updates");
DECLARE_INPLACE_OP_INFERER(ScatterInplaceInferer, {"X", "Out"});
-DECLARE_INPLACE_OP_INFERER(ScatterGradInplaceInferer,
-{framework::GradVarName("Out"),
-framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
@@ -151,8 +148,7 @@ REGISTER_OPERATOR(scatter, ops::ScatterOp, ops::ScatterOpMaker,
ops::ScatterGradMaker<paddle::imperative::OpBase>,
ops::ScatterInplaceInferer);
REGISTER_OPERATOR(scatter_grad, ops::ScatterGradOp,
-ops::ScatterGradNoNeedBufferVarsInferer,
-ops::ScatterGradInplaceInferer);
+ops::ScatterGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(scatter, ops::ScatterOpKernel<float>,
ops::ScatterOpKernel<double>, ops::ScatterOpKernel<int>,
ops::ScatterOpKernel<int64_t>);
......
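Dropping ScatterGradInplaceInferer is part of the fix, not a cleanup: dX is no longer byte-identical to dOut (its indexed rows get zeroed), while dUpdates is still gathered from dOut afterwards, so the two tensors must not share storage. A toy NumPy demonstration of the aliasing hazard (hypothetical in-place sharing):

```python
import numpy as np

dout = np.ones((4, 2))
ids = np.array([1, 2])

dx = dout                     # hypothetical in-place reuse: dX aliases dOut
dx[ids] = 0.0                 # the fix zeroes the scattered rows of dX ...
dupdates = dout[ids]          # ... which would corrupt the later gather
assert (dupdates == 0).all()  # wrong: with a real copy these would be ones
```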
@@ -67,26 +67,32 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *Ids = ctx.Input<Tensor>("Ids");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
-if (dX) {
-// In place gradient: dX = dO
-framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
-}
-if (dUpdates) {
-dUpdates->mutable_data<T>(ctx.GetPlace());
-// Gradient by Gather: dUpdates = dO[Ids]
const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match, true,
platform::errors::InvalidArgument(
-"scatter_op Index holds the wrong type, it holds [%s], "
+"scatter_op index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
+if (dX) {
+framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
+if (index_type == framework::proto::VarType::INT32) {
+GPUScatterGradForX<T, int32_t>(ctx.device_context(), *Ids, dX);
+} else {
+GPUScatterGradForX<T, int64_t>(ctx.device_context(), *Ids, dX);
+}
+}
+if (dUpdates) {
+dUpdates->mutable_data<T>(ctx.GetPlace());
+// Gradient by Gather: dUpdates = dO[Ids]
+if (index_type == framework::proto::VarType::INT32) {
+GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
......
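Taken together, the corrected backward for overwrite-mode scatter is: dX equals dOut with the rows at Ids zeroed (those rows of X never reach Out), and dUpdates = dOut[Ids] by gather. The old unconditional "dX = dO" copy leaked gradient into the overwritten rows. A NumPy reference sketch (illustrative names):

```python
import numpy as np

def scatter_grad(dout, ids):
    """Reference backward for Out = scatter(X, Ids, Updates), overwrite mode."""
    dx = dout.copy()
    dx[ids] = 0.0         # rows replaced by Updates contribute nothing via X
    dupdates = dout[ids]  # Gradient by Gather: dUpdates = dOut[Ids]
    return dx, dupdates

dout = np.arange(8.0).reshape(4, 2)
dx, dup = scatter_grad(dout, np.array([0, 2]))
# dx has rows 0 and 2 zeroed; dup holds rows 0 and 2 of dOut
```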
@@ -79,13 +79,6 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
auto *Ids = ctx.Input<Tensor>("Ids");
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
-if (dX) {
-// In place gradient: dX = dO
-framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
-}
-if (dUpdates) {
-dUpdates->mutable_data<T>(ctx.GetPlace());
-// Gradient by Gather: dUpdates = dO[Ids]
const auto &index_type = Ids->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
@@ -99,6 +92,19 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
+if (dX) {
+framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
+if (index_type == framework::proto::VarType::INT32) {
+CPUScatterGradForX<T, int32_t>(ctx.device_context(), *Ids, dX);
+} else {
+CPUScatterGradForX<T, int64_t>(ctx.device_context(), *Ids, dX);
+}
+}
+if (dUpdates) {
+dUpdates->mutable_data<T>(ctx.GetPlace());
+// Gradient by Gather: dUpdates = dO[Ids]
+if (index_type == framework::proto::VarType::INT32) {
+CPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
+} else {
......
@@ -78,7 +78,7 @@ class TestScatterNdAddSimpleOp(OpTest):
self.check_output()
def test_check_grad(self):
-self.check_grad(['Updates'], 'Out', in_place=True)
+self.check_grad(['X', 'Updates'], 'Out')
class TestScatterNdAddWithEmptyIndex(OpTest):
@@ -101,7 +101,7 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
self.check_output()
def test_check_grad(self):
-self.check_grad(['X'], 'Out', in_place=True)
+self.check_grad(['X', 'Updates'], 'Out')
class TestScatterNdAddWithHighRankSame(OpTest):
@@ -111,11 +111,11 @@ class TestScatterNdAddWithHighRankSame(OpTest):
def setUp(self):
self.op_type = "scatter_nd_add"
-shape = (10, 9, 8, 1, 15)
+shape = (3, 2, 2, 1, 10)
ref_np = np.random.rand(*shape).astype("float64")
index_np = np.vstack(
[np.random.randint(
-0, s, size=150) for s in shape]).T.astype("int32")
+0, s, size=100) for s in shape]).T.astype("int32")
update_shape = judge_update_shape(ref_np, index_np)
updates_np = np.random.rand(*update_shape).astype("float64")
expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
@@ -127,7 +127,7 @@ class TestScatterNdAddWithHighRankSame(OpTest):
self.check_output()
def test_check_grad(self):
-self.check_grad(['Updates'], 'Out', in_place=True)
+self.check_grad(['X', 'Updates'], 'Out')
class TestScatterNdAddWithHighRankDiff(OpTest):
@@ -137,7 +137,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
def setUp(self):
self.op_type = "scatter_nd_add"
-shape = (10, 9, 8, 1, 15)
+shape = (8, 2, 2, 1, 10)
ref_np = np.random.rand(*shape).astype("double")
index = np.vstack([np.random.randint(0, s, size=500) for s in shape]).T
index_np = index.reshape([10, 5, 10, 5]).astype("int64")
@@ -152,7 +152,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
self.check_output()
def test_check_grad(self):
-self.check_grad(['Updates'], 'Out', in_place=True)
+self.check_grad(['X', 'Updates'], 'Out')
#Test Python API
......
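For scatter_nd_add the forward is additive, so X passes through to Out unchanged and dX really is a plain copy of dOut; only the test style changes here, checking the X and Updates gradients directly instead of via in_place=True. A sketch of the expected gradients under assumed shapes:

```python
import numpy as np

def scatter_nd_add_grad(dout, index):
    """Sketch under assumed shapes: index is (N, k) and addresses the first
    k dims of dOut. Out = X + nd-scatter(Updates), so X passes straight
    through and dX is exactly dOut."""
    dx = dout.copy()                 # no zeroing needed for the additive op
    dupdates = dout[tuple(index.T)]  # gather dOut at the N nd-indices
    return dx, dupdates
```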
@@ -37,7 +37,7 @@ class TestScatterOp(OpTest):
self.check_output()
def test_check_grad(self):
-self.check_grad(['Updates'], 'Out', in_place=True)
+self.check_grad(["X", "Updates"], "Out")
class TestScatterOp0(OpTest):
@@ -56,7 +56,7 @@ class TestScatterOp0(OpTest):
self.check_output()
def test_check_grad(self):
-self.check_grad(['Updates'], 'Out', in_place=True)
+self.check_grad(["X", "Updates"], "Out")
class TestScatterOp1(OpTest):
@@ -78,7 +78,7 @@ class TestScatterOp1(OpTest):
self.check_output()
def test_check_grad(self):
-self.check_grad(['Updates'], 'Out', in_place=True)
+self.check_grad(["X", "Updates"], "Out")
@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -102,7 +102,7 @@ class TestScatterOp2(OpTest):
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
-self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -130,7 +130,7 @@ class TestScatterOp3(OpTest):
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
-self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
class TestScatterOp4(OpTest):
@@ -148,7 +148,7 @@ class TestScatterOp4(OpTest):
self.check_output()
def test_check_grad(self):
-self.check_grad(['Updates'], 'Out', in_place=True)
+self.check_grad(['X', 'Updates'], 'Out')
@unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -172,7 +172,7 @@ class TestScatterOp5(OpTest):
def test_check_grad(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
-self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
class TestScatterAPI(unittest.TestCase):
......
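The scatter tests likewise now check the X gradient analytically rather than through in_place=True. A self-contained finite-difference check of the zeroed-row rule (illustrative only, independent of OpTest):

```python
import numpy as np

x = np.random.rand(4, 2)
updates = np.random.rand(2, 2)
ids = np.array([1, 3])

def f(x):
    out = x.copy()
    out[ids] = updates  # overwrite-mode scatter forward
    return out.sum()

eps, num = 1e-6, np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += eps
    xm[idx] -= eps
    num[idx] = (f(xp) - f(xm)) / (2 * eps)

ana = np.ones_like(x)
ana[ids] = 0.0  # the corrected dX rule: overwritten rows get no gradient
assert np.allclose(num, ana, atol=1e-4)
```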