diff --git a/paddle/operators/cond_op.cc b/paddle/operators/cond_op.cc
index 7d7f1ba3b1178bf931debb9bd2eb9e2901b09bf5..2737104a205cbc1e18ce4a3a45592a416d38a874 100644
--- a/paddle/operators/cond_op.cc
+++ b/paddle/operators/cond_op.cc
@@ -126,7 +126,7 @@ void CondOp::PrepareDataForSubnet(
       dim[0] = index_tensors[i].dims()[0];
       tensor_child->mutable_data<float>(dim, platform::CPUPlace());

-      CPUGather<float>(dev_ctx, tensor_parent, &index_tensors[i], tensor_child);
+      CPUGather<float>(dev_ctx, *tensor_parent, index_tensors[i], tensor_child);
     }
   }

@@ -187,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
       Variable* var_child = sub_scopes[i]->FindVar(output);
       PADDLE_ENFORCE_NOT_NULL(var_child);
       auto* tensor_child = &var_child->Get<LoDTensor>();
-      ScatterAssign<float>(dev_ctx, tensor_child, &index_tensors[i],
+      ScatterAssign<float>(dev_ctx, *tensor_child, index_tensors[i],
                            tensor_parent);
     }
   }
diff --git a/paddle/operators/gather.cu.h b/paddle/operators/gather.cu.h
index 2ae11376a2ba9d2a140ad468e263cf06f25417ca..8d04ecd284226c7b4c6cdd5531915fee2d94ce61 100644
--- a/paddle/operators/gather.cu.h
+++ b/paddle/operators/gather.cu.h
@@ -46,14 +46,14 @@ __global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
  * return: output tensor
  */
 template <typename T>
-void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
-               const Tensor* index, Tensor* output) {
+void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
+               const Tensor& index, Tensor* output) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];

-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   framework::DDim output_dims(src_dims);
   output_dims[0] = index_size;

@@ -61,8 +61,8 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
   int slice_size = 1;
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

-  const T* p_src = src->data<T>();
-  const int* p_index = index->data<int>();
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
   T* p_output = output->data<T>();

   int block = 512;
diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h
index 1e39a6da271eb9e319e91ef28a895e986ce65491..052db49cb3c2594eca8b9a5e3716689480089703 100644
--- a/paddle/operators/gather.h
+++ b/paddle/operators/gather.h
@@ -24,6 +24,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

+using framework::Tensor;
+
 /**
  * A thin wrapper for gathering on cpu tensor
  * Return a new tensor from source tensor, gathered according to index
@@ -32,21 +34,19 @@ namespace operators {
  * return: output tensor
  */
 template <typename T>
-void CPUGather(const platform::DeviceContext& ctx,
-               const paddle::framework::Tensor* src,
-               const paddle::framework::Tensor* index,
-               paddle::framework::Tensor* output) {
+void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
+               const Tensor& index, Tensor* output) {
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];

-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   framework::DDim output_dims(src_dims);
   output_dims[0] = index_size;

-  const T* p_src = src->data<T>();
-  const int* p_index = index->data<int>();
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
   T* p_output = output->data<T>();

   // slice size
diff --git a/paddle/operators/gather_op.cu b/paddle/operators/gather_op.cu
index 9937be5915d38ea55507fbabce4b5ce6b2f9618f..92219d6a433e6db0bb9886ed8670cbafaa843ff8 100644
--- a/paddle/operators/gather_op.cu
+++ b/paddle/operators/gather_op.cu
@@ -32,7 +32,7 @@ class GatherOpCUDAKernel : public framework::OpKernel {

     output->mutable_data<T>(ctx.GetPlace());

-    GPUGather<T>(ctx.device_context(), x, index, output);
+    GPUGather<T>(ctx.device_context(), *x, *index, output);
   }
 };

@@ -52,7 +52,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel {
     auto place = ctx.GetEigenDevice<platform::GPUPlace>();
     dxt.device(place) = dxt.constant(static_cast<T>(0));

-    GPUScatterAssign<T>(ctx.device_context(), dO, Index, dX);
+    GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
   }
 };

diff --git a/paddle/operators/gather_op.h b/paddle/operators/gather_op.h
index 5bd2c36f7b767718979a8990e540461e206ee830..8276ed0d3d8b676aafab45fae70942e78b72b8e6 100644
--- a/paddle/operators/gather_op.h
+++ b/paddle/operators/gather_op.h
@@ -36,7 +36,7 @@ class GatherOpKernel : public framework::OpKernel {

     output->mutable_data<T>(ctx.GetPlace());

-    CPUGather<T>(ctx.device_context(), x, index, output);
+    CPUGather<T>(ctx.device_context(), *x, *index, output);
   }
 };

@@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel {
     auto place = ctx.GetEigenDevice<platform::CPUPlace>();
     dxt.device(place) = dxt.constant(static_cast<T>(0));

-    ScatterAssign<T>(ctx.device_context(), dO, Index, dX);
+    ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
   }
 };

diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc
index d8bf8dd9a42f86fe1d2a7b454912f632d4907ec5..cbd86b87961ee24aa889e208de5ac38e03a33135 100644
--- a/paddle/operators/gather_test.cc
+++ b/paddle/operators/gather_test.cc
@@ -43,7 +43,7 @@ TEST(Gather, GatherData) {

   auto* cpu_place = new paddle::platform::CPUPlace();
   paddle::platform::CPUDeviceContext ctx(*cpu_place);
-  CPUGather<int>(ctx, src, index, output);
+  CPUGather<int>(ctx, *src, *index, output);

   for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
   for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
diff --git a/paddle/operators/scatter.cu.h b/paddle/operators/scatter.cu.h
index f4a3965d94c11fe53e72d4c6df5a6f88ccc420fb..d95436be4f25b9df4aaef57ddb249ecf944f0666 100644
--- a/paddle/operators/scatter.cu.h
+++ b/paddle/operators/scatter.cu.h
@@ -19,6 +19,8 @@ namespace paddle {
 namespace operators {

+using Tensor = framework::Tensor;
+
 #define CUDA_1D_KERNEL_LOOP(i, n)                              \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)

@@ -45,16 +47,14 @@ __global__ void ScatterCUDAKernel(const T* params, const int* indices,
  * return: output tensor
  */
 template <typename T>
-void GPUScatterAssign(const platform::DeviceContext& ctx,
-                      const paddle::framework::Tensor* src,
-                      const paddle::framework::Tensor* index,
-                      paddle::framework::Tensor* output) {
+void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
+                      const Tensor& index, Tensor* output) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];

-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   framework::DDim output_dims(src_dims);
   output_dims[0] = index_size;

@@ -62,8 +62,8 @@ void GPUScatterAssign(const platform::DeviceContext& ctx,
   int slice_size = 1;
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

-  const T* p_src = src->data<T>();
-  const int* p_index = index->data<int>();
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
   T* p_output = output->data<T>();

   int block = 512;
diff --git a/paddle/operators/scatter.h b/paddle/operators/scatter.h
index 0d174d3b5b84326690589c9f0c39e3e1460aaead..c1fb844ebd2ff7ca7dbdb8e8ac3c1fff4c0c6607 100644
--- a/paddle/operators/scatter.h
+++ b/paddle/operators/scatter.h
@@ -33,20 +33,18 @@ using Tensor = framework::Tensor;
  * return: output tensor
  */
 template <typename T>
-void ScatterAssign(const platform::DeviceContext& ctx,
-                   const paddle::framework::Tensor* src,
-                   const paddle::framework::Tensor* index,
-                   paddle::framework::Tensor* output) {
+void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
+                   const Tensor& index, Tensor* output) {
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];

-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   auto dst_dims = output->dims();

-  const T* p_src = src->data<T>();
-  const int* p_index = index->data<int>();
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
   T* p_output = output->data<T>();

   // check src shape and dst shape should match
diff --git a/paddle/operators/scatter_op.cu b/paddle/operators/scatter_op.cu
index 6d13a876f98a400f8b0b8cdc0f3ecf297895e335..06f4d759447b6dcd28b50576dfc246fc466d9336 100644
--- a/paddle/operators/scatter_op.cu
+++ b/paddle/operators/scatter_op.cu
@@ -32,7 +32,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel {

     Out->ShareDataWith<T>(*Ref);

-    GPUScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
+    GPUScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
   }
 };

@@ -51,7 +51,7 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel {
     dRef->ShareDataWith<T>(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates = dO[Index]
-    GPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
+    GPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
   }
 };

diff --git a/paddle/operators/scatter_op.h b/paddle/operators/scatter_op.h
index ac04968549cb0fcbaf318bcc79e1a95470cb42f1..6101219006414e4865f676e3ca5d2a88949ad17a 100644
--- a/paddle/operators/scatter_op.h
+++ b/paddle/operators/scatter_op.h
@@ -37,7 +37,7 @@ class ScatterOpKernel : public framework::OpKernel {
     // In place output: Out = Ref, Out[Index] += Updates
     Out->ShareDataWith<T>(*Ref);
     // Apply ScatterUpdate: Out[index] += Updates[:]
-    ScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
+    ScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
   }
 };

@@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel {
     dRef->ShareDataWith<T>(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates += dO[Index]
-    CPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
+    CPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
   }
 };

diff --git a/paddle/operators/scatter_test.cc b/paddle/operators/scatter_test.cc
index 321bba3dadd68d08d332e9d48dcc1de11eed652a..00dbdacbfef7af826790472acc6caa285c259e0e 100644
--- a/paddle/operators/scatter_test.cc
+++ b/paddle/operators/scatter_test.cc
@@ -42,7 +42,7 @@ TEST(scatter, ScatterUpdate) {

   auto* cpu_place = new paddle::platform::CPUPlace();
   paddle::platform::CPUDeviceContext ctx(*cpu_place);
-  ScatterAssign<float>(ctx, src, index, output);
+  ScatterAssign<float>(ctx, *src, *index, output);

   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0));
   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
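
Note (not part of the patch above): after this change the gather/scatter helpers take their read-only tensors (src and index) by const reference and only the tensor they write (output) by pointer, which is why every call site switches from passing x, index to *x, *index. The sketch below is a minimal, hypothetical usage example of the new CPUGather signature; it mirrors the style of gather_test.cc and assumes the same headers and helpers used there (make_ddim, CPUDeviceContext). The shapes, values, and the function name GatherByRefExample are illustrative only.

#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/gather.h"
#include "paddle/platform/place.h"

void GatherByRefExample() {
  using paddle::framework::Tensor;
  using paddle::framework::make_ddim;

  Tensor src, index, output;
  auto cpu = paddle::platform::CPUPlace();

  // A 4 x 2 source tensor: row r holds {r, r}.
  float* p_src = src.mutable_data<float>(make_ddim({4, 2}), cpu);
  for (int i = 0; i < 8; ++i) p_src[i] = static_cast<float>(i / 2);

  // A 1-D int index tensor selecting rows 3 and 1 of src.
  int* p_index = index.mutable_data<int>(make_ddim({2}), cpu);
  p_index[0] = 3;
  p_index[1] = 1;

  // The output must be allocated up front with index_size rows.
  output.mutable_data<float>(make_ddim({2, 2}), cpu);

  paddle::platform::CPUDeviceContext ctx(cpu);

  // Old call convention: CPUGather<float>(ctx, &src, &index, &output);
  // New: read-only inputs by const reference, the mutated output by pointer.
  paddle::operators::CPUGather<float>(ctx, src, index, &output);
}

Passing src and index by reference makes the mutated argument stand out at the call site (it is the only pointer) and rules out accidental null inputs, without changing the gather semantics exercised in gather_test.cc.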