From 8fa1d84d881a9a5a09a31644bf343d1b7343bc3d Mon Sep 17 00:00:00 2001 From: gongweibao Date: Thu, 20 Sep 2018 09:45:58 +0000 Subject: [PATCH] add --- tools/test_generator.py | 201 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 188 insertions(+), 13 deletions(-) diff --git a/tools/test_generator.py b/tools/test_generator.py index 15f9f7db0..399dfe78e 100644 --- a/tools/test_generator.py +++ b/tools/test_generator.py @@ -23,6 +23,8 @@ from paddle.fluid.proto import framework_pb2 from paddle.fluid.framework import OpProtoHolder, Variable from paddle.fluid.layer_helper import LayerHelper +g_filer_attrs = ['op_role', 'op_role_var', 'op_namescope', 'dtype'] + def _convert_(name): """ @@ -46,6 +48,8 @@ def _get_inputs(op_type): for ipt in op_proto.inputs: inputs[ipt.name] = "" + return inputs + def _get_outputs(op_type): op_proto = OpProtoHolder.instance().get_op_proto(op_type) @@ -53,25 +57,177 @@ def _get_outputs(op_type): for ipt in op_proto.outputs: outputs[ipt.name] = "" + return outputs + + +def _get_attrs(op_type): + op_proto = OpProtoHolder.instance().get_op_proto(op_type) + return op_proto.attrs + + +def get_indent_space(indent, space_num=4): + ret = "" + for i in range(0, indent * space_num): + ret += " " + + return ret + + +def get_input_comments(op_type, indent=2): + ret = "" + inputs = _get_inputs(op_type) + for t in inputs: + ret += get_indent_space(2) + "input(${%s_type}): ${%s_comment}\n" % ( + _convert_(t), _convert_(t)) -def get_input_comments(op_type): - return "" + for t in _get_attrs(op_type): + if t.name in g_filer_attrs: + continue + ret += get_indent_space(2) + "input(${%s_type}): ${%s_comment}\n" % ( + _convert_(t.name), _convert_(t.name)) + return ret -def get_output_comments(op_type): - return "" + +def get_output_comments(op_type, indent=2): + ret = "" + for t in _get_outputs(op_type): + ret += get_indent_space(2) + "output(${%s_type}): ${%s_comment}\n" % ( + _convert_(t), _convert_(t)) + return ret def get_func_args(op_type): - return "" + ret = "" + inputs = _get_inputs(op_type) + for t in inputs: + ret += "%s," % _convert_(t) + + for t in _get_attrs(op_type): + if t.name in g_filer_attrs: + continue + + default = re.findall("\(.+\, default (.+)\(?\)", t.comment) + if len(default) > 0: + #print(default[0]) + ret += "{}={},".format(_convert_(t.name), default[0]) + continue + + ret += "%s=," % _convert_(t.name) + + return ret.strip(',') def get_inputs(op_type): - return "" + ret = "inputs={" + inputs = _get_inputs(op_type) + for t in inputs: + ret += "{}={},".format(t, _convert_(t)) + ret = ret.strip(",") + ret += "}" + + if ret == "inputs={}": + return "" + + return ret + + +""" +def get_input_dtype(op_type): + dtype = None + for ipt in _get_inputs(): + name = _convert_(ipt.name) + val = kwargs.pop(name, []) + if not isinstance(val, list) and not isinstance(val, tuple): + val = [val] + if len(val) == 0: + val = [args[0]] + args = args[1:] + + for each in val: + if not isinstance(each, Variable): + raise ValueError("input of {0} must be variable".format( + op_type)) + + if dtype is None: + dtype = each.dtype + elif dtype != each.dtype: + raise ValueError( + "operator {0} must input same dtype. {1} vs {2}".format( + op_type, dtype, each.dtype)) + + return dtype +""" def get_outputs(op_type): - return "" + ret = "outputs={" + inputs = _get_outputs(op_type) + for t in inputs: + ret += "{}={},".format(t, _convert_(t)) + ret = ret.strip(",") + ret += "}" + + if ret == "inputs={}": + return "" + + return ret + + +""" + attr_names = sorted(op.attr_names) + attrs_str = "" + for i in range(0, len(attr_names)): + name = attr_names[i] + + attr_type = op.desc.attr_type(name) + if attr_type == core.AttrType.BLOCK: + a = "{name} = block[{value}]".format( + name=name, type=attr_type, value=op.block_attr_id(name)) + attrs_str += a + if i != len(attr_names) - 1: + attrs_str += ", " + continue + + if attr_type == core.AttrType.BLOCKS: + a = "{name} = blocks{value}".format( + name=name, type=attr_type, value=op.blocks_attr_ids(name)) + attrs_str += a + if i != len(attr_names) - 1: + attrs_str += ", " + continue + + a = "{name} = {value}".format( + name=name, type=attr_type, value=op.desc.attr(name)) + attrs_str += a + if i != len(attr_names) - 1: + attrs_str += ", " +""" + + +def get_attrs(op_type): + ret = "attrs={" + for t in _get_attrs(op_type): + if t.name in g_filer_attrs: + continue + + ret += "%s=%s," % (t.name, _convert_(t.name)) + + ret = ret.strip(",") + ret += "}" + + return ret + + +def get_outvars(op_type, indent=1): + ret = "" + for t in _get_outputs(op_type): + ret += get_indent_space( + indent + ) + "%s = helper.create_tmp_variable(dtype=helper.input_dtype())\n" % ( + _convert_(t)) + ret = ret.strip('\n') + return ret def get_op_py(op_type): @@ -80,32 +236,51 @@ def get_op_py(op_type): args = get_func_args(op_type) inputs = get_inputs(op_type) outputs = get_outputs(op_type) + attrs = get_attrs(op_type) + out_vars = get_outvars(op_type) code = """ -\@templatedoc() +@templatedoc() def {op_type}({args}): \"\"\" {op_type} + {comment} + Args: - {input_comments} +{input_comments} Returns: - {output_comments} +{output_comments} \"\"\" + + helper = LayerHelper('{op_type}', **locals()) +{generated_outvar} helper.append_op( type='{op_type}', {inputs}, - {outputs}) + {outputs}, + {attrs}) + + return out """.format( - input_comments=input_comments, + comment="${comment}", + input_comments=input_comments.strip('\n'), output_comments=output_comments, args=args, + generated_outvar=out_vars, op_type=op_type, inputs=inputs, - outputs=outputs) + outputs=outputs, + attrs=attrs) return code print(get_op_py("uniform_random_batch_size_like")) +#print(get_op_py("gaussian_random")) +#print(get_op_py("sampling_id")) +#print(get_op_py("gaussian_random_batch_size_like")) +#print(get_op_py("sum")) +#print(get_op_py("slice")) +#print(get_op_py("shape")) #get_meta("linear_chain_crf") -- GitLab