Commit 2d876b86 authored by zchen0211

gather scatter fix according to google style
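This commit applies the Google C++ Style Guide parameter-passing convention to the gather/scatter helpers: inputs that are only read are passed by const reference, while the mutated output stays a raw pointer so the write is explicit at the call site. A minimal sketch of the convention, using the CPUGather signature from this diff (illustrative only, not part of the commit):

// Before: every tensor is a pointer, so the call site cannot tell
// which arguments may be mutated.
void CPUGather(const platform::DeviceContext& ctx, const Tensor* src,
               const Tensor* index, Tensor* output);

// After: read-only inputs by const reference, mutated output by pointer.
// Callers that hold Tensor* now dereference the inputs, e.g.
// CPUGather<T>(ctx.device_context(), *x, *index, output);
void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
               const Tensor& index, Tensor* output);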

Parent 2ccaec4f
@@ -126,7 +126,7 @@ void CondOp::PrepareDataForSubnet(
dim[0] = index_tensors[i].dims()[0];
tensor_child->mutable_data<float>(dim, platform::CPUPlace());
- CPUGather<float>(dev_ctx, tensor_parent, &index_tensors[i], tensor_child);
+ CPUGather<float>(dev_ctx, *tensor_parent, index_tensors[i], tensor_child);
}
}
@@ -187,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
Variable* var_child = sub_scopes[i]->FindVar(output);
PADDLE_ENFORCE_NOT_NULL(var_child);
auto* tensor_child = &var_child->Get<LoDTensor>();
- ScatterAssign<float>(dev_ctx, tensor_child, &index_tensors[i],
+ ScatterAssign<float>(dev_ctx, *tensor_child, index_tensors[i],
tensor_parent);
}
}
@@ -46,14 +46,14 @@ __global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
* return: output tensor
*/
template <typename T>
- void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
-                const Tensor* index, Tensor* output) {
+ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
+                const Tensor& index, Tensor* output) {
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
- PADDLE_ENFORCE(index->dims().size() == 1);
- int index_size = index->dims()[0];
+ PADDLE_ENFORCE(index.dims().size() == 1);
+ int index_size = index.dims()[0];
- auto src_dims = src->dims();
+ auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
@@ -61,8 +61,8 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
- const T* p_src = src->data<T>();
- const int* p_index = index->data<int>();
+ const T* p_src = src.data<T>();
+ const int* p_index = index.data<int>();
T* p_output = output->data<T>();
int block = 512;
@@ -24,6 +24,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
+ using framework::Tensor;
/**
* A thin wrapper for gathering on cpu tensor
* Return a new tensor from source tensor, gathered according to index
@@ -32,21 +34,19 @@ namespace operators {
* return: output tensor
*/
template <typename T>
- void CPUGather(const platform::DeviceContext& ctx,
-                const paddle::framework::Tensor* src,
-                const paddle::framework::Tensor* index,
-                paddle::framework::Tensor* output) {
+ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
+                const Tensor& index, Tensor* output) {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// check index of shape 1-D
- PADDLE_ENFORCE(index->dims().size() == 1);
- int index_size = index->dims()[0];
+ PADDLE_ENFORCE(index.dims().size() == 1);
+ int index_size = index.dims()[0];
- auto src_dims = src->dims();
+ auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
- const T* p_src = src->data<T>();
- const int* p_index = index->data<int>();
+ const T* p_src = src.data<T>();
+ const int* p_index = index.data<int>();
T* p_output = output->data<T>();
// slice size
@@ -32,7 +32,7 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace());
- GPUGather<T>(ctx.device_context(), x, index, output);
+ GPUGather<T>(ctx.device_context(), *x, *index, output);
}
};
@@ -52,7 +52,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
auto place = ctx.GetEigenDevice<platform::GPUPlace>();
dxt.device(place) = dxt.constant(static_cast<T>(0));
- GPUScatterAssign<T>(ctx.device_context(), dO, Index, dX);
+ GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
}
};
@@ -36,7 +36,7 @@ class GatherOpKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace());
- CPUGather<T>(ctx.device_context(), x, index, output);
+ CPUGather<T>(ctx.device_context(), *x, *index, output);
}
};
@@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
auto place = ctx.GetEigenDevice<platform::CPUPlace>();
dxt.device(place) = dxt.constant(static_cast<T>(0));
- ScatterAssign<T>(ctx.device_context(), dO, Index, dX);
+ ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
}
};
@@ -43,7 +43,7 @@ TEST(Gather, GatherData) {
auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place);
- CPUGather<int>(ctx, src, index, output);
+ CPUGather<int>(ctx, *src, *index, output);
for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
@@ -19,6 +19,8 @@
namespace paddle {
namespace operators {
+ using Tensor = framework::Tensor;
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
@@ -45,16 +47,14 @@ __global__ void ScatterCUDAKernel(const T* params, const int* indices,
* return: output tensor
*/
template <typename T>
- void GPUScatterAssign(const platform::DeviceContext& ctx,
-                       const paddle::framework::Tensor* src,
-                       const paddle::framework::Tensor* index,
-                       paddle::framework::Tensor* output) {
+ void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
+                       const Tensor& index, Tensor* output) {
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
- PADDLE_ENFORCE(index->dims().size() == 1);
- int index_size = index->dims()[0];
+ PADDLE_ENFORCE(index.dims().size() == 1);
+ int index_size = index.dims()[0];
- auto src_dims = src->dims();
+ auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
@@ -62,8 +62,8 @@ void GPUScatterAssign(const platform::DeviceContext& ctx,
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
- const T* p_src = src->data<T>();
- const int* p_index = index->data<int>();
+ const T* p_src = src.data<T>();
+ const int* p_index = index.data<int>();
T* p_output = output->data<T>();
int block = 512;
@@ -33,20 +33,18 @@ using Tensor = framework::Tensor;
* return: output tensor
*/
template <typename T>
- void ScatterAssign(const platform::DeviceContext& ctx,
-                    const paddle::framework::Tensor* src,
-                    const paddle::framework::Tensor* index,
-                    paddle::framework::Tensor* output) {
+ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
+                    const Tensor& index, Tensor* output) {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// check index of shape 1-D
- PADDLE_ENFORCE(index->dims().size() == 1);
- int index_size = index->dims()[0];
+ PADDLE_ENFORCE(index.dims().size() == 1);
+ int index_size = index.dims()[0];
- auto src_dims = src->dims();
+ auto src_dims = src.dims();
auto dst_dims = output->dims();
- const T* p_src = src->data<T>();
- const int* p_index = index->data<int>();
+ const T* p_src = src.data<T>();
+ const int* p_index = index.data<int>();
T* p_output = output->data<T>();
// check src shape and dst shape should match
@@ -32,7 +32,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
Out->ShareDataWith<T>(*Ref);
- GPUScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
+ GPUScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
}
};
@@ -51,7 +51,7 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
dRef->ShareDataWith<T>(*dOut);
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates = dO[Index]
- GPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
+ GPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
}
};
@@ -37,7 +37,7 @@ class ScatterOpKernel : public framework::OpKernel<T> {
// In place output: Out = Ref, Out[Index] += Updates
Out->ShareDataWith<T>(*Ref);
// Apply ScatterUpdate: Out[index] += Updates[:]
- ScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
+ ScatterAssign<T>(ctx.device_context(), *Updates, *Index, Out);
}
};
@@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
dRef->ShareDataWith<T>(*dOut);
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates += dO[Index]
- CPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
+ CPUGather<T>(ctx.device_context(), *dOut, *Index, dUpdates);
}
};
@@ -42,7 +42,7 @@ TEST(scatter, ScatterUpdate) {
auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place);
- ScatterAssign<float>(ctx, src, index, output);
+ ScatterAssign<float>(ctx, *src, *index, output);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0));
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
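End to end, a caller now builds tensors as usual and dereferences only the inputs. A minimal sketch following the gather_test.cc pattern above, assuming the 2017-era paddle/operators/gather.h layout (shapes and values are illustrative, not part of the commit):

#include "paddle/operators/gather.h"

// Build a 2x4 int tensor and gather its rows in reverse order.
paddle::framework::Tensor src, index, output;
int* p_src = src.mutable_data<int>(paddle::framework::make_ddim({2, 4}),
                                   paddle::platform::CPUPlace());
int* p_index = index.mutable_data<int>(paddle::framework::make_ddim({2}),
                                       paddle::platform::CPUPlace());
for (int i = 0; i < 8; ++i) p_src[i] = i;
p_index[0] = 1;  // row 1 of src becomes row 0 of output
p_index[1] = 0;
output.mutable_data<int>(paddle::framework::make_ddim({2, 4}),
                         paddle::platform::CPUPlace());

paddle::platform::CPUDeviceContext ctx(paddle::platform::CPUPlace());
// Per this commit: inputs by const reference, output by pointer.
CPUGather<int>(ctx, src, index, &output);
// output now holds rows {4, 5, 6, 7, 0, 1, 2, 3}.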