提交 8f09109a 编写于 作者: G gongweibao

fix

上级 8fa1d84d
......@@ -29,93 +29,25 @@ from .. import unique_name
from functools import reduce
__all__ = [
'fc',
'embedding',
'dynamic_lstm',
'dynamic_lstmp',
'dynamic_gru',
'gru_unit',
'linear_chain_crf',
'crf_decoding',
'cos_sim',
'cross_entropy',
'square_error_cost',
'chunk_eval',
'sequence_conv',
'conv2d',
'conv3d',
'sequence_pool',
'sequence_softmax',
'softmax',
'pool2d',
'pool3d',
'batch_norm',
'beam_search_decode',
'conv2d_transpose',
'conv3d_transpose',
'sequence_expand',
'sequence_expand_as',
'sequence_pad',
'lstm_unit',
'reduce_sum',
'reduce_mean',
'reduce_max',
'reduce_min',
'reduce_prod',
'sequence_first_step',
'sequence_last_step',
'dropout',
'split',
'ctc_greedy_decoder',
'edit_distance',
'l2_normalize',
'matmul',
'topk',
'warpctc',
'sequence_reshape',
'transpose',
'im2sequence',
'nce',
'hsigmoid',
'beam_search',
'row_conv',
'multiplex',
'layer_norm',
'softmax_with_cross_entropy',
'smooth_l1',
'one_hot',
'autoincreased_step_counter',
'reshape',
'squeeze',
'unsqueeze',
'lod_reset',
'lrn',
'pad',
'pad_constant_like',
'label_smooth',
'roi_pool',
'dice_loss',
'image_resize',
'image_resize_short',
'resize_bilinear',
'gather',
'scatter',
'sequence_scatter',
'random_crop',
'mean_iou',
'relu',
'log',
'crop',
'rank_loss',
'prelu',
'flatten',
'sequence_mask',
'stack',
'pad2d',
'unstack',
'sequence_enumerate',
'expand',
'sequence_concat',
'fc', 'embedding', 'dynamic_lstm', 'dynamic_lstmp', 'dynamic_gru',
'gru_unit', 'linear_chain_crf', 'crf_decoding', 'cos_sim', 'cross_entropy',
'square_error_cost', 'chunk_eval', 'sequence_conv', 'conv2d', 'conv3d',
'sequence_pool', 'sequence_softmax', 'softmax', 'pool2d', 'pool3d',
'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'conv3d_transpose',
'sequence_expand', 'sequence_expand_as', 'sequence_pad', 'lstm_unit',
'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', 'reduce_prod',
'sequence_first_step', 'sequence_last_step', 'dropout', 'split',
'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', 'matmul', 'topk',
'warpctc', 'sequence_reshape', 'transpose', 'im2sequence', 'nce',
'hsigmoid', 'beam_search', 'row_conv', 'multiplex', 'layer_norm',
'softmax_with_cross_entropy', 'smooth_l1', 'one_hot',
'autoincreased_step_counter', 'reshape', 'squeeze', 'unsqueeze',
'lod_reset', 'lrn', 'pad', 'pad_constant_like', 'label_smooth', 'roi_pool',
'dice_loss', 'image_resize', 'image_resize_short', 'resize_bilinear',
'gather', 'scatter', 'sequence_scatter', 'random_crop', 'mean_iou', 'relu',
'log', 'crop', 'rank_loss', 'prelu', 'flatten', 'sequence_mask', 'stack',
'pad2d', 'unstack', 'sequence_enumerate', 'expand', 'sequence_concat',
'uniform_random_batch_size_like'
]
......@@ -6234,3 +6166,54 @@ def expand(x, expand_times, name=None):
outputs={'Out': out},
attrs={'expand_times': expand_times})
return out
from paddle.fluid.framework import convert_np_dtype_to_dtype_
def uniform_random_batch_size_like(input,
shape,
dtype='float32',
input_dim_idx=0,
output_dim_idx=0,
min=-1.0,
max=1.0,
seed=0):
"""
UniformRandomBatchSizeLike operator.
This operator initializes a tensor with the same batch_size as the Input tensor with random values sampled from a uniform distribution.
Args:
input (Variable): Tensor whose input_dim_idx'th dimension specifies the batch_size.
shape (tuple|list): the shape of the output.
input_dim_idx (Int): The index of input's batch size dimension.
output_dim_idx (Int): The index of output's batch size dimension.
min (Float): Minimum value of uniform random.
max (Float): Maximum value of uniform random.
seed (Int): Random seed used for generating samples. 0 means use a seed generated by the system.
Note that if seed is not 0, this operator will always generate the same random numbers every time.
dtype(np.dtype|core.VarDesc.VarType|str): The type of data : float32, float_16, int etc
Returns:
output(Variable): Output of this operator.
"""
helper = LayerHelper('uniform_random_batch_size_like', **locals())
out = helper.create_tmp_variable(dtype)
c_dtype = convert_np_dtype_to_dtype_(dtype)
helper.append_op(
type='uniform_random_batch_size_like',
inputs={'Input': input},
outputs={'Out': out},
attrs={
'shape': shape,
'input_dim_idx': input_dim_idx,
'output_dim_idx': output_dim_idx,
'min': min,
'max': max,
'seed': seed,
'dtype': c_dtype
})
return out
......@@ -62,7 +62,6 @@ __all__ = [
'logical_or',
'logical_xor',
'logical_not',
'uniform_random_batch_size_like',
'gaussian_random',
'sampling_id',
'gaussian_random_batch_size_like',
......
......@@ -252,6 +252,9 @@ class OpTest(unittest.TestCase):
block = program.global_block()
self._append_ops(block)
from paddle.fluid.transpiler.details import program_to_code
program_to_code(program)
inputs = self._get_inputs(block)
outputs = self._get_outputs(block)
feed_map = self.feed_var(inputs, place)
......
......@@ -596,6 +596,14 @@ class TestBook(unittest.TestCase):
out = layers.expand(x, [1, 2])
print(str(program))
def test_uniform_random_batch_size_like(self):
program = Program()
with program_guard(program):
input = layers.data(
name="input", shape=[500, 2000], dtype='float32')
out = layers.uniform_random_batch_size_like(input, [-1, 2000])
self.assertIsNotNone(out)
if __name__ == '__main__':
unittest.main()
......@@ -23,7 +23,7 @@ from paddle.fluid.proto import framework_pb2
from paddle.fluid.framework import OpProtoHolder, Variable
from paddle.fluid.layer_helper import LayerHelper
g_filer_attrs = ['op_role', 'op_role_var', 'op_namescope', 'dtype']
g_filer_attrs = ['op_role', 'op_role_var', 'op_namescope']
def _convert_(name):
......@@ -46,7 +46,7 @@ def _get_inputs(op_type):
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
inputs = dict()
for ipt in op_proto.inputs:
inputs[ipt.name] = ""
inputs[ipt.name] = ipt.comment
return inputs
......@@ -60,6 +60,34 @@ def _get_outputs(op_type):
return outputs
_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$")
_single_dollar_pattern_ = re.compile(r"\$([^\$]+)\$")
_two_bang_pattern_ = re.compile(r"!!([^!]+)!!")
def escape_math(text):
return _two_bang_pattern_.sub(
r'$$\1$$',
_single_dollar_pattern_.sub(r':math:`\1`',
_two_dollar_pattern_.sub(r"!!\1!!", text)))
def get_comment(op_type):
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
comment_lines = op_proto.comment.split("\n")
comment = ""
for line in comment_lines:
line = line.strip()
if len(line) != 0:
comment += escape_math(line)
comment += " "
elif len(comment) != 0:
comment += "\n "
return comment
def _get_attrs(op_type):
op_proto = OpProtoHolder.instance().get_op_proto(op_type)
return op_proto.attrs
......@@ -77,14 +105,14 @@ def get_input_comments(op_type, indent=2):
ret = ""
inputs = _get_inputs(op_type)
for t in inputs:
ret += get_indent_space(2) + "input(${%s_type}): ${%s_comment}\n" % (
_convert_(t), _convert_(t))
ret += get_indent_space(2) + "%s (Type): %s\n" % (_convert_(t),
inputs[t])
for t in _get_attrs(op_type):
if t.name in g_filer_attrs:
continue
ret += get_indent_space(2) + "input(${%s_type}): ${%s_comment}\n" % (
_convert_(t.name), _convert_(t.name))
ret += get_indent_space(2) + "%s (%s): %s\n" % (
_convert_(t.name), t.type, _convert_(t.comment))
return ret
......@@ -122,7 +150,7 @@ def get_inputs(op_type):
ret = "inputs={"
inputs = _get_inputs(op_type)
for t in inputs:
ret += "{}={},".format(t, _convert_(t))
ret += "'{}': {},".format(t, _convert_(t))
ret = ret.strip(",")
ret += "}"
......@@ -132,39 +160,11 @@ def get_inputs(op_type):
return ret
"""
def get_input_dtype(op_type):
dtype = None
for ipt in _get_inputs():
name = _convert_(ipt.name)
val = kwargs.pop(name, [])
if not isinstance(val, list) and not isinstance(val, tuple):
val = [val]
if len(val) == 0:
val = [args[0]]
args = args[1:]
for each in val:
if not isinstance(each, Variable):
raise ValueError("input of {0} must be variable".format(
op_type))
if dtype is None:
dtype = each.dtype
elif dtype != each.dtype:
raise ValueError(
"operator {0} must input same dtype. {1} vs {2}".format(
op_type, dtype, each.dtype))
return dtype
"""
def get_outputs(op_type):
ret = "outputs={"
inputs = _get_outputs(op_type)
for t in inputs:
ret += "{}={},".format(t, _convert_(t))
ret += "'{}': {},".format(t, _convert_(t))
ret = ret.strip(",")
ret += "}"
......@@ -174,44 +174,13 @@ def get_outputs(op_type):
return ret
"""
attr_names = sorted(op.attr_names)
attrs_str = ""
for i in range(0, len(attr_names)):
name = attr_names[i]
attr_type = op.desc.attr_type(name)
if attr_type == core.AttrType.BLOCK:
a = "{name} = block[{value}]".format(
name=name, type=attr_type, value=op.block_attr_id(name))
attrs_str += a
if i != len(attr_names) - 1:
attrs_str += ", "
continue
if attr_type == core.AttrType.BLOCKS:
a = "{name} = blocks{value}".format(
name=name, type=attr_type, value=op.blocks_attr_ids(name))
attrs_str += a
if i != len(attr_names) - 1:
attrs_str += ", "
continue
a = "{name} = {value}".format(
name=name, type=attr_type, value=op.desc.attr(name))
attrs_str += a
if i != len(attr_names) - 1:
attrs_str += ", "
"""
def get_attrs(op_type):
ret = "attrs={"
for t in _get_attrs(op_type):
if t.name in g_filer_attrs:
continue
ret += "%s=%s," % (t.name, _convert_(t.name))
ret += "'%s': %s," % (t.name, _convert_(t.name))
ret = ret.strip(",")
ret += "}"
......@@ -220,12 +189,13 @@ def get_attrs(op_type):
def get_outvars(op_type, indent=1):
inputs = _get_inputs(op_type)
ret = ""
for t in _get_outputs(op_type):
ret += get_indent_space(
indent
) + "%s = helper.create_tmp_variable(dtype=helper.input_dtype())\n" % (
_convert_(t))
) + "%s = helper.create_tmp_variable(dtype=helper.input_dtype('%s'))\n" % (
(_convert_(t), list(inputs)[0]))
ret = ret.strip('\n')
return ret
......@@ -238,17 +208,15 @@ def get_op_py(op_type):
outputs = get_outputs(op_type)
attrs = get_attrs(op_type)
out_vars = get_outvars(op_type)
comment = get_comment(op_type)
code = """
@templatedoc()
def {op_type}({args}):
\"\"\"
{op_type}
{comment}
Args:
{input_comments}
Returns:
{output_comments}
\"\"\"
......@@ -263,7 +231,7 @@ def {op_type}({args}):
return out
""".format(
comment="${comment}",
comment=comment,
input_comments=input_comments.strip('\n'),
output_comments=output_comments,
args=args,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册