diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index b4eb89e2d7b72801556efe8121eb323e2d6ab346..4c76326e7cecf3dc9a0298864803578c05d49d3b 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -141,9 +141,27 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
         net->ops_[op_offset]->Rename(name, dup_outputs.back());
       }
       // collect all the offset to append `add` op for each alias
-      insert_position.push_back(
-          {dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}},
-                                               {{"Out", {name}}}, {})});
+      //
+      // one variable is shared between multiple operators.
+      // insert add operator one by one, then add it to output
+      for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
+           ++output_idx) {
+        auto insert_add_x = dup_outputs[output_idx + 1];
+        auto insert_add_y = dup_outputs[output_idx];
+        auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx);
+        // the last add op inserted writes to the real output name
+        if (output_idx == dup_outputs.size() - 2) {
+          insert_add_out = name;
+        }
+        if (output_idx != 0) {
+          insert_add_y = name + "@SHARED@" + std::to_string(output_idx - 1);
+        }
+        insert_position.push_back(
+            {dup_op.back(),
+             OpRegistry::CreateOp(
+                 "sum", {{"X", {insert_add_x, insert_add_y}}},
+                 {{"Out", {insert_add_out}}}, {})});
+      }
     }
 
     // make sure the inserted `add` ops follow the BFS order.
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 6932f5b989a3e21ebc44ec4fec9f5223f2547d7a..a36e7bde8c61c6a8bfc6eefd46ce5e09907476ef 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -133,15 +133,18 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
   }
 };
 
-class AddOpMaker : public OpProtoAndCheckerMaker {
+class SumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+  SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "x").AsDuplicable();
-    AddOutput("Out", "out");
+    AddInput("X", "the input tensors of sum operator.")
+        .AsDuplicable()
+        .NotInGradient();
+    AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
     AddComment("");
   }
 };
+
 }  // namespace framework
 }  // namespace paddle
 
@@ -154,7 +157,7 @@ REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
 REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker);
-REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP);
+REGISTER_OP(sum, f::NOP, f::SumOpMaker, sum_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
             f::NOP);
@@ -283,7 +286,7 @@ TEST(Backward, net_shared_weight) {
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
-  ASSERT_EQ("add", bwd_net->ops_[2]->Type());
+  ASSERT_EQ("sum", bwd_net->ops_[2]->Type());
 }
 
 TEST(Backward, op_register_grad_not_for_network) {
diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc
index db20b69f3fab1a5a5bfcd445536437dc4b428559..2737104a205cbc1e18ce4a3a45592a416d38a874 100644
--- a/paddle/operators/cond_op.cc
+++ b/paddle/operators/cond_op.cc
@@ -126,8 +126,7 @@ void CondOp::PrepareDataForSubnet(
       dim[0] = index_tensors[i].dims()[0];
       tensor_child->mutable_data<float>(dim, platform::CPUPlace());
 
-      Gather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
-                    tensor_child);
+      CPUGather<float>(dev_ctx, *tensor_parent, index_tensors[i], tensor_child);
     }
   }
 
@@ -188,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
     Variable* var_child = sub_scopes[i]->FindVar(output);
     PADDLE_ENFORCE_NOT_NULL(var_child);
     auto* tensor_child = &var_child->Get<LoDTensor>();
-    ScatterUpdate<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
+    ScatterAssign<float>(dev_ctx, *tensor_child, index_tensors[i],
                          tensor_parent);
   }
 }
diff --git a/paddle/operators/gather.cu.h b/paddle/operators/gather.cu.h
new file mode 100644
index 0000000000000000000000000000000000000000..8d04ecd284226c7b4c6cdd5531915fee2d94ce61
--- /dev/null
+++ b/paddle/operators/gather.cu.h
@@ -0,0 +1,79 @@
+/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/tensor.h"
+#include "paddle/platform/place.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+using platform::Place;
+
+#define CUDA_1D_KERNEL_LOOP(i, n)                              \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+template <typename T>
+__global__ void GatherCUDAKernel(const T* params, const int* indices,
+                                 T* output, size_t index_size,
+                                 size_t slice_size) {
+  CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
+    int indices_i = i / slice_size;
+    int slice_i = i - indices_i * slice_size;  // offset inside the slice
+    int gather_i = indices[indices_i];
+    int params_i = gather_i * slice_size + slice_i;
+    *(output + i) = *(params + params_i);
+  }
+}
+
+/**
+ * A thin wrapper on gpu tensor
+ * Return a new tensor from source tensor, gathered according to index
+ * input[src]: type-T source Tensor
+ * input[index]: type-int index Tensor (1-D)
+ * return: output tensor
+ */
+template <typename T>
+void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
+               const Tensor& index, Tensor* output) {
+  // PADDLE_ENFORCE(platform::is_gpu_place(place));
+  // check index of shape 1-D
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];
+
+  auto src_dims = src.dims();
+  framework::DDim output_dims(src_dims);
+  output_dims[0] = index_size;
+
+  // slice size
+  int slice_size = 1;
+  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
+
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
+  T* p_output = output->data<T>();
+
+  int block = 512;
+  int n = slice_size * index_size;
+  int grid = (n + block - 1) / block;
+
+  GatherCUDAKernel<T><<<
+      grid, block, 0,
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+      p_src, p_index, p_output, index_size, slice_size);
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h
index 92fb51ec17709bc6f8abb2f516a9240fb5dc3a77..052db49cb3c2594eca8b9a5e3716689480089703 100644
--- a/paddle/operators/gather.h
+++ b/paddle/operators/gather.h
@@ -24,49 +24,40 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-// Implementation of CPU copy
-template <typename T>
-void CPUGather(const T* src, const int* indices, const int slice_size,
-               const int index_size, T* output) {
-  const size_t slice_bytes = slice_size * sizeof(T);
-
-  for (int i = 0; i < index_size; ++i) {
-    int index_ = indices[i];
-    memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes);
-  }
-}
-
-// Implementation of GPU copy:
-template <typename T>
-void GPUGather(const T* src, const int* index, const int slice_size,
-               const int index_size, T* output);
+using framework::Tensor;
 
 /**
+ * A thin wrapper for gathering on cpu tensor
  * Return a new tensor from source tensor, gathered according to index
  * input[src]: type-T source Tensor
  * input[index]: type-int index Tensor (1-D)
  * return: output tensor
  */
 template <typename T>
-void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
-            const paddle::framework::Tensor* index,
-            paddle::framework::Tensor* output) {
+void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
+               const Tensor& index, Tensor* output) {
+  PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];
 
-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   framework::DDim output_dims(src_dims);
   output_dims[0] = index_size;
 
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
+  T* p_output = output->data<T>();
+
   // slice size
   int slice_size = 1;
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
 
-  // Gathering
-  if (platform::is_cpu_place(place)) {
-    CPUGather<T>(src->data<T>(), index->data<int>(), slice_size, index_size,
-                 output->data<T>());
+  const size_t slice_bytes = slice_size * sizeof(T);
+
+  for (int i = 0; i < index_size; ++i) {
+    int index_ = p_index[i];
+    memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
   }
 }
diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc
index da22bd0c52c27d7decd10e2e2b34fa38d0620da8..fe305337cbebd7c679ae1b8ee8aa2740472ee109 100644
--- a/paddle/operators/gather_op.cc
+++ b/paddle/operators/gather_op.cc
@@ -31,6 +31,8 @@ class GatherOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of GatherOp should not be null.");
 
+    auto index_dims = ctx->GetInputDim("Index");
+    PADDLE_ENFORCE(index_dims.size() == 1);
     int batch_size = ctx->GetInputDim("Index")[0];
     PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
     framework::DDim output_dims(ctx->GetInputDim("X"));
@@ -79,8 +81,5 @@ Out = X[Index]
 namespace ops = paddle::operators;
 REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
             ops::GatherGradOp);
-REGISTER_OP_CPU_KERNEL(gather,
-                       ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    gather_grad,
-    ops::GatherGradientOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>);
+REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>);
diff --git a/paddle/operators/gather_op.cu b/paddle/operators/gather_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..92219d6a433e6db0bb9886ed8670cbafaa843ff8
--- /dev/null
+++ b/paddle/operators/gather_op.cu
@@ -0,0 +1,64 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "gather.cu.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/operators/gather_op.h"
+#include "scatter.cu.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class GatherOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "This kernel only runs on GPU device.");
+    auto *x = ctx.Input<Tensor>("X");
+    auto *index = ctx.Input<Tensor>("Index");
+    auto *output = ctx.Output<Tensor>("Out");
+
+    output->mutable_data<T>(ctx.GetPlace());
+
+    GPUGather<T>(ctx.device_context(), *x, *index, output);
+  }
+};
+
+template <typename T>
+class GatherGradOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "This kernel only runs on GPU device.");
+    auto *Index = ctx.Input<Tensor>("Index");
+    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto *x = ctx.Input<Tensor>("X");
+
+    dX->mutable_data<T>(ctx.GetPlace());
+    auto dxt = framework::EigenVector<T>::Flatten(*dX);
+    auto place = ctx.GetEigenDevice<platform::GPUPlace>();
+    dxt.device(place) = dxt.constant(static_cast<T>(0));
+
+    GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(gather, ops::GatherOpCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>);
diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h
index 073e566e8f6962d62cc1b738672843421dcb4ee5..8276ed0d3d8b676aafab45fae70942e78b72b8e6 100644
--- a/paddle/operators/gather_op.h
+++ b/paddle/operators/gather_op.h
@@ -23,29 +23,40 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
-template <typename Place, typename T>
+template <typename T>
 class GatherOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *X = ctx.Input<Tensor>("X");
-    auto *Index = ctx.Input<Tensor>("Index");
-    auto *Y = ctx.Output<Tensor>("Out");
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
+
+    auto *x = ctx.Input<Tensor>("X");
+    auto *index = ctx.Input<Tensor>("Index");
+    auto *output = ctx.Output<Tensor>("Out");
+
+    output->mutable_data<T>(ctx.GetPlace());
 
-    Y->mutable_data<T>(ctx.GetPlace());
-    Gather<T>(ctx.GetPlace(), X, Index, Y);
+    CPUGather<T>(ctx.device_context(), *x, *index, output);
   }
 };
 
-template <typename Place, typename T>
+template <typename T>
 class GatherGradientOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
+
     auto *Index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
     dX->mutable_data<T>(ctx.GetPlace());
-    ScatterUpdate<T>(ctx.GetPlace(), dO, Index, dX);
+    auto dxt = framework::EigenVector<T>::Flatten(*dX);
+    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
+    dxt.device(place) = dxt.constant(static_cast<T>(0));
+
+    ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
   }
 };
diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc
index 0ae1e99452973feb6d085dd6ef51e2afca988f59..cbd86b87961ee24aa889e208de5ac38e03a33135 100644
--- a/paddle/operators/gather_test.cc
+++ b/paddle/operators/gather_test.cc
@@ -41,7 +41,9 @@ TEST(Gather, GatherData) {
 
   int* p_output = output->mutable_data<int>(make_ddim({2, 4}), CPUPlace());
 
-  Gather<int>(CPUPlace(), src, index, output);
+  auto* cpu_place = new paddle::platform::CPUPlace();
+  paddle::platform::CPUDeviceContext ctx(*cpu_place);
+  CPUGather<int>(ctx, *src, *index, output);
 
   for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
   for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
diff --git a/paddle/operators/scatter.cu.h b/paddle/operators/scatter.cu.h
new file mode 100644
index 0000000000000000000000000000000000000000..d95436be4f25b9df4aaef57ddb249ecf944f0666
--- /dev/null
+++ b/paddle/operators/scatter.cu.h
@@ -0,0 +1,80 @@
+/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/tensor.h"
+#include "paddle/platform/place.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+#define CUDA_1D_KERNEL_LOOP(i, n)                              \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+template <typename T>
+__global__ void ScatterCUDAKernel(const T* params, const int* indices,
+                                  T* output, size_t index_size,
+                                  size_t slice_size) {
+  CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
+    int indices_i = i / slice_size;
+    int slice_i = i - indices_i * slice_size;  // offset inside the slice
+    int scatter_i = indices[indices_i];
+    int out_i = scatter_i * slice_size + slice_i;
+    *(output + out_i) = *(params + i);
+  }
+}
+
+/**
+ * A thin wrapper on gpu tensor
+ * Return a new updated tensor from source tensor, scatter-assigned according to
+ * index
+ * input[src]: type-T source Tensor
+ * input[index]: type-int index Tensor (1-D)
+ * return: output tensor
+ */
+template <typename T>
+void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
+                      const Tensor& index, Tensor* output) {
+  // PADDLE_ENFORCE(platform::is_gpu_place(place));
+  // check index of shape 1-D
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];
+
+  auto src_dims = src.dims();
+  framework::DDim output_dims(src_dims);
+  output_dims[0] = index_size;
+
+  // slice size
+  int slice_size = 1;
+  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
+
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
+  T* p_output = output->data<T>();
+
+  int block = 512;
+  int n = slice_size * index_size;
+  int grid = (n + block - 1) / block;
+
+  ScatterCUDAKernel<T><<<
+      grid, block, 0,
+      reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
+      p_src, p_index, p_output, index_size, slice_size);
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/scatter.h b/paddle/operators/scatter.h
index fec1046a9748637ac41aa35143a8c2cf5528913f..c1fb844ebd2ff7ca7dbdb8e8ac3c1fff4c0c6607 100644
--- a/paddle/operators/scatter.h
+++ b/paddle/operators/scatter.h
@@ -24,63 +24,42 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
 
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-
-// Implementation of CPU copy
-template <typename T>
-void CPUScatterUpdate(const paddle::framework::Tensor* src, const int* index,
-                      const size_t index_size,
-                      paddle::framework::Tensor* output) {
-  paddle::framework::DDim output_dims = output->dims();
-
-  for (size_t i = 0; i < index_size; ++i) {
-    int index_ = index[i];
-
-    paddle::framework::Tensor src_ = *src;
-    paddle::framework::Tensor output_ = *output;
-    if (index_size > 1) src_ = src->Slice<T>(i, i + 1);
-    if (output_dims[0] > 1) output_ = output->Slice<T>(index_, index_ + 1);
-
-    auto X = EigenVector<T>::Flatten(src_);
-    auto Y = EigenVector<T>::Flatten(output_);
-
-    Y = X + Y;
-  }
-}
-
-// Implementation of GPU scatter:
-template <typename T>
-void GPUScatterUpdate(const T* src, const int* index, const int slice_size,
-                      const int index_size, T* output);
 
 /**
  * Return a updated tensor from source tensor, scattered according to index:
- * dst[i] += src[index[i]]
+ * dst[i] = src[index[i]]
  * input[src]: type-T source Tensor
  * input[index]: type-int index Tensor (1-D)
  * return: output tensor
  */
 template <typename T>
-void ScatterUpdate(const platform::Place& place,
-                   const paddle::framework::Tensor* src,
-                   const paddle::framework::Tensor* index,
-                   paddle::framework::Tensor* output) {
+void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
+                   const Tensor& index, Tensor* output) {
+  PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];
 
-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   auto dst_dims = output->dims();
 
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
+  T* p_output = output->data<T>();
+
   // check src shape and dst shape should match
   for (int i = 1; i < src_dims.size(); i++)
     PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
 
-  if (platform::is_cpu_place(place)) {
-    CPUScatterUpdate<T>(src, index->data<int>(), index_size, output);
-  } else {
+  // slice size
+  size_t slice_size = 1;
+  for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
+
+  const size_t slice_bytes = slice_size * sizeof(T);
+
+  for (int i = 0; i < index_size; ++i) {
+    int index_ = p_index[i];
+    memcpy(p_output + index_ * slice_size, p_src + i * slice_size, slice_bytes);
   }
 }
diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc
index cadd8841b6ab3a3674054240265eb6d4b474db1e..d15ba151539987c133ac57e102df53551483c6dd 100644
--- a/paddle/operators/scatter_op.cc
+++ b/paddle/operators/scatter_op.cc
@@ -97,8 +97,5 @@ Out[Index] = Ref[Index] + Updates
 namespace ops = paddle::operators;
 REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad,
             ops::ScatterGradOp);
-REGISTER_OP_CPU_KERNEL(scatter,
-                       ops::ScatterOpKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    scatter_grad,
-    ops::ScatterGradientOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(scatter, ops::ScatterOpKernel<float>);
+REGISTER_OP_CPU_KERNEL(scatter_grad, ops::ScatterGradientOpKernel<float>);
diff --git a/paddle/operators/scatter_op.cu b/paddle/operators/scatter_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..06f4d759447b6dcd28b50576dfc246fc466d9336
--- /dev/null
+++ b/paddle/operators/scatter_op.cu
@@ -0,0 +1,63 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "gather.cu.h"
+#include "paddle/operators/gather_op.h"
+#include "scatter.cu.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class ScatterOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "This kernel only runs on GPU device.");
+    auto *Ref = ctx.Input<Tensor>("Ref");
+    auto *Index = ctx.Input<Tensor>("Index");
+    auto *Updates = ctx.Input<Tensor>("Updates");
+    auto *Out = ctx.Output<Tensor>("Out");
+
+    Out->ShareDataWith<T>(*Ref);
+
+    GPUScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
+  }
+};
+
+template <typename T>
+class ScatterGradOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "This kernel only runs on GPU device.");
+    auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
+    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
+    auto *Index = ctx.Input<Tensor>("Index");
+    auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+
+    // In place gradient: dRef = dO
+    dRef->ShareDataWith<T>(*dOut);
+    dUpdates->mutable_data<T>(ctx.GetPlace());
+    // Gradient by Gather: dUpdates = dO[Index]
+    GPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(scatter, ops::ScatterOpCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(scatter_grad, ops::ScatterGradOpCUDAKernel<float>);
diff --git a/paddle/operators/scatter_op.h b/paddle/operators/scatter_op.h
index a8eb54399a932913de208e1ddc90a6ff0dfaa452..6101219006414e4865f676e3ca5d2a88949ad17a 100644
--- a/paddle/operators/scatter_op.h
+++ b/paddle/operators/scatter_op.h
@@ -23,10 +23,12 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
-template <typename Place, typename T>
+template <typename T>
 class ScatterOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
     auto *Ref = ctx.Input<Tensor>("Ref");
     auto *Index = ctx.Input<Tensor>("Index");
     auto *Updates = ctx.Input<Tensor>("Updates");
@@ -35,14 +37,16 @@ class ScatterOpKernel : public framework::OpKernel {
     // In place output: Out = Ref, Out[Index] += Updates
     Out->ShareDataWith<T>(*Ref);
     // Apply ScatterUpdate: Out[index] += Updates[:]
-    ScatterUpdate<T>(ctx.GetPlace(), Updates, Index, Out);
+    ScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
   }
 };
 
-template <typename Place, typename T>
+template <typename T>
 class ScatterGradientOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
     auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
     auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
     auto *Index = ctx.Input<Tensor>("Index");
@@ -52,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel {
     dRef->ShareDataWith<T>(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates += dO[Index]
-    Gather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
+    CPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
   }
 };
diff --git a/paddle/operators/scatter_test.cc b/paddle/operators/scatter_test.cc
index 26fdaff1460a297fa638181641991f732533fe52..00dbdacbfef7af826790472acc6caa285c259e0e 100644
--- a/paddle/operators/scatter_test.cc
+++ b/paddle/operators/scatter_test.cc
@@ -40,7 +40,9 @@ TEST(scatter, ScatterUpdate) {
 
   float* p_output = output->mutable_data<float>(make_ddim({4, 4}), CPUPlace());
 
-  ScatterUpdate<float>(CPUPlace(), src, index, output);
+  auto* cpu_place = new paddle::platform::CPUPlace();
+  paddle::platform::CPUDeviceContext ctx(*cpu_place);
+  ScatterAssign<float>(ctx, *src, *index, output);
 
   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0));
   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
diff --git a/python/paddle/v2/framework/tests/test_scatter_op.py b/python/paddle/v2/framework/tests/test_scatter_op.py
index 33c73c52631a09ea0fefdeb9467991ae9c04321c..1032269d5dfb02e3518b9ef2820d5d0dcc8a51a0 100644
--- a/python/paddle/v2/framework/tests/test_scatter_op.py
+++ b/python/paddle/v2/framework/tests/test_scatter_op.py
@@ -10,7 +10,7 @@ class TestScatterOp(OpTest):
         index_np = np.array([1, 2]).astype("int32")
         updates_np = np.random.random((2, 3)).astype("float32")
         output_np = np.copy(ref_np)
-        output_np[index_np] += updates_np
+        output_np[index_np] = updates_np
         self.inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
         self.outputs = {'Out': output_np}
 
@@ -18,7 +18,7 @@ class TestScatterOp(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['Updates', 'Ref'], 'Out', in_place=True)
+        self.check_grad(['Updates'], 'Out', in_place=True)
 
 
 if __name__ == "__main__":
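
Editorial note (not part of the patch): below is a minimal standalone sketch of how the refactored CPU helpers introduced above are expected to be called, mirroring gather_test.cc and scatter_test.cc. The shapes and values are illustrative only; the explicit <float> template argument is required because T cannot be deduced from the Tensor arguments.

// Sketch only: CPUGather copies out[i] = src[index[i]];
// ScatterAssign copies dst[index[i]] = src[i] (assign, not accumulate).
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/gather.h"
#include "paddle/operators/scatter.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"

int main() {
  using paddle::framework::Tensor;
  using paddle::framework::make_ddim;

  paddle::platform::CPUPlace place;
  paddle::platform::CPUDeviceContext ctx(place);

  // src: 3 x 4 tensor, row r filled with the value r.
  Tensor src;
  float* p_src = src.mutable_data<float>(make_ddim({3, 4}), place);
  for (int i = 0; i < 12; ++i) p_src[i] = static_cast<float>(i / 4);

  // index: gather rows {2, 0}.
  Tensor index;
  int* p_index = index.mutable_data<int>(make_ddim({2}), place);
  p_index[0] = 2;
  p_index[1] = 0;

  // The output must be pre-allocated with the gathered shape {2, 4}.
  Tensor out;
  out.mutable_data<float>(make_ddim({2, 4}), place);
  paddle::operators::CPUGather<float>(ctx, src, index, &out);

  // Write the gathered rows back in place: src[index[i]] = out[i].
  paddle::operators::ScatterAssign<float>(ctx, out, index, &src);
  return 0;
}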