diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc index 55822827d9aa9b157a5a8ee93f71702994651847..7d7f1ba3b1178bf931debb9bd2eb9e2901b09bf5 100644 --- a/paddle/operators/cond_op.cc +++ b/paddle/operators/cond_op.cc @@ -126,8 +126,7 @@ void CondOp::PrepareDataForSubnet( dim[0] = index_tensors[i].dims()[0]; tensor_child->mutable_data(dim, platform::CPUPlace()); - CPUGather(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i], - tensor_child); + CPUGather(dev_ctx, tensor_parent, &index_tensors[i], tensor_child); } } @@ -188,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope, Variable* var_child = sub_scopes[i]->FindVar(output); PADDLE_ENFORCE_NOT_NULL(var_child); auto* tensor_child = &var_child->Get(); - ScatterAssign(dev_ctx.GetPlace(), tensor_child, &index_tensors[i], + ScatterAssign(dev_ctx, tensor_child, &index_tensors[i], tensor_parent); } } diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index cb635f6825591fb80fc48f511a2f3652a348bd1b..1e39a6da271eb9e319e91ef28a895e986ce65491 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -32,11 +32,11 @@ namespace operators { * return: output tensor */ template -void CPUGather(const platform::Place& place, +void CPUGather(const platform::DeviceContext& ctx, const paddle::framework::Tensor* src, const paddle::framework::Tensor* index, paddle::framework::Tensor* output) { - PADDLE_ENFORCE(platform::is_cpu_place(place)); + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace())); // check index of shape 1-D PADDLE_ENFORCE(index->dims().size() == 1); int index_size = index->dims()[0]; diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h index fb065b8da7d043c30c6684ab8347f96a5d0007d1..5bd2c36f7b767718979a8990e540461e206ee830 100644 --- a/paddle/operators/gather_op.h +++ b/paddle/operators/gather_op.h @@ -36,7 +36,7 @@ class GatherOpKernel : public framework::OpKernel { output->mutable_data(ctx.GetPlace()); - CPUGather(ctx.GetPlace(), x, index, output); + CPUGather(ctx.device_context(), x, index, output); } }; @@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel { auto place = ctx.GetEigenDevice(); dxt.device(place) = dxt.constant(static_cast(0)); - ScatterAssign(ctx.GetPlace(), dO, Index, dX); + ScatterAssign(ctx.device_context(), dO, Index, dX); } }; diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc index 3c1d06ccd10b62a2c6dc25c299b329502d6db83d..d8bf8dd9a42f86fe1d2a7b454912f632d4907ec5 100644 --- a/paddle/operators/gather_test.cc +++ b/paddle/operators/gather_test.cc @@ -41,7 +41,9 @@ TEST(Gather, GatherData) { int* p_output = output->mutable_data(make_ddim({2, 4}), CPUPlace()); - CPUGather(CPUPlace(), src, index, output); + auto* cpu_place = new paddle::platform::CPUPlace(); + paddle::platform::CPUDeviceContext ctx(*cpu_place); + CPUGather(ctx, src, index, output); for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); diff --git a/paddle/operators/scatter.h b/paddle/operators/scatter.h index f895f22e281eeb9ff2c9cb927d3fd9e923e145e8..0d174d3b5b84326690589c9f0c39e3e1460aaead 100644 --- a/paddle/operators/scatter.h +++ b/paddle/operators/scatter.h @@ -33,11 +33,11 @@ using Tensor = framework::Tensor; * return: output tensor */ template -void ScatterAssign(const platform::Place& place, +void ScatterAssign(const platform::DeviceContext& ctx, const paddle::framework::Tensor* src, const paddle::framework::Tensor* index, paddle::framework::Tensor* output) { - PADDLE_ENFORCE(platform::is_cpu_place(place)); + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace())); // check index of shape 1-D PADDLE_ENFORCE(index->dims().size() == 1); int index_size = index->dims()[0]; diff --git a/paddle/operators/scatter_op.h b/paddle/operators/scatter_op.h index 771a1f2ddb79874038792d9eff1d733d00535051..ac04968549cb0fcbaf318bcc79e1a95470cb42f1 100644 --- a/paddle/operators/scatter_op.h +++ b/paddle/operators/scatter_op.h @@ -37,7 +37,7 @@ class ScatterOpKernel : public framework::OpKernel { // In place output: Out = Ref, Out[Index] += Updates Out->ShareDataWith(*Ref); // Apply ScatterUpdate: Out[index] += Updates[:] - ScatterAssign(ctx.GetPlace(), Updates, Index, Out); + ScatterAssign(ctx.device_context(), Updates, Index, Out); } }; @@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel { dRef->ShareDataWith(*dOut); dUpdates->mutable_data(ctx.GetPlace()); // Gradient by Gather: dUpdates += dO[Index] - CPUGather(ctx.GetPlace(), dOut, Index, dUpdates); + CPUGather(ctx.device_context(), dOut, Index, dUpdates); } }; diff --git a/paddle/operators/scatter_test.cc b/paddle/operators/scatter_test.cc index bace6419d0be4b44b246251750bea685f14a106b..321bba3dadd68d08d332e9d48dcc1de11eed652a 100644 --- a/paddle/operators/scatter_test.cc +++ b/paddle/operators/scatter_test.cc @@ -40,7 +40,9 @@ TEST(scatter, ScatterUpdate) { float* p_output = output->mutable_data(make_ddim({4, 4}), CPUPlace()); - ScatterAssign(CPUPlace(), src, index, output); + auto* cpu_place = new paddle::platform::CPUPlace(); + paddle::platform::CPUDeviceContext ctx(*cpu_place); + ScatterAssign(ctx, src, index, output); for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0)); for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data()[i], float(0));