Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
236af566
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
236af566
编写于
9月 25, 2017
作者:
Y
Yibing Liu
浏览文件
操作
浏览文件
下载
差异文件
separate index tensor from candidate tensors in multiplex_op
上级
f94109d4
e114aad8
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
427 addition
and
51 deletion
+427
-51
doc/faq/index_cn.rst
doc/faq/index_cn.rst
+122
-1
paddle/operators/crop_op.h
paddle/operators/crop_op.h
+3
-3
paddle/operators/math/math_function.cc
paddle/operators/math/math_function.cc
+26
-0
paddle/operators/math/math_function.cu
paddle/operators/math/math_function.cu
+36
-0
paddle/operators/math/math_function.h
paddle/operators/math/math_function.h
+7
-0
paddle/operators/math/math_function_test.cc
paddle/operators/math/math_function_test.cc
+170
-0
paddle/operators/multiplex_op.cc
paddle/operators/multiplex_op.cc
+25
-18
paddle/operators/multiplex_op.cu
paddle/operators/multiplex_op.cu
+14
-11
paddle/operators/multiplex_op.h
paddle/operators/multiplex_op.h
+15
-12
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+2
-1
python/paddle/v2/framework/tests/test_multiplex_op.py
python/paddle/v2/framework/tests/test_multiplex_op.py
+7
-5
未找到文件。
doc/faq/index_cn.rst
浏览文件 @
236af566
...
...
@@ -390,4 +390,125 @@ PaddlePaddle保存的模型参数文件内容由16字节头信息和网络参数
* 如果发现最早的报错就是网络通信的问题,很有可能是非独占方式执行导致的端口冲突,可以联系OP,看当前MPI集群是否支持resource=full参数提交,如果支持增加此参数提交,并更换job 端口。
* 如果当前MPI集群并不支持任务独占模式,可以联系OP是否可以更换集群或升级当前集群。
\ No newline at end of file
* 如果当前MPI集群并不支持任务独占模式,可以联系OP是否可以更换集群或升级当前集群。
19. PaddlePaddle如何输出多个层
------------------------------
* 将需要输出的层作为 :code:`paddle.inference.Inference()` 接口的 :code:`output_layer` 参数输入,代码如下:
.. code-block:: python
inferer = paddle.inference.Inference(output_layer=[layer1, layer2], parameters=parameters)
* 指定要输出的字段进行输出。以输出 :code:`value` 字段为例,代码如下:
.. code-block:: python
out = inferer.infer(input=data_batch, flatten_result=False, field=["value"])
这里设置 :code:`flatten_result=False`,得到的输出结果是元素个数等于输出字段数的 :code:`list`,该 :code:`list` 的每个元素是由所有输出层相应字段结果组成的 :code:`list`,每个字段结果的类型是 :code:`numpy.array`。:code:`flatten_result` 的默认值为 :code:`True`,该情况下,PaddlePaddle会分别对每个字段将所有输出层的结果按行进行拼接,如果各输出层该字段 :code:`numpy.array` 结果的相应维数不匹配,程序将不能正常运行。
20. :code:`paddle.layer.memory` 的参数 :code:`name` 如何使用
-------------------------------------------------------------
* :code:`paddle.layer.memory` 用于获取特定layer上一时间步的输出,该layer是通过参数 :code:`name` 指定,即,:code:`paddle.layer.memory` 会关联参数 :code:`name` 取值相同的layer,并将该layer上一时间步的输出作为自身当前时间步的输出。
* PaddlePaddle的所有layer都有唯一的name,用户通过参数 :code:`name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。而 :code:`paddle.layer.memory` 不是真正的layer,其name由参数 :code:`memory_name` 设定,当用户没有显式设定时,PaddlePaddle会自动设定。:code:`paddle.layer.memory` 的参数 :code:`name` 用于指定其要关联的layer,需要用户显式设定。
21. dropout 使用
-----------------
* 在PaddlePaddle中使用dropout有两种方式
* 在相应layer的 :code:`layer_atter` 设置 :code:`drop_rate`,以 :code:`paddle.layer.fc` 为例,代码如下:
.. code-block:: python
fc = paddle.layer.fc(input=input, layer_attr=paddle.attr.ExtraLayerAttribute(drop_rate=0.5))
* 使用 :code:`paddle.layer.dropout`,以 :code:`paddle.layer.fc` 为例,代码如下:
.. code-block:: python
fc = paddle.layer.fc(input=input)
drop_fc = paddle.layer.dropout(input=fc, dropout_rate=0.5)
* :code:`paddle.layer.dropout` 实际上使用了 :code:`paddle.layer.add_to`,并在该layer里采用第一种方式设置 :code:`drop_rate` 来使用dropout的。这种方式对内存消耗较大。
* PaddlePaddle在激活函数里实现dropout,而不是在layer里实现。
* :code:`paddle.layer.lstmemory`、:code:`paddle.layer.grumemory`、:code:`paddle.layer.recurrent` 不是通过一般的方式来实现对输出的激活,所以不能采用第一种方式在这几个layer里设置 :code:`drop_rate` 来使用dropout。若要对这几个layer使用dropout,可采用第二种方式,即使用 :code:`paddle.layer.dropout`。
22. 如何设置学习率退火(learning rate annealing)
------------------------------------------------
在相应的优化算法里设置learning_rate_schedule及相关参数,以使用Adam算法为例,代码如下:
.. code-block:: python
optimizer = paddle.optimizer.Adam(
learning_rate=1e-3,
learning_rate_decay_a=0.5,
learning_rate_decay_b=0.75,
learning_rate_schedule="poly",)
PaddlePaddle目前支持8种learning_rate_schedule,这8种learning_rate_schedule及其对应学习率计算方式如下:
* "constant"
lr = learning_rate
* "poly"
lr = learning_rate * pow(1 + learning_rate_decay_a * num_samples_processed, -learning_rate_decay_b)
其中,num_samples_processed为已训练样本数,下同。
* "caffe_poly"
lr = learning_rate * pow(1.0 - num_samples_processed / learning_rate_decay_a, learning_rate_decay_b)
* "exp"
lr = learning_rate * pow(learning_rate_decay_a, num_samples_processed / learning_rate_decay_b)
* "discexp"
lr = learning_rate * pow(learning_rate_decay_a, floor(num_samples_processed / learning_rate_decay_b))
* "linear"
lr = max(learning_rate - learning_rate_decay_a * num_samples_processed, learning_rate_decay_b)
* "manual"
这是一种按已训练样本数分段取值的学习率退火方法。使用该learning_rate_schedule时,用户通过参数 :code:`learning_rate_args` 设置学习率衰减因子分段函数,当前的学习率为所设置 :code:`learning_rate` 与当前的衰减因子的乘积。以使用Adam算法为例,代码如下:
.. code-block:: python
optimizer = paddle.optimizer.Adam(
learning_rate=1e-3,
learning_rate_schedule="manual",
learning_rate_args="1000:1.0,2000:0.9,3000:0.8",)
在该示例中,当已训练样本数小于等于1000时,学习率为 :code:`1e-3 * 1.0`;当已训练样本数大于1000小于等于2000时,学习率为 :code:`1e-3 * 0.9`;当已训练样本数大于2000时,学习率为 :code:`1e-3 * 0.8`。
* "pass_manual"
这是一种按已训练pass数分段取值的学习率退火方法。使用该learning_rate_schedule时,用户通过参数 :code:`learning_rate_args` 设置学习率衰减因子分段函数,当前的学习率为所设置 :code:`learning_rate` 与当前的衰减因子的乘积。以使用Adam算法为例,代码如下:
.. code-block:: python
optimizer = paddle.optimizer.Adam(
learning_rate=1e-3,
learning_rate_schedule="manual",
learning_rate_args="1:1.0,2:0.9,3:0.8",)
在该示例中,当已训练pass数小于等于1时,学习率为 :code:`1e-3 * 1.0`;当已训练pass数大于1小于等于2时,学习率为 :code:`1e-3 * 0.9`;当已训练pass数大于2时,学习率为 :code:`1e-3 * 0.8`。
23. 出现 :code:`Duplicated layer name` 错误怎么办
--------------------------------------------------
出现该错误的原因一般是用户对不同layer的参数 :code:`name` 设置了相同的取值。遇到该错误时,先找出参数 :code:`name` 取值相同的layer,然后将这些layer的参数 :code:`name` 设置为不同的值。
paddle/operators/crop_op.h
浏览文件 @
236af566
...
...
@@ -38,10 +38,10 @@ class CropKernel : public framework::OpKernel {
auto
out_stride
=
framework
::
stride
(
out
->
dims
());
auto
offsets
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"offsets"
);
PADDLE_ENFORCE_EQ
(
x
->
dims
().
size
(),
offsets
.
size
(
),
x
->
dims
().
size
(),
static_cast
<
int64_t
>
(
offsets
.
size
()
),
"Offsets size should be equal to dimension size of input tensor."
);
int64_t
offset
=
0
;
for
(
in
t
i
=
0
;
i
<
offsets
.
size
();
++
i
)
{
for
(
size_
t
i
=
0
;
i
<
offsets
.
size
();
++
i
)
{
offset
+=
(
x_stride
[
i
]
*
offsets
[
i
]);
}
StridedMemcpy
<
T
>
(
context
.
device_context
(),
x_data
+
offset
,
x_stride
,
...
...
@@ -57,7 +57,7 @@ void CropGradFunction(const framework::ExecutionContext& context) {
d_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
offsets
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"offsets"
);
Eigen
::
array
<
std
::
pair
<
int
,
int
>
,
D
>
paddings
;
for
(
in
t
i
=
0
;
i
<
D
;
++
i
)
{
for
(
size_
t
i
=
0
;
i
<
D
;
++
i
)
{
paddings
[
i
].
first
=
offsets
[
i
];
paddings
[
i
].
second
=
d_x
->
dims
()[
i
]
-
d_out
->
dims
()[
i
]
-
offsets
[
i
];
}
...
...
paddle/operators/math/math_function.cc
浏览文件 @
236af566
...
...
@@ -48,6 +48,32 @@ void gemm<platform::CPUPlace, double>(const platform::DeviceContext& context,
beta
,
C
,
ldc
);
}
template
<
>
void
gemm
<
platform
::
CPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
int
ldb
,
const
float
beta
,
float
*
C
,
const
int
ldc
)
{
cblas_sgemm
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
gemm
<
platform
::
CPUPlace
,
double
>
(
const
platform
::
DeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
B
,
const
int
ldb
,
const
double
beta
,
double
*
C
,
const
int
ldc
)
{
cblas_dgemm
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
void
matmul
<
platform
::
CPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
...
...
paddle/operators/math/math_function.cu
浏览文件 @
236af566
...
...
@@ -63,6 +63,42 @@ void gemm<platform::GPUPlace, double>(const platform::DeviceContext& context,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
));
}
template
<
>
void
gemm
<
platform
::
GPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
float
alpha
,
const
float
*
A
,
const
int
lda
,
const
float
*
B
,
const
int
ldb
,
const
float
beta
,
float
*
C
,
const
int
ldc
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemm
(
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
));
}
template
<
>
void
gemm
<
platform
::
GPUPlace
,
double
>
(
const
platform
::
DeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
double
alpha
,
const
double
*
A
,
const
int
lda
,
const
double
*
B
,
const
int
ldb
,
const
double
beta
,
double
*
C
,
const
int
ldc
)
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t
cuTransA
=
transA
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
transB
==
false
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasDgemm
(
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
));
}
template
<
>
void
matmul
<
platform
::
GPUPlace
,
float
>
(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
matrix_a
,
...
...
paddle/operators/math/math_function.h
浏览文件 @
236af566
...
...
@@ -70,6 +70,13 @@ void gemm(const platform::DeviceContext& context, const CBLAS_TRANSPOSE transA,
const
CBLAS_TRANSPOSE
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
T
*
B
,
const
T
beta
,
T
*
C
);
// gemm wrapper with stride args for matrix uncontinuous in memory
template
<
typename
Place
,
typename
T
>
void
gemm
(
const
platform
::
DeviceContext
&
context
,
const
bool
transA
,
const
bool
transB
,
const
int
M
,
const
int
N
,
const
int
K
,
const
T
alpha
,
const
T
*
A
,
const
int
lda
,
const
T
*
B
,
const
int
ldb
,
const
T
beta
,
T
*
C
,
const
int
ldc
);
// matrix multiply with continuous memory
template
<
typename
Place
,
typename
T
>
void
matmul
(
const
platform
::
DeviceContext
&
context
,
...
...
paddle/operators/math/math_function_test.cc
浏览文件 @
236af566
...
...
@@ -72,4 +72,174 @@ TEST(math_function, trans_mul_notrans) {
EXPECT_EQ
(
out_ptr
[
8
],
29
);
delete
gpu_place
;
}
TEST
(
math_function
,
gemm_notrans_cublas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
paddle
::
framework
::
Tensor
input3
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
input3_gpu
;
int
m
=
2
;
int
n
=
3
;
int
k
=
3
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr1
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr1
,
6
*
sizeof
(
float
));
float
*
input2_ptr
=
input2
.
mutable_data
<
float
>
({
3
,
4
},
*
cpu_place
);
float
arr2
[
12
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
};
memcpy
(
input2_ptr
,
arr2
,
12
*
sizeof
(
float
));
float
*
input3_ptr
=
input3
.
mutable_data
<
float
>
({
2
,
4
},
*
cpu_place
);
float
arr3
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
};
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
);
input2_gpu
.
CopyFrom
<
float
>
(
input2
,
*
gpu_place
);
input3_gpu
.
CopyFrom
<
float
>
(
input3
,
*
gpu_place
);
float
*
a
=
input1_gpu
.
data
<
float
>
();
float
*
b
=
input2_gpu
.
data
<
float
>
();
float
*
c
=
input3_gpu
.
mutable_data
<
float
>
(
*
gpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
GPUPlace
,
float
>
(
context
,
false
,
false
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
1
,
4
,
1
,
c
+
1
,
4
);
input3
.
CopyFrom
<
float
>
(
input3_gpu
,
*
cpu_place
);
// numpy code:
// a = np.arange(6).reshape(2, 3)
// b = np.arange(12).reshape(3, 4)[:, 1:]
// c = np.arange(8).reshape(2, 4)[:, 1:]
// out = np.arange(8).reshape(2, 4)
// out[:, 1:] = np.dot(a, b) + c
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
EXPECT_EQ
(
input3_ptr
[
2
],
28
);
EXPECT_EQ
(
input3_ptr
[
3
],
32
);
EXPECT_EQ
(
input3_ptr
[
4
],
4
);
EXPECT_EQ
(
input3_ptr
[
5
],
73
);
EXPECT_EQ
(
input3_ptr
[
6
],
86
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
delete
gpu_place
;
}
TEST
(
math_function
,
gemm_trans_cublas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
paddle
::
framework
::
Tensor
input3
;
paddle
::
framework
::
Tensor
input1_gpu
;
paddle
::
framework
::
Tensor
input2_gpu
;
paddle
::
framework
::
Tensor
input3_gpu
;
int
m
=
2
;
int
n
=
3
;
int
k
=
3
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr1
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr1
,
6
*
sizeof
(
float
));
float
*
input2_ptr
=
input2
.
mutable_data
<
float
>
({
4
,
3
},
*
cpu_place
);
float
arr2
[
12
]
=
{
0
,
4
,
8
,
1
,
5
,
9
,
2
,
6
,
10
,
3
,
7
,
11
};
memcpy
(
input2_ptr
,
arr2
,
12
*
sizeof
(
float
));
float
*
input3_ptr
=
input3
.
mutable_data
<
float
>
({
2
,
4
},
*
cpu_place
);
float
arr3
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
};
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
auto
*
gpu_place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
paddle
::
platform
::
CUDADeviceContext
context
(
*
gpu_place
);
input1_gpu
.
CopyFrom
<
float
>
(
input1
,
*
gpu_place
);
input2_gpu
.
CopyFrom
<
float
>
(
input2
,
*
gpu_place
);
input3_gpu
.
CopyFrom
<
float
>
(
input3
,
*
gpu_place
);
float
*
a
=
input1_gpu
.
data
<
float
>
();
float
*
b
=
input2_gpu
.
data
<
float
>
();
float
*
c
=
input3_gpu
.
mutable_data
<
float
>
(
*
gpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
GPUPlace
,
float
>
(
context
,
false
,
true
,
m
,
n
,
k
,
1
,
a
,
3
,
b
+
3
,
3
,
1
,
c
+
1
,
4
);
input3
.
CopyFrom
<
float
>
(
input3_gpu
,
*
cpu_place
);
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
EXPECT_EQ
(
input3_ptr
[
2
],
28
);
EXPECT_EQ
(
input3_ptr
[
3
],
32
);
EXPECT_EQ
(
input3_ptr
[
4
],
4
);
EXPECT_EQ
(
input3_ptr
[
5
],
73
);
EXPECT_EQ
(
input3_ptr
[
6
],
86
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
delete
gpu_place
;
}
#endif
TEST
(
math_function
,
gemm_notrans_cblas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
paddle
::
framework
::
Tensor
input3
;
int
m
=
2
;
int
n
=
3
;
int
k
=
3
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr1
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr1
,
6
*
sizeof
(
float
));
float
*
input2_ptr
=
input2
.
mutable_data
<
float
>
({
3
,
4
},
*
cpu_place
);
float
arr2
[
12
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
};
memcpy
(
input2_ptr
,
arr2
,
12
*
sizeof
(
float
));
float
*
input3_ptr
=
input3
.
mutable_data
<
float
>
({
2
,
4
},
*
cpu_place
);
float
arr3
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
};
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
paddle
::
platform
::
CPUDeviceContext
context
(
*
cpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
CPUPlace
,
float
>
(
context
,
false
,
false
,
m
,
n
,
k
,
1
,
input1_ptr
,
3
,
input2_ptr
+
1
,
4
,
1
,
input3_ptr
+
1
,
4
);
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
EXPECT_EQ
(
input3_ptr
[
2
],
28
);
EXPECT_EQ
(
input3_ptr
[
3
],
32
);
EXPECT_EQ
(
input3_ptr
[
4
],
4
);
EXPECT_EQ
(
input3_ptr
[
5
],
73
);
EXPECT_EQ
(
input3_ptr
[
6
],
86
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
}
TEST
(
math_function
,
gemm_trans_clbas
)
{
paddle
::
framework
::
Tensor
input1
;
paddle
::
framework
::
Tensor
input2
;
paddle
::
framework
::
Tensor
input3
;
int
m
=
2
;
int
n
=
3
;
int
k
=
3
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
float
*
input1_ptr
=
input1
.
mutable_data
<
float
>
({
2
,
3
},
*
cpu_place
);
float
arr1
[
6
]
=
{
0
,
1
,
2
,
3
,
4
,
5
};
memcpy
(
input1_ptr
,
arr1
,
6
*
sizeof
(
float
));
float
*
input2_ptr
=
input2
.
mutable_data
<
float
>
({
4
,
3
},
*
cpu_place
);
float
arr2
[
12
]
=
{
0
,
4
,
8
,
1
,
5
,
9
,
2
,
6
,
10
,
3
,
7
,
11
};
memcpy
(
input2_ptr
,
arr2
,
12
*
sizeof
(
float
));
float
*
input3_ptr
=
input3
.
mutable_data
<
float
>
({
2
,
4
},
*
cpu_place
);
float
arr3
[
8
]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
};
memcpy
(
input3_ptr
,
arr3
,
8
*
sizeof
(
float
));
paddle
::
platform
::
CPUDeviceContext
context
(
*
cpu_place
);
paddle
::
operators
::
math
::
gemm
<
paddle
::
platform
::
CPUPlace
,
float
>
(
context
,
false
,
true
,
m
,
n
,
k
,
1
,
input1_ptr
,
3
,
input2_ptr
+
3
,
3
,
1
,
input3_ptr
+
1
,
4
);
EXPECT_EQ
(
input3_ptr
[
0
],
0
);
EXPECT_EQ
(
input3_ptr
[
1
],
24
);
EXPECT_EQ
(
input3_ptr
[
2
],
28
);
EXPECT_EQ
(
input3_ptr
[
3
],
32
);
EXPECT_EQ
(
input3_ptr
[
4
],
4
);
EXPECT_EQ
(
input3_ptr
[
5
],
73
);
EXPECT_EQ
(
input3_ptr
[
6
],
86
);
EXPECT_EQ
(
input3_ptr
[
7
],
99
);
}
paddle/operators/multiplex_op.cc
浏览文件 @
236af566
...
...
@@ -25,24 +25,30 @@ class MultiplexOp : public framework::OperatorWithKernel {
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Ids"
),
"Input(Ids) shouldn't be null."
);
PADDLE_ENFORCE
(
!
ctx
.
MultiInputVar
(
"X"
).
empty
(),
"
Input(X) should not be null
."
);
"
MultiInput(X) shouldn't be empty
."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
"Out"
),
"Output(Out) shouldn't be null."
);
auto
ids_dim
=
ctx
.
Input
<
Tensor
>
(
"Ids"
)
->
dims
();
PADDLE_ENFORCE
(
ids_dim
.
size
()
==
2
&&
ids_dim
[
1
]
==
1
,
"The index tensor must be a vector with size batchSize x 1."
);
auto
ins
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
num_ins
=
ins
.
size
();
PADDLE_ENFORCE
(
num_ins
>
2
,
"multiplex operator should have more than 2 inputs."
);
PADDLE_ENFORCE_EQ
(
ins
[
0
]
->
dims
().
size
(),
1
,
"The first input must be a index vector."
);
auto
in_dim
=
ins
[
1
]
->
dims
();
PADDLE_ENFORCE
(
num_ins
>
1
,
"multiplex operator should have more than "
"one candidate input tensors."
);
for
(
size_t
i
=
2
;
i
<
num_ins
;
i
++
)
{
auto
in_dim
=
ins
[
0
]
->
dims
();
PADDLE_ENFORCE
(
in_dim
.
size
()
==
2
,
"Candidate tensors must be matrix."
);
for
(
size_t
i
=
1
;
i
<
num_ins
;
i
++
)
{
auto
dim
=
ins
[
i
]
->
dims
();
PADDLE_ENFORCE
(
in_dim
==
dim
,
"All the input tensors except the first one must have the "
"same size."
);
"All the candidate tensors must have the same size."
);
}
out
->
Resize
(
in_dim
);
}
...
...
@@ -53,25 +59,26 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
MultiplexOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input tensors of multiplex operator."
).
AsDuplicable
();
AddInput
(
"Ids"
,
"The index tensor of multiplex operator."
);
AddInput
(
"X"
,
"The candidate tensors of multiplex operator."
)
.
AsDuplicable
();
AddOutput
(
"Out"
,
"The output tensor of multiplex operator."
);
AddComment
(
R"DOC(Multiplex operator
Multiplex multiple tensors according to the index provided by the first
input tensor.
ins[0]
: the index tensor.
ins[1:N]: the candidate output tensors
.
Ids
: the index tensor.
X[0 : N - 1]: the candidate tensors for output (N >= 2)
.
For each index i from 0 to batchSize - 1, the output is the i-th row of the
the (
index[i] + 1
)-th tensor.
the (
Ids[i]
)-th tensor.
For i-th row of the output tensor:
y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{
1
}.width - 1)
y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{
0
}.width - 1)
where y is the output tensor. `x_{k}` is the k-th input tensor
and `k = x{0}[i] + 1`.
and `k = Ids[i]`.
)DOC"
);
}
};
...
...
@@ -90,8 +97,8 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) shouldn't be null."
);
auto
d_ins
=
ctx
.
MultiOutput
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
ins
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
//
don't compute gradient for index (ins[0]
)
for
(
size_t
i
=
1
;
i
<
ins
.
size
();
i
++
)
{
//
No need to compute gradient for Input(Ids
)
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
if
(
d_ins
[
i
])
{
d_ins
[
i
]
->
Resize
(
ins
[
i
]
->
dims
());
}
...
...
paddle/operators/multiplex_op.cu
浏览文件 @
236af566
...
...
@@ -25,21 +25,23 @@ class MultiplexGPUKernel : public framework::OpKernel {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
ins
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
auto
*
ids
=
ctx
.
Input
<
Tensor
>
(
"Ids"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
rows
=
ins
[
1
]
->
dims
()[
0
];
auto
cols
=
ins
[
1
]
->
dims
()[
1
];
auto
rows
=
ins
[
0
]
->
dims
()[
0
];
auto
cols
=
ins
[
0
]
->
dims
()[
1
];
// copy index to cpu
Tensor
index_t_cpu
;
index_t_cpu
.
CopyFrom
<
T
>
(
*
(
ins
[
0
])
,
platform
::
CPUPlace
());
auto
*
index
=
index_t_cpu
.
data
<
T
>
();
index_t_cpu
.
CopyFrom
<
int32_t
>
(
*
ids
,
platform
::
CPUPlace
());
auto
*
index
=
index_t_cpu
.
data
<
int32_t
>
();
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
();
Place
place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
for
(
auto
i
=
0
;
i
<
rows
;
i
++
)
{
size_t
k
=
(
size_t
)
index
[
i
]
+
1
;
int32_t
k
=
index
[
i
];
PADDLE_ENFORCE_GE
(
k
,
0
,
"index must be nonnegative."
);
PADDLE_ENFORCE_LT
(
k
,
ins
.
size
(),
"index exceeds the number of candidate tensors."
);
memory
::
Copy
(
place
,
out
->
data
<
T
>
()
+
i
*
cols
,
place
,
...
...
@@ -54,8 +56,9 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
d_out
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
ins
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
auto
*
ids
=
ctx
.
Input
<
Tensor
>
(
"Ids"
);
auto
d_ins
=
ctx
.
MultiOutput
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
for
(
size_t
i
=
1
;
i
<
d_ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
d_ins
.
size
();
i
++
)
{
if
(
d_ins
[
i
])
{
d_ins
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_ins
[
i
]);
...
...
@@ -63,19 +66,19 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
}
}
auto
rows
=
ins
[
1
]
->
dims
()[
0
];
auto
cols
=
ins
[
1
]
->
dims
()[
1
];
auto
rows
=
ins
[
0
]
->
dims
()[
0
];
auto
cols
=
ins
[
0
]
->
dims
()[
1
];
// copy index to cpu
Tensor
index_t_cpu
;
index_t_cpu
.
CopyFrom
<
T
>
(
*
(
ins
[
0
])
,
platform
::
CPUPlace
());
auto
*
index
=
index_t_cpu
.
data
<
T
>
();
index_t_cpu
.
CopyFrom
<
int32_t
>
(
*
ids
,
platform
::
CPUPlace
());
auto
*
index
=
index_t_cpu
.
data
<
int32_t
>
();
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
();
Place
place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
for
(
auto
i
=
0
;
i
<
rows
;
i
++
)
{
size_t
k
=
(
size_t
)
index
[
i
]
+
1
;
size_t
k
=
static_cast
<
size_t
>
(
index
[
i
])
;
if
(
d_ins
[
k
])
{
memory
::
Copy
(
place
,
d_ins
[
k
]
->
data
<
T
>
()
+
i
*
cols
,
place
,
d_out
->
data
<
T
>
()
+
i
*
cols
,
cols
*
sizeof
(
T
),
stream
);
...
...
paddle/operators/multiplex_op.h
浏览文件 @
236af566
...
...
@@ -27,17 +27,19 @@ class MultiplexCPUKernel : public framework::OpKernel {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
ins
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
);
auto
ids
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Ids"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
rows
=
ins
[
1
]
->
dims
()[
0
];
auto
cols
=
ins
[
1
]
->
dims
()[
1
];
auto
*
index
=
ins
[
0
]
->
data
<
T
>
();
auto
rows
=
ins
[
0
]
->
dims
()[
0
];
auto
cols
=
ins
[
0
]
->
dims
()[
1
];
auto
index
=
ids
->
data
<
int32_t
>
();
Place
place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
for
(
auto
i
=
0
;
i
<
rows
;
i
++
)
{
size_t
k
=
(
size_t
)
index
[
i
]
+
1
;
PADDLE_ENFORCE_LT
(
k
,
ins
.
size
(),
int32_t
k
=
index
[
i
];
PADDLE_ENFORCE_GE
(
k
,
0
,
"index must be nonnegative."
);
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
k
),
ins
.
size
(),
"index exceeds the number of candidate tensors."
);
memory
::
Copy
(
place
,
out
->
data
<
T
>
()
+
i
*
cols
,
place
,
ins
[
k
]
->
data
<
T
>
()
+
i
*
cols
,
cols
*
sizeof
(
T
));
...
...
@@ -50,10 +52,11 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
d_out
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
ids
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Ids"
);
auto
ins
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
);
auto
d_ins
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
for
(
size_t
i
=
1
;
i
<
d_ins
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
d_ins
.
size
();
i
++
)
{
if
(
d_ins
[
i
])
{
d_ins
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
d_ins
[
i
]);
...
...
@@ -61,12 +64,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
}
}
auto
rows
=
ins
[
1
]
->
dims
()[
0
];
auto
cols
=
ins
[
1
]
->
dims
()[
1
];
auto
*
index
=
i
ns
[
0
]
->
data
<
T
>
();
auto
rows
=
ins
[
0
]
->
dims
()[
0
];
auto
cols
=
ins
[
0
]
->
dims
()[
1
];
auto
*
index
=
i
ds
->
data
<
int32_t
>
();
Place
place
=
boost
::
get
<
Place
>
(
ctx
.
GetPlace
());
for
(
auto
i
=
0
;
i
<
rows
;
i
++
)
{
size_t
k
=
(
size_t
)
index
[
i
]
+
1
;
size_t
k
=
static_cast
<
size_t
>
(
index
[
i
])
;
if
(
d_ins
[
k
])
{
memory
::
Copy
(
place
,
d_ins
[
k
]
->
data
<
T
>
()
+
i
*
cols
,
place
,
d_out
->
data
<
T
>
()
+
i
*
cols
,
cols
*
sizeof
(
T
));
...
...
@@ -74,5 +77,5 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
}
}
};
}
}
}
// namespace operators
}
// namespace paddle
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
236af566
...
...
@@ -921,7 +921,7 @@ def data_layer(name, size, depth=None, height=None, width=None,
data = data_layer(name="input", size=1000)
:param name: The name of this layer.
It is optional.
:param name: The name of this layer.
:type name: basestring
:param size: Size of this data layer.
:type size: int
...
...
@@ -3668,6 +3668,7 @@ def gru_step_naive_layer(input,
:param param_attr:
:param layer_attr:
:return:
:rtype: LayerOutput
"""
if
input
.
size
%
3
!=
0
:
raise
ValueError
(
"GruStep input size must be divided by 3"
)
...
...
python/paddle/v2/framework/tests/test_multiplex_op.py
浏览文件 @
236af566
...
...
@@ -6,20 +6,22 @@ from op_test import OpTest
class
TestMultiplexOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"multiplex"
rows
=
3
index
=
np
.
array
([
3
,
1
,
0
])
rows
=
4
index
=
np
.
arange
(
0
,
rows
).
astype
(
'int32'
)
np
.
random
.
shuffle
(
index
)
index
=
np
.
reshape
(
index
,
(
rows
,
1
))
ins1
=
np
.
random
.
random
((
rows
,
10
)).
astype
(
"float32"
)
ins2
=
np
.
random
.
random
((
rows
,
10
)).
astype
(
"float32"
)
ins3
=
np
.
random
.
random
((
rows
,
10
)).
astype
(
"float32"
)
ins4
=
np
.
random
.
random
((
rows
,
10
)).
astype
(
"float32"
)
self
.
inputs
=
{
'
X'
:
[(
'index'
,
index
),
(
'x1'
,
ins1
),
(
'x2'
,
ins2
),
(
'x3'
,
ins3
)
,
(
'x4'
,
ins4
)]
'
Ids'
:
index
,
'X'
:
[(
'x1'
,
ins1
),
(
'x2'
,
ins2
),
(
'x3'
,
ins3
),
(
'x4'
,
ins4
)]
}
# multiplex output
output
=
np
.
zeros_like
(
ins1
)
for
i
in
range
(
0
,
rows
):
k
=
index
[
i
]
+
1
k
=
index
[
i
]
[
0
]
output
[
i
]
=
self
.
inputs
[
'X'
][
k
][
1
][
i
]
self
.
outputs
=
{
'Out'
:
output
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录