Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2d876b86
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
2d876b86
编写于
10月 03, 2017
作者:
Z
zchen0211
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
gather scatter fix according to google style
上级
2ccaec4f
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
44 addition
and
46 deletion
+44
-46
paddle/operators/cond_op.cc
paddle/operators/cond_op.cc
+2
-2
paddle/operators/gather.cu.h
paddle/operators/gather.cu.h
+7
-7
paddle/operators/gather.h
paddle/operators/gather.h
+9
-9
paddle/operators/gather_op.cu
paddle/operators/gather_op.cu
+2
-2
paddle/operators/gather_op.h
paddle/operators/gather_op.h
+2
-2
paddle/operators/gather_test.cc
paddle/operators/gather_test.cc
+1
-1
paddle/operators/scatter.cu.h
paddle/operators/scatter.cu.h
+9
-9
paddle/operators/scatter.h
paddle/operators/scatter.h
+7
-9
paddle/operators/scatter_op.cu
paddle/operators/scatter_op.cu
+2
-2
paddle/operators/scatter_op.h
paddle/operators/scatter_op.h
+2
-2
paddle/operators/scatter_test.cc
paddle/operators/scatter_test.cc
+1
-1
未找到文件。
paddle/operators/cond_op.cc
浏览文件 @
2d876b86
...
@@ -126,7 +126,7 @@ void CondOp::PrepareDataForSubnet(
...
@@ -126,7 +126,7 @@ void CondOp::PrepareDataForSubnet(
dim
[
0
]
=
index_tensors
[
i
].
dims
()[
0
];
dim
[
0
]
=
index_tensors
[
i
].
dims
()[
0
];
tensor_child
->
mutable_data
<
float
>
(
dim
,
platform
::
CPUPlace
());
tensor_child
->
mutable_data
<
float
>
(
dim
,
platform
::
CPUPlace
());
CPUGather
<
float
>
(
dev_ctx
,
tensor_parent
,
&
index_tensors
[
i
],
tensor_child
);
CPUGather
<
float
>
(
dev_ctx
,
*
tensor_parent
,
index_tensors
[
i
],
tensor_child
);
}
}
}
}
...
@@ -187,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
...
@@ -187,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
Variable
*
var_child
=
sub_scopes
[
i
]
->
FindVar
(
output
);
Variable
*
var_child
=
sub_scopes
[
i
]
->
FindVar
(
output
);
PADDLE_ENFORCE_NOT_NULL
(
var_child
);
PADDLE_ENFORCE_NOT_NULL
(
var_child
);
auto
*
tensor_child
=
&
var_child
->
Get
<
LoDTensor
>
();
auto
*
tensor_child
=
&
var_child
->
Get
<
LoDTensor
>
();
ScatterAssign
<
float
>
(
dev_ctx
,
tensor_child
,
&
index_tensors
[
i
],
ScatterAssign
<
float
>
(
dev_ctx
,
*
tensor_child
,
index_tensors
[
i
],
tensor_parent
);
tensor_parent
);
}
}
}
}
...
...
paddle/operators/gather.cu.h
浏览文件 @
2d876b86
...
@@ -46,14 +46,14 @@ __global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
...
@@ -46,14 +46,14 @@ __global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
* return: output tensor
* return: output tensor
*/
*/
template
<
typename
T
>
template
<
typename
T
>
void
GPUGather
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
*
src
,
void
GPUGather
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
src
,
const
Tensor
*
index
,
Tensor
*
output
)
{
const
Tensor
&
index
,
Tensor
*
output
)
{
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
// check index of shape 1-D
PADDLE_ENFORCE
(
index
->
dims
().
size
()
==
1
);
PADDLE_ENFORCE
(
index
.
dims
().
size
()
==
1
);
int
index_size
=
index
->
dims
()[
0
];
int
index_size
=
index
.
dims
()[
0
];
auto
src_dims
=
src
->
dims
();
auto
src_dims
=
src
.
dims
();
framework
::
DDim
output_dims
(
src_dims
);
framework
::
DDim
output_dims
(
src_dims
);
output_dims
[
0
]
=
index_size
;
output_dims
[
0
]
=
index_size
;
...
@@ -61,8 +61,8 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
...
@@ -61,8 +61,8 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor* src,
int
slice_size
=
1
;
int
slice_size
=
1
;
for
(
int
i
=
1
;
i
<
src_dims
.
size
();
++
i
)
slice_size
*=
src_dims
[
i
];
for
(
int
i
=
1
;
i
<
src_dims
.
size
();
++
i
)
slice_size
*=
src_dims
[
i
];
const
T
*
p_src
=
src
->
data
<
T
>
();
const
T
*
p_src
=
src
.
data
<
T
>
();
const
int
*
p_index
=
index
->
data
<
int
>
();
const
int
*
p_index
=
index
.
data
<
int
>
();
T
*
p_output
=
output
->
data
<
T
>
();
T
*
p_output
=
output
->
data
<
T
>
();
int
block
=
512
;
int
block
=
512
;
...
...
paddle/operators/gather.h
浏览文件 @
2d876b86
...
@@ -24,6 +24,8 @@ limitations under the License. */
...
@@ -24,6 +24,8 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
framework
::
Tensor
;
/**
/**
* A thin wrapper for gathering on cpu tensor
* A thin wrapper for gathering on cpu tensor
* Return a new tensor from source tensor, gathered according to index
* Return a new tensor from source tensor, gathered according to index
...
@@ -32,21 +34,19 @@ namespace operators {
...
@@ -32,21 +34,19 @@ namespace operators {
* return: output tensor
* return: output tensor
*/
*/
template
<
typename
T
>
template
<
typename
T
>
void
CPUGather
(
const
platform
::
DeviceContext
&
ctx
,
void
CPUGather
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
src
,
const
paddle
::
framework
::
Tensor
*
src
,
const
Tensor
&
index
,
Tensor
*
output
)
{
const
paddle
::
framework
::
Tensor
*
index
,
paddle
::
framework
::
Tensor
*
output
)
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()));
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()));
// check index of shape 1-D
// check index of shape 1-D
PADDLE_ENFORCE
(
index
->
dims
().
size
()
==
1
);
PADDLE_ENFORCE
(
index
.
dims
().
size
()
==
1
);
int
index_size
=
index
->
dims
()[
0
];
int
index_size
=
index
.
dims
()[
0
];
auto
src_dims
=
src
->
dims
();
auto
src_dims
=
src
.
dims
();
framework
::
DDim
output_dims
(
src_dims
);
framework
::
DDim
output_dims
(
src_dims
);
output_dims
[
0
]
=
index_size
;
output_dims
[
0
]
=
index_size
;
const
T
*
p_src
=
src
->
data
<
T
>
();
const
T
*
p_src
=
src
.
data
<
T
>
();
const
int
*
p_index
=
index
->
data
<
int
>
();
const
int
*
p_index
=
index
.
data
<
int
>
();
T
*
p_output
=
output
->
data
<
T
>
();
T
*
p_output
=
output
->
data
<
T
>
();
// slice size
// slice size
...
...
paddle/operators/gather_op.cu
浏览文件 @
2d876b86
...
@@ -32,7 +32,7 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -32,7 +32,7 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
GPUGather
<
T
>
(
ctx
.
device_context
(),
x
,
index
,
output
);
GPUGather
<
T
>
(
ctx
.
device_context
(),
*
x
,
*
index
,
output
);
}
}
};
};
...
@@ -52,7 +52,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -52,7 +52,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
auto
place
=
ctx
.
GetEigenDevice
<
platform
::
GPUPlace
>
();
auto
place
=
ctx
.
GetEigenDevice
<
platform
::
GPUPlace
>
();
dxt
.
device
(
place
)
=
dxt
.
constant
(
static_cast
<
T
>
(
0
));
dxt
.
device
(
place
)
=
dxt
.
constant
(
static_cast
<
T
>
(
0
));
GPUScatterAssign
<
T
>
(
ctx
.
device_context
(),
dO
,
Index
,
dX
);
GPUScatterAssign
<
T
>
(
ctx
.
device_context
(),
*
dO
,
*
Index
,
dX
);
}
}
};
};
...
...
paddle/operators/gather_op.h
浏览文件 @
2d876b86
...
@@ -36,7 +36,7 @@ class GatherOpKernel : public framework::OpKernel<T> {
...
@@ -36,7 +36,7 @@ class GatherOpKernel : public framework::OpKernel<T> {
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
CPUGather
<
T
>
(
ctx
.
device_context
(),
x
,
index
,
output
);
CPUGather
<
T
>
(
ctx
.
device_context
(),
*
x
,
*
index
,
output
);
}
}
};
};
...
@@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
...
@@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
auto
place
=
ctx
.
GetEigenDevice
<
platform
::
CPUPlace
>
();
auto
place
=
ctx
.
GetEigenDevice
<
platform
::
CPUPlace
>
();
dxt
.
device
(
place
)
=
dxt
.
constant
(
static_cast
<
T
>
(
0
));
dxt
.
device
(
place
)
=
dxt
.
constant
(
static_cast
<
T
>
(
0
));
ScatterAssign
<
T
>
(
ctx
.
device_context
(),
dO
,
Index
,
dX
);
ScatterAssign
<
T
>
(
ctx
.
device_context
(),
*
dO
,
*
Index
,
dX
);
}
}
};
};
...
...
paddle/operators/gather_test.cc
浏览文件 @
2d876b86
...
@@ -43,7 +43,7 @@ TEST(Gather, GatherData) {
...
@@ -43,7 +43,7 @@ TEST(Gather, GatherData) {
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
paddle
::
platform
::
CPUDeviceContext
ctx
(
*
cpu_place
);
paddle
::
platform
::
CPUDeviceContext
ctx
(
*
cpu_place
);
CPUGather
<
int
>
(
ctx
,
src
,
index
,
output
);
CPUGather
<
int
>
(
ctx
,
*
src
,
*
index
,
output
);
for
(
int
i
=
0
;
i
<
4
;
++
i
)
EXPECT_EQ
(
p_output
[
i
],
i
+
4
);
for
(
int
i
=
0
;
i
<
4
;
++
i
)
EXPECT_EQ
(
p_output
[
i
],
i
+
4
);
for
(
int
i
=
4
;
i
<
8
;
++
i
)
EXPECT_EQ
(
p_output
[
i
],
i
-
4
);
for
(
int
i
=
4
;
i
<
8
;
++
i
)
EXPECT_EQ
(
p_output
[
i
],
i
-
4
);
...
...
paddle/operators/scatter.cu.h
浏览文件 @
2d876b86
...
@@ -19,6 +19,8 @@
...
@@ -19,6 +19,8 @@
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
#define CUDA_1D_KERNEL_LOOP(i, n) \
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
i += blockDim.x * gridDim.x)
...
@@ -45,16 +47,14 @@ __global__ void ScatterCUDAKernel(const T* params, const int* indices,
...
@@ -45,16 +47,14 @@ __global__ void ScatterCUDAKernel(const T* params, const int* indices,
* return: output tensor
* return: output tensor
*/
*/
template
<
typename
T
>
template
<
typename
T
>
void
GPUScatterAssign
(
const
platform
::
DeviceContext
&
ctx
,
void
GPUScatterAssign
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
src
,
const
paddle
::
framework
::
Tensor
*
src
,
const
Tensor
&
index
,
Tensor
*
output
)
{
const
paddle
::
framework
::
Tensor
*
index
,
paddle
::
framework
::
Tensor
*
output
)
{
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
// check index of shape 1-D
PADDLE_ENFORCE
(
index
->
dims
().
size
()
==
1
);
PADDLE_ENFORCE
(
index
.
dims
().
size
()
==
1
);
int
index_size
=
index
->
dims
()[
0
];
int
index_size
=
index
.
dims
()[
0
];
auto
src_dims
=
src
->
dims
();
auto
src_dims
=
src
.
dims
();
framework
::
DDim
output_dims
(
src_dims
);
framework
::
DDim
output_dims
(
src_dims
);
output_dims
[
0
]
=
index_size
;
output_dims
[
0
]
=
index_size
;
...
@@ -62,8 +62,8 @@ void GPUScatterAssign(const platform::DeviceContext& ctx,
...
@@ -62,8 +62,8 @@ void GPUScatterAssign(const platform::DeviceContext& ctx,
int
slice_size
=
1
;
int
slice_size
=
1
;
for
(
int
i
=
1
;
i
<
src_dims
.
size
();
++
i
)
slice_size
*=
src_dims
[
i
];
for
(
int
i
=
1
;
i
<
src_dims
.
size
();
++
i
)
slice_size
*=
src_dims
[
i
];
const
T
*
p_src
=
src
->
data
<
T
>
();
const
T
*
p_src
=
src
.
data
<
T
>
();
const
int
*
p_index
=
index
->
data
<
int
>
();
const
int
*
p_index
=
index
.
data
<
int
>
();
T
*
p_output
=
output
->
data
<
T
>
();
T
*
p_output
=
output
->
data
<
T
>
();
int
block
=
512
;
int
block
=
512
;
...
...
paddle/operators/scatter.h
浏览文件 @
2d876b86
...
@@ -33,20 +33,18 @@ using Tensor = framework::Tensor;
...
@@ -33,20 +33,18 @@ using Tensor = framework::Tensor;
* return: output tensor
* return: output tensor
*/
*/
template
<
typename
T
>
template
<
typename
T
>
void
ScatterAssign
(
const
platform
::
DeviceContext
&
ctx
,
void
ScatterAssign
(
const
platform
::
DeviceContext
&
ctx
,
const
Tensor
&
src
,
const
paddle
::
framework
::
Tensor
*
src
,
const
Tensor
&
index
,
Tensor
*
output
)
{
const
paddle
::
framework
::
Tensor
*
index
,
paddle
::
framework
::
Tensor
*
output
)
{
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()));
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()));
// check index of shape 1-D
// check index of shape 1-D
PADDLE_ENFORCE
(
index
->
dims
().
size
()
==
1
);
PADDLE_ENFORCE
(
index
.
dims
().
size
()
==
1
);
int
index_size
=
index
->
dims
()[
0
];
int
index_size
=
index
.
dims
()[
0
];
auto
src_dims
=
src
->
dims
();
auto
src_dims
=
src
.
dims
();
auto
dst_dims
=
output
->
dims
();
auto
dst_dims
=
output
->
dims
();
const
T
*
p_src
=
src
->
data
<
T
>
();
const
T
*
p_src
=
src
.
data
<
T
>
();
const
int
*
p_index
=
index
->
data
<
int
>
();
const
int
*
p_index
=
index
.
data
<
int
>
();
T
*
p_output
=
output
->
data
<
T
>
();
T
*
p_output
=
output
->
data
<
T
>
();
// check src shape and dst shape should match
// check src shape and dst shape should match
...
...
paddle/operators/scatter_op.cu
浏览文件 @
2d876b86
...
@@ -32,7 +32,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -32,7 +32,7 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
Out
->
ShareDataWith
<
T
>
(
*
Ref
);
Out
->
ShareDataWith
<
T
>
(
*
Ref
);
GPUScatterAssign
<
T
>
(
ctx
.
device_context
(),
Updates
,
Index
,
Out
);
GPUScatterAssign
<
T
>
(
ctx
.
device_context
(),
*
Updates
,
*
Index
,
Out
);
}
}
};
};
...
@@ -51,7 +51,7 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -51,7 +51,7 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
dRef
->
ShareDataWith
<
T
>
(
*
dOut
);
dRef
->
ShareDataWith
<
T
>
(
*
dOut
);
dUpdates
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
dUpdates
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Gradient by Gather: dUpdates = dO[Index]
// Gradient by Gather: dUpdates = dO[Index]
GPUGather
<
T
>
(
ctx
.
device_context
(),
dOut
,
Index
,
dUpdates
);
GPUGather
<
T
>
(
ctx
.
device_context
(),
*
dOut
,
*
Index
,
dUpdates
);
}
}
};
};
...
...
paddle/operators/scatter_op.h
浏览文件 @
2d876b86
...
@@ -37,7 +37,7 @@ class ScatterOpKernel : public framework::OpKernel<T> {
...
@@ -37,7 +37,7 @@ class ScatterOpKernel : public framework::OpKernel<T> {
// In place output: Out = Ref, Out[Index] += Updates
// In place output: Out = Ref, Out[Index] += Updates
Out
->
ShareDataWith
<
T
>
(
*
Ref
);
Out
->
ShareDataWith
<
T
>
(
*
Ref
);
// Apply ScatterUpdate: Out[index] += Updates[:]
// Apply ScatterUpdate: Out[index] += Updates[:]
ScatterAssign
<
T
>
(
ctx
.
device_context
(),
Updates
,
Index
,
Out
);
ScatterAssign
<
T
>
(
ctx
.
device_context
(),
*
Updates
,
*
Index
,
Out
);
}
}
};
};
...
@@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
...
@@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
dRef
->
ShareDataWith
<
T
>
(
*
dOut
);
dRef
->
ShareDataWith
<
T
>
(
*
dOut
);
dUpdates
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
dUpdates
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Gradient by Gather: dUpdates += dO[Index]
// Gradient by Gather: dUpdates += dO[Index]
CPUGather
<
T
>
(
ctx
.
device_context
(),
dOut
,
Index
,
dUpdates
);
CPUGather
<
T
>
(
ctx
.
device_context
(),
*
dOut
,
*
Index
,
dUpdates
);
}
}
};
};
...
...
paddle/operators/scatter_test.cc
浏览文件 @
2d876b86
...
@@ -42,7 +42,7 @@ TEST(scatter, ScatterUpdate) {
...
@@ -42,7 +42,7 @@ TEST(scatter, ScatterUpdate) {
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
paddle
::
platform
::
CPUDeviceContext
ctx
(
*
cpu_place
);
paddle
::
platform
::
CPUDeviceContext
ctx
(
*
cpu_place
);
ScatterAssign
<
float
>
(
ctx
,
src
,
index
,
output
);
ScatterAssign
<
float
>
(
ctx
,
*
src
,
*
index
,
output
);
for
(
size_t
i
=
0
;
i
<
4
;
++
i
)
EXPECT_EQ
(
p_output
[
i
],
float
(
0
));
for
(
size_t
i
=
0
;
i
<
4
;
++
i
)
EXPECT_EQ
(
p_output
[
i
],
float
(
0
));
for
(
size_t
i
=
0
;
i
<
4
;
++
i
)
EXPECT_EQ
(
output
->
data
<
float
>
()[
i
],
float
(
0
));
for
(
size_t
i
=
0
;
i
<
4
;
++
i
)
EXPECT_EQ
(
output
->
data
<
float
>
()[
i
],
float
(
0
));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录