Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9328c3cf
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9328c3cf
编写于
6月 11, 2018
作者:
Y
Yu Yang
提交者:
GitHub
6月 11, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #11308 from reyoung/feature/polish_api_ref
Simplize API Reference Documentation
上级
17b42fc2
dd26329b
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
152 addition
and
123 deletion
+152
-123
paddle/fluid/operators/batch_size_like.h
paddle/fluid/operators/batch_size_like.h
+7
-7
paddle/fluid/operators/bilinear_interp_op.cc
paddle/fluid/operators/bilinear_interp_op.cc
+5
-6
paddle/fluid/operators/fill_constant_batch_size_like_op.cc
paddle/fluid/operators/fill_constant_batch_size_like_op.cc
+7
-7
paddle/fluid/operators/linear_chain_crf_op.cc
paddle/fluid/operators/linear_chain_crf_op.cc
+0
-2
paddle/fluid/operators/load_op.cc
paddle/fluid/operators/load_op.cc
+4
-11
paddle/fluid/operators/max_sequence_len_op.cc
paddle/fluid/operators/max_sequence_len_op.cc
+9
-4
python/paddle/fluid/layers/control_flow.py
python/paddle/fluid/layers/control_flow.py
+12
-16
python/paddle/fluid/layers/io.py
python/paddle/fluid/layers/io.py
+28
-1
python/paddle/fluid/layers/layer_function_generator.py
python/paddle/fluid/layers/layer_function_generator.py
+29
-11
python/paddle/fluid/layers/learning_rate_scheduler.py
python/paddle/fluid/layers/learning_rate_scheduler.py
+19
-17
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+15
-8
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+17
-33
未找到文件。
paddle/fluid/operators/batch_size_like.h
浏览文件 @
9328c3cf
...
...
@@ -54,18 +54,18 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel {
class
BatchSizeLikeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
final
{
AddInput
(
"Input"
,
"(Tensor) Tensor "
"
whose input_dim_idx'th dimension specifies the batch_size"
);
AddInput
(
"Input"
,
"Tensor
whose input_dim_idx'th dimension specifies the batch_size"
);
AddOutput
(
"Out"
,
"
(Tensor)
Tensor of specified shape will be filled "
"Tensor of specified shape will be filled "
"with the specified value"
);
AddAttr
<
std
::
vector
<
int
>>
(
"shape"
,
"
(vector<int>)
The shape of the output"
);
AddAttr
<
std
::
vector
<
int
>>
(
"shape"
,
"The shape of the output"
);
AddAttr
<
int
>
(
"input_dim_idx"
,
"
(int, default 0)
The index of input's batch size dimension"
)
"
default 0.
The index of input's batch size dimension"
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"output_dim_idx"
,
"
(int, default 0)
The index of output's batch size dimension"
)
"
default 0.
The index of output's batch size dimension"
)
.
SetDefault
(
0
);
Apply
();
}
...
...
paddle/fluid/operators/bilinear_interp_op.cc
浏览文件 @
9328c3cf
...
...
@@ -56,17 +56,16 @@ class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void
Make
()
override
{
AddInput
(
"X"
,
"
(Tensor)
The input tensor of bilinear interpolation, "
"The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)"
);
AddInput
(
"OutSize"
,
"
(Tensor)
This is a 1-D tensor with two number. "
"This is a 1-D tensor with two number. "
"The first number is height and the second number is width."
)
.
AsDispensable
();
AddOutput
(
"Out"
,
"(Tensor) The dimension of output is (N x C x out_h x out_w]"
);
AddOutput
(
"Out"
,
"The dimension of output is (N x C x out_h x out_w)"
);
AddAttr
<
int
>
(
"out_h"
,
"
(int)
output height of bilinear interpolation op."
);
AddAttr
<
int
>
(
"out_w"
,
"
(int)
output width of bilinear interpolation op."
);
AddAttr
<
int
>
(
"out_h"
,
"output height of bilinear interpolation op."
);
AddAttr
<
int
>
(
"out_w"
,
"output width of bilinear interpolation op."
);
AddComment
(
R"DOC(
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
...
...
paddle/fluid/operators/fill_constant_batch_size_like_op.cc
浏览文件 @
9328c3cf
...
...
@@ -32,16 +32,16 @@ class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
class
FillConstantBatchSizeLikeOpMaker
:
public
BatchSizeLikeOpMaker
{
protected:
void
Apply
()
override
{
AddAttr
<
int
>
(
"dtype"
,
"(int, default 5 (FP32)) "
"Output data type
"
)
AddAttr
<
int
>
(
"dtype"
,
"It could be numpy.dtype. Output data type. Default is float32
"
)
.
SetDefault
(
framework
::
proto
::
VarType
::
FP32
);
AddAttr
<
float
>
(
"value"
,
"
(float, default 0)
The value to be filled"
)
AddAttr
<
float
>
(
"value"
,
"
default 0.
The value to be filled"
)
.
SetDefault
(
0.0
f
);
AddComment
(
R"DOC(
FillConstantBatchSizeLike Operator.
Fill up a variable with specified constant value
.
This function creates a tensor of specified *shape*, *dtype* and batch size,
and initializes this with a constant supplied in *value*. The batch size is
obtained from the `input` tensor
.
)DOC"
);
}
...
...
paddle/fluid/operators/linear_chain_crf_op.cc
浏览文件 @
9328c3cf
...
...
@@ -67,8 +67,6 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
"mini-batch. Note: S is equal to the sequence number in a mini-batch. "
"The output is no longer a LoDTensor."
);
AddComment
(
R"DOC(
LinearChainCRF Operator.
Conditional Random Field defines an undirected probabilistic graph with nodes
denoting random variables and edges denoting dependencies between these
variables. CRF learns the conditional probability $P(Y|X)$, where
...
...
paddle/fluid/operators/load_op.cc
浏览文件 @
9328c3cf
...
...
@@ -74,25 +74,18 @@ class LoadOp : public framework::OperatorBase {
class
LoadOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddOutput
(
"Out"
,
"
(Tensor)
The tensor need to be loaded"
);
AddOutput
(
"Out"
,
"The tensor need to be loaded"
);
AddAttr
<
bool
>
(
"load_as_fp16"
,
"(boolean, default false)"
"If true, the tensor will be first loaded and then "
"converted to float16 data type. Otherwise, the tensor will be "
"directly loaded without data type conversion."
)
"directly loaded without data type conversion.
Default is false.
"
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"file_path"
,
"(string) "
"Variable will be loaded from
\"
file_path
\"
."
)
R"(Variable will be loaded from "file_path")"
)
.
AddCustomChecker
(
[](
const
std
::
string
&
path
)
{
return
!
path
.
empty
();
});
AddComment
(
R"DOC(
Load Operator.
Load operator will load a tensor variable from disk file.
)DOC"
);
AddComment
(
"Load operator will load a tensor variable from disk file."
);
}
};
}
// namespace operators
...
...
paddle/fluid/operators/max_sequence_len_op.cc
浏览文件 @
9328c3cf
...
...
@@ -42,10 +42,15 @@ class MaxSeqenceLenOp : public framework::OperatorBase {
class
MaxSeqenceLenOpProtoMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"RankTable"
,
"The lod_rank_table."
);
AddOutput
(
"Out"
,
"The max sequence length."
);
AddComment
(
R"DOC(Calculate the max sequence length through lod_rank_table.)DOC"
);
AddInput
(
"RankTable"
,
"Input variable which is a LoDRankTable object"
);
AddOutput
(
"Out"
,
"The max sequence length"
);
AddComment
(
R"DOC(
Given a LoDRankTable object, this layer returns the max length of
a batch of sequences. In fact, a LoDRankTable object contains a list of
tuples(<sequence index, sequence length>) and the list is already sorted by
sequence length in descending order, so the operator just returns the
sequence length of the first tuple element
)DOC"
);
}
};
...
...
python/paddle/fluid/layers/control_flow.py
浏览文件 @
9328c3cf
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
import
contextlib
from
layer_function_generator
import
autodoc
from
layer_function_generator
import
autodoc
,
templatedoc
from
tensor
import
assign
,
fill_constant
from
..
import
core
from
..framework
import
Program
,
Variable
,
Operator
...
...
@@ -721,26 +721,22 @@ def lod_rank_table(x, level=0):
return
table
@
templatedoc
()
def
max_sequence_len
(
rank_table
):
"""Max Sequence Len Operator. Given a LoDRankTable object, this layer
returns the max length of a batch of sequences. In fact, a LoDRankTable
object contains a list of tuples(<sequence index, sequence length>) and
the list is already sorted by sequence length in descending order, so the
operator just returns the sequence length of the first tuple element.
"""
${comment}
>>> import paddle.fluid as fluid
>>> x = fluid.layers.data(name='x', shape=[10], dtype='float32',
>>> lod_level=1)
>>> rank_table = layers.lod_rank_table(x=x, level=0)
>>> max_seq_len = layers.max_sequence_len(rank_table)
Args:
rank_table
(Variable): Input variable which is a LoDRankTable object
.
rank_table
(${rank_table_type}): ${rank_table_comment}
.
Returns:
Variable: The max length of sequence.
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[10],
dtype='float32', lod_level=1)
rank_table = layers.lod_rank_table(x=x, level=0)
max_seq_len = layers.max_sequence_len(rank_table)
${out_comment}.
"""
helper
=
LayerHelper
(
"max_seqence_len"
,
**
locals
())
res
=
helper
.
create_tmp_variable
(
dtype
=
"int64"
)
...
...
python/paddle/fluid/layers/io.py
浏览文件 @
9328c3cf
...
...
@@ -19,11 +19,12 @@ from ..unique_name import generate as unique_name
from
control_flow
import
BlockGuard
from
..layer_helper
import
LayerHelper
from
..executor
import
global_scope
from
layer_function_generator
import
generate_layer_fn
,
templatedoc
__all__
=
[
'data'
,
'BlockGuardServ'
,
'ListenAndServ'
,
'Send'
,
'open_recordio_file'
,
'open_files'
,
'read_file'
,
'shuffle'
,
'batch'
,
'double_buffer'
,
'random_data_generator'
,
'Preprocessor'
'random_data_generator'
,
'Preprocessor'
,
'load'
]
...
...
@@ -662,3 +663,29 @@ class Preprocessor(object):
"sink_var_names"
:
self
.
sink_var_names
})
return
monkey_patch_reader_methods
(
self
.
reader
)
@
templatedoc
()
def
load
(
out
,
file_path
,
load_as_fp16
=
None
):
"""
${comment}
>>> import paddle.fluid as fluid
>>> tmp_tensor = fluid.layers.create_tensor(dtype='float32')
>>> fluid.layers.load(tmp_tensor, "./tmp_tensor.bin")
Args:
out(${out_type}): ${out_comment}.
file_path(${file_path_type}): ${file_path_comment}.
load_as_fp16(${load_as_fp16_type}): ${load_as_fp16_comment}.
Returns:
None
"""
helper
=
LayerHelper
(
"load"
,
**
locals
())
attrs
=
{
"file_path"
:
file_path
}
if
load_as_fp16
is
not
None
:
attrs
[
'load_as_fp16'
]
=
load_as_fp16
helper
.
append_op
(
type
=
"load"
,
inputs
=
{},
output
=
{
"Out"
:
out
},
args
=
attrs
)
python/paddle/fluid/layers/layer_function_generator.py
浏览文件 @
9328c3cf
...
...
@@ -224,7 +224,10 @@ def autodoc(comment=""):
return
__impl__
def
templatedoc
():
_inline_math_single_dollar
=
re
.
compile
(
r
"\$([^\$]+)\$"
)
def
templatedoc
(
op_type
=
None
):
"""
Decorator of layer function. It will use the docstring from the layer
function as the template. The template arguments are:
...
...
@@ -238,32 +241,47 @@ def templatedoc():
Decorated function.
"""
def
trim_ending_dot
(
msg
):
return
msg
.
rstrip
(
'.'
)
def
escape_inline_math
(
msg
):
return
_inline_math_single_dollar
.
sub
(
repl
=
r
':math:`\1`'
,
string
=
msg
)
def
__impl__
(
func
):
op_proto
=
OpProtoHolder
.
instance
().
get_op_proto
(
func
.
__name__
)
if
op_type
is
None
:
op_type_name
=
func
.
__name__
else
:
op_type_name
=
op_type
op_proto
=
OpProtoHolder
.
instance
().
get_op_proto
(
op_type_name
)
tmpl
=
string
.
Template
(
func
.
__doc__
)
comment_lines
=
op_proto
.
comment
.
split
(
"
\n
"
)
comment
=
""
for
line
in
comment_lines
:
line
=
line
.
lstrip
()
comment
+=
line
comment
+=
"
\n
"
args
=
{
"comment"
:
comment
}
line
=
line
.
strip
()
if
len
(
line
)
!=
0
:
comment
+=
escape_inline_math
(
line
)
comment
+=
" "
elif
len
(
comment
)
!=
0
:
comment
+=
"
\n
\n
"
args
=
{
"comment"
:
trim_ending_dot
(
comment
)}
for
each_input
in
op_proto
.
inputs
:
input_name
=
_convert_
(
each_input
.
name
)
args
[
"{0}_comment"
.
format
(
input_name
)]
=
each_input
.
comment
args
[
"{0}_comment"
.
format
(
input_name
)]
=
trim_ending_dot
(
each_input
.
comment
)
args
[
"{0}_type"
.
format
(
input_name
)]
=
"Variable"
for
each_attr
in
op_proto
.
attrs
:
input_name
=
_convert_
(
each_attr
.
name
)
args
[
"{0}_comment"
.
format
(
input_name
)]
=
each_attr
.
comment
args
[
"{0}_comment"
.
format
(
input_name
)]
=
trim_ending_dot
(
each_attr
.
comment
)
args
[
"{0}_type"
.
format
(
input_name
)]
=
_type_to_str_
(
each_attr
.
type
)
for
each_opt
in
op_proto
.
outputs
:
output_name
=
_convert_
(
each_opt
.
name
)
args
[
"{0}_comment"
.
format
(
output_name
)]
=
each_opt
.
comment
args
[
"{0}_comment"
.
format
(
output_name
)]
=
trim_ending_dot
(
each_opt
.
comment
)
args
[
"{0}_type"
.
format
(
output_name
)]
=
"Variable"
func
.
__doc__
=
tmpl
.
substitute
(
args
)
return
func
...
...
python/paddle/fluid/layers/learning_rate_scheduler.py
浏览文件 @
9328c3cf
...
...
@@ -11,6 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
When training a model, it's often useful to decay the
learning rate during training process, this is called
learning_rate_decay. There are many strategies to do
this, this module will provide some classical method.
User can also implement their own learning_rate_decay
strategy according to this module.
"""
import
control_flow
import
nn
...
...
@@ -22,14 +30,6 @@ __all__ = [
'exponential_decay'
,
'natural_exp_decay'
,
'inverse_time_decay'
,
'polynomial_decay'
,
'piecewise_decay'
,
'noam_decay'
]
"""
When training a model, it's often useful to decay the
learning rate during training process, this is called
learning_rate_decay. There are many strategies to do
this, this module will provide some classical method.
User can also implement their own learning_rate_decay
strategy according to this module.
"""
def
_decay_step_counter
(
begin
=
0
):
...
...
@@ -41,18 +41,20 @@ def _decay_step_counter(begin=0):
def
noam_decay
(
d_model
,
warmup_steps
):
"""Apply decay to learning rate.
```python
lr_value = np.power(d_model, -0.5) * np.min([
np.power(current_steps, -0.5),
np.power(warmup_steps, -1.5) * current_steps
])
```
"""
Noam decay method. The numpy implementation of noam decay as follows.
>>> import numpy as np
>>> lr_value = np.power(d_model, -0.5) * np.min([
>>> np.power(current_steps, -0.5),
>>> np.power(warmup_steps, -1.5) * current_steps])
Please reference `attention is all you need
<https://arxiv.org/pdf/1706.03762.pdf>`_.
Args:
d_model(Variable): The dimensionality of input and output of model.
Reference: attention is all you need
https://arxiv.org/pdf/1706.03762.pdf
warmup_steps(Variable): A super parameter.
Returns:
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
9328c3cf
...
...
@@ -4037,18 +4037,25 @@ def image_resize(input,
return
out
@
templatedoc
(
op_type
=
"bilinear_interp"
)
def
resize_bilinear
(
input
,
out_shape
=
None
,
scale
=
None
,
name
=
None
):
"""
This is an alias of layer 'image_resize' with bilinear interpolation.
${comment}
Args:
input(${x_type}): ${x_comment}.
out_shape(${out_size_type}): ${out_size_comment}.
The mathematical meaning of resize bilinear layer is
Bilinear interpolation.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this layer) on a rectilinear 2D grid.
scale(float|None): The multiplier for the input height or width. At
least one of out_shape or scale must be set. And out_shape has
a higher priority than scale. Default: None.
name(str|None): The output variable name.
Returns:
For details, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
${out_comment}.
"""
return
image_resize
(
input
,
out_shape
,
scale
,
name
,
'BILINEAR'
)
...
...
python/paddle/fluid/layers/tensor.py
浏览文件 @
9328c3cf
...
...
@@ -18,6 +18,7 @@ from ..framework import convert_np_dtype_to_dtype_
from
..framework
import
Variable
from
..initializer
import
Constant
,
force_init_on_cpu
from
..core
import
VarDesc
from
layer_function_generator
import
templatedoc
import
numpy
__all__
=
[
...
...
@@ -266,6 +267,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
return
out
@
templatedoc
()
def
fill_constant_batch_size_like
(
input
,
shape
,
dtype
,
...
...
@@ -273,30 +275,28 @@ def fill_constant_batch_size_like(input,
input_dim_idx
=
0
,
output_dim_idx
=
0
):
"""
**fill_constant_batch_size_like**
This function creates a tensor of specified *shape*, *dtype* and batch size,
and initializes this with a constant supplied in *value*. The batch size is
obtained from the `input` tensor.
${comment}
It also sets *stop_gradient* to True.
>>> data = fluid.layers.fill_constant_batch_size_like(
>>> input=like, shape=[1], value=0, dtype='int64')
Args:
input(Variable): Tensor whose dimensions will be used to get batch size
shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor
value(float): Constant value to initialize the output tensor
input_dim_idx(int): Index of input's batch size dimension
output_dim_idx(int): Index of output's batch size dimension
input(${input_type}): ${input_comment}.
Returns:
Variable: The tensor variable storing the output
shape(${shape_type}): ${shape_comment}.
Examples:
.. code-block:: python
dtype(${dtype_type}): ${dtype_comment}.
value(${value_type}): ${value_comment}.
data = fluid.layers.fill_constant_batch_size_like(
input=like, shape=[1], value=0, dtype='int64')
input_dim_idx(${input_dim_idx_type}): ${input_dim_idx_comment}.
output_dim_idx(${output_dim_idx_type}): ${output_dim_idx_comment}.
Returns:
${out_comment}.
"""
helper
=
LayerHelper
(
"fill_constant_batch_size_like"
,
**
locals
())
out
=
helper
.
create_tmp_variable
(
dtype
=
dtype
)
...
...
@@ -437,22 +437,6 @@ def save_combine(x, file_path, overwrite=True):
"overwrite"
:
overwrite
})
def
load
(
out
,
file_path
):
"""
Loads a variable from a given file.
Args:
out(variable): The variable to be read from the disk file.
file_path(str): The path of the disk file.
"""
helper
=
LayerHelper
(
"load"
,
**
locals
())
helper
.
append_op
(
type
=
"load"
,
inputs
=
{},
output
=
{
"Out"
:
out
},
args
=
{
"file_path"
:
file_path
})
def
load_combine
(
out
,
file_path
):
"""
Loads a list of vairables from a single file.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录