Unverified · Commit 9514b4aa authored by ShenLiang, committed by GitHub

Fix scatter grad bug (#30604)

Parent 1f5841c2
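The bug being fixed: scatter_grad computed grad(X) as a verbatim copy of grad(Out) (and even declared the pair in-place), but under overwrite semantics the rows of X selected by Ids are replaced in the forward pass, so those rows contribute nothing to Out and their gradient must be zero. A minimal NumPy sketch of the corrected semantics (function name illustrative):

import numpy as np

def scatter_grad(d_out, ids):
    # grad wrt X: d_out with the scattered rows zeroed out
    d_x = d_out.copy()
    d_x[ids] = 0
    # grad wrt Updates, by gather: "dUpdates = dO[Ids]"
    d_updates = d_out[ids]
    return d_x, d_updates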
@@ -28,8 +28,7 @@ using Tensor = framework::Tensor;
 template <typename T, typename IndexT = int>
 __global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
-                                      size_t index_size, size_t slice_size,
-                                      bool overwrite) {
+                                      size_t index_size, size_t slice_size) {
   CUDA_KERNEL_LOOP(i, index_size * slice_size) {
     int indices_i = i / slice_size;
     int slice_i = i - indices_i * slice_size;  // offset inside the slice
@@ -129,7 +128,7 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
     ScatterInitCUDAKernel<T, IndexT><<<
         grid, block, 0,
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
-        p_index, p_output, index_size, slice_size, overwrite);
+        p_index, p_output, index_size, slice_size);
   }
   ScatterCUDAKernel<T, IndexT><<<
@@ -138,6 +137,37 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
       p_src, p_index, p_output, index_size, slice_size, overwrite);
 }
 
+// This function computes only the gradient of scatter's X input;
+// the gradient of Updates is computed with gather instead.
+template <typename T, typename IndexT = int>
+void GPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index,
+                        Tensor* output) {
+  IndexT index_size = index.dims()[0];
+  auto dst_dims = output->dims();
+  // slice size
+  IndexT slice_size = 1;
+  for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
+  const IndexT* p_index = index.data<IndexT>();
+  T* p_output = output->data<T>();
+  const size_t& slice_bytes = slice_size * sizeof(T);
+
+  // set block and grid num
+  int64_t block = 512;
+  int64_t n = slice_size * index_size;
+  int64_t height = (n + block - 1) / block;
+  int64_t max_grid_dimx =
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
+          .GetCUDAMaxGridDimSize()
+          .x;
+  int64_t grid = height < max_grid_dimx ? height : max_grid_dimx;
+
+  ScatterInitCUDAKernel<T, IndexT><<<
+      grid, block, 0,
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+      p_index, p_output, index_size, slice_size);
+}
+
 template <typename DeviceContext, typename T, typename IndexT = int>
 void GPUScatterNdAdd(const framework::ExecutionContext& context,
                      const Tensor& update, const Tensor& index,
...
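A note on the launch configuration in GPUScatterGradForX: the grid is clamped to the device's maximum x-dimension, and CUDA_KERNEL_LOOP is a grid-stride loop, so a clamped grid still covers all n = slice_size * index_size elements. The arithmetic, as a small Python sketch (the default cap here is an assumption, not queried from a device):

def launch_config(n, block=512, max_grid_dimx=2**31 - 1):
    height = (n + block - 1) // block  # ceil(n / block): blocks for 1:1 coverage
    # Clamp; each thread then strides by grid * block until all n are done.
    return min(height, max_grid_dimx)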
@@ -171,6 +171,24 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
   }
 }
 
+// This function computes only the gradient of scatter's X input;
+// the gradient of Updates is computed with gather instead.
+template <typename T, typename IndexT = int>
+void CPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index,
+                        Tensor* output) {
+  int index_size = index.dims()[0];
+  auto dst_dims = output->dims();
+  const IndexT* p_index = index.data<IndexT>();
+  T* p_output = output->data<T>();
+  size_t slice_size = 1;
+  for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
+  const size_t slice_bytes = slice_size * sizeof(T);
+
+  for (int i = 0; i < index_size; ++i) {
+    const IndexT& index_ = p_index[i];
+    memset(p_output + slice_size * index_, 0, slice_bytes);
+  }
+}
+
 template <typename T, typename IndexT = int>
 void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
                   const Tensor& index, Tensor* output) {
...
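Since slice_size is the product of all trailing dimensions, p_output + slice_size * index_ is the start of leading-dimension slice index_, and the memset clears exactly that slice. A NumPy analogue (a sketch, not Paddle API):

import numpy as np

d_x = np.random.rand(5, 4, 3)            # slice_size = 4 * 3 = 12
ids = np.array([0, 2], dtype=np.int64)
d_x.reshape(5, -1)[ids, :] = 0           # same effect as the memset loop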
@@ -65,7 +65,6 @@ class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
     auto *Ids = ctx.Input<Tensor>("Index");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
     if (dX) {
-      // In place gradient: dX = dO
       framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
     }
     if (dUpdates) {
...
@@ -71,8 +71,7 @@ class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
     if (dX) {
-      // In place gradient: dX = dO
-      framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
+      framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
     }
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
...
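Note that the scatter_nd_add gradient keeps the plain copy: that op adds updates into X rather than overwriting it, so grad(X) = grad(Out) remains exactly right; only the stale "in place" comment goes away and TensorCopySync is relaxed to TensorCopy.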
@@ -138,9 +138,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(ScatterGradNoNeedBufferVarsInferer,
                                     "Updates");
 
 DECLARE_INPLACE_OP_INFERER(ScatterInplaceInferer, {"X", "Out"});
-DECLARE_INPLACE_OP_INFERER(ScatterGradInplaceInferer,
-                           {framework::GradVarName("Out"),
-                            framework::GradVarName("X")});
 
 }  // namespace operators
 }  // namespace paddle
@@ -151,8 +148,7 @@ REGISTER_OPERATOR(scatter, ops::ScatterOp, ops::ScatterOpMaker,
                   ops::ScatterGradMaker<paddle::imperative::OpBase>,
                   ops::ScatterInplaceInferer);
 REGISTER_OPERATOR(scatter_grad, ops::ScatterGradOp,
-                  ops::ScatterGradNoNeedBufferVarsInferer,
-                  ops::ScatterGradInplaceInferer);
+                  ops::ScatterGradNoNeedBufferVarsInferer);
 REGISTER_OP_CPU_KERNEL(scatter, ops::ScatterOpKernel<float>,
                        ops::ScatterOpKernel<double>, ops::ScatterOpKernel<int>,
                        ops::ScatterOpKernel<int64_t>);
...
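Dropping ScatterGradInplaceInferer is part of the fix rather than a cleanup: grad(X) may no longer alias grad(Out), because the kernels below zero the scattered rows of grad(X) after the copy while the grad(Updates) path still gathers from grad(Out); sharing one buffer would let the zeroing corrupt the gathered values.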
@@ -67,27 +67,33 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    const auto &index_type = Ids->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        platform::errors::InvalidArgument(
+            "scatter_op index holds the wrong type, it holds [%s], "
+            "but desires to be [%s] or [%s]",
+            paddle::framework::DataTypeToString(index_type),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT32),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT64)));
     if (dX) {
-      // In place gradient: dX = dO
       framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
+      if (index_type == framework::proto::VarType::INT32) {
+        GPUScatterGradForX<T, int32_t>(ctx.device_context(), *Ids, dX);
+      } else {
+        GPUScatterGradForX<T, int64_t>(ctx.device_context(), *Ids, dX);
+      }
     }
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
       // Gradient by Gather: dUpdates = dO[Ids]
-      const auto &index_type = Ids->type();
-      bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                              index_type == framework::proto::VarType::INT64;
-      PADDLE_ENFORCE_EQ(
-          index_type_match, true,
-          platform::errors::InvalidArgument(
-              "scatter_op Index holds the wrong type, it holds [%s], "
-              "but desires to be [%s] or [%s]",
-              paddle::framework::DataTypeToString(index_type),
-              paddle::framework::DataTypeToString(
-                  framework::proto::VarType::INT32),
-              paddle::framework::DataTypeToString(
-                  framework::proto::VarType::INT64)));
-      // Gradient by Gather: dUpdates = dO[Ids]
       if (index_type == framework::proto::VarType::INT32) {
         GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
       } else {
...
@@ -79,26 +79,32 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
     auto *Ids = ctx.Input<Tensor>("Ids");
     auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    const auto &index_type = Ids->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE_EQ(
+        index_type_match, true,
+        platform::errors::InvalidArgument(
+            "scatter_op index holds the wrong type, it holds [%s], "
+            "but desires to be [%s] or [%s]",
+            paddle::framework::DataTypeToString(index_type),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT32),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT64)));
     if (dX) {
-      // In place gradient: dX = dO
       framework::TensorCopy(*dOut, ctx.GetPlace(), dX);
+      if (index_type == framework::proto::VarType::INT32) {
+        CPUScatterGradForX<T, int32_t>(ctx.device_context(), *Ids, dX);
+      } else {
+        CPUScatterGradForX<T, int64_t>(ctx.device_context(), *Ids, dX);
+      }
     }
     if (dUpdates) {
       dUpdates->mutable_data<T>(ctx.GetPlace());
       // Gradient by Gather: dUpdates = dO[Ids]
-      const auto &index_type = Ids->type();
-      bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                              index_type == framework::proto::VarType::INT64;
-      PADDLE_ENFORCE_EQ(
-          index_type_match, true,
-          platform::errors::InvalidArgument(
-              "scatter_op index holds the wrong type, it holds [%s], "
-              "but desires to be [%s] or [%s]",
-              paddle::framework::DataTypeToString(index_type),
-              paddle::framework::DataTypeToString(
-                  framework::proto::VarType::INT32),
-              paddle::framework::DataTypeToString(
-                  framework::proto::VarType::INT64)));
       if (index_type == framework::proto::VarType::INT32) {
         CPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);
       } else {
...
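An end-to-end check of the fixed behavior from Python (a sketch assuming the Paddle 2.x dygraph API; the expected values follow from the gradient semantics above):

import paddle

x = paddle.ones([3, 2])
x.stop_gradient = False
index = paddle.to_tensor([1], dtype='int64')
updates = paddle.full([1, 2], 5.0)
updates.stop_gradient = False

out = paddle.scatter(x, index, updates, overwrite=True)
out.sum().backward()
# x.grad       -> [[1, 1], [0, 0], [1, 1]]  (scattered row zeroed)
# updates.grad -> [[1, 1]]                  (gathered from grad(Out))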
@@ -78,7 +78,7 @@ class TestScatterNdAddSimpleOp(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')
 
 
 class TestScatterNdAddWithEmptyIndex(OpTest):
@@ -101,7 +101,7 @@ class TestScatterNdAddWithEmptyIndex(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')
 
 
 class TestScatterNdAddWithHighRankSame(OpTest):
@@ -111,11 +111,11 @@ class TestScatterNdAddWithHighRankSame(OpTest):
     def setUp(self):
         self.op_type = "scatter_nd_add"
-        shape = (10, 9, 8, 1, 15)
+        shape = (3, 2, 2, 1, 10)
         ref_np = np.random.rand(*shape).astype("float64")
         index_np = np.vstack(
             [np.random.randint(
-                0, s, size=150) for s in shape]).T.astype("int32")
+                0, s, size=100) for s in shape]).T.astype("int32")
         update_shape = judge_update_shape(ref_np, index_np)
         updates_np = np.random.rand(*update_shape).astype("float64")
         expect_np = numpy_scatter_nd_add(ref_np.copy(), index_np, updates_np)
@@ -127,7 +127,7 @@ class TestScatterNdAddWithHighRankSame(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')
 
 
 class TestScatterNdAddWithHighRankDiff(OpTest):
@@ -137,7 +137,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
     def setUp(self):
         self.op_type = "scatter_nd_add"
-        shape = (10, 9, 8, 1, 15)
+        shape = (8, 2, 2, 1, 10)
         ref_np = np.random.rand(*shape).astype("double")
         index = np.vstack([np.random.randint(0, s, size=500) for s in shape]).T
         index_np = index.reshape([10, 5, 10, 5]).astype("int64")
@@ -152,7 +152,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')
 
 
 #Test Python API
...
@@ -37,7 +37,7 @@ class TestScatterOp(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(["X", "Updates"], "Out")
 
 
 class TestScatterOp0(OpTest):
@@ -56,7 +56,7 @@ class TestScatterOp0(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(["X", "Updates"], "Out")
 
 
 class TestScatterOp1(OpTest):
@@ -78,7 +78,7 @@ class TestScatterOp1(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(["X", "Updates"], "Out")
 
 
 @unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -102,7 +102,7 @@ class TestScatterOp2(OpTest):
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
 
 
 @unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -130,7 +130,7 @@ class TestScatterOp3(OpTest):
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
 
 
 class TestScatterOp4(OpTest):
@@ -148,7 +148,7 @@ class TestScatterOp4(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates'], 'Out', in_place=True)
+        self.check_grad(['X', 'Updates'], 'Out')
 
 
 @unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -172,7 +172,7 @@ class TestScatterOp5(OpTest):
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
+            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
 
 
 class TestScatterAPI(unittest.TestCase):
...
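The test updates all follow the same pattern: with grad(X) no longer a plain alias of grad(Out), the old in_place=True shortcut is dropped and check_grad(['X', 'Updates'], 'Out') verifies both gradients numerically; the high-rank scatter_nd_add shapes are also shrunk, presumably to keep the now-heavier gradient check fast.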