diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 1ce2cf83e8dcbcc3eb18ea3b7099dfb257fc8158..d27289ed6d14a08895e33675781bcff986a3686b 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -170,6 +170,14 @@ paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, key paddle.fluid.layers.sequence_enumerate ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.expand ArgSpec(args=['x', 'expand_times', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_concat ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.scale ArgSpec(args=['x', 'scale', 'bias', 'bias_after_scale'], varargs=None, keywords=None, defaults=(1.0, 0.0, True)) +paddle.fluid.layers.elementwise_add ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn'], varargs=None, keywords=None, defaults=(-1, False)) +paddle.fluid.layers.elementwise_div ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn'], varargs=None, keywords=None, defaults=(-1, False)) +paddle.fluid.layers.elementwise_sub ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn'], varargs=None, keywords=None, defaults=(-1, False)) +paddle.fluid.layers.elementwise_mul ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn'], varargs=None, keywords=None, defaults=(-1, False)) +paddle.fluid.layers.elementwise_max ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn'], varargs=None, keywords=None, defaults=(-1, False)) +paddle.fluid.layers.elementwise_min ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn'], varargs=None, keywords=None, defaults=(-1, False)) +paddle.fluid.layers.elementwise_pow ArgSpec(args=['x', 'y', 'axis', 'use_mkldnn'], varargs=None, keywords=None, defaults=(-1, False)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) @@ -234,15 +242,7 @@ paddle.fluid.layers.Print ArgSpec(args=['input', 'first_n', 'message', 'summariz paddle.fluid.layers.is_empty ArgSpec(args=['x', 'cond'], varargs=None, keywords='ignored', defaults=(None,)) paddle.fluid.layers.mean ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.mul ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.scale ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.elementwise_add ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.elementwise_div ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.elementwise_sub ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.elementwise_mul ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.elementwise_max ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.elementwise_min ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) -paddle.fluid.layers.elementwise_pow ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.clip ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.clip_by_norm ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.logical_and ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index 13be6c65be58314a75124106eb09b1300305baf0..bf4df4f600c14050b636b7ee6d7b6973b57adb94 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -46,9 +46,15 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( **Scale operator** -Multiply the input tensor with a float scalar to scale the input tensor. +Apply scaling and bias addition to the input tensor. -$$Out = scale*X$$ +if bias_after_scale=True: + +$$Out = scale*X + bias$$ + +else: + +$$Out = scale*(X + bias)$$ )DOC"); AddAttr("scale", "The scaling factor of the scale operator.") .SetDefault(1.0); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 0abbb6815123f8ba65b637b3f3accef91fe66ef8..59ffc5c8a104473e30ea12337a3e7e7b4ceb5387 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -489,7 +489,8 @@ class OpProtoHolder(object): def generated_op_attr_names(): return { core.op_proto_and_checker_maker.kOpRoleAttrName(), - core.op_proto_and_checker_maker.kOpRoleVarAttrName() + core.op_proto_and_checker_maker.kOpRoleVarAttrName(), + core.op_proto_and_checker_maker.kOpNameScopeAttrName() } diff --git a/python/paddle/fluid/layers/layer_function_generator.py b/python/paddle/fluid/layers/layer_function_generator.py index 8963d74de014d69c590276d5ff7080111f614230..00d0b7e6083e64fd791af0d9e0ee29c5bffdf809 100644 --- a/python/paddle/fluid/layers/layer_function_generator.py +++ b/python/paddle/fluid/layers/layer_function_generator.py @@ -58,7 +58,7 @@ def escape_math(text): _two_dollar_pattern_.sub(r"!!\1!!", text))) -def _generate_doc_string_(op_proto): +def _generate_doc_string_(op_proto, additional_args_lines=None): """ Generate docstring by OpProto @@ -98,6 +98,13 @@ def _generate_doc_string_(op_proto): buf.write(escape_math(each_attr.comment)) buf.write('\n') + if additional_args_lines is not None: + for line in additional_args_lines: + line = line.strip() + buf.write(' ') + buf.write(line) + buf.write('\n') + if len(op_proto.outputs) != 0: buf.write('\nReturns:\n') buf.write(' ') diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py index be368007dd7061ba7fc97414dbadfce00d158776..7c6cb932a3767efd6a3e8d5ad1ebfb2badf6211e 100644 --- a/python/paddle/fluid/layers/learning_rate_scheduler.py +++ b/python/paddle/fluid/layers/learning_rate_scheduler.py @@ -67,7 +67,7 @@ def noam_decay(d_model, warmup_steps): a = global_step**-0.5 b = (warmup_steps**-1.5) * global_step - lr_value = (d_model**-0.5) * ops.elementwise_min(a, b) + lr_value = (d_model**-0.5) * nn.elementwise_min(a, b) return lr_value @@ -234,7 +234,7 @@ def polynomial_decay(learning_rate, else: decay_steps_var = tensor.fill_constant( shape=[1], dtype='float32', value=float(decay_steps)) - global_step = ops.elementwise_min(x=global_step, y=decay_steps_var) + global_step = nn.elementwise_min(x=global_step, y=decay_steps_var) decayed_lr = (learning_rate - end_learning_rate) * \ ((1 - global_step / decay_steps) ** power) + end_learning_rate diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index f896cfa04b3e0d89daaa1bd7fd893b5892a09a4e..4b696b06baa5f4df29edb8d7aefd7c75d3ca765f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -20,9 +20,9 @@ from __future__ import print_function import numpy as np from ..layer_helper import LayerHelper from ..initializer import Normal, Constant -from ..framework import Variable +from ..framework import Variable, OpProtoHolder from ..param_attr import ParamAttr -from .layer_function_generator import autodoc, templatedoc +from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_ from .tensor import concat from . import utils from .. import unique_name @@ -116,6 +116,14 @@ __all__ = [ 'sequence_enumerate', 'expand', 'sequence_concat', + 'scale', + 'elementwise_add', + 'elementwise_div', + 'elementwise_sub', + 'elementwise_mul', + 'elementwise_max', + 'elementwise_min', + 'elementwise_pow', ] @@ -3605,7 +3613,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): attrs={ 'transpose_X': transpose_x, 'transpose_Y': transpose_y, - 'alpha': alpha, + 'alpha': float(alpha), }) return out @@ -6234,3 +6242,98 @@ def expand(x, expand_times, name=None): outputs={'Out': out}, attrs={'expand_times': expand_times}) return out + + +def _elementwise_op(helper): + op_type = helper.layer_type + x = helper.kwargs.get('x', None) + y = helper.kwargs.get('y', None) + assert x is not None, 'x cannot be None in {}'.format(op_type) + assert y is not None, 'y cannot be None in {}'.format(op_type) + axis = helper.kwargs.get('axis', -1) + use_mkldnn = helper.kwargs.get('use_mkldnn', False) + out = helper.create_tmp_variable(dtype=x.dtype) + helper.append_op( + type=op_type, + inputs={'X': x, + 'Y': y}, + outputs={'Out': out}, + attrs={'axis': axis, + 'use_mkldnn': use_mkldnn}) + return helper.append_activation(out) + + +@templatedoc() +def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, name=None): + """ + ${comment} + + Args: + x(${x_type}): ${x_comment} + scale(${scale_type}): ${scale_comment} + bias(${bias_type}): ${bias_comment} + bias_after_scale(${bias_after_scale_type}): ${bias_after_scale_comment} + name(basestring|None): Name of the output. + + Returns: + out(${out_type}): ${out_comment} + """ + + helper = LayerHelper('scale', **locals()) + out = helper.create_tmp_variable(dtype=x.dtype) + if name is None: + out = helper.create_tmp_variable(dtype=x.dtype) + else: + out = helper.create_variable( + name=name, dtype=x.dtype, persistable=False) + + helper.append_op( + type='scale', + inputs={'X': x}, + outputs={'Out': out}, + attrs={ + 'scale': float(scale), + 'bias': float(bias), + 'bias_after_scale': bias_after_scale + }) + return out + + +def elementwise_add(x, y, axis=-1, use_mkldnn=False, act=None): + return _elementwise_op(LayerHelper('elementwise_add', **locals())) + + +def elementwise_div(x, y, axis=-1, use_mkldnn=False, act=None): + return _elementwise_op(LayerHelper('elementwise_div', **locals())) + + +def elementwise_sub(x, y, axis=-1, use_mkldnn=False, act=None): + return _elementwise_op(LayerHelper('elementwise_sub', **locals())) + + +def elementwise_mul(x, y, axis=-1, use_mkldnn=False, act=None): + return _elementwise_op(LayerHelper('elementwise_mul', **locals())) + + +def elementwise_max(x, y, axis=-1, use_mkldnn=False, act=None): + return _elementwise_op(LayerHelper('elementwise_max', **locals())) + + +def elementwise_min(x, y, axis=-1, use_mkldnn=False, act=None): + return _elementwise_op(LayerHelper('elementwise_min', **locals())) + + +def elementwise_pow(x, y, axis=-1, use_mkldnn=False, act=None): + return _elementwise_op(LayerHelper('elementwise_pow', **locals())) + + +for func in [ + elementwise_add, elementwise_div, elementwise_sub, elementwise_mul, + elementwise_max, elementwise_min, elementwise_pow +]: + op_proto = OpProtoHolder.instance().get_op_proto(func.__name__) + func.__doc__ = _generate_doc_string_( + op_proto, + additional_args_lines=[ + "act(basestring|None): Activation to be applied to the output." + ]) diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index 129252653dc139b7405626e6fd410704a4ad06d9..3e25394a05fd9e6fa1f6ac968b2d12a1c0e43c34 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -47,15 +47,7 @@ __activations__ = [ __all__ = [ 'mean', 'mul', - 'scale', 'sigmoid_cross_entropy_with_logits', - 'elementwise_add', - 'elementwise_div', - 'elementwise_sub', - 'elementwise_mul', - 'elementwise_max', - 'elementwise_min', - 'elementwise_pow', 'clip', 'clip_by_norm', 'logical_and', @@ -75,6 +67,11 @@ __all__ = [ for _OP in set(__all__): globals()[_OP] = generate_layer_fn(_OP) +# It is a hot fix in some unittest using: +# fluid.layers.scale(x=x, scale=10.0, out=out_var) +# e.g.: test_program_code.py, test_dist_train.py +globals()['_scale'] = generate_layer_fn('scale') + __all__ += ["uniform_random"] _uniform_random_ = generate_layer_fn('uniform_random') diff --git a/python/paddle/fluid/tests/unittests/test_dist_train.py b/python/paddle/fluid/tests/unittests/test_dist_train.py index 083525ccf54d389b60c4aaa9f8c6223f07c773cd..d0875d9ea442d0e88dfd958e5948b26225416df2 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_train.py +++ b/python/paddle/fluid/tests/unittests/test_dist_train.py @@ -27,6 +27,7 @@ import paddle.fluid.layers as layers from paddle.fluid.layers.io import ListenAndServ from paddle.fluid.layers.io import Recv from paddle.fluid.layers.io import Send +import paddle.fluid.layers.ops as ops from paddle.fluid import core @@ -89,7 +90,7 @@ class TestSendOp(unittest.TestCase): name="X", append_batch_size=False) fluid.initializer.Constant(value=1.0)(x, main.global_block()) - layers.scale(x=x, scale=10.0, out=out_var) + ops._scale(x=x, scale=10.0, out=out_var) self.server_exe = fluid.Executor(place) self.server_exe.run(main) diff --git a/python/paddle/fluid/tests/unittests/test_program_code.py b/python/paddle/fluid/tests/unittests/test_program_code.py index e9c2b928617dce3904ca119896ca81454256e82e..27b22ba9392b63c0ccd7904ff03d737b977cc9fc 100644 --- a/python/paddle/fluid/tests/unittests/test_program_code.py +++ b/python/paddle/fluid/tests/unittests/test_program_code.py @@ -25,6 +25,7 @@ import paddle.fluid.layers as layers from paddle.fluid.layers.io import ListenAndServ from paddle.fluid.layers.io import Recv from paddle.fluid.layers.io import Send +import paddle.fluid.layers.ops as ops from paddle.fluid.transpiler.details import program_to_code @@ -52,7 +53,7 @@ class TestProgram2Code(unittest.TestCase): name="X", append_batch_size=False) fluid.initializer.Constant(value=1.0)(x, main.global_block()) - layers.scale(x=x, scale=10.0, out=out_var) + ops._scale(x=x, scale=10.0, out=out_var) program_to_code(main)