Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a7563602
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a7563602
编写于
4月 17, 2020
作者:
G
gfwm0502
提交者:
GitHub
4月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
OP/API (While/while_loop/DynamicRNN) : Error Message Enhancement (#23896)
As the title
上级
b8866225
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
137 addition
and
76 deletion
+137
-76
paddle/fluid/operators/controlflow/while_op.cc
paddle/fluid/operators/controlflow/while_op.cc
+59
-40
paddle/fluid/operators/controlflow/while_op_helper.cc
paddle/fluid/operators/controlflow/while_op_helper.cc
+26
-8
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+19
-28
python/paddle/fluid/tests/unittests/test_dyn_rnn.py
python/paddle/fluid/tests/unittests/test_dyn_rnn.py
+33
-0
未找到文件。
paddle/fluid/operators/controlflow/while_op.cc
浏览文件 @
a7563602
...
...
@@ -49,10 +49,17 @@ class WhileOp : public framework::OperatorBase {
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
Input
(
kCondition
)));
PADDLE_ENFORCE_NOT_NULL
(
scope
.
FindVar
(
Input
(
kCondition
)),
platform
::
errors
::
NotFound
(
"Input(Condition) of WhileOp is not found."
));
auto
&
cond
=
scope
.
FindVar
(
Input
(
kCondition
))
->
Get
<
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
cond
.
dims
(),
paddle
::
framework
::
make_ddim
({
1
}));
PADDLE_ENFORCE_EQ
(
cond
.
dims
(),
paddle
::
framework
::
make_ddim
({
1
}),
platform
::
errors
::
InvalidArgument
(
"The shape of Input(Condition) of WhileOp must be 1. But now "
"the Condition's shape is "
,
cond
.
dims
().
to_str
(),
".
\n
"
));
framework
::
Executor
executor
(
dev_place
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
);
...
...
@@ -72,7 +79,9 @@ class WhileOp : public framework::OperatorBase {
step_scopes
->
clear
();
}
PADDLE_ENFORCE_EQ
(
step_scopes
->
size
(),
0
,
"The StepScope should be empty."
);
PADDLE_ENFORCE_EQ
(
step_scopes
->
size
(),
0
,
platform
::
errors
::
PreconditionNotMet
(
"The Output(StepScope) of WhileOp should be empty."
));
bool
cond_data
=
GetCondData
(
cond
);
bool
is_test
=
Attr
<
bool
>
(
"is_test"
);
...
...
@@ -160,8 +169,10 @@ class WhileGradOp : public framework::OperatorBase {
private:
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
PADDLE_ENFORCE
(
!
Attr
<
bool
>
(
"is_test"
),
"GradOp is only callable when is_test is false"
);
PADDLE_ENFORCE_EQ
(
Attr
<
bool
>
(
"is_test"
),
false
,
platform
::
errors
::
InvalidArgument
(
"WhileGradOp is only callable when is_test is false."
));
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
...
...
@@ -180,7 +191,14 @@ class WhileGradOp : public framework::OperatorBase {
auto
inside_og_names
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"original_output_grad"
);
PADDLE_ENFORCE_EQ
(
outside_og_names
.
size
(),
inside_og_names
.
size
());
PADDLE_ENFORCE_EQ
(
outside_og_names
.
size
(),
inside_og_names
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The number of original output gradient names "
"does not match the number of backward input "
"gradient names. The number of Backward input "
"names is %d and the numbers of original output "
"gradient names is %d."
,
outside_og_names
.
size
(),
inside_og_names
.
size
()));
for
(
auto
cur_scope_iter
=
step_scopes
->
rbegin
();
cur_scope_iter
!=
step_scopes
->
rend
();
++
cur_scope_iter
)
{
...
...
@@ -222,11 +240,18 @@ class WhileGradOp : public framework::OperatorBase {
inside_array
[
j
].
set_lod
(
outside_array
->
at
(
j
).
lod
());
inside_array
[
j
].
ShareDataWith
(
outside_array
->
at
(
j
));
}
else
{
PADDLE_ENFORCE_EQ
(
inside_array
[
j
].
numel
(),
0
);
PADDLE_ENFORCE_EQ
(
inside_array
[
j
].
numel
(),
0
,
platform
::
errors
::
InvalidArgument
(
"The numel of %d-th element of var %s (LoDTensorArray) "
"in while block must be 0, but received its numel is %d."
,
j
,
inside_og_name
,
inside_array
[
j
].
numel
()));
}
}
}
else
{
PADDLE_THROW
(
"Currently only support LoDTensor and LoDTensorArray."
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Currently only support LoDTensor and LoDTensorArray in "
"WhileGradOp."
));
}
}
executor
.
RunPreparedContext
(
ctx
.
get
(),
*
cur_scope_iter
,
false
,
true
,
...
...
@@ -236,7 +261,13 @@ class WhileGradOp : public framework::OperatorBase {
// and inputs.
auto
&
pg_ig_names
=
Outputs
(
kXGRAD
);
auto
&
p_names
=
Inputs
(
kX
);
PADDLE_ENFORCE_EQ
(
pg_ig_names
.
size
(),
p_names
.
size
());
PADDLE_ENFORCE_EQ
(
pg_ig_names
.
size
(),
p_names
.
size
(),
platform
::
errors
::
PreconditionNotMet
(
"The number of names in Outputs(X@GRAD) does not "
"match the number of names in Inputs(X). The "
"number of names in Outputs(X@GRAD) is %d and "
"the number of names in Inputs(X) is %d."
,
pg_ig_names
.
size
(),
p_names
.
size
()));
for
(
size_t
param_id
=
0
;
param_id
<
pg_ig_names
.
size
();
++
param_id
)
{
if
(
pg_ig_names
[
param_id
]
==
framework
::
kEmptyVarName
)
{
continue
;
// parameter doesn't have gradient
...
...
@@ -247,7 +278,9 @@ class WhileGradOp : public framework::OperatorBase {
// for example lookup_table_grad_op, the input(Idx) doesn't have
// gradient.
auto
pg_ig_var
=
cur_scope
.
FindVar
(
inside_grad_name
);
PADDLE_ENFORCE
(
pg_ig_var
!=
nullptr
);
PADDLE_ENFORCE_NOT_NULL
(
pg_ig_var
,
platform
::
errors
::
NotFound
(
"Variable %s is not found."
,
inside_grad_name
));
if
(
pg_ig_var
->
IsType
<
framework
::
LoDTensorArray
>
())
{
auto
pg_ig_lod_t_arr
=
pg_ig_var
->
GetMutable
<
framework
::
LoDTensorArray
>
();
...
...
@@ -277,13 +310,16 @@ class WhileGradOp : public framework::OperatorBase {
// zero gradient variable in step 0
if
(
cur_scope_iter
==
step_scopes
->
rbegin
())
{
auto
*
var
=
(
*
cur_scope_iter
)
->
FindVar
(
inside_grad_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Can not find var %s"
,
inside_grad_name
);
PADDLE_ENFORCE
(
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
NotFound
(
"Variable %s is not found."
,
inside_grad_name
));
PADDLE_ENFORCE_EQ
(
var
->
IsType
<
framework
::
LoDTensorArray
>
()
||
var
->
IsType
<
LoDTensor
>
(),
"Currently the type of var only can be LoDTensorArray, "
"or LoDTensor, but the received var[%s] is %s."
,
inside_grad_name
,
framework
::
ToTypeName
(
var
->
Type
()));
true
,
platform
::
errors
::
InvalidArgument
(
"Currently the type of var only can be LoDTensorArray, "
"or LoDTensor, but the received var[%s] is %s."
,
inside_grad_name
,
framework
::
ToTypeName
(
var
->
Type
())));
if
(
var
->
IsType
<
LoDTensor
>
())
{
auto
&
inside_tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
...
...
@@ -422,41 +458,24 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
ctx
->
HasOutputs
(
framework
::
GradVarName
(
kX
));
ctx
->
HasInputs
(
kOutputs
);
ctx
->
HasInputs
(
framework
::
GradVarName
(
kOutputs
));
auto
pg_ig_names
=
ctx
->
Outputs
(
kXGRAD
);
std
::
vector
<
framework
::
InferShapeVarPtr
>
in_var_ptrs
=
ctx
->
GetInputVarPtrs
(
kX
);
std
::
vector
<
framework
::
InferShapeVarPtr
>
out_var_ptrs
=
ctx
->
GetOutputVarPtrs
(
kXGRAD
);
PADDLE_ENFORCE
(
in_var_ptrs
.
size
()
==
out_var_ptrs
.
size
());
PADDLE_ENFORCE_EQ
(
in_var_ptrs
.
size
(),
out_var_ptrs
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of Inputs(X) must be the same as "
"the size of Outputs(X@GRAD)."
));
for
(
size_t
i
=
0
;
i
<
in_var_ptrs
.
size
();
++
i
)
{
if
(
pg_ig_names
[
i
]
==
framework
::
kEmptyVarName
)
{
continue
;
}
if
(
ctx
->
IsRuntime
())
{
framework
::
Variable
*
in_var
=
boost
::
get
<
framework
::
Variable
*>
(
in_var_ptrs
[
i
]);
framework
::
Variable
*
out_var
=
boost
::
get
<
framework
::
Variable
*>
(
out_var_ptrs
[
i
]);
auto
type
=
framework
::
ToVarType
(
in_var
->
Type
());
if
(
type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
out_var
->
GetMutable
<
LoDTensor
>
()
->
Resize
(
in_var
->
Get
<
framework
::
LoDTensor
>
().
dims
());
}
else
if
(
type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
out_var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
set_height
(
in_var
->
Get
<
framework
::
SelectedRows
>
().
GetCompleteDims
()[
0
]);
}
else
if
(
type
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
PADDLE_THROW
(
"WhileGradOp doesn't support type %d"
,
static_cast
<
int
>
(
type
));
}
}
else
{
framework
::
VarDesc
*
in_var
=
boost
::
get
<
framework
::
VarDesc
*>
(
in_var_ptrs
[
i
]);
boost
::
get
<
framework
::
VarDesc
*>
(
out_var_ptrs
[
i
])
->
SetShape
(
in_var
->
GetShape
());
}
framework
::
VarDesc
*
in_var
=
boost
::
get
<
framework
::
VarDesc
*>
(
in_var_ptrs
[
i
]);
boost
::
get
<
framework
::
VarDesc
*>
(
out_var_ptrs
[
i
])
->
SetShape
(
in_var
->
GetShape
());
}
}
};
...
...
paddle/fluid/operators/controlflow/while_op_helper.cc
浏览文件 @
a7563602
...
...
@@ -83,7 +83,11 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
auto
&
in_grads
=
bwd_op
.
Outputs
().
at
(
framework
::
GradVarName
(
kX
));
PADDLE_ENFORCE_EQ
(
fwd_input
.
size
(),
in_grads
.
size
(),
"Backward input gradient number does not match forward input number."
);
platform
::
errors
::
PreconditionNotMet
(
"Backward output gradient number does not match forward input number."
"The number of forward input number is %d and the number of backward "
"output geadient number is %d."
,
fwd_input
.
size
(),
in_grads
.
size
()));
std
::
unordered_set
<
std
::
string
>
backward_skip_vars
;
for
(
size_t
i
=
0
;
i
<
in_grads
.
size
();
++
i
)
{
...
...
@@ -104,7 +108,13 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
static
void
FindAllWhileAndWhileGradOp
(
const
framework
::
ProgramDesc
&
program
,
std
::
vector
<
OpVariant
>
*
while_ops
,
std
::
vector
<
OpVariant
>
*
while_grad_ops
)
{
PADDLE_ENFORCE_GE
(
while_ops
->
size
(),
while_grad_ops
->
size
());
PADDLE_ENFORCE_GE
(
while_ops
->
size
(),
while_grad_ops
->
size
(),
platform
::
errors
::
PreconditionNotMet
(
"There are more while_grad_ops than forward while_ops in the graph "
"or program, the number of while_ops is %d and the number of "
"while_grad_ops is %d."
,
while_ops
->
size
(),
while_grad_ops
->
size
()));
for
(
size_t
i
=
1
;
i
<
program
.
Size
();
++
i
)
{
auto
&
block
=
program
.
Block
(
i
);
for
(
size_t
j
=
0
;
j
<
block
.
OpSize
();
++
j
)
{
...
...
@@ -117,8 +127,13 @@ static void FindAllWhileAndWhileGradOp(const framework::ProgramDesc &program,
}
}
PADDLE_ENFORCE_GE
(
while_ops
->
size
(),
while_grad_ops
->
size
(),
"There are extra while_grad ops in the graph or program"
);
PADDLE_ENFORCE_GE
(
while_ops
->
size
(),
while_grad_ops
->
size
(),
platform
::
errors
::
InvalidArgument
(
"There are more while_grad_ops than forward while_ops in the graph "
"or program, the number of while_ops is %d and the number of "
"while_grad_ops is %d."
,
while_ops
->
size
(),
while_grad_ops
->
size
()));
}
static
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
...
...
@@ -140,13 +155,16 @@ static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
const
OpVariant
*
matched_fwd_op
=
nullptr
;
for
(
auto
&
fwd_op
:
while_op_set
)
{
if
(
IsMatchedWhileOpAndWhileGradOp
(
fwd_op
,
bwd_op
))
{
PADDLE_ENFORCE
(
matched_fwd_op
==
nullptr
,
"Found multiple matched while ops"
);
PADDLE_ENFORCE_EQ
(
matched_fwd_op
,
nullptr
,
platform
::
errors
::
PreconditionNotMet
(
"Found multiple while forward ops match while "
"grad ops."
));
matched_fwd_op
=
&
fwd_op
;
}
}
PADDLE_ENFORCE_NOT_NULL
(
matched_fwd_op
,
"Cannot find matched forward while op."
);
platform
::
errors
::
PreconditionNotMet
(
"Cannot find matched forward while op."
));
ModifyWhileOpAndWhileGradOpAttr
(
*
matched_fwd_op
,
bwd_op
);
while_op_set
.
erase
(
*
matched_fwd_op
);
}
...
...
@@ -209,7 +227,7 @@ bool GetCondData(const framework::LoDTensor &cond) {
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"This version of PaddlePaddle does NOT support GPU but got GPU tensor "
"Cond in WhileOp. Please compile WITH_GPU option"
));
"Cond in WhileOp. Please compile WITH_GPU option
.
"
));
#endif
return
cpu_cond
->
data
<
bool
>
()[
0
];
}
...
...
python/paddle/fluid/layers/control_flow.py
浏览文件 @
a7563602
...
...
@@ -882,14 +882,10 @@ class While(object):
def
__init__
(
self
,
cond
,
is_test
=
False
,
name
=
None
):
self
.
helper
=
LayerHelper
(
"while"
,
name
=
name
)
self
.
status
=
While
.
BEFORE_WHILE_BLOCK
if
not
isinstance
(
cond
,
Variable
):
raise
TypeError
(
"condition should be a variable"
)
assert
isinstance
(
cond
,
Variable
)
if
cond
.
dtype
!=
core
.
VarDesc
.
VarType
.
BOOL
:
raise
TypeError
(
"condition should be a boolean variable"
)
check_variable_and_dtype
(
cond
,
'cond'
,
[
'bool'
],
'fluid.layers.While'
)
if
reduce
(
lambda
a
,
b
:
a
*
b
,
cond
.
shape
,
1
)
!=
1
:
raise
TypeError
(
"condition expected shape as [], but given shape as {0}."
.
"condition expected shape as [
1
], but given shape as {0}."
.
format
(
list
(
cond
.
shape
)))
self
.
cond_var
=
cond
self
.
is_test
=
is_test
...
...
@@ -999,19 +995,16 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
raise
TypeError
(
"cond in while_loop should be callable"
)
if
not
callable
(
body
):
raise
TypeError
(
"body in while_loop should be callable"
)
if
not
isinstance
(
loop_vars
,
(
list
,
tuple
)):
raise
TypeError
(
"loop_vars in while_loop should be a list or tuple"
)
check_type
(
loop_vars
,
'loop_vars'
,
(
list
,
tuple
),
'fluid.layers.while_loop'
)
if
len
(
loop_vars
)
==
0
:
raise
ValueError
(
"loop_vars in while_loop should not be empty"
)
pre_cond
=
cond
(
*
loop_vars
)
if
not
isinstance
(
pre_cond
,
Variable
):
raise
TypeError
(
"cond in while_loop should return a variable"
)
if
pre_cond
.
dtype
!=
core
.
VarDesc
.
VarType
.
BOOL
:
raise
TypeError
(
"cond in while_loop should return a boolean variable"
)
check_variable_and_dtype
(
pre_cond
,
'var of cond returned'
,
[
'bool'
],
'fluid.layers.while_loop'
)
if
reduce
(
lambda
a
,
b
:
a
*
b
,
pre_cond
.
shape
,
1
)
!=
1
:
raise
TypeError
(
"the shape of the variable returned by cond should be [],"
"the shape of the variable returned by cond should be [
1
],"
"but given shape as {0}."
.
format
(
list
(
pre_cond
.
shape
)))
if
in_dygraph_mode
():
...
...
@@ -2906,9 +2899,7 @@ class DynamicRNN(object):
rnn_output = drnn()
"""
self
.
_assert_in_rnn_block_
(
"step_input"
)
if
not
isinstance
(
x
,
Variable
):
raise
TypeError
(
"step_input() can only take a Variable as its input."
)
check_type
(
x
,
'x'
,
Variable
,
'fluid.layers.DynamicRNN.step_input()'
)
parent_block
=
self
.
_parent_block_
()
if
self
.
lod_rank_table
is
None
:
self
.
lod_rank_table
=
parent_block
.
create_var
(
...
...
@@ -3075,9 +3066,7 @@ class DynamicRNN(object):
rnn_output = drnn()
"""
self
.
_assert_in_rnn_block_
(
"static_input"
)
if
not
isinstance
(
x
,
Variable
):
raise
TypeError
(
"static_input() can only take a Variable as its input"
)
check_type
(
x
,
'x'
,
Variable
,
'fluid.layers.DynamicRNN.static_input()'
)
if
self
.
lod_rank_table
is
None
:
raise
RuntimeError
(
"static_input() must be called after step_input()."
)
...
...
@@ -3242,10 +3231,12 @@ class DynamicRNN(object):
"""
self
.
_assert_in_rnn_block_
(
'memory'
)
self
.
_init_zero_idx_
()
if
shape
is
not
None
:
check_type
(
shape
,
'shape'
,
(
list
,
tuple
),
'fluid.layers.DynamicRNN.memory()'
)
if
init
is
not
None
:
if
not
isinstance
(
init
,
Variable
):
raise
TypeError
(
"The input arg `init` of memory() must be a Variable"
)
check_type
(
init
,
'init'
,
Variable
,
'fluid.layers.DynamicRNN.memory()'
)
parent_block
=
self
.
_parent_block_
()
init_tensor
=
init
if
need_reorder
==
True
:
...
...
@@ -3326,12 +3317,10 @@ class DynamicRNN(object):
ValueError: When :code:`update_memory()` is called before :code:`step_input()` .
"""
self
.
_assert_in_rnn_block_
(
'update_memory'
)
if
not
isinstance
(
ex_mem
,
Variable
):
raise
TypeError
(
"The input arg `ex_mem` of update_memory() must "
"be a Variable"
)
if
not
isinstance
(
new_mem
,
Variable
):
raise
TypeError
(
"The input arg `new_mem` of update_memory() must "
"be a Variable"
)
check_type
(
ex_mem
,
'ex_mem'
,
Variable
,
'fluid.layers.DynamicRNN.update_memory()'
)
check_type
(
new_mem
,
'new_mem'
,
Variable
,
'fluid.layers.DynamicRNN.update_memory()'
)
mem_array
=
self
.
mem_dict
.
get
(
ex_mem
.
name
,
None
)
if
mem_array
is
None
:
...
...
@@ -3358,6 +3347,8 @@ class DynamicRNN(object):
self
.
_assert_in_rnn_block_
(
'output'
)
parent_block
=
self
.
_parent_block_
()
for
each
in
outputs
:
check_type
(
each
,
"outputs"
,
Variable
,
"fluid.layers.DynamicRNN.output"
)
outside_array
=
parent_block
.
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
"_"
.
join
(
[
self
.
helper
.
name
,
"output_array"
,
each
.
name
])),
...
...
python/paddle/fluid/tests/unittests/test_dyn_rnn.py
浏览文件 @
a7563602
...
...
@@ -19,6 +19,7 @@ import paddle
import
unittest
import
numpy
from
paddle.fluid.framework
import
Program
,
program_guard
from
paddle.fluid.layers.control_flow
import
lod_rank_table
from
paddle.fluid.layers.control_flow
import
max_sequence_len
from
paddle.fluid.layers.control_flow
import
lod_tensor_to_array
...
...
@@ -299,5 +300,37 @@ class TestDynamicRNN(unittest.TestCase):
self
.
train_data
=
train_data_orig
class
TestDynamicRNNErrors
(
unittest
.
TestCase
):
def
test_errors
(
self
):
with
program_guard
(
Program
(),
Program
()):
init
=
fluid
.
layers
.
zeros
(
shape
=
[
1
],
dtype
=
'float32'
)
shape
=
'shape'
sentence
=
fluid
.
data
(
name
=
'sentence'
,
shape
=
[
None
,
32
],
dtype
=
'float32'
,
lod_level
=
1
)
# The type of Input(shape) in API(memory) must be list or tuple
def
input_shape_type_of_memory
():
drnn
=
fluid
.
layers
.
DynamicRNN
()
with
drnn
.
block
():
res
=
drnn
.
memory
(
init
,
shape
)
self
.
assertRaises
(
TypeError
,
input_shape_type_of_memory
)
# The type of element of Input(*outputs) in API(output) must be Variable.
def
outputs_type_of_output
():
drnn
=
fluid
.
layers
.
DynamicRNN
()
with
drnn
.
block
():
word
=
drnn
.
step_input
(
sentence
)
memory
=
drnn
.
memory
(
shape
=
[
10
],
dtype
=
'float32'
,
value
=
0
)
hidden
=
fluid
.
layers
.
fc
(
input
=
[
word
,
memory
],
size
=
10
,
act
=
'tanh'
)
out
=
np
.
ones
(
1
).
astype
(
'float32'
)
drnn
.
update_memory
(
ex_mem
=
memory
,
new_mem
=
hidden
)
drnn
.
output
(
hidden
,
out
)
self
.
assertRaises
(
TypeError
,
outputs_type_of_output
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录