提交 8c302d48 编写于 作者: F fengjiayi

remove kwargs in layer apis

上级 e9d30991
......@@ -829,12 +829,12 @@ def crf_decoding(input, param_attr, label=None):
return viterbi_path
def cos_sim(X, Y, **kwargs):
def cos_sim(X, Y):
"""
This function performs the cosine similarity between two tensors
X and Y and returns that as the output.
"""
helper = LayerHelper('cos_sim', **kwargs)
helper = LayerHelper('cos_sim', **locals())
out = helper.create_tmp_variable(dtype=X.dtype)
xnorm = helper.create_tmp_variable(dtype=X.dtype)
ynorm = helper.create_tmp_variable(dtype=X.dtype)
......@@ -848,7 +848,7 @@ def cos_sim(X, Y, **kwargs):
return out
def dropout(x, dropout_prob, is_test=False, seed=None, **kwargs):
def dropout(x, dropout_prob, is_test=False, seed=None):
"""
Computes dropout.
......@@ -877,7 +877,7 @@ def dropout(x, dropout_prob, is_test=False, seed=None, **kwargs):
droped = fluid.layers.dropout(input=x, dropout_rate=0.5)
"""
helper = LayerHelper('dropout', **kwargs)
helper = LayerHelper('dropout', **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
mask = helper.create_tmp_variable(dtype=x.dtype, stop_gradient=True)
helper.append_op(
......@@ -894,7 +894,7 @@ def dropout(x, dropout_prob, is_test=False, seed=None, **kwargs):
return out
def cross_entropy(input, label, **kwargs):
def cross_entropy(input, label, soft_label=False):
"""
**Cross Entropy Layer**
......@@ -903,15 +903,15 @@ def cross_entropy(input, label, **kwargs):
computation.
1) One-hot cross-entropy:
`soft_label = False`, `Label[i, 0]` indicates the class index for sample i:
`soft_label = False`, `Label[i, 0]` indicates the class index for sample i:
.. math::
Y[i] = -\log(X[i, Label[i]])
2) Soft-label cross-entropy:
`soft_label = True`, `Label[i, j]` indicates the soft label of class j
for sample i:
`soft_label = True`, `Label[i, j]` indicates the soft label of class j
for sample i:
.. math::
......@@ -921,8 +921,8 @@ def cross_entropy(input, label, **kwargs):
equals one.
3) One-hot cross-entropy with vecterized `label`:
As a special case of 2), when each row of 'label' has only one
non-zero element which is equal to 1, soft-label cross-entropy degenerates
As a special case of 2), when each row of 'label' has only one
non-zero element which is equal to 1, soft-label cross-entropy degenerates
to a one-hot cross-entropy with one-hot label representation.
Args:
......@@ -936,7 +936,7 @@ def cross_entropy(input, label, **kwargs):
tensor<int64> with shape [N x 1]. When
`soft_label` is set to `True`, `label` is a
tensor<float/double> with shape [N x D].
soft_label (bool, via `**kwargs`): a flag indicating whether to
soft_label (bool): a flag indicating whether to
interpretate the given labels as soft
labels, default `False`.
......@@ -956,18 +956,18 @@ def cross_entropy(input, label, **kwargs):
predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
"""
helper = LayerHelper('cross_entropy', **kwargs)
helper = LayerHelper('cross_entropy', **locals())
out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='cross_entropy',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out]},
attrs=kwargs)
attrs={"soft_label": soft_label})
return out
def square_error_cost(input, label, **kwargs):
def square_error_cost(input, label):
"""
**Square error cost layer**
......@@ -1002,7 +1002,7 @@ def square_error_cost(input, label, **kwargs):
cost = layers.square_error_cost(input=y_predict, label=y)
"""
helper = LayerHelper('square_error_cost', **kwargs)
helper = LayerHelper('square_error_cost', **locals())
minus_out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='elementwise_sub',
......@@ -1017,12 +1017,12 @@ def square_error_cost(input, label, **kwargs):
return square_out
def accuracy(input, label, k=1, correct=None, total=None, **kwargs):
def accuracy(input, label, k=1, correct=None, total=None):
"""
This function computes the accuracy using the input and label.
The output is the top_k inputs and their indices.
"""
helper = LayerHelper("accuracy", **kwargs)
helper = LayerHelper("accuracy", **locals())
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
......@@ -1055,13 +1055,12 @@ def chunk_eval(input,
label,
chunk_scheme,
num_chunk_types,
excluded_chunk_types=None,
**kwargs):
excluded_chunk_types=None):
"""
This function computes and outputs the precision, recall and
F1-score of chunk detection.
"""
helper = LayerHelper("chunk_eval", **kwargs)
helper = LayerHelper("chunk_eval", **locals())
# prepare output
precision = helper.create_tmp_variable(dtype="float32")
......@@ -1293,7 +1292,7 @@ def conv2d(input,
return helper.append_activation(pre_act)
def sequence_pool(input, pool_type, **kwargs):
def sequence_pool(input, pool_type):
"""
This function add the operator for sequence pooling.
It pools features of all time-steps of each instance, and is applied
......@@ -1343,7 +1342,7 @@ def sequence_pool(input, pool_type, **kwargs):
sqrt_x = fluid.layers.sequence_pool(input=x, pool_type='sqrt')
max_x = fluid.layers.sequence_pool(input=x, pool_type='max')
"""
helper = LayerHelper('sequence_pool', input=input, **kwargs)
helper = LayerHelper('sequence_pool', **locals())
dtype = helper.input_dtype()
pool_out = helper.create_tmp_variable(dtype)
max_index = helper.create_tmp_variable(dtype)
......@@ -1363,7 +1362,7 @@ def sequence_pool(input, pool_type, **kwargs):
return pool_out
def sequence_first_step(input, **kwargs):
def sequence_first_step(input):
"""
This funciton get the first step of sequence.
......@@ -1396,7 +1395,7 @@ def sequence_first_step(input, **kwargs):
return sequence_pool(input=input, pool_type="first")
def sequence_last_step(input, **kwargs):
def sequence_last_step(input):
"""
This funciton get the last step of sequence.
......@@ -2336,7 +2335,8 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
normed = fluid.layers.l2_normalize(x=data, axis=1)
"""
if len(x.shape) == 1: axis = 0
if len(x.shape) == 1:
axis = 0
helper = LayerHelper("l2_normalize", **locals())
......@@ -2654,7 +2654,7 @@ def ctc_greedy_decoder(input, blank, name=None):
return ctc_out
def warpctc(input, label, blank=0, norm_by_times=False, **kwargs):
def warpctc(input, label, blank=0, norm_by_times=False):
"""
An operator integrating the open source Warp-CTC library
(https://github.com/baidu-research/warp-ctc)
......@@ -2695,7 +2695,7 @@ def warpctc(input, label, blank=0, norm_by_times=False, **kwargs):
cost = layers.warpctc(input=y_predict, label=y)
"""
helper = LayerHelper('warpctc', **kwargs)
helper = LayerHelper('warpctc', **locals())
loss_out = helper.create_tmp_variable(dtype=input.dtype)
grad_out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册