Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
dfd1eee7
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dfd1eee7
编写于
10月 12, 2019
作者:
G
Guo Sheng
提交者:
Tao Luo
10月 12, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add seq2seq api related code (#19820)
上级
e87cabb7
变更
24
展开全部
隐藏空白更改
内联
并排
Showing
24 changed file
with
2480 addition
and
72 deletion
+2480
-72
paddle/fluid/API.spec
paddle/fluid/API.spec
+36
-2
paddle/fluid/operators/assign_op.cc
paddle/fluid/operators/assign_op.cc
+4
-2
paddle/fluid/operators/fill_constant_batch_size_like_op.cc
paddle/fluid/operators/fill_constant_batch_size_like_op.cc
+8
-1
paddle/fluid/operators/fill_constant_batch_size_like_op.cu.cc
...le/fluid/operators/fill_constant_batch_size_like_op.cu.cc
+3
-1
paddle/fluid/operators/fill_constant_batch_size_like_op.h
paddle/fluid/operators/fill_constant_batch_size_like_op.h
+14
-5
paddle/fluid/operators/fill_constant_op.cu.cc
paddle/fluid/operators/fill_constant_op.cu.cc
+1
-0
paddle/fluid/operators/gather_nd_op.cc
paddle/fluid/operators/gather_nd_op.cc
+8
-3
paddle/fluid/operators/gather_nd_op.cu
paddle/fluid/operators/gather_nd_op.cu
+1
-0
paddle/fluid/operators/gather_tree_op.cc
paddle/fluid/operators/gather_tree_op.cc
+78
-0
paddle/fluid/operators/gather_tree_op.cu
paddle/fluid/operators/gather_tree_op.cu
+80
-0
paddle/fluid/operators/gather_tree_op.h
paddle/fluid/operators/gather_tree_op.h
+58
-0
paddle/fluid/operators/reduce_ops/reduce_all_op.cc
paddle/fluid/operators/reduce_ops/reduce_all_op.cc
+3
-1
paddle/fluid/operators/reduce_ops/reduce_any_op.cc
paddle/fluid/operators/reduce_ops/reduce_any_op.cc
+3
-1
paddle/fluid/operators/reduce_ops/reduce_op.h
paddle/fluid/operators/reduce_ops/reduce_op.h
+20
-7
paddle/fluid/operators/tensor_array_to_tensor_op.cc
paddle/fluid/operators/tensor_array_to_tensor_op.cc
+63
-12
python/paddle/fluid/layers/__init__.py
python/paddle/fluid/layers/__init__.py
+4
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+76
-0
python/paddle/fluid/layers/rnn.py
python/paddle/fluid/layers/rnn.py
+1165
-0
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+75
-36
python/paddle/fluid/layers/utils.py
python/paddle/fluid/layers/utils.py
+172
-0
python/paddle/fluid/tests/unittests/test_gather_tree_op.py
python/paddle/fluid/tests/unittests/test_gather_tree_op.py
+65
-0
python/paddle/fluid/tests/unittests/test_rnn_cell_api.py
python/paddle/fluid/tests/unittests/test_rnn_cell_api.py
+249
-0
python/paddle/fluid/tests/unittests/test_rnn_decode_api.py
python/paddle/fluid/tests/unittests/test_rnn_decode_api.py
+214
-0
python/paddle/fluid/tests/unittests/test_tensor_array_to_tensor.py
...ddle/fluid/tests/unittests/test_tensor_array_to_tensor.py
+80
-1
未找到文件。
paddle/fluid/API.spec
浏览文件 @
dfd1eee7
...
@@ -306,6 +306,7 @@ paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'tran
...
@@ -306,6 +306,7 @@ paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'tran
paddle.fluid.layers.filter_by_instag (ArgSpec(args=['ins', 'ins_tag', 'filter_tag', 'is_lod'], varargs=None, keywords=None, defaults=None), ('document', '7703a2088af8de4128b143ff1164ca4a'))
paddle.fluid.layers.filter_by_instag (ArgSpec(args=['ins', 'ins_tag', 'filter_tag', 'is_lod'], varargs=None, keywords=None, defaults=None), ('document', '7703a2088af8de4128b143ff1164ca4a'))
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '3c6b30e9cd57b38d4a5fa1ade887f779'))
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '3c6b30e9cd57b38d4a5fa1ade887f779'))
paddle.fluid.layers.hard_swish (ArgSpec(args=['x', 'threshold', 'scale', 'offset', 'name'], varargs=None, keywords=None, defaults=(6.0, 6.0, 3.0, None)), ('document', 'bd763b9ca99239d624c3cb4626e3627a'))
paddle.fluid.layers.hard_swish (ArgSpec(args=['x', 'threshold', 'scale', 'offset', 'name'], varargs=None, keywords=None, defaults=(6.0, 6.0, 3.0, None)), ('document', 'bd763b9ca99239d624c3cb4626e3627a'))
paddle.fluid.layers.gather_tree (ArgSpec(args=['ids', 'parents'], varargs=None, keywords=None, defaults=None), ('document', '201b54fa7512305078c70a6610beaead'))
paddle.fluid.layers.mse_loss (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', '88b967ef5132567396062d5d654b3064'))
paddle.fluid.layers.mse_loss (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', '88b967ef5132567396062d5d654b3064'))
paddle.fluid.layers.uniform_random (ArgSpec(args=['shape', 'dtype', 'min', 'max', 'seed'], varargs=None, keywords=None, defaults=('float32', -1.0, 1.0, 0)), ('document', '34e7c1ff0263baf9551000b6bb3bc47e'))
paddle.fluid.layers.uniform_random (ArgSpec(args=['shape', 'dtype', 'min', 'max', 'seed'], varargs=None, keywords=None, defaults=('float32', -1.0, 1.0, 0)), ('document', '34e7c1ff0263baf9551000b6bb3bc47e'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545'))
...
@@ -318,11 +319,11 @@ paddle.fluid.layers.create_tensor (ArgSpec(args=['dtype', 'name', 'persistable']
...
@@ -318,11 +319,11 @@ paddle.fluid.layers.create_tensor (ArgSpec(args=['dtype', 'name', 'persistable']
paddle.fluid.layers.create_parameter (ArgSpec(args=['shape', 'dtype', 'name', 'attr', 'is_bias', 'default_initializer'], varargs=None, keywords=None, defaults=(None, None, False, None)), ('document', '727aa63c061919bee38547fb126d9428'))
paddle.fluid.layers.create_parameter (ArgSpec(args=['shape', 'dtype', 'name', 'attr', 'is_bias', 'default_initializer'], varargs=None, keywords=None, defaults=(None, None, False, None)), ('document', '727aa63c061919bee38547fb126d9428'))
paddle.fluid.layers.create_global_var (ArgSpec(args=['shape', 'value', 'dtype', 'persistable', 'force_cpu', 'name'], varargs=None, keywords=None, defaults=(False, False, None)), ('document', 'fa7f74cfb940521cc9fdffabc83debbf'))
paddle.fluid.layers.create_global_var (ArgSpec(args=['shape', 'value', 'dtype', 'persistable', 'force_cpu', 'name'], varargs=None, keywords=None, defaults=(False, False, None)), ('document', 'fa7f74cfb940521cc9fdffabc83debbf'))
paddle.fluid.layers.cast (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '45df178cbd8c302f92c30ebdaaa6fa8a'))
paddle.fluid.layers.cast (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '45df178cbd8c302f92c30ebdaaa6fa8a'))
paddle.fluid.layers.tensor_array_to_tensor (ArgSpec(args=['input', 'axis', 'name'
], varargs=None, keywords=None, defaults=(1, None)), ('document', 'dd7d2f1e12a8a4225d017209866e5621
'))
paddle.fluid.layers.tensor_array_to_tensor (ArgSpec(args=['input', 'axis', 'name'
, 'use_stack'], varargs=None, keywords=None, defaults=(1, None, False)), ('document', '4aa82374218ccf593bb8011df79c71e3
'))
paddle.fluid.layers.concat (ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'ec7d6e716fb29ef1e73e1e3efa5ca46b'))
paddle.fluid.layers.concat (ArgSpec(args=['input', 'axis', 'name'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'ec7d6e716fb29ef1e73e1e3efa5ca46b'))
paddle.fluid.layers.sums (ArgSpec(args=['input', 'out'], varargs=None, keywords=None, defaults=(None,)), ('document', '5df743d578638cd2bbb9369499b44af4'))
paddle.fluid.layers.sums (ArgSpec(args=['input', 'out'], varargs=None, keywords=None, defaults=(None,)), ('document', '5df743d578638cd2bbb9369499b44af4'))
paddle.fluid.layers.assign (ArgSpec(args=['input', 'output'], varargs=None, keywords=None, defaults=(None,)), ('document', '8bd94aef4e123986d9a8c29f67b5532b'))
paddle.fluid.layers.assign (ArgSpec(args=['input', 'output'], varargs=None, keywords=None, defaults=(None,)), ('document', '8bd94aef4e123986d9a8c29f67b5532b'))
paddle.fluid.layers.fill_constant_batch_size_like (ArgSpec(args=['input', 'shape', 'dtype', 'value', 'input_dim_idx', 'output_dim_idx'
], varargs=None, keywords=None, defaults=(0, 0)), ('document', '37a288e4400f6d5510e982827461c11b
'))
paddle.fluid.layers.fill_constant_batch_size_like (ArgSpec(args=['input', 'shape', 'dtype', 'value', 'input_dim_idx', 'output_dim_idx'
, 'force_cpu'], varargs=None, keywords=None, defaults=(0, 0, False)), ('document', '2bb57637664173fee5f654e55896aec6
'))
paddle.fluid.layers.fill_constant (ArgSpec(args=['shape', 'dtype', 'value', 'force_cpu', 'out'], varargs=None, keywords=None, defaults=(False, None)), ('document', '66e1e468666dd47e5b2715226cebeac0'))
paddle.fluid.layers.fill_constant (ArgSpec(args=['shape', 'dtype', 'value', 'force_cpu', 'out'], varargs=None, keywords=None, defaults=(False, None)), ('document', '66e1e468666dd47e5b2715226cebeac0'))
paddle.fluid.layers.argmin (ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)), ('document', '53629e27597e5dfb7020aac5bc639ebb'))
paddle.fluid.layers.argmin (ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)), ('document', '53629e27597e5dfb7020aac5bc639ebb'))
paddle.fluid.layers.argmax (ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)), ('document', 'd9a89fbedbaebd5f65897ac75ee636f3'))
paddle.fluid.layers.argmax (ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)), ('document', 'd9a89fbedbaebd5f65897ac75ee636f3'))
...
@@ -467,6 +468,39 @@ paddle.fluid.layers.MultivariateNormalDiag.entropy (ArgSpec(args=['self'], varar
...
@@ -467,6 +468,39 @@ paddle.fluid.layers.MultivariateNormalDiag.entropy (ArgSpec(args=['self'], varar
paddle.fluid.layers.MultivariateNormalDiag.kl_divergence (ArgSpec(args=['self', 'other'], varargs=None, keywords=None, defaults=None), ('document', 'd9190d29dbd54c81f747a6436c35f062'))
paddle.fluid.layers.MultivariateNormalDiag.kl_divergence (ArgSpec(args=['self', 'other'], varargs=None, keywords=None, defaults=None), ('document', 'd9190d29dbd54c81f747a6436c35f062'))
paddle.fluid.layers.MultivariateNormalDiag.log_prob (ArgSpec(args=['self', 'value'], varargs=None, keywords=None, defaults=None), ('document', 'c0edd2e2fc76711477b32dc4da9de768'))
paddle.fluid.layers.MultivariateNormalDiag.log_prob (ArgSpec(args=['self', 'value'], varargs=None, keywords=None, defaults=None), ('document', 'c0edd2e2fc76711477b32dc4da9de768'))
paddle.fluid.layers.MultivariateNormalDiag.sample (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '08a2bbcaa20ee176ee7ec3d05737a0f6'))
paddle.fluid.layers.MultivariateNormalDiag.sample (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '08a2bbcaa20ee176ee7ec3d05737a0f6'))
paddle.fluid.layers.RNNCell ('paddle.fluid.layers.rnn.RNNCell', ('document', '2c3a2d3ecb4a3cec130395e7df0bd5c9'))
paddle.fluid.layers.RNNCell.__init__
paddle.fluid.layers.RNNCell.call (ArgSpec(args=['self', 'inputs', 'states'], varargs=None, keywords='kwargs', defaults=None), ('document', '3ac714b638258c520d66f682be67b658'))
paddle.fluid.layers.RNNCell.get_initial_states (ArgSpec(args=['self', 'batch_ref', 'shape', 'dtype', 'init_value'], varargs=None, keywords=None, defaults=(None, None, 0)), ('document', '003d1b4c99128f798ac0b0eecc81c489'))
paddle.fluid.layers.GRUCell ('paddle.fluid.layers.rnn.GRUCell', ('document', '7b2902a91258c4688a879805290adc00'))
paddle.fluid.layers.GRUCell.__init__ (ArgSpec(args=['self', 'hidden_size', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, 'float32', 'GRUCell')), ('document', '3624a6c93b4a999d0d809eb1a66d272e'))
paddle.fluid.layers.GRUCell.call (ArgSpec(args=['self', 'inputs', 'states'], varargs=None, keywords=None, defaults=None), ('document', '6094ab09a56c732c76abb5105327ea54'))
paddle.fluid.layers.GRUCell.get_initial_states (ArgSpec(args=['self', 'batch_ref', 'shape', 'dtype', 'init_value'], varargs=None, keywords=None, defaults=(None, None, 0)), ('document', '003d1b4c99128f798ac0b0eecc81c489'))
paddle.fluid.layers.LSTMCell ('paddle.fluid.layers.rnn.LSTMCell', ('document', '5cbd87bce446ba0f50398ce2772d43e9'))
paddle.fluid.layers.LSTMCell.__init__ (ArgSpec(args=['self', 'hidden_size', 'param_attr', 'bias_attr', 'gate_activation', 'activation', 'forget_bias', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, 1.0, 'float32', 'LSTMCell')), ('document', '9015961869b436d2739a0347618028e3'))
paddle.fluid.layers.LSTMCell.call (ArgSpec(args=['self', 'inputs', 'states'], varargs=None, keywords=None, defaults=None), ('document', '9c84a477021e4a7d0a497c1e6a31be2d'))
paddle.fluid.layers.LSTMCell.get_initial_states (ArgSpec(args=['self', 'batch_ref', 'shape', 'dtype', 'init_value'], varargs=None, keywords=None, defaults=(None, None, 0)), ('document', '003d1b4c99128f798ac0b0eecc81c489'))
paddle.fluid.layers.Decoder ('paddle.fluid.layers.rnn.Decoder', ('document', '23838bd065fddca1557a6a3368d9e365'))
paddle.fluid.layers.Decoder.__init__
paddle.fluid.layers.Decoder.finalize (ArgSpec(args=['self', 'outputs', 'final_states', 'sequence_lengths'], varargs=None, keywords=None, defaults=None), ('document', 'cab7fc752a05db18e99258473f50359d'))
paddle.fluid.layers.Decoder.initialize (ArgSpec(args=['self', 'inits'], varargs=None, keywords=None, defaults=None), ('document', '68cf1846fb58056dbe5a524f1ca9dff5'))
paddle.fluid.layers.Decoder.step (ArgSpec(args=['self', 'time', 'inputs', 'states'], varargs=None, keywords=None, defaults=None), ('document', '151d0229930b9654689f86c85f7c4c3f'))
paddle.fluid.layers.BeamSearchDecoder ('paddle.fluid.layers.rnn.BeamSearchDecoder', ('document', 'd7ef0c9229bfe73e0daefcfda24a2635'))
paddle.fluid.layers.BeamSearchDecoder.OutputWrapper ('paddle.fluid.layers.rnn.OutputWrapper', ('document', 'a7141ebf1fb097fa71006cdd35bdc219'))
paddle.fluid.layers.BeamSearchDecoder.OutputWrapper.__init__
paddle.fluid.layers.BeamSearchDecoder.OutputWrapper.count T.count(value) -> integer -- return number of occurrences of value
paddle.fluid.layers.BeamSearchDecoder.OutputWrapper.index T.index(value, [start, [stop]]) -> integer -- return first index of value.
paddle.fluid.layers.BeamSearchDecoder.StateWrapper ('paddle.fluid.layers.rnn.StateWrapper', ('document', '157731f37c88ea01bc746653125a41c8'))
paddle.fluid.layers.BeamSearchDecoder.StateWrapper.__init__
paddle.fluid.layers.BeamSearchDecoder.StateWrapper.count T.count(value) -> integer -- return number of occurrences of value
paddle.fluid.layers.BeamSearchDecoder.StateWrapper.index T.index(value, [start, [stop]]) -> integer -- return first index of value.
paddle.fluid.layers.BeamSearchDecoder.__init__ (ArgSpec(args=['self', 'cell', 'start_token', 'end_token', 'beam_size', 'embedding_fn', 'output_fn'], varargs=None, keywords=None, defaults=(None, None)), ('document', '68951eaed573ec47c17a43155514b2f1'))
paddle.fluid.layers.BeamSearchDecoder.finalize (ArgSpec(args=['self', 'outputs', 'final_states', 'sequence_lengths'], varargs=None, keywords=None, defaults=None), ('document', '9a7f0a8fc5802bf860f2ac960466fb45'))
paddle.fluid.layers.BeamSearchDecoder.initialize (ArgSpec(args=['self', 'initial_cell_states'], varargs=None, keywords=None, defaults=None), ('document', '01ee508a9615e2483fe6ddcf14d5fa25'))
paddle.fluid.layers.BeamSearchDecoder.step (ArgSpec(args=['self', 'time', 'inputs', 'states'], varargs=None, keywords='kwargs', defaults=None), ('document', '35ee583c3c0fe7cceeafa289ed3374bd'))
paddle.fluid.layers.BeamSearchDecoder.tile_beam_merge_with_batch (ArgSpec(args=['x', 'beam_size'], varargs=None, keywords=None, defaults=None), ('document', 'ce7ffacba6f56f57acbf5d4dd82fe04d'))
paddle.fluid.layers.rnn (ArgSpec(args=['cell', 'inputs', 'initial_states', 'sequence_length', 'time_major', 'is_reverse'], varargs=None, keywords='kwargs', defaults=(None, None, False, False)), ('document', 'c36ade777ff43d2ba5542079b66a012b'))
paddle.fluid.layers.dynamic_decode (ArgSpec(args=['decoder', 'inits', 'max_step_num', 'output_time_major'], varargs=None, keywords='kwargs', defaults=(None, None, False)), ('document', '55b44de9d290c0c2ad8fdd635e6ab575'))
paddle.fluid.contrib.InitState ('paddle.fluid.contrib.decoder.beam_search_decoder.InitState', ('document', '3afd1f84232718e628e9e566941c5f05'))
paddle.fluid.contrib.InitState ('paddle.fluid.contrib.decoder.beam_search_decoder.InitState', ('document', '3afd1f84232718e628e9e566941c5f05'))
paddle.fluid.contrib.InitState.__init__ (ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.InitState.__init__ (ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.StateCell ('paddle.fluid.contrib.decoder.beam_search_decoder.StateCell', ('document', 'ecd0066c02867d445d7b461e28220c50'))
paddle.fluid.contrib.StateCell ('paddle.fluid.contrib.decoder.beam_search_decoder.StateCell', ('document', 'ecd0066c02867d445d7b461e28220c50'))
...
...
paddle/fluid/operators/assign_op.cc
浏览文件 @
dfd1eee7
...
@@ -154,10 +154,12 @@ REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker,
...
@@ -154,10 +154,12 @@ REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker,
ops
::
AssignOpProtoMaker
,
ops
::
AssignOpInplaceInferer
);
ops
::
AssignOpProtoMaker
,
ops
::
AssignOpInplaceInferer
);
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
assign
,
float
,
ops
::
AssignKernel
,
double
,
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
assign
,
float
,
ops
::
AssignKernel
,
double
,
ops
::
AssignKernel
,
int
,
ops
::
AssignKernel
,
ops
::
AssignKernel
,
int
,
ops
::
AssignKernel
,
int64_t
,
ops
::
AssignKernel
);
int64_t
,
ops
::
AssignKernel
,
bool
,
ops
::
AssignKernel
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR
(
assign
,
float
,
ops
::
AssignKernel
,
double
,
REGISTER_OP_CUDA_KERNEL_FUNCTOR
(
assign
,
float
,
ops
::
AssignKernel
,
double
,
ops
::
AssignKernel
,
int
,
ops
::
AssignKernel
,
ops
::
AssignKernel
,
int
,
ops
::
AssignKernel
,
int64_t
,
ops
::
AssignKernel
);
int64_t
,
ops
::
AssignKernel
,
bool
,
ops
::
AssignKernel
);
#endif
#endif
paddle/fluid/operators/fill_constant_batch_size_like_op.cc
浏览文件 @
dfd1eee7
...
@@ -38,6 +38,11 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
...
@@ -38,6 +38,11 @@ class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
.
SetDefault
(
framework
::
proto
::
VarType
::
FP32
);
.
SetDefault
(
framework
::
proto
::
VarType
::
FP32
);
AddAttr
<
float
>
(
"value"
,
"default 0. The value to be filled"
)
AddAttr
<
float
>
(
"value"
,
"default 0. The value to be filled"
)
.
SetDefault
(
0.0
f
);
.
SetDefault
(
0.0
f
);
AddAttr
<
bool
>
(
"force_cpu"
,
"(bool, default false) Force fill output variable to cpu "
"memory. Otherwise, fill output variable to the running "
"device"
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
This function creates a tensor of specified *shape*, *dtype* and batch size,
This function creates a tensor of specified *shape*, *dtype* and batch size,
and initializes this with a constant supplied in *value*. The batch size is
and initializes this with a constant supplied in *value*. The batch size is
...
@@ -65,4 +70,6 @@ REGISTER_OP_CPU_KERNEL(
...
@@ -65,4 +70,6 @@ REGISTER_OP_CPU_KERNEL(
ops
::
FillConstantBatchSizeLikeOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
FillConstantBatchSizeLikeOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
int
>
,
ops
::
FillConstantBatchSizeLikeOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
FillConstantBatchSizeLikeOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
int64_t
>
,
ops
::
FillConstantBatchSizeLikeOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
bool
>
);
paddle/fluid/operators/fill_constant_batch_size_like_op.cu.cc
浏览文件 @
dfd1eee7
...
@@ -25,4 +25,6 @@ REGISTER_OP_CUDA_KERNEL(
...
@@ -25,4 +25,6 @@ REGISTER_OP_CUDA_KERNEL(
ops
::
FillConstantBatchSizeLikeOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
ops
::
FillConstantBatchSizeLikeOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
int
>
,
ops
::
FillConstantBatchSizeLikeOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
ops
::
FillConstantBatchSizeLikeOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
int64_t
>
,
ops
::
FillConstantBatchSizeLikeOpKernel
<
paddle
::
platform
::
CUDADeviceContext
,
bool
>
);
paddle/fluid/operators/fill_constant_batch_size_like_op.h
浏览文件 @
dfd1eee7
...
@@ -23,6 +23,11 @@ template <typename DeviceContext, typename T>
...
@@ -23,6 +23,11 @@ template <typename DeviceContext, typename T>
class
FillConstantBatchSizeLikeOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
FillConstantBatchSizeLikeOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
ctx
.
Attr
<
int
>
(
"dtype"
));
auto
value
=
ctx
.
Attr
<
float
>
(
"value"
);
auto
force_cpu
=
ctx
.
Attr
<
bool
>
(
"force_cpu"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
in
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
);
auto
*
in
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
);
if
(
in
->
lod
().
size
()
&&
ctx
.
Attr
<
int
>
(
"input_dim_idx"
)
==
0
)
{
if
(
in
->
lod
().
size
()
&&
ctx
.
Attr
<
int
>
(
"input_dim_idx"
)
==
0
)
{
...
@@ -32,12 +37,16 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
...
@@ -32,12 +37,16 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
odims
[
output_dim_idx
]
=
static_cast
<
int
>
(
in
->
lod
().
back
().
size
())
-
1
;
odims
[
output_dim_idx
]
=
static_cast
<
int
>
(
in
->
lod
().
back
().
size
())
-
1
;
out
->
mutable_data
<
T
>
(
odims
,
ctx
.
GetPlace
());
out
->
mutable_data
<
T
>
(
odims
,
ctx
.
GetPlace
());
}
}
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
value
=
ctx
.
Attr
<
float
>
(
"value"
);
math
::
SetConstant
<
DeviceContext
,
T
>
setter
;
if
(
force_cpu
)
{
setter
(
ctx
.
template
device_context
<
DeviceContext
>(),
out
,
out
->
mutable_data
(
platform
::
CPUPlace
(),
data_type
);
static_cast
<
T
>
(
value
));
}
else
{
out
->
mutable_data
(
ctx
.
GetPlace
(),
data_type
);
}
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
ctx
.
GetPlace
());
math
::
set_constant
(
dev_ctx
,
out
,
value
);
}
}
};
};
...
...
paddle/fluid/operators/fill_constant_op.cu.cc
浏览文件 @
dfd1eee7
...
@@ -19,4 +19,5 @@ REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel<float>,
...
@@ -19,4 +19,5 @@ REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops
::
FillConstantKernel
<
double
>
,
ops
::
FillConstantKernel
<
double
>
,
ops
::
FillConstantKernel
<
int64_t
>
,
ops
::
FillConstantKernel
<
int64_t
>
,
ops
::
FillConstantKernel
<
int
>
,
ops
::
FillConstantKernel
<
int
>
,
ops
::
FillConstantKernel
<
bool
>
,
ops
::
FillConstantKernel
<
paddle
::
platform
::
float16
>
);
ops
::
FillConstantKernel
<
paddle
::
platform
::
float16
>
);
paddle/fluid/operators/gather_nd_op.cc
浏览文件 @
dfd1eee7
...
@@ -60,8 +60,13 @@ class GatherNdOp : public framework::OperatorWithKernel {
...
@@ -60,8 +60,13 @@ class GatherNdOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
ctx
.
device_context
());
const
auto
&
x_type
=
x
->
type
();
return
framework
::
OpKernelType
(
x_type
,
x_type
==
framework
::
proto
::
VarType
::
BOOL
?
x
->
place
()
// to be consistent with compare and logical ops
:
ctx
.
device_context
().
GetPlace
());
}
}
};
};
...
@@ -173,7 +178,7 @@ REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp,
...
@@ -173,7 +178,7 @@ REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp,
REGISTER_OP_CPU_KERNEL
(
gather_nd
,
ops
::
GatherNdOpKernel
<
float
>
,
REGISTER_OP_CPU_KERNEL
(
gather_nd
,
ops
::
GatherNdOpKernel
<
float
>
,
ops
::
GatherNdOpKernel
<
double
>
,
ops
::
GatherNdOpKernel
<
double
>
,
ops
::
GatherNdOpKernel
<
int64_t
>
,
ops
::
GatherNdOpKernel
<
int64_t
>
,
ops
::
GatherNdOpKernel
<
int
>
,
ops
::
GatherNdOpKernel
<
int
>
,
ops
::
GatherNdOpKernel
<
bool
>
,
ops
::
GatherNdOpKernel
<
uint8_t
>
);
ops
::
GatherNdOpKernel
<
uint8_t
>
);
REGISTER_OP_CPU_KERNEL
(
gather_nd_grad
,
ops
::
GatherNdGradOpKernel
<
float
>
,
REGISTER_OP_CPU_KERNEL
(
gather_nd_grad
,
ops
::
GatherNdGradOpKernel
<
float
>
,
...
...
paddle/fluid/operators/gather_nd_op.cu
浏览文件 @
dfd1eee7
...
@@ -95,6 +95,7 @@ REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<CUDA, float>,
...
@@ -95,6 +95,7 @@ REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<CUDA, float>,
ops
::
GatherNdOpCUDAKernel
<
CUDA
,
double
>
,
ops
::
GatherNdOpCUDAKernel
<
CUDA
,
double
>
,
ops
::
GatherNdOpCUDAKernel
<
CUDA
,
int64_t
>
,
ops
::
GatherNdOpCUDAKernel
<
CUDA
,
int64_t
>
,
ops
::
GatherNdOpCUDAKernel
<
CUDA
,
int
>
,
ops
::
GatherNdOpCUDAKernel
<
CUDA
,
int
>
,
ops
::
GatherNdOpCUDAKernel
<
CUDA
,
bool
>
,
ops
::
GatherNdOpCUDAKernel
<
CUDA
,
plat
::
float16
>
);
ops
::
GatherNdOpCUDAKernel
<
CUDA
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
gather_nd_grad
,
REGISTER_OP_CUDA_KERNEL
(
gather_nd_grad
,
...
...
paddle/fluid/operators/gather_tree_op.cc
0 → 100644
浏览文件 @
dfd1eee7
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gather_tree_op.h"
namespace
paddle
{
namespace
operators
{
class
GatherTreeOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input(Ids) of GatherTreeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Parents"
),
"Input(Parents) of GatherTreeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of GatherTreeOp should not be null."
);
auto
ids_dims
=
ctx
->
GetInputDim
(
"Ids"
);
auto
parents_dims
=
ctx
->
GetInputDim
(
"Parents"
);
PADDLE_ENFORCE
(
ids_dims
==
parents_dims
,
"The shape of Input(Parents) must be same with the shape of "
"Input(Ids)."
);
ctx
->
SetOutputDim
(
"Out"
,
ids_dims
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Ids"
)
->
type
(),
ctx
.
device_context
());
}
};
class
GatherTreeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Ids"
,
"The Tensor with shape [length, batch_size, beam_size] containing "
"the selected ids of all time steps."
);
AddInput
(
"Parents"
,
"The Tensor has the same shape as Ids and contains the parents "
"corresponding to selected ids when searching among beams."
);
AddOutput
(
"Out"
,
"A Tensor with shape [length, batch_size, beam_size] containing the "
"full sequences. The sequences is collected by backtracing from the "
"last time step of Ids."
);
AddComment
(
R"DOC(
GatherTree Operator.
Backtrace from the last time step and generate the full sequences by collecting beam search
selected ids.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
gather_tree
,
ops
::
GatherTreeOp
,
ops
::
GatherTreeOpMaker
);
REGISTER_OP_CPU_KERNEL
(
gather_tree
,
ops
::
GatherTreeOpKernel
<
int32_t
>
,
ops
::
GatherTreeOpKernel
<
int64_t
>
);
paddle/fluid/operators/gather_tree_op.cu
0 → 100644
浏览文件 @
dfd1eee7
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather_tree_op.h"
namespace
paddle
{
namespace
operators
{
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template
<
typename
T
>
__global__
void
GatherTree
(
const
T
*
ids_data
,
const
T
*
parents_data
,
T
*
out_data
,
const
int64_t
max_length
,
const
int64_t
batch_size
,
const
int64_t
beam_size
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
batch_size
*
beam_size
)
{
int
batch
=
i
/
beam_size
;
int
beam
=
i
%
beam_size
;
auto
idx
=
(
max_length
-
1
)
*
batch_size
*
beam_size
+
batch
*
beam_size
+
beam
;
out_data
[
idx
]
=
ids_data
[
idx
];
auto
parent
=
parents_data
[
idx
];
for
(
int
step
=
max_length
-
2
;
step
>=
0
;
step
--
)
{
idx
=
step
*
batch_size
*
beam_size
+
batch
*
beam_size
;
out_data
[
idx
+
beam
]
=
ids_data
[
idx
+
parent
];
parent
=
parents_data
[
idx
+
parent
];
}
}
}
template
<
typename
T
>
class
GatherTreeOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
ids
=
ctx
.
Input
<
Tensor
>
(
"Ids"
);
auto
*
parents
=
ctx
.
Input
<
Tensor
>
(
"Parents"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
const
auto
*
ids_data
=
ids
->
data
<
T
>
();
const
auto
*
parents_data
=
parents
->
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
ids_dims
=
ids
->
dims
();
int64_t
max_length
=
ids_dims
[
0
];
int64_t
batch_size
=
ids_dims
[
1
];
int64_t
beam_size
=
ids_dims
[
2
];
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
const
int
block
=
512
;
int
max_threads
=
std
::
min
(
static_cast
<
int64_t
>
(
dev_ctx
.
GetMaxPhysicalThreadCount
()),
batch_size
*
beam_size
);
const
int
grid
=
std
::
max
(
max_threads
/
block
,
1
);
GatherTree
<<<
grid
,
block
>>>
(
ids_data
,
parents_data
,
out_data
,
max_length
,
batch_size
,
beam_size
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
gather_tree
,
ops
::
GatherTreeOpCUDAKernel
<
int32_t
>
,
ops
::
GatherTreeOpCUDAKernel
<
int64_t
>
);
paddle/fluid/operators/gather_tree_op.h
0 → 100644
浏览文件 @
dfd1eee7
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
GatherTreeOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
ids
=
ctx
.
Input
<
Tensor
>
(
"Ids"
);
auto
*
parents
=
ctx
.
Input
<
Tensor
>
(
"Parents"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
const
auto
*
ids_data
=
ids
->
data
<
T
>
();
const
auto
*
parents_data
=
parents
->
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
ids_dims
=
ids
->
dims
();
auto
max_length
=
ids_dims
[
0
];
auto
batch_size
=
ids_dims
[
1
];
auto
beam_size
=
ids_dims
[
2
];
for
(
int
batch
=
0
;
batch
<
batch_size
;
batch
++
)
{
for
(
int
beam
=
0
;
beam
<
beam_size
;
beam
++
)
{
auto
idx
=
(
max_length
-
1
)
*
batch_size
*
beam_size
+
batch
*
beam_size
+
beam
;
out_data
[
idx
]
=
ids_data
[
idx
];
auto
parent
=
parents_data
[
idx
];
for
(
int
step
=
max_length
-
2
;
step
>=
0
;
step
--
)
{
idx
=
step
*
batch_size
*
beam_size
+
batch
*
beam_size
;
out_data
[
idx
+
beam
]
=
ids_data
[
idx
+
parent
];
parent
=
parents_data
[
idx
+
parent
];
}
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/reduce_ops/reduce_all_op.cc
浏览文件 @
dfd1eee7
...
@@ -14,7 +14,9 @@
...
@@ -14,7 +14,9 @@
#include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_all_op.h"
REGISTER_REDUCE_OP_WITHOUT_GRAD
(
reduce_all
);
// kernel's device type is decided by input tensor place, to be consistent with
// compare and logical ops
REGISTER_REDUCE_OP_WITHOUT_GRAD
(
reduce_all
,
UseInputPlace
);
REGISTER_OP_CPU_KERNEL
(
reduce_all
,
REGISTER_OP_CPU_KERNEL
(
reduce_all
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CPUDeviceContext
,
bool
,
ops
::
AllFunctor
>
);
bool
,
ops
::
AllFunctor
>
);
paddle/fluid/operators/reduce_ops/reduce_any_op.cc
浏览文件 @
dfd1eee7
...
@@ -14,7 +14,9 @@
...
@@ -14,7 +14,9 @@
#include "paddle/fluid/operators/reduce_ops/reduce_any_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_any_op.h"
REGISTER_REDUCE_OP_WITHOUT_GRAD
(
reduce_any
);
// kernel's device type is decided by input tensor place, to be consistent with
// compare and logical ops
REGISTER_REDUCE_OP_WITHOUT_GRAD
(
reduce_any
,
UseInputPlace
);
REGISTER_OP_CPU_KERNEL
(
reduce_any
,
REGISTER_OP_CPU_KERNEL
(
reduce_any
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
ReduceKernel
<
paddle
::
platform
::
CPUDeviceContext
,
bool
,
ops
::
AnyFunctor
>
);
bool
,
ops
::
AnyFunctor
>
);
paddle/fluid/operators/reduce_ops/reduce_op.h
浏览文件 @
dfd1eee7
...
@@ -223,6 +223,19 @@ class ReduceOp : public framework::OperatorWithKernel {
...
@@ -223,6 +223,19 @@ class ReduceOp : public framework::OperatorWithKernel {
}
}
};
};
class
ReduceOpUseInputPlace
:
public
ReduceOp
{
public:
using
ReduceOp
::
ReduceOp
;
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
OpKernelType
kt
=
OperatorWithKernel
::
GetExpectedKernelType
(
ctx
);
kt
.
place_
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
place
();
return
kt
;
}
};
class
ReduceGradOp
:
public
framework
::
OperatorWithKernel
{
class
ReduceGradOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
...
@@ -313,11 +326,11 @@ namespace ops = paddle::operators;
...
@@ -313,11 +326,11 @@ namespace ops = paddle::operators;
paddle::framework::DefaultGradOpDescMaker<true>); \
paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(op_name##_grad, ops::ReduceGradOp)
REGISTER_OPERATOR(op_name##_grad, ops::ReduceGradOp)
#define REGISTER_REDUCE_OP_WITHOUT_GRAD(op_name
)
\
#define REGISTER_REDUCE_OP_WITHOUT_GRAD(op_name
, ...)
\
class __##op_name##Maker__ : public ops::ReduceOpMaker { \
class __##op_name##Maker__ : public ops::ReduceOpMaker {
\
protected: \
protected:
\
virtual std::string GetName() const { return #op_name; } \
virtual std::string GetName() const { return #op_name; }
\
virtual std::string GetOpType() const { return "Reduce " #op_name; } \
virtual std::string GetOpType() const { return "Reduce " #op_name; }
\
}; \
};
\
REGISTER_OPERATOR(op_name, ops::ReduceOp
, __##op_name##Maker__,
\
REGISTER_OPERATOR(op_name, ops::ReduceOp
##__VA_ARGS__, __##op_name##Maker__,
\
paddle::framework::EmptyGradOpMaker);
paddle::framework::EmptyGradOpMaker);
paddle/fluid/operators/tensor_array_to_tensor_op.cc
浏览文件 @
dfd1eee7
...
@@ -120,11 +120,18 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
...
@@ -120,11 +120,18 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase {
out
.
Resize
(
out_dims
);
out
.
Resize
(
out_dims
);
LodTensorArray2LodTensorVector
(
scope
,
base_name
,
Input
(
"X"
),
&
names
);
LodTensorArray2LodTensorVector
(
scope
,
base_name
,
Input
(
"X"
),
&
names
);
// Invoke concat Op
auto
concat_op
=
framework
::
OpRegistry
::
CreateOp
(
"concat"
,
{{
"X"
,
names
}},
{{
"Out"
,
{
Output
(
"Out"
)}}},
attrs
);
concat_op
->
Run
(
scope
,
place
);
auto
use_stack
=
Attr
<
bool
>
(
"use_stack"
);
// Invoke concat Op or stack Op
auto
op
=
use_stack
?
framework
::
OpRegistry
::
CreateOp
(
"stack"
,
{{
"X"
,
names
}},
{{
"Y"
,
{
Output
(
"Out"
)}}},
attrs
)
:
framework
::
OpRegistry
::
CreateOp
(
"concat"
,
{{
"X"
,
names
}},
{{
"Out"
,
{
Output
(
"Out"
)}}},
attrs
);
op
->
Run
(
scope
,
place
);
}
}
};
};
...
@@ -139,17 +146,32 @@ class LoDTensorArray2TensorOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -139,17 +146,32 @@ class LoDTensorArray2TensorOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
int
>
(
"axis"
,
AddAttr
<
int
>
(
"axis"
,
"The axis along which the input tensors will be concatenated."
)
"The axis along which the input tensors will be concatenated."
)
.
SetDefault
(
0
);
.
SetDefault
(
0
);
AddAttr
<
bool
>
(
"use_stack"
,
"Act as concat_op or stack_op. For stack mode, all tensors "
"in the tensor array must have the same shape."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
tensor_array_to_tensor Operator.
tensor_array_to_tensor Operator.
Concatenate the input LoDTensorArray along dimension axis to the output Tensor.
If use concat mode, concatenate all tensors in the input LoDTensorArray along
axis into the output Tensor.
Examples:
Input = {[1,2], [3,4], [5,6]}
axis = 0
Output = [1,2,3,4,5,6]
OutputIndex = [2,2,2]
If use stack mode, stack all tensors in the input LoDTensorArray along axis into
the output Tensor.
Examples:
Examples:
Input = {[1,2], [3,4], [5,6]}
Input = {[1,2], [3,4], [5,6]}
axis = 0
axis = 0
Output = [[1,2],
Output = [[1,2],
[3,4],
[3,4],
[5,6]]
[5,6]]
OutputIndex = [
1,1,1
]
OutputIndex = [
2,2,2
]
)DOC"
);
)DOC"
);
}
}
...
@@ -157,12 +179,34 @@ Examples:
...
@@ -157,12 +179,34 @@ Examples:
class
LoDTensorArray2TensorOpInferShape
:
public
framework
::
InferShapeBase
{
class
LoDTensorArray2TensorOpInferShape
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
// in runtime, shape is determined by RunImpl
if
(
ctx
->
IsRuntime
())
return
;
auto
dims
=
ctx
->
GetInputDim
(
"X"
);
// if the shape is empty
if
(
dims
==
framework
::
make_ddim
({
0UL
}))
return
;
// otherwise, suppose the shape of array is the shape of tensor in the
// array, which is consistent with what tensor_array_read_write dose
auto
axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
);
auto
use_stack
=
ctx
->
Attrs
().
Get
<
bool
>
(
"use_stack"
);
if
(
use_stack
)
{
auto
dim_vec
=
framework
::
vectorize
<
int
>
(
dims
);
// use -1 for the stack dim size
dim_vec
.
insert
(
dim_vec
.
begin
()
+
axis
,
-
1
);
dims
=
framework
::
make_ddim
(
dim_vec
);
}
else
{
// use -1 for the concat dim size
dims
[
axis
]
=
-
1
;
}
ctx
->
SetOutputDim
(
"Out"
,
dims
);
}
};
};
class
LoDTensorArray2TensorGradInferShape
:
public
framework
::
InferShapeBase
{
class
LoDTensorArray2TensorGradInferShape
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{}
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
};
};
class
LoDTensorArray2TensorGradInferVarType
class
LoDTensorArray2TensorGradInferVarType
...
@@ -204,11 +248,18 @@ class LoDTensorArray2TensorGradOp : public framework::OperatorBase {
...
@@ -204,11 +248,18 @@ class LoDTensorArray2TensorGradOp : public framework::OperatorBase {
LodTensorVectorResizeFromLodTensorArray
(
scope
,
"grad_name"
,
Input
(
"X"
),
LodTensorVectorResizeFromLodTensorArray
(
scope
,
"grad_name"
,
Input
(
"X"
),
&
grad_names
);
&
grad_names
);
auto
concat_grad_op
=
framework
::
OpRegistry
::
CreateOp
(
auto
use_stack
=
Attr
<
bool
>
(
"use_stack"
);
"concat_grad"
,
{{
"X"
,
names
},
{
"Out@GRAD"
,
{
dout_name
}}},
{{
"X@GRAD"
,
grad_names
}},
attrs
);
auto
grad_op
=
use_stack
?
framework
::
OpRegistry
::
CreateOp
(
"stack_grad"
,
{{
"X"
,
names
},
{
"Y@GRAD"
,
{
dout_name
}}},
{{
"X@GRAD"
,
grad_names
}},
attrs
)
:
framework
::
OpRegistry
::
CreateOp
(
"concat_grad"
,
{{
"X"
,
names
},
{
"Out@GRAD"
,
{
dout_name
}}},
{{
"X@GRAD"
,
grad_names
}},
attrs
);
concat_
grad_op
->
Run
(
scope
,
place
);
grad_op
->
Run
(
scope
,
place
);
LodTensorArrayCreateFromLodTensorArray
(
scope
,
Input
(
"X"
),
dx_name
);
LodTensorArrayCreateFromLodTensorArray
(
scope
,
Input
(
"X"
),
dx_name
);
auto
&
grad_inx
=
auto
&
grad_inx
=
...
...
python/paddle/fluid/layers/__init__.py
浏览文件 @
dfd1eee7
...
@@ -35,6 +35,7 @@ from .metric_op import *
...
@@ -35,6 +35,7 @@ from .metric_op import *
from
.learning_rate_scheduler
import
*
from
.learning_rate_scheduler
import
*
from
.collective
import
*
from
.collective
import
*
from
.distributions
import
*
from
.distributions
import
*
from
.
import
rnn
__all__
=
[]
__all__
=
[]
__all__
+=
nn
.
__all__
__all__
+=
nn
.
__all__
...
@@ -47,3 +48,6 @@ __all__ += detection.__all__
...
@@ -47,3 +48,6 @@ __all__ += detection.__all__
__all__
+=
metric_op
.
__all__
__all__
+=
metric_op
.
__all__
__all__
+=
learning_rate_scheduler
.
__all__
__all__
+=
learning_rate_scheduler
.
__all__
__all__
+=
distributions
.
__all__
__all__
+=
distributions
.
__all__
__all__
+=
rnn
.
__all__
from
.rnn
import
*
python/paddle/fluid/layers/nn.py
浏览文件 @
dfd1eee7
...
@@ -221,6 +221,7 @@ __all__ = [
...
@@ -221,6 +221,7 @@ __all__ = [
'filter_by_instag',
'filter_by_instag',
'shard_index',
'shard_index',
'hard_swish',
'hard_swish',
'gather_tree',
'mse_loss',
'mse_loss',
'uniform_random',
'uniform_random',
]
]
...
@@ -16994,6 +16995,81 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
...
@@ -16994,6 +16995,81 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
return out
return out
def gather_tree(ids, parents):
"""
To be used after beam search. After beam search, we get selected ids at
each time step and the corresponding parents in the search tree. Both ids
and parents have the layout :attr:`[max_time, batch_size, beam_size]`. Then
:attr:`gather_tree` is used to backtrace from the last time step and
generate the full sequences by collecting selected ids.
Here is an example:
.. code-block:: text
Given:
ids = [[[2 2]
[6 1]]
[[3 9]
[6 1]]
[[0 1]
[9 0]]]
parents = [[[0 0]
[1 1]]
[[1 0]
[1 0]]
[[0 0]
[0 1]]]
Then:
gather_tree(ids, parents)
= [[[2 2]
[1 6]]
[[3 3]
[6 1]]
[[0 1]
[9 0]]]
Args:
ids(Variable): A Tensor with shape :attr:`[length, batch_size, beam_size]`
and data type :attr:`int32` or :attr:`int64`. It contains the selected
ids of all time steps.
parents(Variable): A Tensor with the same shape and data type as :attr:`ids`,
It contains the parents corresponding to selected ids when searching
among beams.
Returns:
Variable: A Tensor with the same shape and data type as :attr:`ids`. \
It contains the full sequences. The sequences are collected from \
:attr:`ids` by backtracing according to :attr:`parents`.
Examples:
.. code-block:: python
import paddle.fluid as fluid
ids = fluid.layers.data(name='ids',
shape=[5, 2, 2],
dtype='int64',
append_batch_size=False)
parents = fluid.layers.data(name='parents',
shape=[5, 2, 2],
dtype='int64',
append_batch_size=False)
final_sequences = fluid.layers.gather_tree(ids, parents)
"""
helper = LayerHelper('gather_tree', **locals())
out = helper.create_variable_for_type_inference(dtype=ids.dtype)
helper.append_op(
type="gather_tree",
inputs={"Ids": ids,
"Parents": parents},
outputs={"Out": out})
return out
def mse_loss(input, label):
def mse_loss(input, label):
"""
"""
This op accepts input predications and target label and returns the mean square error.
This op accepts input predications and target label and returns the mean square error.
...
...
python/paddle/fluid/layers/rnn.py
0 → 100644
浏览文件 @
dfd1eee7
此差异已折叠。
点击以展开。
python/paddle/fluid/layers/tensor.py
浏览文件 @
dfd1eee7
...
@@ -273,50 +273,85 @@ def concat(input, axis=0, name=None):
...
@@ -273,50 +273,85 @@ def concat(input, axis=0, name=None):
return
out
return
out
def
tensor_array_to_tensor
(
input
,
axis
=
1
,
name
=
None
):
def
tensor_array_to_tensor
(
input
,
axis
=
1
,
name
=
None
,
use_stack
=
False
):
"""
"""
This OP concatenates the input LodTensorArray along the axis.
This function concatenates or stacks all tensors in the input LoDTensorArray
along the axis mentioned and returns that as the output.
For Example:
.. code-block:: text
Case 1:
Given:
input.data = {[[0.6, 0.1, 0.3],
[0.5, 0.3, 0.2]],
[[1.3],
[1.8]],
[[2.3, 2.1],
[2.5, 2.4]]}
axis = 1, use_stack = False
Then:
output.data = [[0.6, 0.1, 0.3, 1.3, 2.3, 2.1],
[0.5, 0.3, 0.2, 1.8, 2.5, 2.4]]
output_index.data = [3, 1, 2]
Case 2:
Given:
input.data = {[[0.6, 0.1],
[0.5, 0.3]],
[[0.3, 1.3],
[0.2, 1.8]],
[[2.3, 2.1],
[2.5, 2.4]]}
axis = 1, use_stack = True
Then:
output.data = [[[0.6, 0.1]
[0.3, 1.3]
[2.3, 2.1],
[[0.5, 0.3]
[0.2, 1.8]
[2.5, 2.4]]]
output_index.data = [2, 2, 2]
Args:
Args:
input(Variable): A LodTensorArray with data type float32, float64, int32,
input(Variable): A LodTensorArray variable.
int64.
axis(int): The axis along which the tensors in attr::`input` will be
axis(int, optional): Axis to compute indices along. The effective range
concatenated or stacked.
is [-R, R), where R is Rank(x). when axis<0, it works the same way
name(str|None): A name for this layer(optional). If set None, the layer
as axis+R. Default is 1.
will be named automatically.
name (str, optional): The default value is None. Normally there is no
use_stack(bool): Act as concat_op or stack_op. For stack mode, all
need for user to set this property. For more information, please
tensors in the tensor array must have the same shape.
refer to :ref:`api_guide_Name`.
Returns:
Returns:
Variable: A LoDTensor with the same data type as input's
Variable: The concatenated or stacked tensor variable.
Variable: The input LodTensorArray items' dims along the axis.
Variable: A 1-D tensor variable with int32 data type. The data in this
\
tensor contains all input including tensors' sizes along the axis.
Examples:
Examples:
.. code-block:: python
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid as fluid
import numpy as np
import numpy as np
x0 = fluid.layers.assign(np.random.rand(2, 2).astype("float32"))
place = fluid.CPUPlace()
x1 = fluid.layers.assign(np.random.rand(2, 2).astype("float32"))
i = fluid.layers.fill_constant(shape=[1], dtype="int64", value=0)
x1 = fluid.data(name="x", shape=[2,2], lod_level=0)
array = fluid.layers.create_array(dtype='float32')
tmp = fluid.layers.fill_constant(shape=[2,3], dtype="float32", value=1)
fluid.layers.array_write(x0, i, array)
x_arr = fluid.layers.create_array(dtype="float32")
fluid.layers.array_write(x1, i + 1, array)
c0 = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
output, output_index = fluid.layers.tensor_array_to_tensor(input=array)
fluid.layers.array_write(x=tmp, i=c0, array=x_arr)
c1 = fluid.layers.fill_constant(shape=[1], dtype='int64', value=1)
fluid.layers.array_write(x=x1, i=c1, array=x_arr)
output, output_index = fluid.layers.tensor_array_to_tensor(input=x_arr, axis=1)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
feedx = fluid.LoDTensor()
feedx.set(np.array([[1.3,-2.4],[0,4]]).astype("float32"), place)
res = exe.run(fluid.default_main_program(), feed={'x':feedx}, fetch_list=[output], return_numpy=False)
print(np.array(res[0]))
# [[ 1. 1. 1. 1.3 -2.4]
# [ 1. 1. 1. 0. 4. ]]
"""
"""
helper
=
LayerHelper
(
'tensor_array_to_tensor'
,
**
locals
())
helper
=
LayerHelper
(
'tensor_array_to_tensor'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
helper
.
input_dtype
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
helper
.
input_dtype
())
...
@@ -326,7 +361,8 @@ def tensor_array_to_tensor(input, axis=1, name=None):
...
@@ -326,7 +361,8 @@ def tensor_array_to_tensor(input, axis=1, name=None):
inputs
=
{
'X'
:
input
},
inputs
=
{
'X'
:
input
},
outputs
=
{
'Out'
:
[
out
],
outputs
=
{
'Out'
:
[
out
],
'OutIndex'
:
[
out_index
]},
'OutIndex'
:
[
out_index
]},
attrs
=
{
'axis'
:
axis
})
attrs
=
{
'axis'
:
axis
,
'use_stack'
:
use_stack
})
return
out
,
out_index
return
out
,
out_index
...
@@ -517,7 +553,8 @@ def fill_constant_batch_size_like(input,
...
@@ -517,7 +553,8 @@ def fill_constant_batch_size_like(input,
dtype
,
dtype
,
value
,
value
,
input_dim_idx
=
0
,
input_dim_idx
=
0
,
output_dim_idx
=
0
):
output_dim_idx
=
0
,
force_cpu
=
False
):
"""
"""
This OP creates a Tesnor accroding the shape and dtype, and initializes the
This OP creates a Tesnor accroding the shape and dtype, and initializes the
Tensor with the constants provided in ``value``. When the input is LoDTensor
Tensor with the constants provided in ``value``. When the input is LoDTensor
...
@@ -537,6 +574,7 @@ def fill_constant_batch_size_like(input,
...
@@ -537,6 +574,7 @@ def fill_constant_batch_size_like(input,
The default value is 0.
The default value is 0.
output_dim_idx(int): Used to specify which dimension of Tensor is created to be set
output_dim_idx(int): Used to specify which dimension of Tensor is created to be set
the value of batch_size of input Tensor. The default value is 0.
the value of batch_size of input Tensor. The default value is 0.
force_cpu(bool): data should be on CPU if it's true, defalut value is False.
Returns:
Returns:
Variable: Tensor which will be created according to dtype.
Variable: Tensor which will be created according to dtype.
...
@@ -562,7 +600,8 @@ def fill_constant_batch_size_like(input,
...
@@ -562,7 +600,8 @@ def fill_constant_batch_size_like(input,
'dtype'
:
out
.
dtype
,
'dtype'
:
out
.
dtype
,
'value'
:
float
(
value
),
'value'
:
float
(
value
),
'input_dim_idx'
:
input_dim_idx
,
'input_dim_idx'
:
input_dim_idx
,
'output_dim_idx'
:
output_dim_idx
'output_dim_idx'
:
output_dim_idx
,
'force_cpu'
:
force_cpu
or
force_init_on_cpu
()
})
})
out
.
stop_gradient
=
True
out
.
stop_gradient
=
True
return
out
return
out
...
...
python/paddle/fluid/layers/utils.py
浏览文件 @
dfd1eee7
...
@@ -13,6 +13,8 @@
...
@@ -13,6 +13,8 @@
# limitations under the License.
# limitations under the License.
from
__future__
import
print_function
from
__future__
import
print_function
import
collections
import
six
import
numpy
as
np
import
numpy
as
np
...
@@ -59,3 +61,173 @@ def convert_to_list(value, n, name, dtype=np.int):
...
@@ -59,3 +61,173 @@ def convert_to_list(value, n, name, dtype=np.int):
"including element "
+
str
(
single_value
)
+
" of type"
+
" "
"including element "
+
str
(
single_value
)
+
" of type"
+
" "
+
str
(
type
(
single_value
)))
+
str
(
type
(
single_value
)))
return
value_list
return
value_list
def
is_sequence
(
seq
):
"""
Whether `seq` is an entry or nested structure
"""
if
isinstance
(
seq
,
dict
):
return
True
return
(
isinstance
(
seq
,
collections
.
Sequence
)
and
not
isinstance
(
seq
,
six
.
string_types
))
def
_sorted
(
dict_
):
"""
Returns a sorted list of the dict keys, with error if keys not sortable.
"""
try
:
return
sorted
(
six
.
iterkeys
(
dict_
))
except
TypeError
:
raise
TypeError
(
"nest only supports dicts with sortable keys."
)
def
_yield_value
(
iterable
):
if
isinstance
(
iterable
,
dict
):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
# ordered and plain dicts (e.g., flattening a dict but using a
# corresponding `OrderedDict` to pack it back).
for
key
in
_sorted
(
iterable
):
yield
iterable
[
key
]
else
:
for
value
in
iterable
:
yield
value
def
_yield_flat_nest
(
nest
):
for
n
in
_yield_value
(
nest
):
if
is_sequence
(
n
):
for
ni
in
_yield_flat_nest
(
n
):
yield
ni
else
:
yield
n
def
flatten
(
nest
):
"""
Traverse all entries in the nested structure and put them into an list.
"""
if
is_sequence
(
nest
):
return
list
(
_yield_flat_nest
(
nest
))
else
:
return
[
nest
]
def
_sequence_like
(
instance
,
args
):
"""
Convert the sequence `args` to the same type as `instance`.
"""
if
isinstance
(
instance
,
dict
):
# Pack dictionaries in a deterministic order by sorting the keys.
# Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
# ordered and plain dicts (e.g., flattening a dict but using a
# corresponding `OrderedDict` to pack it back).
result
=
dict
(
zip
(
_sorted
(
instance
),
args
))
return
type
(
instance
)((
key
,
result
[
key
])
for
key
in
six
.
iterkeys
(
instance
))
elif
(
isinstance
(
instance
,
tuple
)
and
hasattr
(
instance
,
"_fields"
)
and
isinstance
(
instance
.
_fields
,
collections
.
Sequence
)
and
all
(
isinstance
(
f
,
six
.
string_types
)
for
f
in
instance
.
_fields
)):
# This is a namedtuple
return
type
(
instance
)(
*
args
)
else
:
# Not a namedtuple
return
type
(
instance
)(
args
)
def
_packed_nest_with_indices
(
structure
,
flat
,
index
):
"""
Helper function for pack_sequence_as.
"""
packed
=
[]
for
s
in
_yield_value
(
structure
):
if
is_sequence
(
s
):
new_index
,
child
=
_packed_nest_with_indices
(
s
,
flat
,
index
)
packed
.
append
(
_sequence_like
(
s
,
child
))
index
=
new_index
else
:
packed
.
append
(
flat
[
index
])
index
+=
1
return
index
,
packed
def
pack_sequence_as
(
structure
,
flat_sequence
):
"""
Pack a given flattened sequence into a given structure.
"""
if
not
is_sequence
(
flat_sequence
):
raise
TypeError
(
"flat_sequence must be a sequence"
)
if
not
is_sequence
(
structure
):
if
len
(
flat_sequence
)
!=
1
:
raise
ValueError
(
"Structure is a scalar but len(flat_sequence) == %d > 1"
%
len
(
flat_sequence
))
return
flat_sequence
[
0
]
flat_structure
=
flatten
(
structure
)
if
len
(
flat_structure
)
!=
len
(
flat_sequence
):
raise
ValueError
(
"Could not pack sequence. Structure had %d elements, but flat_sequence "
"had %d elements. Structure: %s, flat_sequence: %s."
%
(
len
(
flat_structure
),
len
(
flat_sequence
),
structure
,
flat_sequence
))
_
,
packed
=
_packed_nest_with_indices
(
structure
,
flat_sequence
,
0
)
return
_sequence_like
(
structure
,
packed
)
def
map_structure
(
func
,
*
structure
):
"""
Apply `func` to each entry in `structure` and return a new structure.
"""
flat_structure
=
[
flatten
(
s
)
for
s
in
structure
]
entries
=
zip
(
*
flat_structure
)
return
pack_sequence_as
(
structure
[
0
],
[
func
(
*
x
)
for
x
in
entries
])
def
_recursive_assert_same_structure
(
nest1
,
nest2
,
check_types
):
"""
Helper function for `assert_same_structure`.
"""
is_sequence_nest1
=
is_sequence
(
nest1
)
if
is_sequence_nest1
!=
is_sequence
(
nest2
):
raise
ValueError
(
"The two structures don't have the same nested structure.
\n\n
"
"First structure: %s
\n\n
Second structure: %s."
%
(
nest1
,
nest2
))
if
not
is_sequence_nest1
:
return
# finished checking
if
check_types
:
type_nest1
=
type
(
nest1
)
type_nest2
=
type
(
nest2
)
if
type_nest1
!=
type_nest2
:
raise
TypeError
(
"The two structures don't have the same sequence type. First "
"structure has type %s, while second structure has type %s."
%
(
type_nest1
,
type_nest2
))
if
isinstance
(
nest1
,
dict
):
keys1
=
set
(
six
.
iterkeys
(
nest1
))
keys2
=
set
(
six
.
iterkeys
(
nest2
))
if
keys1
!=
keys2
:
raise
ValueError
(
"The two dictionaries don't have the same set of keys. First "
"structure has keys {}, while second structure has keys {}."
.
format
(
keys1
,
keys2
))
nest1_as_sequence
=
[
n
for
n
in
_yield_value
(
nest1
)]
nest2_as_sequence
=
[
n
for
n
in
_yield_value
(
nest2
)]
for
n1
,
n2
in
zip
(
nest1_as_sequence
,
nest2_as_sequence
):
_recursive_assert_same_structure
(
n1
,
n2
,
check_types
)
def
assert_same_structure
(
nest1
,
nest2
,
check_types
=
True
):
"""
Confirm two nested structures with the same structure.
"""
len_nest1
=
len
(
flatten
(
nest1
))
if
is_sequence
(
nest1
)
else
1
len_nest2
=
len
(
flatten
(
nest2
))
if
is_sequence
(
nest2
)
else
1
if
len_nest1
!=
len_nest2
:
raise
ValueError
(
"The two structures don't have the same number of "
"elements.
\n\n
First structure (%i elements): %s
\n\n
"
"Second structure (%i elements): %s"
%
(
len_nest1
,
nest1
,
len_nest2
,
nest2
))
_recursive_assert_same_structure
(
nest1
,
nest2
,
check_types
)
python/paddle/fluid/tests/unittests/test_gather_tree_op.py
0 → 100644
浏览文件 @
dfd1eee7
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle.fluid
as
fluid
class
TestGatherTreeOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"gather_tree"
max_length
,
batch_size
,
beam_size
=
5
,
2
,
2
ids
=
np
.
random
.
randint
(
0
,
high
=
10
,
size
=
(
max_length
,
batch_size
,
beam_size
))
parents
=
np
.
random
.
randint
(
0
,
high
=
beam_size
,
size
=
(
max_length
,
batch_size
,
beam_size
))
self
.
inputs
=
{
"Ids"
:
ids
,
"Parents"
:
parents
}
self
.
outputs
=
{
'Out'
:
self
.
backtrace
(
ids
,
parents
)}
def
test_check_output
(
self
):
self
.
check_output
()
@
staticmethod
def
backtrace
(
ids
,
parents
):
out
=
np
.
zeros_like
(
ids
)
(
max_length
,
batch_size
,
beam_size
)
=
ids
.
shape
for
batch
in
range
(
batch_size
):
for
beam
in
range
(
beam_size
):
out
[
max_length
-
1
,
batch
,
beam
]
=
ids
[
max_length
-
1
,
batch
,
beam
]
parent
=
parents
[
max_length
-
1
,
batch
,
beam
]
for
step
in
range
(
max_length
-
2
,
-
1
,
-
1
):
out
[
step
,
batch
,
beam
]
=
ids
[
step
,
batch
,
parent
]
parent
=
parents
[
step
,
batch
,
parent
]
return
out
class
TestGatherTreeOpAPI
(
OpTest
):
def
test_case
(
self
):
ids
=
fluid
.
layers
.
data
(
name
=
'ids'
,
shape
=
[
5
,
2
,
2
],
dtype
=
'int64'
,
append_batch_size
=
False
)
parents
=
fluid
.
layers
.
data
(
name
=
'parents'
,
shape
=
[
5
,
2
,
2
],
dtype
=
'int64'
,
append_batch_size
=
False
)
final_sequences
=
fluid
.
layers
.
gather_tree
(
ids
,
parents
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_rnn_cell_api.py
0 → 100644
浏览文件 @
dfd1eee7
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
layers
import
paddle.fluid.core
as
core
from
paddle.fluid.executor
import
Executor
from
paddle.fluid
import
framework
from
paddle.fluid.layers.rnn
import
LSTMCell
,
GRUCell
,
RNNCell
from
paddle.fluid.layers
import
rnn
as
dynamic_rnn
from
paddle.fluid
import
contrib
from
paddle.fluid.contrib.layers
import
basic_lstm
import
paddle.fluid.layers.utils
as
utils
import
numpy
as
np
class
TestLSTMCell
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_size
=
4
self
.
input_size
=
16
self
.
hidden_size
=
16
def
test_run
(
self
):
inputs
=
fluid
.
data
(
name
=
'inputs'
,
shape
=
[
None
,
self
.
input_size
],
dtype
=
'float32'
)
pre_hidden
=
fluid
.
data
(
name
=
'pre_hidden'
,
shape
=
[
None
,
self
.
hidden_size
],
dtype
=
'float32'
)
pre_cell
=
fluid
.
data
(
name
=
'pre_cell'
,
shape
=
[
None
,
self
.
hidden_size
],
dtype
=
'float32'
)
cell
=
LSTMCell
(
self
.
hidden_size
)
lstm_hidden_new
,
lstm_states_new
=
cell
(
inputs
,
[
pre_hidden
,
pre_cell
])
lstm_unit
=
contrib
.
layers
.
rnn_impl
.
BasicLSTMUnit
(
"basicLSTM"
,
self
.
hidden_size
,
None
,
None
,
None
,
None
,
1.0
,
"float32"
)
lstm_hidden
,
lstm_cell
=
lstm_unit
(
inputs
,
pre_hidden
,
pre_cell
)
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
else
:
place
=
core
.
CPUPlace
()
exe
=
Executor
(
place
)
exe
.
run
(
framework
.
default_startup_program
())
inputs_np
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
(
self
.
batch_size
,
self
.
input_size
)).
astype
(
'float32'
)
pre_hidden_np
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
(
self
.
batch_size
,
self
.
hidden_size
)).
astype
(
'float32'
)
pre_cell_np
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
(
self
.
batch_size
,
self
.
hidden_size
)).
astype
(
'float32'
)
param_names
=
[[
"LSTMCell/BasicLSTMUnit_0.w_0"
,
"basicLSTM/BasicLSTMUnit_0.w_0"
],
[
"LSTMCell/BasicLSTMUnit_0.b_0"
,
"basicLSTM/BasicLSTMUnit_0.b_0"
]]
for
names
in
param_names
:
param
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
names
[
0
]).
get_tensor
(
))
param
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
size
=
param
.
shape
).
astype
(
'float32'
)
fluid
.
global_scope
().
find_var
(
names
[
0
]).
get_tensor
().
set
(
param
,
place
)
fluid
.
global_scope
().
find_var
(
names
[
1
]).
get_tensor
().
set
(
param
,
place
)
out
=
exe
.
run
(
feed
=
{
'inputs'
:
inputs_np
,
'pre_hidden'
:
pre_hidden_np
,
'pre_cell'
:
pre_cell_np
},
fetch_list
=
[
lstm_hidden_new
,
lstm_hidden
])
self
.
assertTrue
(
np
.
allclose
(
out
[
0
],
out
[
1
],
rtol
=
1e-4
,
atol
=
0
))
class
TestGRUCell
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_size
=
4
self
.
input_size
=
16
self
.
hidden_size
=
16
def
test_run
(
self
):
inputs
=
fluid
.
data
(
name
=
'inputs'
,
shape
=
[
None
,
self
.
input_size
],
dtype
=
'float32'
)
pre_hidden
=
layers
.
data
(
name
=
'pre_hidden'
,
shape
=
[
None
,
self
.
hidden_size
],
append_batch_size
=
False
,
dtype
=
'float32'
)
cell
=
GRUCell
(
self
.
hidden_size
)
gru_hidden_new
,
_
=
cell
(
inputs
,
pre_hidden
)
gru_unit
=
contrib
.
layers
.
rnn_impl
.
BasicGRUUnit
(
"basicGRU"
,
self
.
hidden_size
,
None
,
None
,
None
,
None
,
"float32"
)
gru_hidden
=
gru_unit
(
inputs
,
pre_hidden
)
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
else
:
place
=
core
.
CPUPlace
()
exe
=
Executor
(
place
)
exe
.
run
(
framework
.
default_startup_program
())
inputs_np
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
(
self
.
batch_size
,
self
.
input_size
)).
astype
(
'float32'
)
pre_hidden_np
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
(
self
.
batch_size
,
self
.
hidden_size
)).
astype
(
'float32'
)
param_names
=
[
[
"GRUCell/BasicGRUUnit_0.w_0"
,
"basicGRU/BasicGRUUnit_0.w_0"
],
[
"GRUCell/BasicGRUUnit_0.w_1"
,
"basicGRU/BasicGRUUnit_0.w_1"
],
[
"GRUCell/BasicGRUUnit_0.b_0"
,
"basicGRU/BasicGRUUnit_0.b_0"
],
[
"GRUCell/BasicGRUUnit_0.b_1"
,
"basicGRU/BasicGRUUnit_0.b_1"
]
]
for
names
in
param_names
:
param
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
names
[
0
]).
get_tensor
(
))
param
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
size
=
param
.
shape
).
astype
(
'float32'
)
fluid
.
global_scope
().
find_var
(
names
[
0
]).
get_tensor
().
set
(
param
,
place
)
fluid
.
global_scope
().
find_var
(
names
[
1
]).
get_tensor
().
set
(
param
,
place
)
out
=
exe
.
run
(
feed
=
{
'inputs'
:
inputs_np
,
'pre_hidden'
:
pre_hidden_np
},
fetch_list
=
[
gru_hidden_new
,
gru_hidden
])
self
.
assertTrue
(
np
.
allclose
(
out
[
0
],
out
[
1
],
rtol
=
1e-4
,
atol
=
0
))
class
TestRnn
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_size
=
4
self
.
input_size
=
16
self
.
hidden_size
=
16
self
.
seq_len
=
4
def
test_run
(
self
):
inputs_basic_lstm
=
fluid
.
data
(
name
=
'inputs_basic_lstm'
,
shape
=
[
None
,
None
,
self
.
input_size
],
dtype
=
'float32'
)
sequence_length
=
fluid
.
data
(
name
=
"sequence_length"
,
shape
=
[
None
],
dtype
=
'int64'
)
inputs_dynamic_rnn
=
layers
.
transpose
(
inputs_basic_lstm
,
perm
=
[
1
,
0
,
2
])
cell
=
LSTMCell
(
self
.
hidden_size
,
name
=
"LSTMCell_for_rnn"
)
output
,
final_state
=
dynamic_rnn
(
cell
=
cell
,
inputs
=
inputs_dynamic_rnn
,
sequence_length
=
sequence_length
,
is_reverse
=
False
)
output_new
=
layers
.
transpose
(
output
,
perm
=
[
1
,
0
,
2
])
rnn_out
,
last_hidden
,
last_cell
=
basic_lstm
(
inputs_basic_lstm
,
None
,
None
,
self
.
hidden_size
,
num_layers
=
1
,
\
batch_first
=
False
,
bidirectional
=
False
,
sequence_length
=
sequence_length
,
forget_bias
=
1.0
)
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
else
:
place
=
core
.
CPUPlace
()
exe
=
Executor
(
place
)
exe
.
run
(
framework
.
default_startup_program
())
inputs_basic_lstm_np
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
(
self
.
seq_len
,
self
.
batch_size
,
self
.
input_size
)).
astype
(
'float32'
)
sequence_length_np
=
np
.
ones
(
self
.
batch_size
,
dtype
=
'int64'
)
*
self
.
seq_len
inputs_np
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
(
self
.
batch_size
,
self
.
input_size
)).
astype
(
'float32'
)
pre_hidden_np
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
(
self
.
batch_size
,
self
.
hidden_size
)).
astype
(
'float32'
)
pre_cell_np
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
(
self
.
batch_size
,
self
.
hidden_size
)).
astype
(
'float32'
)
param_names
=
[[
"LSTMCell_for_rnn/BasicLSTMUnit_0.w_0"
,
"basic_lstm_layers_0/BasicLSTMUnit_0.w_0"
],
[
"LSTMCell_for_rnn/BasicLSTMUnit_0.b_0"
,
"basic_lstm_layers_0/BasicLSTMUnit_0.b_0"
]]
for
names
in
param_names
:
param
=
np
.
array
(
fluid
.
global_scope
().
find_var
(
names
[
0
]).
get_tensor
(
))
param
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
size
=
param
.
shape
).
astype
(
'float32'
)
fluid
.
global_scope
().
find_var
(
names
[
0
]).
get_tensor
().
set
(
param
,
place
)
fluid
.
global_scope
().
find_var
(
names
[
1
]).
get_tensor
().
set
(
param
,
place
)
out
=
exe
.
run
(
feed
=
{
'inputs_basic_lstm'
:
inputs_basic_lstm_np
,
'sequence_length'
:
sequence_length_np
,
'inputs'
:
inputs_np
,
'pre_hidden'
:
pre_hidden_np
,
'pre_cell'
:
pre_cell_np
},
fetch_list
=
[
output_new
,
rnn_out
])
self
.
assertTrue
(
np
.
allclose
(
out
[
0
],
out
[
1
],
rtol
=
1e-4
))
class
TestRnnUtil
(
unittest
.
TestCase
):
"""
Test cases for rnn apis' utility methods for coverage.
"""
def
test_case
(
self
):
inputs
=
{
"key1"
:
1
,
"key2"
:
2
}
func
=
lambda
x
:
x
+
1
outputs
=
utils
.
map_structure
(
func
,
inputs
)
utils
.
assert_same_structure
(
inputs
,
outputs
)
try
:
inputs
[
"key3"
]
=
3
utils
.
assert_same_structure
(
inputs
,
outputs
)
except
ValueError
as
identifier
:
pass
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_rnn_decode_api.py
0 → 100644
浏览文件 @
dfd1eee7
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
layers
import
paddle.fluid.core
as
core
from
paddle.fluid.executor
import
Executor
from
paddle.fluid
import
framework
from
paddle.fluid.layers.rnn
import
LSTMCell
,
GRUCell
,
RNNCell
,
BeamSearchDecoder
,
dynamic_decode
from
paddle.fluid.layers
import
rnn
as
dynamic_rnn
from
paddle.fluid
import
contrib
from
paddle.fluid.contrib.layers
import
basic_lstm
import
numpy
as
np
class
EncoderCell
(
RNNCell
):
def
__init__
(
self
,
num_layers
,
hidden_size
,
dropout_prob
=
0.
):
self
.
num_layers
=
num_layers
self
.
hidden_size
=
hidden_size
self
.
dropout_prob
=
dropout_prob
self
.
lstm_cells
=
[]
for
i
in
range
(
num_layers
):
self
.
lstm_cells
.
append
(
LSTMCell
(
hidden_size
))
def
call
(
self
,
step_input
,
states
):
new_states
=
[]
for
i
in
range
(
self
.
num_layers
):
out
,
new_state
=
self
.
lstm_cells
[
i
](
step_input
,
states
[
i
])
step_input
=
layers
.
dropout
(
out
,
self
.
dropout_prob
)
if
self
.
dropout_prob
>
0
else
out
new_states
.
append
(
new_state
)
return
step_input
,
new_states
@
property
def
state_shape
(
self
):
return
[
cell
.
state_shape
for
cell
in
self
.
lstm_cells
]
class
DecoderCell
(
RNNCell
):
def
__init__
(
self
,
num_layers
,
hidden_size
,
dropout_prob
=
0.
):
self
.
num_layers
=
num_layers
self
.
hidden_size
=
hidden_size
self
.
dropout_prob
=
dropout_prob
self
.
lstm_cells
=
[]
for
i
in
range
(
num_layers
):
self
.
lstm_cells
.
append
(
LSTMCell
(
hidden_size
))
def
attention
(
self
,
hidden
,
encoder_output
,
encoder_padding_mask
):
query
=
layers
.
fc
(
hidden
,
size
=
encoder_output
.
shape
[
-
1
],
bias_attr
=
False
)
attn_scores
=
layers
.
matmul
(
layers
.
unsqueeze
(
query
,
[
1
]),
encoder_output
,
transpose_y
=
True
)
if
encoder_padding_mask
is
not
None
:
attn_scores
=
layers
.
elementwise_add
(
attn_scores
,
encoder_padding_mask
)
attn_scores
=
layers
.
softmax
(
attn_scores
)
attn_out
=
layers
.
squeeze
(
layers
.
matmul
(
attn_scores
,
encoder_output
),
[
1
])
attn_out
=
layers
.
concat
([
attn_out
,
hidden
],
1
)
attn_out
=
layers
.
fc
(
attn_out
,
size
=
self
.
hidden_size
,
bias_attr
=
False
)
return
attn_out
def
call
(
self
,
step_input
,
states
,
encoder_output
,
encoder_padding_mask
=
None
):
lstm_states
,
input_feed
=
states
new_lstm_states
=
[]
step_input
=
layers
.
concat
([
step_input
,
input_feed
],
1
)
for
i
in
range
(
self
.
num_layers
):
out
,
new_lstm_state
=
self
.
lstm_cells
[
i
](
step_input
,
lstm_states
[
i
])
step_input
=
layers
.
dropout
(
out
,
self
.
dropout_prob
)
if
self
.
dropout_prob
>
0
else
out
new_lstm_states
.
append
(
new_lstm_state
)
out
=
self
.
attention
(
step_input
,
encoder_output
,
encoder_padding_mask
)
return
out
,
[
new_lstm_states
,
out
]
class
TestDynamicDecode
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_size
=
4
self
.
input_size
=
16
self
.
hidden_size
=
16
self
.
seq_len
=
4
def
test_run
(
self
):
start_token
=
0
end_token
=
1
src_vocab_size
=
10
trg_vocab_size
=
10
num_layers
=
1
hidden_size
=
self
.
hidden_size
beam_size
=
8
max_length
=
self
.
seq_len
src
=
layers
.
data
(
name
=
"src"
,
shape
=
[
-
1
,
1
],
dtype
=
'int64'
)
src_len
=
layers
.
data
(
name
=
"src_len"
,
shape
=
[
-
1
],
dtype
=
'int64'
)
trg
=
layers
.
data
(
name
=
"trg"
,
shape
=
[
-
1
,
1
],
dtype
=
'int64'
)
trg_len
=
layers
.
data
(
name
=
"trg_len"
,
shape
=
[
-
1
],
dtype
=
'int64'
)
src_embeder
=
lambda
x
:
fluid
.
embedding
(
x
,
size
=
[
src_vocab_size
,
hidden_size
],
param_attr
=
fluid
.
ParamAttr
(
name
=
"src_embedding"
))
trg_embeder
=
lambda
x
:
fluid
.
embedding
(
x
,
size
=
[
trg_vocab_size
,
hidden_size
],
param_attr
=
fluid
.
ParamAttr
(
name
=
"trg_embedding"
))
# use basic_lstm
encoder_cell
=
EncoderCell
(
num_layers
,
hidden_size
)
encoder_output
,
encoder_final_state
=
dynamic_rnn
(
cell
=
encoder_cell
,
inputs
=
src_embeder
(
src
),
sequence_length
=
src_len
,
is_reverse
=
False
)
src_mask
=
layers
.
sequence_mask
(
src_len
,
maxlen
=
layers
.
shape
(
src
)[
1
],
dtype
=
'float32'
)
encoder_padding_mask
=
(
src_mask
-
1.0
)
*
1000000000
encoder_padding_mask
=
layers
.
unsqueeze
(
encoder_padding_mask
,
[
1
])
decoder_cell
=
DecoderCell
(
num_layers
,
hidden_size
)
decoder_initial_states
=
[
encoder_final_state
,
decoder_cell
.
get_initial_states
(
batch_ref
=
encoder_output
,
shape
=
[
hidden_size
])
]
decoder_output
,
_
=
dynamic_rnn
(
cell
=
decoder_cell
,
inputs
=
trg_embeder
(
trg
),
initial_states
=
decoder_initial_states
,
sequence_length
=
None
,
encoder_output
=
encoder_output
,
encoder_padding_mask
=
encoder_padding_mask
)
output_layer
=
lambda
x
:
layers
.
fc
(
x
,
size
=
trg_vocab_size
,
num_flatten_dims
=
len
(
x
.
shape
)
-
1
,
param_attr
=
fluid
.
ParamAttr
(
name
=
"output_w"
),
bias_attr
=
False
)
# inference
encoder_output
=
BeamSearchDecoder
.
tile_beam_merge_with_batch
(
encoder_output
,
beam_size
)
encoder_padding_mask
=
BeamSearchDecoder
.
tile_beam_merge_with_batch
(
encoder_padding_mask
,
beam_size
)
beam_search_decoder
=
BeamSearchDecoder
(
decoder_cell
,
start_token
,
end_token
,
beam_size
,
embedding_fn
=
trg_embeder
,
output_fn
=
output_layer
)
outputs
,
_
=
dynamic_decode
(
beam_search_decoder
,
inits
=
decoder_initial_states
,
max_step_num
=
max_length
,
encoder_output
=
encoder_output
,
encoder_padding_mask
=
encoder_padding_mask
)
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
else
:
place
=
core
.
CPUPlace
()
exe
=
Executor
(
place
)
exe
.
run
(
framework
.
default_startup_program
())
src_np
=
np
.
random
.
randint
(
0
,
src_vocab_size
,
(
self
.
batch_size
,
max_length
)).
astype
(
'int64'
)
src_len_np
=
np
.
ones
(
self
.
batch_size
,
dtype
=
'int64'
)
*
max_length
trg_np
=
np
.
random
.
randint
(
0
,
trg_vocab_size
,
(
self
.
batch_size
,
max_length
)).
astype
(
'int64'
)
trg_len_np
=
np
.
ones
(
self
.
batch_size
,
dtype
=
'int64'
)
*
max_length
out
=
exe
.
run
(
feed
=
{
'src'
:
src_np
,
'src_len'
:
src_len_np
,
'trg'
:
trg_np
,
'trg_len'
:
trg_len_np
},
fetch_list
=
[
outputs
])
self
.
assertTrue
(
out
[
0
].
shape
[
0
]
==
self
.
batch_size
)
self
.
assertTrue
(
out
[
0
].
shape
[
1
]
<=
max_length
+
1
)
self
.
assertTrue
(
out
[
0
].
shape
[
2
]
==
beam_size
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_tensor_array_to_tensor.py
浏览文件 @
dfd1eee7
...
@@ -23,6 +23,8 @@ from paddle.fluid.executor import Executor
...
@@ -23,6 +23,8 @@ from paddle.fluid.executor import Executor
class
TestLoDTensorArrayConcat
(
unittest
.
TestCase
):
class
TestLoDTensorArrayConcat
(
unittest
.
TestCase
):
"""Test case for concat mode of tensor_array_to_tensor."""
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"tensor_array_to_tensor"
self
.
op_type
=
"tensor_array_to_tensor"
self
.
attrs
=
{
"axis"
:
0
}
self
.
attrs
=
{
"axis"
:
0
}
...
@@ -98,7 +100,7 @@ class TestLoDTensorArrayConcat(unittest.TestCase):
...
@@ -98,7 +100,7 @@ class TestLoDTensorArrayConcat(unittest.TestCase):
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
out
=
exe
.
run
(
program
,
fetch_list
=
fetch_list
,
scope
=
scope
)
out
=
exe
.
run
(
program
,
fetch_list
=
fetch_list
,
scope
=
scope
)
#print ("index: ", numpy.array(out[1]))
#print ("index: ", numpy.array(out[1]))
# test forward
# test forward
tensor_res
=
numpy
.
array
(
out
[
0
])
tensor_res
=
numpy
.
array
(
out
[
0
])
...
@@ -138,5 +140,82 @@ class TestLoDTensorArrayConcat(unittest.TestCase):
...
@@ -138,5 +140,82 @@ class TestLoDTensorArrayConcat(unittest.TestCase):
numpy
.
array
(
random_grad
[
i
+
1
]))
numpy
.
array
(
random_grad
[
i
+
1
]))
class
TestLoDTensorArrayStack
(
unittest
.
TestCase
):
"""Test case for stack mode of tensor_array_to_tensor."""
def
setUp
(
self
):
self
.
op_type
=
"tensor_array_to_tensor"
self
.
attrs
=
{
"axis"
:
1
,
"use_stack"
:
True
}
self
.
inputs
=
[
numpy
.
random
.
rand
(
2
,
3
,
4
).
astype
(
"float32"
),
numpy
.
random
.
rand
(
2
,
3
,
4
).
astype
(
"float32"
),
numpy
.
random
.
rand
(
2
,
3
,
4
).
astype
(
"float32"
)
]
self
.
outputs
=
[
numpy
.
stack
(
self
.
inputs
,
axis
=
self
.
attrs
[
"axis"
]),
numpy
.
array
(
[
x
.
shape
[
self
.
attrs
[
"axis"
]]
for
x
in
self
.
inputs
],
dtype
=
"int32"
)
]
self
.
input_grads
=
[
numpy
.
ones_like
(
x
)
for
x
in
self
.
inputs
]
self
.
set_program
()
for
var
in
self
.
program
.
list_vars
():
# to avoid scope clearing after execution
var
.
persistable
=
True
def
set_program
(
self
):
self
.
program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
self
.
program
):
self
.
array
=
array
=
fluid
.
layers
.
create_array
(
dtype
=
'float32'
)
idx
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
"int64"
,
value
=
0
)
for
i
,
x
in
enumerate
(
self
.
inputs
):
x
=
fluid
.
layers
.
assign
(
x
)
fluid
.
layers
.
array_write
(
x
,
idx
+
i
,
array
)
output
,
output_index
=
fluid
.
layers
.
tensor_array_to_tensor
(
input
=
array
,
**
self
.
attrs
)
loss
=
fluid
.
layers
.
reduce_sum
(
output
)
fluid
.
backward
.
append_backward
(
loss
)
self
.
output_vars
=
[
output
,
output_index
]
def
run_check
(
self
,
executor
,
scope
):
executor
.
run
(
self
.
program
,
scope
=
scope
)
for
i
,
output
in
enumerate
(
self
.
outputs
):
numpy
.
allclose
(
numpy
.
array
(
scope
.
var
(
self
.
output_vars
[
i
].
name
).
get_tensor
()),
output
,
atol
=
0
)
tensor_array_grad
=
scope
.
var
(
self
.
array
.
name
).
get_lod_tensor_array
()
for
i
,
input_grad
in
enumerate
(
self
.
input_grads
):
numpy
.
allclose
(
numpy
.
array
(
tensor_array_grad
[
i
]),
input_grad
,
atol
=
0
)
def
test_cpu
(
self
):
scope
=
core
.
Scope
()
place
=
core
.
CPUPlace
()
executor
=
fluid
.
Executor
(
place
)
self
.
run_check
(
executor
,
scope
)
def
test_gpu
(
self
):
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
scope
=
core
.
Scope
()
executor
=
fluid
.
Executor
(
place
)
self
.
run_check
(
executor
,
scope
)
class
TestTensorArrayToTensorAPI
(
unittest
.
TestCase
):
def
test_case
(
self
):
x0
=
fluid
.
layers
.
assign
(
numpy
.
random
.
rand
(
2
,
3
,
4
).
astype
(
"float32"
))
x1
=
fluid
.
layers
.
assign
(
numpy
.
random
.
rand
(
2
,
3
,
4
).
astype
(
"float32"
))
i
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
"int64"
,
value
=
0
)
array
=
fluid
.
layers
.
create_array
(
dtype
=
'float32'
)
fluid
.
layers
.
array_write
(
x0
,
i
,
array
)
fluid
.
layers
.
array_write
(
x1
,
i
+
1
,
array
)
output
,
output_index
=
fluid
.
layers
.
tensor_array_to_tensor
(
input
=
array
,
axis
=
1
,
use_stack
=
True
)
output
,
output_index
=
fluid
.
layers
.
tensor_array_to_tensor
(
input
=
array
,
axis
=
1
,
use_stack
=
False
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录