From 88a8eedda17dead5471f4d9a64e291e49b522775 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Thu, 28 Sep 2017 14:36:38 -0700 Subject: [PATCH] scatter gather gpu gather scatter gpu --- paddle/operators/cond_op.cc | 6 +- paddle/operators/gather.cu.h | 84 ++++++++++++++++++ paddle/operators/gather.h | 35 ++++---- paddle/operators/gather_op.cc | 9 +- paddle/operators/gather_op.cu | 70 +++++++++++++++ paddle/operators/gather_op.h | 27 ++++-- paddle/operators/gather_test.cc | 2 +- paddle/operators/scatter.cu.h | 86 +++++++++++++++++++ paddle/operators/scatter.h | 45 ++++------ paddle/operators/scatter_op.cc | 7 +- paddle/operators/scatter_op.cu | 63 ++++++++++++++ paddle/operators/scatter_op.h | 12 ++- paddle/operators/scatter_test.cc | 2 +- .../v2/framework/tests/test_scatter_op.py | 4 +- 14 files changed, 375 insertions(+), 77 deletions(-) create mode 100644 paddle/operators/gather.cu.h create mode 100644 paddle/operators/gather_op.cu create mode 100644 paddle/operators/scatter.cu.h create mode 100644 paddle/operators/scatter_op.cu diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc index aaffa6661f..157656786a 100644 --- a/paddle/operators/cond_op.cc +++ b/paddle/operators/cond_op.cc @@ -169,8 +169,8 @@ void CondOp::Run(const Scope& scope, tensor_child->Resize(dim); tensor_child->mutable_data(dim, platform::CPUPlace()); - Gather(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i], - tensor_child); + CPUTGather(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i], + tensor_child); } } @@ -194,7 +194,7 @@ void CondOp::Run(const Scope& scope, PADDLE_ENFORCE_NOT_NULL(v); LoDTensor* tensor_child = v->GetMutable(); - ScatterUpdate(dev_ctx.GetPlace(), tensor_child, &index_tensors[i], + ScatterAssign(dev_ctx.GetPlace(), tensor_child, &index_tensors[i], tensor_parent); } } diff --git a/paddle/operators/gather.cu.h b/paddle/operators/gather.cu.h new file mode 100644 index 0000000000..c96071e295 --- /dev/null +++ b/paddle/operators/gather.cu.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/tensor.h" +#include "paddle/platform/place.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; +using platform::Place; + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void GatherCUDAKernel(const T* params, const int* indices, T* output, + size_t index_size, size_t slice_size) { + CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) { + int indices_i = i / slice_size; + int slice_i = i - indices_i * slice_size; // offset inside the slice + int gather_i = indices[indices_i]; + int params_i = gather_i * slice_size + slice_i; + *(output + i) = *(params + params_i); + } +} + +// Implementation of GPU copy: +template +struct GPUGather { + void operator()(const T* src, const int* index, const int slice_size, + const int index_size, T* output) { + int block = 512; + int n = slice_size * index_size; + int grid = (n + block - 1) / block; + GatherCUDAKernel<<>>(src, index, output, index_size, + slice_size); + } +}; + +/** + * A thin wrapper on gpu tensor + * Return a new tensor from source tensor, gathered according to index + * input[src]: type-T source Tensor + * input[index]: type-int index Tensor (1-D) + * return: output tensor + */ +template +void GPUTGather(const Place& place, const Tensor* src, const Tensor* index, + Tensor* output) { + PADDLE_ENFORCE(platform::is_gpu_place(place)); + // check index of shape 1-D + PADDLE_ENFORCE(index->dims().size() == 1); + int index_size = index->dims()[0]; + + auto src_dims = src->dims(); + framework::DDim output_dims(src_dims); + output_dims[0] = index_size; + + // slice size + int slice_size = 1; + for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + + // Gathering + GPUGather gather_functor; + gather_functor(src->data(), index->data(), slice_size, index_size, + output->data()); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index 92fb51ec17..a3db17bd3d 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -26,31 +26,31 @@ namespace operators { // Implementation of CPU copy template -void CPUGather(const T* src, const int* indices, const int slice_size, - const int index_size, T* output) { - const size_t slice_bytes = slice_size * sizeof(T); +struct CPUGather { + void operator()(const T* src, const int* indices, const int slice_size, + const int index_size, T* output) { + const size_t slice_bytes = slice_size * sizeof(T); - for (int i = 0; i < index_size; ++i) { - int index_ = indices[i]; - memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes); + for (int i = 0; i < index_size; ++i) { + int index_ = indices[i]; + memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes); + } } -} - -// Implementation of GPU copy: -template -void GPUGather(const T* src, const int* index, const int slice_size, - const int index_size, T* output); +}; /** + * A thin wrapper on cpu tensor * Return a new tensor from source tensor, gathered according to index * input[src]: type-T source Tensor * input[index]: type-int index Tensor (1-D) * return: output tensor */ template -void Gather(const platform::Place& place, const paddle::framework::Tensor* src, - const paddle::framework::Tensor* index, - paddle::framework::Tensor* output) { +void CPUTGather(const platform::Place& place, + const paddle::framework::Tensor* src, + const paddle::framework::Tensor* index, + paddle::framework::Tensor* output) { + PADDLE_ENFORCE(platform::is_cpu_place(place)); // check index of shape 1-D PADDLE_ENFORCE(index->dims().size() == 1); int index_size = index->dims()[0]; @@ -64,10 +64,9 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src, for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; // Gathering - if (platform::is_cpu_place(place)) { - CPUGather(src->data(), index->data(), slice_size, index_size, + CPUGather gather_functor; + gather_functor(src->data(), index->data(), slice_size, index_size, output->data()); - } } } // namespace operators diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index da22bd0c52..fe305337cb 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -31,6 +31,8 @@ class GatherOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of GatherOp should not be null."); + auto index_dims = ctx->GetInputDim("Index"); + PADDLE_ENFORCE(index_dims.size() == 1); int batch_size = ctx->GetInputDim("Index")[0]; PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0"); framework::DDim output_dims(ctx->GetInputDim("X")); @@ -79,8 +81,5 @@ Out = X[Index] namespace ops = paddle::operators; REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad, ops::GatherGradOp); -REGISTER_OP_CPU_KERNEL(gather, - ops::GatherOpKernel); -REGISTER_OP_CPU_KERNEL( - gather_grad, - ops::GatherGradientOpKernel); +REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel); +REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel); diff --git a/paddle/operators/gather_op.cu b/paddle/operators/gather_op.cu new file mode 100644 index 0000000000..f3ed692666 --- /dev/null +++ b/paddle/operators/gather_op.cu @@ -0,0 +1,70 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "gather.cu.h" +#include "paddle/framework/eigen.h" +#include "paddle/operators/gather_op.h" +#include "scatter.cu.h" + +namespace paddle { +namespace operators { + +// template +__global__ void print_arr(const float *params, const int N) { + CUDA_1D_KERNEL_LOOP(i, N) { printf("device: %d, %f\n", i, params[i]); } +} + +template +class GatherOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto *x = ctx.Input("X"); + auto *index = ctx.Input("Index"); + auto *output = ctx.Output("Out"); + + output->mutable_data(ctx.GetPlace()); + + GPUTGather(ctx.GetPlace(), x, index, output); + } +}; + +template +class GatherGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + LOG(INFO) << "Gather grad here"; + auto *Index = ctx.Input("Index"); + auto *dX = ctx.Output(framework::GradVarName("X")); + auto *dO = ctx.Input(framework::GradVarName("Out")); + auto *x = ctx.Input("X"); + + dX->mutable_data(ctx.GetPlace()); + auto dxt = framework::EigenVector::Flatten(*dX); + auto place = ctx.GetEigenDevice(); + dxt.device(place) = dxt.constant(static_cast(0)); + + GPUTScatter(ctx.GetPlace(), dO, Index, dX); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(gather, ops::GatherOpCUDAKernel); +REGISTER_OP_GPU_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel); diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h index 073e566e8f..b80a4ab370 100644 --- a/paddle/operators/gather_op.h +++ b/paddle/operators/gather_op.h @@ -23,29 +23,40 @@ namespace operators { using Tensor = framework::Tensor; -template +template class GatherOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { - auto *X = ctx.Input("X"); - auto *Index = ctx.Input("Index"); - auto *Y = ctx.Output("Out"); + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "This kernel only runs on CPU."); + + auto *x = ctx.Input("X"); + auto *index = ctx.Input("Index"); + auto *output = ctx.Output("Out"); + + output->mutable_data(ctx.GetPlace()); - Y->mutable_data(ctx.GetPlace()); - Gather(ctx.GetPlace(), X, Index, Y); + CPUTGather(ctx.GetPlace(), x, index, output); } }; -template +template class GatherGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "This kernel only runs on CPU."); + auto *Index = ctx.Input("Index"); auto *dX = ctx.Output(framework::GradVarName("X")); auto *dO = ctx.Input(framework::GradVarName("Out")); dX->mutable_data(ctx.GetPlace()); - ScatterUpdate(ctx.GetPlace(), dO, Index, dX); + auto dxt = framework::EigenVector::Flatten(*dX); + auto place = ctx.GetEigenDevice(); + dxt.device(place) = dxt.constant(static_cast(0)); + + ScatterAssign(ctx.GetPlace(), dO, Index, dX); } }; diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc index 0ae1e99452..ea06ae2847 100644 --- a/paddle/operators/gather_test.cc +++ b/paddle/operators/gather_test.cc @@ -41,7 +41,7 @@ TEST(Gather, GatherData) { int* p_output = output->mutable_data(make_ddim({2, 4}), CPUPlace()); - Gather(CPUPlace(), src, index, output); + CPUTGather(CPUPlace(), src, index, output); for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); diff --git a/paddle/operators/scatter.cu.h b/paddle/operators/scatter.cu.h new file mode 100644 index 0000000000..82e5040305 --- /dev/null +++ b/paddle/operators/scatter.cu.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/tensor.h" +#include "paddle/platform/place.h" + +namespace paddle { +namespace operators { + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void ScatterCUDAKernel(const T* params, const int* indices, + T* output, size_t index_size, + size_t slice_size) { + CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) { + int indices_i = i / slice_size; + int slice_i = i - indices_i * slice_size; // offset inside the slice + int scatter_i = indices[indices_i]; + int out_i = scatter_i * slice_size + slice_i; + *(output + out_i) = *(params + i); + } +} + +// Implementation of GPU copy: +template +struct GPUScatterAssign { + void operator()(const T* src, const int* index, const int slice_size, + const int index_size, T* output) { + int block = 512; + int n = slice_size * index_size; + int grid = (n + block - 1) / block; + // printf("grid, block: %d %d\n", grid, block); + ScatterCUDAKernel<<>>(src, index, output, index_size, + slice_size); + } +}; + +/** + * A thin wrapper on gpu tensor + * Return a new updated tensor from source tensor, scatter-assigned according to + * index + * input[src]: type-T source Tensor + * input[index]: type-int index Tensor (1-D) + * return: output tensor + */ +template +void GPUTScatter(const platform::Place& place, + const paddle::framework::Tensor* src, + const paddle::framework::Tensor* index, + paddle::framework::Tensor* output) { + PADDLE_ENFORCE(platform::is_gpu_place(place)); + // check index of shape 1-D + PADDLE_ENFORCE(index->dims().size() == 1); + int index_size = index->dims()[0]; + + auto src_dims = src->dims(); + framework::DDim output_dims(src_dims); + output_dims[0] = index_size; + + // slice size + int slice_size = 1; + for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + + // Scatter Assign + GPUScatterAssign scatter_functor; + scatter_functor(src->data(), index->data(), slice_size, index_size, + output->data()); +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/scatter.h b/paddle/operators/scatter.h index 6b542675c2..670204b4dd 100644 --- a/paddle/operators/scatter.h +++ b/paddle/operators/scatter.h @@ -24,49 +24,33 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -using EigenVector = framework::EigenVector; // Implementation of CPU copy template -void CPUScatterUpdate(const paddle::framework::Tensor* src, const int* index, - const size_t index_size, - paddle::framework::Tensor* output) { - paddle::framework::DDim output_dims = output->dims(); +void CPUScatterAssign(const T* src, const int* index, const int slice_size, + const int index_size, T* output) { + // paddle::framework::DDim output_dims = output->dims(); + const size_t slice_bytes = slice_size * sizeof(T); - for (size_t i = 0; i < index_size; ++i) { + for (int i = 0; i < index_size; ++i) { int index_ = index[i]; - - paddle::framework::Tensor src_ = *src; - paddle::framework::Tensor output_ = *output; - if (index_size > 1) src_ = src->Slice(i, i + 1); - if (output_dims[0] > 1) output_ = output->Slice(index_, index_ + 1); - - auto X = EigenVector::Flatten(src_); - auto Y = EigenVector::Flatten(output_); - - Y = X + Y; + memcpy(output + index_ * slice_size, src + i * slice_size, slice_bytes); } } -// Implementation of GPU scatter: -template -void GPUScatterUpdate(const T* src, const int* index, const int slice_size, - const int index_size, T* output); - /** * Return a updated tensor from source tensor, scattered according to index: - * dst[i] += src[index[i]] + * dst[i] = src[index[i]] * input[src]: type-T source Tensor * input[index]: type-int index Tensor (1-D) * return: output tensor */ template -void ScatterUpdate(const platform::Place& place, +void ScatterAssign(const platform::Place& place, const paddle::framework::Tensor* src, const paddle::framework::Tensor* index, paddle::framework::Tensor* output) { + PADDLE_ENFORCE(platform::is_cpu_place(place)); // check index of shape 1-D PADDLE_ENFORCE(index->dims().size() == 1); int index_size = index->dims()[0]; @@ -74,18 +58,19 @@ void ScatterUpdate(const platform::Place& place, auto src_dims = src->dims(); auto dst_dims = output->dims(); + const T* p_src = src->data(); + const int* p_index = index->data(); + T* p_output = output->data(); + // check src shape and dst shape should match for (int i = 1; i < src_dims.size(); i++) PADDLE_ENFORCE(src_dims[i] == dst_dims[i]); // slice size size_t slice_size = 1; - for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; - if (platform::is_cpu_place(place)) { - CPUScatterUpdate(src, index->data(), index_size, output); - } else { - } + CPUScatterAssign(p_src, p_index, slice_size, index_size, p_output); } } // namespace operators diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc index cadd8841b6..d15ba15153 100644 --- a/paddle/operators/scatter_op.cc +++ b/paddle/operators/scatter_op.cc @@ -97,8 +97,5 @@ Out[Index] = Ref[Index] + Updates namespace ops = paddle::operators; REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad, ops::ScatterGradOp); -REGISTER_OP_CPU_KERNEL(scatter, - ops::ScatterOpKernel); -REGISTER_OP_CPU_KERNEL( - scatter_grad, - ops::ScatterGradientOpKernel); +REGISTER_OP_CPU_KERNEL(scatter, ops::ScatterOpKernel); +REGISTER_OP_CPU_KERNEL(scatter_grad, ops::ScatterGradientOpKernel); diff --git a/paddle/operators/scatter_op.cu b/paddle/operators/scatter_op.cu new file mode 100644 index 0000000000..e27a926c6a --- /dev/null +++ b/paddle/operators/scatter_op.cu @@ -0,0 +1,63 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "gather.cu.h" +#include "paddle/operators/gather_op.h" +#include "scatter.cu.h" + +namespace paddle { +namespace operators { + +template +class ScatterOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto *Ref = ctx.Input("Ref"); + auto *Index = ctx.Input("Index"); + auto *Updates = ctx.Input("Updates"); + auto *Out = ctx.Output("Out"); + + Out->ShareDataWith(*Ref); + + GPUTScatter(ctx.GetPlace(), Updates, Index, Out); + } +}; + +template +class ScatterGradOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + auto *dRef = ctx.Output(framework::GradVarName("Ref")); + auto *dUpdates = ctx.Output(framework::GradVarName("Updates")); + auto *Index = ctx.Input("Index"); + auto *dOut = ctx.Input(framework::GradVarName("Out")); + + // In place gradient: dRef = dO + dRef->ShareDataWith(*dOut); + dUpdates->mutable_data(ctx.GetPlace()); + // Gradient by Gather: dUpdates = dO[Index] + GPUTGather(ctx.GetPlace(), dOut, Index, dUpdates); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(scatter, ops::ScatterOpCUDAKernel); +REGISTER_OP_GPU_KERNEL(scatter_grad, ops::ScatterGradOpCUDAKernel); diff --git a/paddle/operators/scatter_op.h b/paddle/operators/scatter_op.h index a8eb54399a..74b2718f43 100644 --- a/paddle/operators/scatter_op.h +++ b/paddle/operators/scatter_op.h @@ -23,10 +23,12 @@ namespace operators { using Tensor = framework::Tensor; -template +template class ScatterOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "This kernel only runs on CPU."); auto *Ref = ctx.Input("Ref"); auto *Index = ctx.Input("Index"); auto *Updates = ctx.Input("Updates"); @@ -35,14 +37,16 @@ class ScatterOpKernel : public framework::OpKernel { // In place output: Out = Ref, Out[Index] += Updates Out->ShareDataWith(*Ref); // Apply ScatterUpdate: Out[index] += Updates[:] - ScatterUpdate(ctx.GetPlace(), Updates, Index, Out); + ScatterAssign(ctx.GetPlace(), Updates, Index, Out); } }; -template +template class ScatterGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "This kernel only runs on CPU."); auto *dRef = ctx.Output(framework::GradVarName("Ref")); auto *dUpdates = ctx.Output(framework::GradVarName("Updates")); auto *Index = ctx.Input("Index"); @@ -52,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel { dRef->ShareDataWith(*dOut); dUpdates->mutable_data(ctx.GetPlace()); // Gradient by Gather: dUpdates += dO[Index] - Gather(ctx.GetPlace(), dOut, Index, dUpdates); + CPUTGather(ctx.GetPlace(), dOut, Index, dUpdates); } }; diff --git a/paddle/operators/scatter_test.cc b/paddle/operators/scatter_test.cc index 26fdaff146..bace6419d0 100644 --- a/paddle/operators/scatter_test.cc +++ b/paddle/operators/scatter_test.cc @@ -40,7 +40,7 @@ TEST(scatter, ScatterUpdate) { float* p_output = output->mutable_data(make_ddim({4, 4}), CPUPlace()); - ScatterUpdate(CPUPlace(), src, index, output); + ScatterAssign(CPUPlace(), src, index, output); for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0)); for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data()[i], float(0)); diff --git a/python/paddle/v2/framework/tests/test_scatter_op.py b/python/paddle/v2/framework/tests/test_scatter_op.py index 33c73c5263..1032269d5d 100644 --- a/python/paddle/v2/framework/tests/test_scatter_op.py +++ b/python/paddle/v2/framework/tests/test_scatter_op.py @@ -10,7 +10,7 @@ class TestScatterOp(OpTest): index_np = np.array([1, 2]).astype("int32") updates_np = np.random.random((2, 3)).astype("float32") output_np = np.copy(ref_np) - output_np[index_np] += updates_np + output_np[index_np] = updates_np self.inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np} self.outputs = {'Out': output_np} @@ -18,7 +18,7 @@ class TestScatterOp(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(['Updates', 'Ref'], 'Out', in_place=True) + self.check_grad(['Updates'], 'Out', in_place=True) if __name__ == "__main__": -- GitLab