Unverified commit a7563602 — authored April 17, 2020 by gfwm0502; committed via GitHub on April 17, 2020.
OP/API (While/while_loop/DynamicRNN) : Error Message Enhancement (#23896)
As the title
Parent: b8866225
Showing 4 changed files with 137 additions and 76 deletions (+137, -76).
paddle/fluid/operators/controlflow/while_op.cc         +59 -40
paddle/fluid/operators/controlflow/while_op_helper.cc  +26 -8
python/paddle/fluid/layers/control_flow.py             +19 -28
python/paddle/fluid/tests/unittests/test_dyn_rnn.py    +33 -0
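Every change below follows one pattern: a bare PADDLE_ENFORCE / PADDLE_THROW or an ad-hoc raise TypeError is replaced by a typed check (platform::errors::NotFound, InvalidArgument, PreconditionNotMet, Unimplemented on the C++ side; check_type and check_variable_and_dtype on the Python side) whose message names the offending input and, where possible, its actual value. For orientation, here is a minimal sketch of the contract those checks guard, written against the Paddle 1.x fluid API in this tree (the variable names are illustrative, not taken from the patch):

    import paddle.fluid as fluid

    # While requires cond to be a bool Variable with exactly one element.
    i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
    limit = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
    cond = fluid.layers.less_than(x=i, y=limit)  # bool Variable of shape [1]

    loop = fluid.layers.While(cond=cond)
    with loop.block():
        i = fluid.layers.increment(x=i, value=1, in_place=True)
        fluid.layers.less_than(x=i, y=limit, cond=cond)  # refresh the condition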
paddle/fluid/operators/controlflow/while_op.cc
@@ -49,10 +49,17 @@ class WhileOp : public framework::OperatorBase {
 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
+    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)),
+                            platform::errors::NotFound(
+                                "Input(Condition) of WhileOp is not found."));
     auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
-    PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
+    PADDLE_ENFORCE_EQ(
+        cond.dims(), paddle::framework::make_ddim({1}),
+        platform::errors::InvalidArgument(
+            "The shape of Input(Condition) of WhileOp must be 1. But now "
+            "the Condition's shape is ",
+            cond.dims().to_str(), ".\n"));
     framework::Executor executor(dev_place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
@@ -72,7 +79,9 @@ class WhileOp : public framework::OperatorBase {
       step_scopes->clear();
     }
-    PADDLE_ENFORCE_EQ(step_scopes->size(), 0, "The StepScope should be empty.");
+    PADDLE_ENFORCE_EQ(step_scopes->size(), 0,
+                      platform::errors::PreconditionNotMet(
+                          "The Output(StepScope) of WhileOp should be empty."));
     bool cond_data = GetCondData(cond);
     bool is_test = Attr<bool>("is_test");
@@ -160,8 +169,10 @@ class WhileGradOp : public framework::OperatorBase {
 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {
-    PADDLE_ENFORCE(!Attr<bool>("is_test"),
-                   "GradOp is only callable when is_test is false");
+    PADDLE_ENFORCE_EQ(
+        Attr<bool>("is_test"), false,
+        platform::errors::InvalidArgument(
+            "WhileGradOp is only callable when is_test is false."));
     // get device context from pool
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
@@ -180,7 +191,14 @@ class WhileGradOp : public framework::OperatorBase {
     auto inside_og_names =
         Attr<std::vector<std::string>>("original_output_grad");
-    PADDLE_ENFORCE_EQ(outside_og_names.size(), inside_og_names.size());
+    PADDLE_ENFORCE_EQ(
+        outside_og_names.size(), inside_og_names.size(),
+        platform::errors::InvalidArgument(
+            "The number of original output gradient names "
+            "does not match the number of backward input "
+            "gradient names. The number of Backward input "
+            "names is %d and the numbers of original output "
+            "gradient names is %d.",
+            outside_og_names.size(), inside_og_names.size()));
     for (auto cur_scope_iter = step_scopes->rbegin();
          cur_scope_iter != step_scopes->rend(); ++cur_scope_iter) {
@@ -222,11 +240,18 @@ class WhileGradOp : public framework::OperatorBase {
             inside_array[j].set_lod(outside_array->at(j).lod());
             inside_array[j].ShareDataWith(outside_array->at(j));
           } else {
-            PADDLE_ENFORCE_EQ(inside_array[j].numel(), 0);
+            PADDLE_ENFORCE_EQ(
+                inside_array[j].numel(), 0,
+                platform::errors::InvalidArgument(
+                    "The numel of %d-th element of var %s (LoDTensorArray) "
+                    "in while block must be 0, but received its numel is %d.",
+                    j, inside_og_name, inside_array[j].numel()));
           }
         }
       } else {
-        PADDLE_THROW("Currently only support LoDTensor and LoDTensorArray.");
+        PADDLE_THROW(platform::errors::Unimplemented(
+            "Currently only support LoDTensor and LoDTensorArray in "
+            "WhileGradOp."));
       }
     }
     executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false, true,
@@ -236,7 +261,13 @@ class WhileGradOp : public framework::OperatorBase {
     // and inputs.
     auto &pg_ig_names = Outputs(kXGRAD);
     auto &p_names = Inputs(kX);
-    PADDLE_ENFORCE_EQ(pg_ig_names.size(), p_names.size());
+    PADDLE_ENFORCE_EQ(pg_ig_names.size(), p_names.size(),
+                      platform::errors::PreconditionNotMet(
+                          "The number of names in Outputs(X@GRAD) does not "
+                          "match the number of names in Inputs(X). The "
+                          "number of names in Outputs(X@GRAD) is %d and "
+                          "the number of names in Inputs(X) is %d.",
+                          pg_ig_names.size(), p_names.size()));
     for (size_t param_id = 0; param_id < pg_ig_names.size(); ++param_id) {
       if (pg_ig_names[param_id] == framework::kEmptyVarName) {
         continue;  // parameter doesn't have gradient
@@ -247,7 +278,9 @@ class WhileGradOp : public framework::OperatorBase {
       // for example lookup_table_grad_op, the input(Idx) doesn't have
       // gradient.
       auto pg_ig_var = cur_scope.FindVar(inside_grad_name);
-      PADDLE_ENFORCE(pg_ig_var != nullptr);
+      PADDLE_ENFORCE_NOT_NULL(
+          pg_ig_var, platform::errors::NotFound("Variable %s is not found.",
+                                                inside_grad_name));
       if (pg_ig_var->IsType<framework::LoDTensorArray>()) {
         auto pg_ig_lod_t_arr =
             pg_ig_var->GetMutable<framework::LoDTensorArray>();
@@ -277,13 +310,16 @@ class WhileGradOp : public framework::OperatorBase {
       // zero gradient variable in step 0
       if (cur_scope_iter == step_scopes->rbegin()) {
         auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
-        PADDLE_ENFORCE_NOT_NULL(var, "Can not find var %s", inside_grad_name);
-        PADDLE_ENFORCE(
+        PADDLE_ENFORCE_NOT_NULL(
+            var, platform::errors::NotFound("Variable %s is not found.",
+                                            inside_grad_name));
+        PADDLE_ENFORCE_EQ(
             var->IsType<framework::LoDTensorArray>() ||
                 var->IsType<LoDTensor>(),
-            "Currently the type of var only can be LoDTensorArray, "
-            "or LoDTensor, but the received var[%s] is %s.",
-            inside_grad_name, framework::ToTypeName(var->Type()));
+            true,
+            platform::errors::InvalidArgument(
+                "Currently the type of var only can be LoDTensorArray, "
+                "or LoDTensor, but the received var[%s] is %s.",
+                inside_grad_name, framework::ToTypeName(var->Type())));
         if (var->IsType<LoDTensor>()) {
           auto &inside_tensor = var->Get<framework::LoDTensor>();
@@ -422,43 +458,26 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
     ctx->HasOutputs(framework::GradVarName(kX));
     ctx->HasInputs(kOutputs);
     ctx->HasInputs(framework::GradVarName(kOutputs));
     auto pg_ig_names = ctx->Outputs(kXGRAD);
     std::vector<framework::InferShapeVarPtr> in_var_ptrs =
         ctx->GetInputVarPtrs(kX);
     std::vector<framework::InferShapeVarPtr> out_var_ptrs =
         ctx->GetOutputVarPtrs(kXGRAD);
-    PADDLE_ENFORCE(in_var_ptrs.size() == out_var_ptrs.size());
+    PADDLE_ENFORCE_EQ(in_var_ptrs.size(), out_var_ptrs.size(),
+                      platform::errors::InvalidArgument(
+                          "The size of Inputs(X) must be the same as "
+                          "the size of Outputs(X@GRAD)."));
     for (size_t i = 0; i < in_var_ptrs.size(); ++i) {
       if (pg_ig_names[i] == framework::kEmptyVarName) {
         continue;
       }
-      if (ctx->IsRuntime()) {
-        framework::Variable *in_var =
-            boost::get<framework::Variable *>(in_var_ptrs[i]);
-        framework::Variable *out_var =
-            boost::get<framework::Variable *>(out_var_ptrs[i]);
-        auto type = framework::ToVarType(in_var->Type());
-        if (type == framework::proto::VarType::LOD_TENSOR) {
-          out_var->GetMutable<LoDTensor>()->Resize(
-              in_var->Get<framework::LoDTensor>().dims());
-        } else if (type == framework::proto::VarType::SELECTED_ROWS) {
-          out_var->GetMutable<framework::SelectedRows>()->set_height(
-              in_var->Get<framework::SelectedRows>().GetCompleteDims()[0]);
-        } else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
-          PADDLE_THROW("WhileGradOp doesn't support type %d",
-                       static_cast<int>(type));
-        }
-      } else {
-        framework::VarDesc *in_var =
-            boost::get<framework::VarDesc *>(in_var_ptrs[i]);
-        boost::get<framework::VarDesc *>(out_var_ptrs[i])
-            ->SetShape(in_var->GetShape());
-      }
+      framework::VarDesc *in_var =
+          boost::get<framework::VarDesc *>(in_var_ptrs[i]);
+      boost::get<framework::VarDesc *>(out_var_ptrs[i])
+          ->SetShape(in_var->GetShape());
     }
   }
 };
 }  // namespace operators
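The C++ shape check above has a Python-side twin in fluid.layers.While (see the control_flow.py hunk below), so a malformed condition is normally rejected before the C++ operator ever runs. A minimal sketch of that failure mode under the same fluid 1.x API (bad_cond is an illustrative name):

    import paddle.fluid as fluid

    # numel != 1 violates the shape-[1] contract enforced both here and in Python.
    bad_cond = fluid.layers.fill_constant(shape=[2], dtype='bool', value=True)
    try:
        fluid.layers.While(cond=bad_cond)
    except TypeError as e:
        # per this patch: "condition expected shape as [1], but given shape as [2]."
        print(e)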
paddle/fluid/operators/controlflow/while_op_helper.cc
@@ -83,7 +83,11 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
   auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX));
-  PADDLE_ENFORCE_EQ(
-      fwd_input.size(), in_grads.size(),
-      "Backward input gradient number does not match forward input number.");
+  PADDLE_ENFORCE_EQ(
+      fwd_input.size(), in_grads.size(),
+      platform::errors::PreconditionNotMet(
+          "Backward output gradient number does not match forward input number."
+          "The number of forward input number is %d and the number of backward "
+          "output geadient number is %d.",
+          fwd_input.size(), in_grads.size()));
   std::unordered_set<std::string> backward_skip_vars;
   for (size_t i = 0; i < in_grads.size(); ++i) {
@@ -104,7 +108,13 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
 static void FindAllWhileAndWhileGradOp(const framework::ProgramDesc &program,
                                        std::vector<OpVariant> *while_ops,
                                        std::vector<OpVariant> *while_grad_ops) {
-  PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
+  PADDLE_ENFORCE_GE(
+      while_ops->size(), while_grad_ops->size(),
+      platform::errors::PreconditionNotMet(
+          "There are more while_grad_ops than forward while_ops in the graph "
+          "or program, the number of while_ops is %d and the number of "
+          "while_grad_ops is %d.",
+          while_ops->size(), while_grad_ops->size()));
   for (size_t i = 1; i < program.Size(); ++i) {
     auto &block = program.Block(i);
     for (size_t j = 0; j < block.OpSize(); ++j) {
@@ -117,8 +127,13 @@ static void FindAllWhileAndWhileGradOp(const framework::ProgramDesc &program,
     }
   }
-  PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(),
-                    "There are extra while_grad ops in the graph or program");
+  PADDLE_ENFORCE_GE(
+      while_ops->size(), while_grad_ops->size(),
+      platform::errors::InvalidArgument(
+          "There are more while_grad_ops than forward while_ops in the graph "
+          "or program, the number of while_ops is %d and the number of "
+          "while_grad_ops is %d.",
+          while_ops->size(), while_grad_ops->size()));
 }

 static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
@@ -140,13 +155,16 @@ static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
   const OpVariant *matched_fwd_op = nullptr;
   for (auto &fwd_op : while_op_set) {
     if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) {
-      PADDLE_ENFORCE(matched_fwd_op == nullptr,
-                     "Found multiple matched while ops");
+      PADDLE_ENFORCE_EQ(matched_fwd_op, nullptr,
+                        platform::errors::PreconditionNotMet(
+                            "Found multiple while forward ops match while "
+                            "grad ops."));
       matched_fwd_op = &fwd_op;
     }
   }
-  PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
-                          "Cannot find matched forward while op.");
+  PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
+                          platform::errors::PreconditionNotMet(
+                              "Cannot find matched forward while op."));
   ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op);
   while_op_set.erase(*matched_fwd_op);
 }
@@ -209,7 +227,7 @@ bool GetCondData(const framework::LoDTensor &cond) {
 #else
   PADDLE_THROW(platform::errors::PreconditionNotMet(
       "This version of PaddlePaddle does NOT support GPU but got GPU tensor "
-      "Cond in WhileOp. Please compile WITH_GPU option"));
+      "Cond in WhileOp. Please compile WITH_GPU option."));
 #endif
   return cpu_cond->data<bool>()[0];
 }
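The GE checks and the matched_fwd_op checks above fire while preparing safe eager deletion, where every while_grad op must pair with exactly one forward while op. That pairing only exists once a backward pass has been appended; a minimal sketch of a program that produces such a pair, assuming the fluid 1.x API (all names illustrative):

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    def cond(i, x):
        return layers.less_than(i, limit)

    def body(i, x):
        x = layers.scale(x, scale=2.0)  # a differentiable step
        i = layers.increment(i)
        return [i, x]

    i = layers.fill_constant(shape=[1], dtype='int64', value=0)
    limit = layers.fill_constant(shape=[1], dtype='int64', value=10)
    x = fluid.data(name='x', shape=[1], dtype='float32')
    x.stop_gradient = False
    _, out = layers.while_loop(cond, body, [i, x])
    loss = layers.reduce_sum(out)
    # append_backward emits the while_grad op paired with the forward while op.
    fluid.backward.append_backward(loss)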
python/paddle/fluid/layers/control_flow.py
@@ -882,14 +882,10 @@ class While(object):
     def __init__(self, cond, is_test=False, name=None):
         self.helper = LayerHelper("while", name=name)
         self.status = While.BEFORE_WHILE_BLOCK
-        if not isinstance(cond, Variable):
-            raise TypeError("condition should be a variable")
-        assert isinstance(cond, Variable)
-        if cond.dtype != core.VarDesc.VarType.BOOL:
-            raise TypeError("condition should be a boolean variable")
+        check_variable_and_dtype(cond, 'cond', ['bool'], 'fluid.layers.While')
         if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
             raise TypeError(
-                "condition expected shape as [], but given shape as {0}.".
+                "condition expected shape as [1], but given shape as {0}.".
                 format(list(cond.shape)))
         self.cond_var = cond
         self.is_test = is_test
@@ -999,19 +995,16 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
         raise TypeError("cond in while_loop should be callable")
     if not callable(body):
         raise TypeError("body in while_loop should be callable")
-    if not isinstance(loop_vars, (list, tuple)):
-        raise TypeError("loop_vars in while_loop should be a list or tuple")
+    check_type(loop_vars, 'loop_vars', (list, tuple), 'fluid.layers.while_loop')
     if len(loop_vars) == 0:
         raise ValueError("loop_vars in while_loop should not be empty")

     pre_cond = cond(*loop_vars)
-    if not isinstance(pre_cond, Variable):
-        raise TypeError("cond in while_loop should return a variable")
-    if pre_cond.dtype != core.VarDesc.VarType.BOOL:
-        raise TypeError("cond in while_loop should return a boolean variable")
+    check_variable_and_dtype(pre_cond, 'var of cond returned', ['bool'],
+                             'fluid.layers.while_loop')
     if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
         raise TypeError(
-            "the shape of the variable returned by cond should be [],"
+            "the shape of the variable returned by cond should be [1],"
             "but given shape as {0}.".format(list(pre_cond.shape)))

     if in_dygraph_mode():
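With check_type and check_variable_and_dtype in place, a bad loop_vars argument or a non-bool condition result now fails with the uniform TypeError wording used across fluid. A sketch of the first case under this patch (names illustrative):

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    def cond(i):
        return layers.less_than(i, limit)

    def body(i):
        return [layers.increment(i)]

    i = layers.fill_constant(shape=[1], dtype='int64', value=0)
    limit = layers.fill_constant(shape=[1], dtype='int64', value=10)
    try:
        layers.while_loop(cond, body, loop_vars=i)  # a Variable, not list/tuple
    except TypeError as e:
        print(e)  # check_type rejects 'loop_vars' for fluid.layers.while_loop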
@@ -2906,9 +2899,7 @@ class DynamicRNN(object):
             rnn_output = drnn()
         """
         self._assert_in_rnn_block_("step_input")
-        if not isinstance(x, Variable):
-            raise TypeError(
-                "step_input() can only take a Variable as its input.")
+        check_type(x, 'x', Variable, 'fluid.layers.DynamicRNN.step_input()')
         parent_block = self._parent_block_()
         if self.lod_rank_table is None:
             self.lod_rank_table = parent_block.create_var(
@@ -3075,9 +3066,7 @@ class DynamicRNN(object):
             rnn_output = drnn()
         """
         self._assert_in_rnn_block_("static_input")
-        if not isinstance(x, Variable):
-            raise TypeError(
-                "static_input() can only take a Variable as its input")
+        check_type(x, 'x', Variable, 'fluid.layers.DynamicRNN.static_input()')
         if self.lod_rank_table is None:
             raise RuntimeError(
                 "static_input() must be called after step_input().")
@@ -3242,10 +3231,12 @@ class DynamicRNN(object):
"""
self
.
_assert_in_rnn_block_
(
'memory'
)
self
.
_init_zero_idx_
()
if
shape
is
not
None
:
check_type
(
shape
,
'shape'
,
(
list
,
tuple
),
'fluid.layers.DynamicRNN.memory()'
)
if
init
is
not
None
:
if
not
isinstance
(
init
,
Variable
):
raise
TypeError
(
"The input arg `init` of memory() must be a Variable"
)
check_type
(
init
,
'init'
,
Variable
,
'fluid.layers.DynamicRNN.memory()'
)
parent_block
=
self
.
_parent_block_
()
init_tensor
=
init
if
need_reorder
==
True
:
@@ -3326,12 +3317,10 @@ class DynamicRNN(object):
             ValueError: When :code:`update_memory()` is called before :code:`step_input()` .
         """
         self._assert_in_rnn_block_('update_memory')
-        if not isinstance(ex_mem, Variable):
-            raise TypeError("The input arg `ex_mem` of update_memory() must "
-                            "be a Variable")
-        if not isinstance(new_mem, Variable):
-            raise TypeError("The input arg `new_mem` of update_memory() must "
-                            "be a Variable")
+        check_type(ex_mem, 'ex_mem', Variable,
+                   'fluid.layers.DynamicRNN.update_memory()')
+        check_type(new_mem, 'new_mem', Variable,
+                   'fluid.layers.DynamicRNN.update_memory()')

         mem_array = self.mem_dict.get(ex_mem.name, None)
         if mem_array is None:
@@ -3358,6 +3347,8 @@ class DynamicRNN(object):
         self._assert_in_rnn_block_('output')
         parent_block = self._parent_block_()
         for each in outputs:
+            check_type(each, "outputs", Variable,
+                       "fluid.layers.DynamicRNN.output")
             outside_array = parent_block.create_var(
                 name=unique_name.generate_with_ignorable_key("_".join(
                     [self.helper.name, "output_array", each.name])),
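All of the DynamicRNN changes above funnel through check_type, so each argument error now reports the argument name and the expected type. For contrast with the error cases exercised by the new test below, a minimal well-typed step, assuming the fluid 1.x API (tensor names illustrative):

    import paddle.fluid as fluid

    sentence = fluid.data(
        name='sentence', shape=[None, 32], dtype='float32', lod_level=1)
    drnn = fluid.layers.DynamicRNN()
    with drnn.block():
        word = drnn.step_input(sentence)  # a Variable, so check_type passes
        memory = drnn.memory(shape=[10], dtype='float32', value=0)
        hidden = fluid.layers.fc(input=[word, memory], size=10, act='tanh')
        drnn.update_memory(ex_mem=memory, new_mem=hidden)  # both Variables
        drnn.output(hidden)  # a Variable, unlike the numpy array in the test
    rnn_output = drnn()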
python/paddle/fluid/tests/unittests/test_dyn_rnn.py
@@ -19,6 +19,7 @@ import paddle
 import unittest
 import numpy
+from paddle.fluid.framework import Program, program_guard
 from paddle.fluid.layers.control_flow import lod_rank_table
 from paddle.fluid.layers.control_flow import max_sequence_len
 from paddle.fluid.layers.control_flow import lod_tensor_to_array
@@ -299,5 +300,37 @@ class TestDynamicRNN(unittest.TestCase):
         self.train_data = train_data_orig


+class TestDynamicRNNErrors(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            init = fluid.layers.zeros(shape=[1], dtype='float32')
+            shape = 'shape'
+            sentence = fluid.data(
+                name='sentence', shape=[None, 32], dtype='float32', lod_level=1)
+
+            # The type of Input(shape) in API(memory) must be list or tuple
+            def input_shape_type_of_memory():
+                drnn = fluid.layers.DynamicRNN()
+                with drnn.block():
+                    res = drnn.memory(init, shape)
+
+            self.assertRaises(TypeError, input_shape_type_of_memory)
+
+            # The type of element of Input(*outputs) in API(output) must be Variable.
+            def outputs_type_of_output():
+                drnn = fluid.layers.DynamicRNN()
+                with drnn.block():
+                    word = drnn.step_input(sentence)
+                    memory = drnn.memory(shape=[10], dtype='float32', value=0)
+                    hidden = fluid.layers.fc(
+                        input=[word, memory], size=10, act='tanh')
+                    out = np.ones(1).astype('float32')
+                    drnn.update_memory(ex_mem=memory, new_mem=hidden)
+                    drnn.output(hidden, out)
+
+            self.assertRaises(TypeError, outputs_type_of_output)
+
+
 if __name__ == '__main__':
     unittest.main()