Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
023fe1e3
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
023fe1e3
编写于
8月 04, 2017
作者:
T
Tao Luo
提交者:
GitHub
8月 04, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3208 from luotao1/save_log_prob
Save log prob
上级
2dc34c96
a58a77e8
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
29 addition
and
9 deletion
+29
-9
paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp
paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp
+19
-7
paddle/gserver/gradientmachines/RecurrentGradientMachine.h
paddle/gserver/gradientmachines/RecurrentGradientMachine.h
+10
-2
未找到文件。
paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp
浏览文件 @
023fe1e3
...
@@ -967,8 +967,9 @@ void RecurrentGradientMachine::generateSequence() {
...
@@ -967,8 +967,9 @@ void RecurrentGradientMachine::generateSequence() {
size_t
numSequences
=
getGenBatchSize
();
size_t
numSequences
=
getGenBatchSize
();
resizeBootFrame
(
numSequences
);
resizeBootFrame
(
numSequences
);
// We create only two sub-network in generation for alternate use.
// We create only two sub-network in generation, one stores states of all
// Thus, we can reduce total memory of output_ in layer forward.
// layers in previous time step and the other storing the states at current
// time step.
resizeOrCreateFrames
(
2
);
resizeOrCreateFrames
(
2
);
// outFrameLines_.size() > 1UL
// outFrameLines_.size() > 1UL
...
@@ -1001,10 +1002,9 @@ void RecurrentGradientMachine::generateSequence() {
...
@@ -1001,10 +1002,9 @@ void RecurrentGradientMachine::generateSequence() {
// init outArg
// init outArg
size_t
resultNum
=
generator_
.
config
.
num_results_per_sample
();
size_t
resultNum
=
generator_
.
config
.
num_results_per_sample
();
IVector
::
resizeOrCreate
(
size_t
maxGenWordCount
=
generator_
.
outArg
.
ids
,
generator_
.
config
.
max_num_frames
()
*
numSequences
*
resultNum
;
generator_
.
config
.
max_num_frames
()
*
numSequences
*
resultNum
,
IVector
::
resizeOrCreate
(
generator_
.
outArg
.
ids
,
maxGenWordCount
,
false
);
false
);
if
(
resultNum
>
1
)
{
if
(
resultNum
>
1
)
{
CHECK_LE
(
resultNum
,
static_cast
<
size_t
>
(
generator_
.
config
.
beam_size
()));
CHECK_LE
(
resultNum
,
static_cast
<
size_t
>
(
generator_
.
config
.
beam_size
()));
Matrix
::
resizeOrCreate
(
generator_
.
outArg
.
in
,
Matrix
::
resizeOrCreate
(
generator_
.
outArg
.
in
,
...
@@ -1012,6 +1012,11 @@ void RecurrentGradientMachine::generateSequence() {
...
@@ -1012,6 +1012,11 @@ void RecurrentGradientMachine::generateSequence() {
/* width */
resultNum
,
/* width */
resultNum
,
false
,
false
,
/* useGpu */
false
);
/* useGpu */
false
);
Matrix
::
resizeOrCreate
(
generator_
.
outArg
.
value
,
/* height */
maxGenWordCount
,
/* width */
1
,
false
,
/* useGpu */
false
);
}
}
ICpuGpuVector
::
resizeOrCreate
(
generator_
.
outArg
.
sequenceStartPositions
,
ICpuGpuVector
::
resizeOrCreate
(
generator_
.
outArg
.
sequenceStartPositions
,
numSequences
+
1
,
numSequences
+
1
,
...
@@ -1313,13 +1318,20 @@ void RecurrentGradientMachine::fillGenOutputs() {
...
@@ -1313,13 +1318,20 @@ void RecurrentGradientMachine::fillGenOutputs() {
starts
[
0
]
=
0
;
starts
[
0
]
=
0
;
if
(
numResults
>
1
)
{
if
(
numResults
>
1
)
{
real
*
probs
=
generator_
.
outArg
.
in
->
getData
();
real
*
probs
=
generator_
.
outArg
.
in
->
getData
();
real
*
idsProb
=
generator_
.
outArg
.
value
->
getData
();
size_t
curPos
=
0
;
for
(
size_t
i
=
0
;
i
<
finalPaths_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
finalPaths_
.
size
();
++
i
)
{
for
(
size_t
j
=
0
;
j
<
finalPaths_
[
i
].
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
finalPaths_
[
i
].
size
();
++
j
)
{
Path
&
path
=
finalPaths_
[
i
][
j
];
Path
&
path
=
finalPaths_
[
i
][
j
];
generator_
.
ids
.
push_back
(
path
.
ids
.
size
());
// sequence size
size_t
genLen
=
path
.
ids
.
size
();
generator_
.
ids
.
push_back
(
genLen
);
// sequence size
generator_
.
ids
.
insert
(
generator_
.
ids
.
insert
(
generator_
.
ids
.
end
(),
path
.
ids
.
begin
(),
path
.
ids
.
end
());
generator_
.
ids
.
end
(),
path
.
ids
.
begin
(),
path
.
ids
.
end
());
generator_
.
ids
.
push_back
(
-
1
);
// end of sequence
generator_
.
ids
.
push_back
(
-
1
);
// end of sequence
memcpy
(
idsProb
+
curPos
,
path
.
idsProb
.
data
(),
sizeof
(
real
)
*
genLen
);
curPos
+=
genLen
;
idsProb
[
curPos
++
]
=
-
1.0
;
probs
[
i
*
numResults
+
j
]
=
path
.
logProb
;
probs
[
i
*
numResults
+
j
]
=
path
.
logProb
;
if
(
!
j
&&
dataArgsSize_
)
{
if
(
!
j
&&
dataArgsSize_
)
{
...
...
paddle/gserver/gradientmachines/RecurrentGradientMachine.h
浏览文件 @
023fe1e3
...
@@ -189,6 +189,11 @@ public:
...
@@ -189,6 +189,11 @@ public:
*/
*/
std
::
vector
<
int
>
ids
;
std
::
vector
<
int
>
ids
;
/**
* @brief idsProb, log probability of each generated words.
*/
std
::
vector
<
real
>
idsProb
;
/**
/**
* @brief logProb, current probability of path.
* @brief logProb, current probability of path.
*/
*/
...
@@ -228,11 +233,13 @@ public:
...
@@ -228,11 +233,13 @@ public:
*/
*/
Path
(
Path
&
old
,
int
newId
,
real
logProb
,
int
machineId
,
int
topIndex
)
Path
(
Path
&
old
,
int
newId
,
real
logProb
,
int
machineId
,
int
topIndex
)
:
ids
(
old
.
ids
),
:
ids
(
old
.
ids
),
idsProb
(
old
.
idsProb
),
logProb
(
old
.
logProb
+
logProb
),
logProb
(
old
.
logProb
+
logProb
),
machineId
(
machineId
),
machineId
(
machineId
),
topIndex
(
topIndex
),
topIndex
(
topIndex
),
seqId
(
old
.
seqId
)
{
seqId
(
old
.
seqId
)
{
ids
.
push_back
(
newId
);
ids
.
push_back
(
newId
);
idsProb
.
push_back
(
logProb
);
if
(
!
old
.
probHistory
.
empty
())
{
if
(
!
old
.
probHistory
.
empty
())
{
this
->
probHistory
=
old
.
probHistory
;
this
->
probHistory
=
old
.
probHistory
;
// probHistory store current prob, not sum
// probHistory store current prob, not sum
...
@@ -411,8 +418,9 @@ protected:
...
@@ -411,8 +418,9 @@ protected:
struct
Generator
{
struct
Generator
{
GeneratorConfig
config
;
GeneratorConfig
config
;
std
::
vector
<
int
>
ids
;
// store generated sequences
std
::
vector
<
int
>
ids
;
// store generated sequences
Argument
outArg
;
// final output argument
std
::
vector
<
real
>
idsProb
;
// log probability of each generated word
Argument
outArg
;
// final output argument
};
};
bool
generating_
;
bool
generating_
;
Generator
generator_
;
Generator
generator_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录