From 2ccaec4f57afe94f36ee4781bae6e0eec78b29a8 Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Mon, 2 Oct 2017 18:31:55 -0700 Subject: [PATCH] gather scatter cond --- paddle/operators/cond_op.cc | 5 ++--- paddle/operators/gather.h | 4 ++-- paddle/operators/gather_op.h | 4 ++-- paddle/operators/gather_test.cc | 4 +++- paddle/operators/scatter.h | 4 ++-- paddle/operators/scatter_op.h | 4 ++-- paddle/operators/scatter_test.cc | 4 +++- 7 files changed, 16 insertions(+), 13 deletions(-) diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc index 55822827d9a..7d7f1ba3b11 100644 --- a/paddle/operators/cond_op.cc +++ b/paddle/operators/cond_op.cc @@ -126,8 +126,7 @@ void CondOp::PrepareDataForSubnet( dim[0] = index_tensors[i].dims()[0]; tensor_child->mutable_data(dim, platform::CPUPlace()); - CPUGather(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i], - tensor_child); + CPUGather(dev_ctx, tensor_parent, &index_tensors[i], tensor_child); } } @@ -188,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope, Variable* var_child = sub_scopes[i]->FindVar(output); PADDLE_ENFORCE_NOT_NULL(var_child); auto* tensor_child = &var_child->Get(); - ScatterAssign(dev_ctx.GetPlace(), tensor_child, &index_tensors[i], + ScatterAssign(dev_ctx, tensor_child, &index_tensors[i], tensor_parent); } } diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index cb635f68255..1e39a6da271 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -32,11 +32,11 @@ namespace operators { * return: output tensor */ template -void CPUGather(const platform::Place& place, +void CPUGather(const platform::DeviceContext& ctx, const paddle::framework::Tensor* src, const paddle::framework::Tensor* index, paddle::framework::Tensor* output) { - PADDLE_ENFORCE(platform::is_cpu_place(place)); + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace())); // check index of shape 1-D PADDLE_ENFORCE(index->dims().size() == 1); int index_size = index->dims()[0]; diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h index fb065b8da7d..5bd2c36f7b7 100644 --- a/paddle/operators/gather_op.h +++ b/paddle/operators/gather_op.h @@ -36,7 +36,7 @@ class GatherOpKernel : public framework::OpKernel { output->mutable_data(ctx.GetPlace()); - CPUGather(ctx.GetPlace(), x, index, output); + CPUGather(ctx.device_context(), x, index, output); } }; @@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel { auto place = ctx.GetEigenDevice(); dxt.device(place) = dxt.constant(static_cast(0)); - ScatterAssign(ctx.GetPlace(), dO, Index, dX); + ScatterAssign(ctx.device_context(), dO, Index, dX); } }; diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc index 3c1d06ccd10..d8bf8dd9a42 100644 --- a/paddle/operators/gather_test.cc +++ b/paddle/operators/gather_test.cc @@ -41,7 +41,9 @@ TEST(Gather, GatherData) { int* p_output = output->mutable_data(make_ddim({2, 4}), CPUPlace()); - CPUGather(CPUPlace(), src, index, output); + auto* cpu_place = new paddle::platform::CPUPlace(); + paddle::platform::CPUDeviceContext ctx(*cpu_place); + CPUGather(ctx, src, index, output); for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); diff --git a/paddle/operators/scatter.h b/paddle/operators/scatter.h index f895f22e281..0d174d3b5b8 100644 --- a/paddle/operators/scatter.h +++ b/paddle/operators/scatter.h @@ -33,11 +33,11 @@ using Tensor = framework::Tensor; * return: output tensor */ template -void ScatterAssign(const platform::Place& place, +void ScatterAssign(const platform::DeviceContext& ctx, const paddle::framework::Tensor* src, const paddle::framework::Tensor* index, paddle::framework::Tensor* output) { - PADDLE_ENFORCE(platform::is_cpu_place(place)); + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace())); // check index of shape 1-D PADDLE_ENFORCE(index->dims().size() == 1); int index_size = index->dims()[0]; diff --git a/paddle/operators/scatter_op.h b/paddle/operators/scatter_op.h index 771a1f2ddb7..ac04968549c 100644 --- a/paddle/operators/scatter_op.h +++ b/paddle/operators/scatter_op.h @@ -37,7 +37,7 @@ class ScatterOpKernel : public framework::OpKernel { // In place output: Out = Ref, Out[Index] += Updates Out->ShareDataWith(*Ref); // Apply ScatterUpdate: Out[index] += Updates[:] - ScatterAssign(ctx.GetPlace(), Updates, Index, Out); + ScatterAssign(ctx.device_context(), Updates, Index, Out); } }; @@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel { dRef->ShareDataWith(*dOut); dUpdates->mutable_data(ctx.GetPlace()); // Gradient by Gather: dUpdates += dO[Index] - CPUGather(ctx.GetPlace(), dOut, Index, dUpdates); + CPUGather(ctx.device_context(), dOut, Index, dUpdates); } }; diff --git a/paddle/operators/scatter_test.cc b/paddle/operators/scatter_test.cc index bace6419d0b..321bba3dadd 100644 --- a/paddle/operators/scatter_test.cc +++ b/paddle/operators/scatter_test.cc @@ -40,7 +40,9 @@ TEST(scatter, ScatterUpdate) { float* p_output = output->mutable_data(make_ddim({4, 4}), CPUPlace()); - ScatterAssign(CPUPlace(), src, index, output); + auto* cpu_place = new paddle::platform::CPUPlace(); + paddle::platform::CPUDeviceContext ctx(*cpu_place); + ScatterAssign(ctx, src, index, output); for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0)); for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data()[i], float(0)); -- GitLab