Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
581cf909
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
581cf909
编写于
7月 14, 2017
作者:
C
Cao Ying
提交者:
GitHub
7月 14, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2865 from lcy-seso/add_gated_unit_layer
add configuration helper for the gated unit.
上级
58f3de95
e2fd06c3
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
223 addition
and
2 deletion
+223
-2
doc/api/v2/config/layer.rst
doc/api/v2/config/layer.rst
+5
-0
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+95
-1
python/paddle/trainer_config_helpers/tests/configs/file_list.sh
.../paddle/trainer_config_helpers/tests/configs/file_list.sh
+1
-1
python/paddle/trainer_config_helpers/tests/configs/protostr/test_gated_unit_layer.protostr
...ers/tests/configs/protostr/test_gated_unit_layer.protostr
+106
-0
python/paddle/trainer_config_helpers/tests/configs/test_gated_unit_layer.py
...ner_config_helpers/tests/configs/test_gated_unit_layer.py
+16
-0
未找到文件。
doc/api/v2/config/layer.rst
浏览文件 @
581cf909
...
@@ -474,6 +474,11 @@ prelu
...
@@ -474,6 +474,11 @@ prelu
.. autoclass:: paddle.v2.layer.prelu
.. autoclass:: paddle.v2.layer.prelu
:noindex:
:noindex:
gated_unit
-----------
.. autoclass:: paddle.v2.layer.gated_unit
:noindex:
Detection output Layer
Detection output Layer
======================
======================
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
581cf909
...
@@ -126,6 +126,7 @@ __all__ = [
...
@@ -126,6 +126,7 @@ __all__ = [
'row_conv_layer'
,
'row_conv_layer'
,
'dropout_layer'
,
'dropout_layer'
,
'prelu_layer'
,
'prelu_layer'
,
'gated_unit_layer'
,
]
]
...
@@ -5862,7 +5863,7 @@ def prelu_layer(input,
...
@@ -5862,7 +5863,7 @@ def prelu_layer(input,
:rtype: LayerOutput
:rtype: LayerOutput
"""
"""
assert
isinstance
(
input
,
LayerOutput
),
'prelu_layer
only accepts one input
'
assert
isinstance
(
input
,
LayerOutput
),
'prelu_layer
accepts only one input.
'
assert
isinstance
(
param_attr
,
ParameterAttribute
)
assert
isinstance
(
param_attr
,
ParameterAttribute
)
l
=
Layer
(
l
=
Layer
(
...
@@ -5876,3 +5877,96 @@ def prelu_layer(input,
...
@@ -5876,3 +5877,96 @@ def prelu_layer(input,
layer_type
=
LayerType
.
PRELU
,
layer_type
=
LayerType
.
PRELU
,
parents
=
input
,
parents
=
input
,
size
=
l
.
config
.
size
)
size
=
l
.
config
.
size
)
@
wrap_name_default
()
@
layer_support
(
ERROR_CLIPPING
,
DROPOUT
)
@
wrap_act_default
(
act
=
LinearActivation
())
def
gated_unit_layer
(
input
,
size
,
act
=
None
,
name
=
None
,
gate_attr
=
None
,
gate_param_attr
=
None
,
gate_bias_attr
=
True
,
inproj_attr
=
None
,
inproj_param_attr
=
None
,
inproj_bias_attr
=
True
,
layer_attr
=
None
):
"""
The gated unit layer implements a simple gating mechanism over the input.
The input :math:`X` is first projected into a new space :math:`X'`, and
it is also used to produce a gate weight :math:`\sigma`. Element-wise
prodict between :match:`X'` and :math:`\sigma` is finally returned.
Reference:
Language Modeling with Gated Convolutional Networks
https://arxiv.org/abs/1612.08083
.. math::
y=
\\
text{act}(X \cdot W + b)\otimes \sigma(X \cdot V + c)
The example usage is:
.. code-block:: python
gated_unit = gated_unit_layer(size=128, input=input_layer))
:param input: input for this layer.
:type input: LayerOutput
:param size: output size of the gated unit.
:type size: int
:param act: activation type of the projected input.
:type act: BaseActivation
:param name: name of this layer.
:type name: basestring
:param gate_attr: Attributes to tune the gate output, for example, error
clipping threshold, dropout and so on. See ExtraLayerAttribute for
more details.
:type gate_attr: ExtraLayerAttribute|None
:param gate_param_attr: Attributes to tune the learnable projected matrix
parameter of the gate.
:type gate_param_attr: ParameterAttribute|None
:param gate_bias_attr: Attributes to tune the learnable bias of the gate.
:type gate_bias_attr: ParameterAttribute|None
:param inproj_attr: Attributes to the tune the projected input, for
example, error clipping threshold, dropout and so on. See
ExtraLayerAttribute for more details.
:type inproj_attr: ExtraLayerAttribute|None
:param inproj_param_attr: Attributes to tune the learnable parameter of
the projection of input.
:type inproj_param_attr: ParameterAttribute|None
:param inproj_bias_attr: Attributes to tune the learnable bias of
projection of the input.
:type inproj_bias_attr: ParameterAttribute|None
:param layer_attr: Attributes to tune the final output of the gated unit,
for example, error clipping threshold, dropout and so on. See
ExtraLayerAttribute for more details.
:type layer_attr: ExtraLayerAttribute|None
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert
isinstance
(
input
,
LayerOutput
),
'The gated linear unit accepts only one input.'
input_proj
=
fc_layer
(
input
=
input
,
name
=
"%s_input_proj"
%
name
,
size
=
size
,
act
=
act
,
layer_attr
=
inproj_attr
,
param_attr
=
inproj_param_attr
,
bias_attr
=
inproj_bias_attr
)
gate
=
fc_layer
(
size
=
size
,
name
=
"%s_gate"
%
name
,
act
=
SigmoidActivation
(),
input
=
input
,
layer_attr
=
gate_attr
,
param_attr
=
gate_param_attr
,
bias_attr
=
gate_bias_attr
)
return
mixed_layer
(
name
=
"%s_gated_act"
%
name
,
input
=
dotmul_operator
(
input_proj
,
gate
),
layer_attr
=
layer_attr
)
python/paddle/trainer_config_helpers/tests/configs/file_list.sh
浏览文件 @
581cf909
...
@@ -7,6 +7,6 @@ test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight
...
@@ -7,6 +7,6 @@ test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight
test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
test_recursive_topology
)
test_recursive_topology
test_gated_unit_layer
)
export
whole_configs
=(
test_split_datasource
)
export
whole_configs
=(
test_split_datasource
)
python/paddle/trainer_config_helpers/tests/configs/protostr/test_gated_unit_layer.protostr
0 → 100644
浏览文件 @
581cf909
type: "nn"
layers {
name: "input"
type: "data"
size: 256
active_type: ""
}
layers {
name: "__gated_unit_layer_0___input_proj"
type: "fc"
size: 512
active_type: "tanh"
inputs {
input_layer_name: "input"
input_parameter_name: "___gated_unit_layer_0___input_proj.w0"
}
bias_parameter_name: "___gated_unit_layer_0___input_proj.wbias"
error_clipping_threshold: 100.0
}
layers {
name: "__gated_unit_layer_0___gate"
type: "fc"
size: 512
active_type: "sigmoid"
inputs {
input_layer_name: "input"
input_parameter_name: "___gated_unit_layer_0___gate.w0"
}
bias_parameter_name: "___gated_unit_layer_0___gate.wbias"
error_clipping_threshold: 100.0
}
layers {
name: "__gated_unit_layer_0___gated_act"
type: "mixed"
size: 512
active_type: ""
inputs {
input_layer_name: "__gated_unit_layer_0___input_proj"
}
inputs {
input_layer_name: "__gated_unit_layer_0___gate"
}
error_clipping_threshold: 100.0
operator_confs {
type: "dot_mul"
input_indices: 0
input_indices: 1
input_sizes: 512
input_sizes: 512
output_size: 512
dotmul_scale: 1
}
}
parameters {
name: "___gated_unit_layer_0___input_proj.w0"
size: 131072
initial_mean: 0.0
initial_std: 0.0001
dims: 256
dims: 512
initial_strategy: 0
initial_smart: false
}
parameters {
name: "___gated_unit_layer_0___input_proj.wbias"
size: 512
initial_mean: 0.0
initial_std: 1
dims: 1
dims: 512
initial_strategy: 0
initial_smart: false
}
parameters {
name: "___gated_unit_layer_0___gate.w0"
size: 131072
initial_mean: 0.0
initial_std: 0.0001
dims: 256
dims: 512
initial_strategy: 0
initial_smart: false
}
parameters {
name: "___gated_unit_layer_0___gate.wbias"
size: 512
initial_mean: 0.0
initial_std: 1
dims: 1
dims: 512
initial_strategy: 0
initial_smart: false
}
input_layer_names: "input"
output_layer_names: "__gated_unit_layer_0___gated_act"
sub_models {
name: "root"
layer_names: "input"
layer_names: "__gated_unit_layer_0___input_proj"
layer_names: "__gated_unit_layer_0___gate"
layer_names: "__gated_unit_layer_0___gated_act"
input_layer_names: "input"
output_layer_names: "__gated_unit_layer_0___gated_act"
is_recurrent_layer_group: false
}
python/paddle/trainer_config_helpers/tests/configs/test_gated_unit_layer.py
0 → 100644
浏览文件 @
581cf909
from
paddle.trainer_config_helpers
import
*
data
=
data_layer
(
name
=
'input'
,
size
=
256
)
glu
=
gated_unit_layer
(
size
=
512
,
input
=
data
,
act
=
TanhActivation
(),
gate_attr
=
ExtraLayerAttribute
(
error_clipping_threshold
=
100.0
),
gate_param_attr
=
ParamAttr
(
initial_std
=
1e-4
),
gate_bias_attr
=
ParamAttr
(
initial_std
=
1
),
inproj_attr
=
ExtraLayerAttribute
(
error_clipping_threshold
=
100.0
),
inproj_param_attr
=
ParamAttr
(
initial_std
=
1e-4
),
inproj_bias_attr
=
ParamAttr
(
initial_std
=
1
),
layer_attr
=
ExtraLayerAttribute
(
error_clipping_threshold
=
100.0
))
outputs
(
glu
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录