Commit 2d876b86 authored by zchen0211

gather scatter fix according to google style

Parent 2ccaec4f
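This commit applies one mechanical change across the gather/scatter helpers and every call site: read-only tensor arguments move from pointers to const references, while the mutated output stays a pointer, following the Google C++ Style Guide convention for input versus output parameters. The pattern, distilled from the hunks below:

  // Before: read-only inputs passed by pointer.
  template <typename T>
  void CPUGather(const platform::DeviceContext& ctx, const Tensor* src,
                 const Tensor* index, Tensor* output);

  // After: read-only inputs passed by const reference; the written-to
  // output remains a pointer.
  template <typename T>
  void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
                 const Tensor& index, Tensor* output);

  // Call sites dereference their pointers accordingly:
  CPUGather<T>(ctx.device_context(), *x, *index, output);  // was (ctx, x, index, output)

GPUGather, ScatterAssign, and GPUScatterAssign change identically.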
@@ -126,7 +126,7 @@ void CondOp::PrepareDataForSubnet(
       dim[0] = index_tensors[i].dims()[0];
       tensor_child->mutable_data<float>(dim, platform::CPUPlace());
-      CPUGather<float>(dev_ctx, tensor_parent, &index_tensors[i], tensor_child);
+      CPUGather<float>(dev_ctx, *tensor_parent, index_tensors[i], tensor_child);
     }
   }
@@ -187,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
     Variable* var_child = sub_scopes[i]->FindVar(output);
     PADDLE_ENFORCE_NOT_NULL(var_child);
     auto* tensor_child = &var_child->Get<LoDTensor>();
-    ScatterAssign<float>(dev_ctx, tensor_child, &index_tensors[i],
-                         tensor_parent);
+    ScatterAssign<float>(dev_ctx, *tensor_child, index_tensors[i],
+                         tensor_parent);
   }
 }
...
@@ -46,14 +46,14 @@ __global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
  * return: output tensor
  */
 template <typename T>
-void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
-               const Tensor* index, Tensor* output) {
+void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
+               const Tensor& index, Tensor* output) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];

-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   framework::DDim output_dims(src_dims);
   output_dims[0] = index_size;
@@ -61,8 +61,8 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
   int slice_size = 1;
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

-  const T* p_src = src->data<T>();
-  const int* p_index = index->data<int>();
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
   T* p_output = output->data<T>();

   int block = 512;
...
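The kernel launch at the end of GPUGather is collapsed in this hunk; only `int block = 512;` is visible. As a hedged sketch of the conventional 1-D launch over the index_size * slice_size output elements (the grid computation and the kernel's trailing parameters are assumptions, not the file's actual code):

  // Sketch only: assumes GatherCUDAKernel's elided parameters are the two
  // sizes, and a ceil-division grid; the real launch is folded out above.
  int n = slice_size * index_size;     // total output elements
  int grid = (n + block - 1) / block;  // ceil(n / block)
  GatherCUDAKernel<T><<<grid, block>>>(p_src, p_index, p_output, index_size,
                                       slice_size);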
@@ -24,6 +24,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

+using framework::Tensor;
+
 /**
  * A thin wrapper for gathering on cpu tensor
  * Return a new tensor from source tensor, gathered according to index
@@ -32,21 +34,19 @@ namespace operators {
  * return: output tensor
  */
 template <typename T>
-void CPUGather(const platform::DeviceContext& ctx,
-               const paddle::framework::Tensor* src,
-               const paddle::framework::Tensor* index,
-               paddle::framework::Tensor* output) {
+void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
+               const Tensor& index, Tensor* output) {
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];

-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   framework::DDim output_dims(src_dims);
   output_dims[0] = index_size;

-  const T* p_src = src->data<T>();
-  const int* p_index = index->data<int>();
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
   T* p_output = output->data<T>();

   // slice size
...
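To make the gather semantics concrete: row i of output is row index[i] of src, so an index of length k over a src of shape [n, d1, ...] yields an output of shape [k, d1, ...]. A minimal CPU sketch against the new signature (values hypothetical; the make_ddim/mutable_data usage mirrors gather_test further down):

  // Hypothetical example: gather rows 2 and 0 of a 3x4 tensor.
  paddle::framework::Tensor src, index, output;
  auto cpu = paddle::platform::CPUPlace();
  int* p_src = src.mutable_data<int>(paddle::framework::make_ddim({3, 4}), cpu);
  int* p_idx = index.mutable_data<int>(paddle::framework::make_ddim({2}), cpu);
  for (int i = 0; i < 12; ++i) p_src[i] = i;  // rows [0..3] [4..7] [8..11]
  p_idx[0] = 2; p_idx[1] = 0;                 // pick rows 2 and 0
  output.mutable_data<int>(paddle::framework::make_ddim({2, 4}), cpu);

  paddle::platform::CPUDeviceContext ctx(cpu);
  CPUGather<int>(ctx, src, index, &output);   // output rows: [8..11], [0..3]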
@@ -32,7 +32,7 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
     output->mutable_data<T>(ctx.GetPlace());

-    GPUGather<T>(ctx.device_context(), x, index, output);
+    GPUGather<T>(ctx.device_context(), *x, *index, output);
   }
 };
@@ -52,7 +52,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetEigenDevice<platform::GPUPlace>();
     dxt.device(place) = dxt.constant(static_cast<T>(0));

-    GPUScatterAssign<T>(ctx.device_context(), dO, Index, dX);
+    GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
   }
 };
...
@@ -36,7 +36,7 @@ class GatherOpKernel : public framework::OpKernel<T> {
     output->mutable_data<T>(ctx.GetPlace());

-    CPUGather<T>(ctx.device_context(), x, index, output);
+    CPUGather<T>(ctx.device_context(), *x, *index, output);
   }
 };
@@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
     auto place = ctx.GetEigenDevice<platform::CPUPlace>();
     dxt.device(place) = dxt.constant(static_cast<T>(0));

-    ScatterAssign<T>(ctx.device_context(), dO, Index, dX);
+    ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
   }
 };
...
@@ -43,7 +43,7 @@ TEST(Gather, GatherData) {
   auto* cpu_place = new paddle::platform::CPUPlace();
   paddle::platform::CPUDeviceContext ctx(*cpu_place);
-  CPUGather<int>(ctx, src, index, output);
+  CPUGather<int>(ctx, *src, *index, output);

   for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
   for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
...
@@ -19,6 +19,8 @@
 namespace paddle {
 namespace operators {

+using Tensor = framework::Tensor;
+
 #define CUDA_1D_KERNEL_LOOP(i, n)                              \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
@@ -45,16 +47,14 @@ __global__ void ScatterCUDAKernel(const T* params, const int* indices,
  * return: output tensor
  */
 template <typename T>
-void GPUScatterAssign(const platform::DeviceContext& ctx,
-                      const paddle::framework::Tensor* src,
-                      const paddle::framework::Tensor* index,
-                      paddle::framework::Tensor* output) {
+void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
+                      const Tensor& index, Tensor* output) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];

-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   framework::DDim output_dims(src_dims);
   output_dims[0] = index_size;
@@ -62,8 +62,8 @@ void GPUScatterAssign(const platform::DeviceContext& ctx,
   int slice_size = 1;
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

-  const T* p_src = src->data<T>();
-  const int* p_index = index->data<int>();
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
   T* p_output = output->data<T>();

   int block = 512;
...
@@ -33,20 +33,18 @@ using Tensor = framework::Tensor;
  * return: output tensor
  */
 template <typename T>
-void ScatterAssign(const platform::DeviceContext& ctx,
-                   const paddle::framework::Tensor* src,
-                   const paddle::framework::Tensor* index,
-                   paddle::framework::Tensor* output) {
+void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
+                   const Tensor& index, Tensor* output) {
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
-  PADDLE_ENFORCE(index->dims().size() == 1);
-  int index_size = index->dims()[0];
+  PADDLE_ENFORCE(index.dims().size() == 1);
+  int index_size = index.dims()[0];

-  auto src_dims = src->dims();
+  auto src_dims = src.dims();
   auto dst_dims = output->dims();

-  const T* p_src = src->data<T>();
-  const int* p_index = index->data<int>();
+  const T* p_src = src.data<T>();
+  const int* p_index = index.data<int>();
   T* p_output = output->data<T>();

   // check src shape and dst shape should match
...
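ScatterAssign is the write-side counterpart: src row i is copied to output row index[i], i.e. output[index[i]] = src[i], and rows not named in index keep their existing contents, which is exactly why the gradient kernels zero dX before scattering into it. The per-row copy performed by the collapsed body is equivalent to this sketch, with slice_size being the product of the trailing dimensions (a sketch, not the file's literal code):

  // Copy each src row to the output row chosen by the index.
  for (int i = 0; i < index_size; ++i) {
    memcpy(p_output + p_index[i] * slice_size, p_src + i * slice_size,
           slice_size * sizeof(T));
  }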
@@ -32,7 +32,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
     Out->ShareDataWith<T>(*Ref);

-    GPUScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
+    GPUScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
   }
 };
@@ -51,7 +51,7 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
     dRef->ShareDataWith<T>(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates = dO[Index]
-    GPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
+    GPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
   }
 };
...
@@ -37,7 +37,7 @@ class ScatterOpKernel : public framework::OpKernel<T> {
     // In place output: Out = Ref, Out[Index] += Updates
     Out->ShareDataWith<T>(*Ref);
     // Apply ScatterUpdate: Out[index] += Updates[:]
-    ScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
+    ScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
   }
 };
@@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
     dRef->ShareDataWith<T>(*dOut);
     dUpdates->mutable_data<T>(ctx.GetPlace());
     // Gradient by Gather: dUpdates += dO[Index]
-    CPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
+    CPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
   }
 };
...
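The gradient kernels above rest on the duality between scatter and gather: the forward op writes Out[Index[i]] = Updates[i], so the Updates gradient simply reads those rows back, dUpdates[i] = dOut[Index[i]], which is what the "Gradient by Gather" comments record. With hypothetical values, Index = {1, 3} and dOut rows {r0, r1, r2, r3}:

  // Forward  (ScatterAssign): Out[1] = Updates[0], Out[3] = Updates[1]
  // Backward (CPUGather):     dUpdates[0] = dOut[1] = r1,
  //                           dUpdates[1] = dOut[3] = r3
  CPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);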
@@ -42,7 +42,7 @@ TEST(scatter, ScatterUpdate) {
   auto* cpu_place = new paddle::platform::CPUPlace();
   paddle::platform::CPUDeviceContext ctx(*cpu_place);
-  ScatterAssign<float>(ctx, src, index, output);
+  ScatterAssign<float>(ctx, *src, *index, output);

   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0));
   for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
...