From 714007115290164d68ebe2fa85b4e37749023cf0 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Thu, 15 Mar 2018 10:29:37 +0800 Subject: [PATCH] "exported scatter to python" (#9038) * "exported scatter to python" * Revert ""exported scatter to python"" This reverts commit 38745a626c3f937bec836c92c98a76deadf0a03d. * "polish scatter and export to python" --- paddle/fluid/operators/scatter_op.cc | 35 +++++++++---------- paddle/fluid/operators/scatter_op.cu | 20 +++++------ paddle/fluid/operators/scatter_op.h | 22 ++++++------ python/paddle/fluid/layers/ops.py | 32 ++++------------- .../fluid/tests/unittests/test_scatter_op.py | 2 +- 5 files changed, 46 insertions(+), 65 deletions(-) diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc index 3fb8b56d26..d6fd621471 100644 --- a/paddle/fluid/operators/scatter_op.cc +++ b/paddle/fluid/operators/scatter_op.cc @@ -23,24 +23,24 @@ class ScatterOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Ref"), - "Input(Ref) of ScatterOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Index"), - "Input(Index) of ScatterOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ScatterOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Ids"), + "Input(Ids) of ScatterOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Updates"), "Input(Updates) of ScatterOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of ScatterOp should not be null."); auto updates_dims = ctx->GetInputDim("Updates"); - auto ref_dims = ctx->GetInputDim("Ref"); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("Index").size(), 1, - "Update Index should be 1-D."); + auto ref_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Ids").size(), 1, + "Update Ids should be 1-D."); PADDLE_ENFORCE_EQ(ref_dims.size(), updates_dims.size(), - "Reference and Updates should have the same shape size"); + "Xerence and Updates should have the same shape size"); PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0], - ctx->GetInputDim("Index")[0], - "Updates and Index should have same batch-size."); + ctx->GetInputDim("Ids")[0], + "Updates and Ids should have same batch-size."); framework::DDim data_dim(updates_dims); for (int i = 1; i < data_dim.size(); ++i) { PADDLE_ENFORCE_EQ(data_dim[i], updates_dims[i]); @@ -52,7 +52,7 @@ class ScatterOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("Ref")->type()), + framework::ToDataType(ctx.Input("X")->type()), ctx.device_context()); } }; @@ -64,14 +64,14 @@ class ScatterGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { ctx->SetOutputDim(framework::GradVarName("Updates"), ctx->GetInputDim("Updates")); - ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref")); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("Ref")->type()), + framework::ToDataType(ctx.Input("X")->type()), ctx.device_context()); } }; @@ -80,9 +80,8 @@ class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { public: ScatterOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("Ref", "The source input of scatter op"); - AddInput("Index", - "The index input of scatter op where Ref will be updated"); + AddInput("X", "The source input of scatter op"); + AddInput("Ids", "The index input of scatter op where X will be updated"); AddInput("Updates", "The updated value of updates op"); AddOutput("Out", "The output of add op"); AddComment(R"DOC( @@ -91,8 +90,8 @@ Scatter Operator. This operator obtains output by updating the input on selected indices on the first axis: $$ -Out = Ref \\ -Out[Index] = Ref[Index] + Updates +Out = X \\ +Out[Ids] = X[Ids] + Updates $$ )DOC"); diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu index bdabb29fa6..ef7d700659 100644 --- a/paddle/fluid/operators/scatter_op.cu +++ b/paddle/fluid/operators/scatter_op.cu @@ -25,14 +25,14 @@ class ScatterOpCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - auto *Ref = ctx.Input("Ref"); - auto *Index = ctx.Input("Index"); + auto *X = ctx.Input("X"); + auto *Ids = ctx.Input("Ids"); auto *Updates = ctx.Input("Updates"); auto *Out = ctx.Output("Out"); - Out->ShareDataWith(*Ref); + Out->ShareDataWith(*X); - GPUScatterAssign(ctx.device_context(), *Updates, *Index, Out); + GPUScatterAssign(ctx.device_context(), *Updates, *Ids, Out); } }; @@ -42,16 +42,16 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - auto *dRef = ctx.Output(framework::GradVarName("Ref")); + auto *dX = ctx.Output(framework::GradVarName("X")); auto *dUpdates = ctx.Output(framework::GradVarName("Updates")); - auto *Index = ctx.Input("Index"); + auto *Ids = ctx.Input("Ids"); auto *dOut = ctx.Input(framework::GradVarName("Out")); - // In place gradient: dRef = dO - dRef->ShareDataWith(*dOut); + // In place gradient: dX = dO + dX->ShareDataWith(*dOut); dUpdates->mutable_data(ctx.GetPlace()); - // Gradient by Gather: dUpdates = dO[Index] - GPUGather(ctx.device_context(), *dOut, *Index, dUpdates); + // Gradient by Gather: dUpdates = dO[Ids] + GPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); } }; diff --git a/paddle/fluid/operators/scatter_op.h b/paddle/fluid/operators/scatter_op.h index 3c6e7ece32..2151d8a924 100644 --- a/paddle/fluid/operators/scatter_op.h +++ b/paddle/fluid/operators/scatter_op.h @@ -29,15 +29,15 @@ class ScatterOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), "This kernel only runs on CPU."); - auto *Ref = ctx.Input("Ref"); - auto *Index = ctx.Input("Index"); + auto *X = ctx.Input("X"); + auto *Ids = ctx.Input("Ids"); auto *Updates = ctx.Input("Updates"); auto *Out = ctx.Output("Out"); - // In place output: Out = Ref, Out[Index] += Updates - Out->ShareDataWith(*Ref); + // In place output: Out = X, Out[Ids] += Updates + Out->ShareDataWith(*X); // Apply ScatterUpdate: Out[index] += Updates[:] - ScatterAssign(ctx.device_context(), *Updates, *Index, Out); + ScatterAssign(ctx.device_context(), *Updates, *Ids, Out); } }; @@ -47,16 +47,16 @@ class ScatterGradientOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), "This kernel only runs on CPU."); - auto *dRef = ctx.Output(framework::GradVarName("Ref")); + auto *dX = ctx.Output(framework::GradVarName("X")); auto *dUpdates = ctx.Output(framework::GradVarName("Updates")); - auto *Index = ctx.Input("Index"); + auto *Ids = ctx.Input("Ids"); auto *dOut = ctx.Input(framework::GradVarName("Out")); - // In place gradient: dRef = dO - dRef->ShareDataWith(*dOut); + // In place gradient: dX = dO + dX->ShareDataWith(*dOut); dUpdates->mutable_data(ctx.GetPlace()); - // Gradient by Gather: dUpdates += dO[Index] - CPUGather(ctx.device_context(), *dOut, *Index, dUpdates); + // Gradient by Gather: dUpdates += dO[Ids] + CPUGather(ctx.device_context(), *dOut, *Ids, dUpdates); } }; diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index 0b88b63962..14ad18d508 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -45,31 +45,13 @@ __activations__ = [ ] __all__ = [ - 'mean', - 'mul', - 'reshape', - 'scale', - 'sigmoid_cross_entropy_with_logits', - 'elementwise_add', - 'elementwise_div', - 'elementwise_sub', - 'elementwise_mul', - 'elementwise_max', - 'elementwise_min', - 'elementwise_pow', - 'clip', - 'clip_by_norm', - 'softmax', - 'sequence_softmax', - 'logical_and', - 'logical_or', - 'logical_xor', - 'logical_not', - 'uniform_random', - 'uniform_random_batch_size_like', - 'gaussian_random', - 'gaussian_random_batch_size_like', - 'cumsum', + 'mean', 'mul', 'reshape', 'scale', 'sigmoid_cross_entropy_with_logits', + 'elementwise_add', 'elementwise_div', 'elementwise_sub', 'elementwise_mul', + 'elementwise_max', 'elementwise_min', 'elementwise_pow', 'clip', + 'clip_by_norm', 'softmax', 'sequence_softmax', 'logical_and', 'logical_or', + 'logical_xor', 'logical_not', 'uniform_random', + 'uniform_random_batch_size_like', 'gaussian_random', + 'gaussian_random_batch_size_like', 'cumsum', 'scatter' ] + __activations__ for _OP in set(__all__): diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py index bb02a40d44..fb17287436 100644 --- a/python/paddle/fluid/tests/unittests/test_scatter_op.py +++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py @@ -25,7 +25,7 @@ class TestScatterOp(OpTest): updates_np = np.random.random((2, 3)).astype("float32") output_np = np.copy(ref_np) output_np[index_np] = updates_np - self.inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np} + self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np} self.outputs = {'Out': output_np} def test_check_output(self): -- GitLab