From 1958654d6f15087c28b44759c1a8d004826f00ce Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Fri, 15 Jun 2018 14:28:17 +0800 Subject: [PATCH] refine \odot in elementwise_mul --- paddle/fluid/operators/elementwise_mul_op.cc | 2 +- .../fluid/layers/layer_function_generator.py | 28 +++++++++++++------ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise_mul_op.cc index ba343909bb8..7cd67e74de6 100644 --- a/paddle/fluid/operators/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise_mul_op.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise_op.h" namespace ops = paddle::operators; -REGISTER_ELEMWISE_OP(elementwise_mul, "Mul", "Out = X \\odot\\ Y"); +REGISTER_ELEMWISE_OP(elementwise_mul, "Mul", "Out = X \\\\odot Y"); REGISTER_OP_CPU_KERNEL( elementwise_mul, ops::ElementwiseMulKernel, diff --git a/python/paddle/fluid/layers/layer_function_generator.py b/python/paddle/fluid/layers/layer_function_generator.py index cb60a3aec9a..0f05ea2b08d 100644 --- a/python/paddle/fluid/layers/layer_function_generator.py +++ b/python/paddle/fluid/layers/layer_function_generator.py @@ -44,6 +44,11 @@ def _type_to_str_(tp): return framework_pb2.AttrType.Name(tp) +_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$") +_single_dollar_pattern_ = re.compile(r"\$([^\$]+)\$") +_two_bang_pattern_ = re.compile(r"!!([^!]+)!!") + + def _generate_doc_string_(op_proto): """ Generate docstring by OpProto @@ -55,22 +60,27 @@ def _generate_doc_string_(op_proto): str: the document string """ + def escape_math(text): + return _two_bang_pattern_.sub( + r'$$\1$$', + _single_dollar_pattern_.sub( + r':math:`\1`', _two_dollar_pattern_.sub(r"!!\1!!", text))) + if not isinstance(op_proto, framework_pb2.OpProto): raise TypeError("OpProto should be `framework_pb2.OpProto`") buf = cStringIO.StringIO() - buf.write(op_proto.comment) + buf.write(escape_math(op_proto.comment)) buf.write('\nArgs:\n') for each_input in op_proto.inputs: line_begin = ' {0}: '.format(_convert_(each_input.name)) buf.write(line_begin) - buf.write(each_input.comment) + buf.write(escape_math(each_input.comment)) buf.write('\n') - buf.write(' ' * len(line_begin)) - buf.write('Duplicable: ') - buf.write(str(each_input.duplicable)) - buf.write(' Optional: ') - buf.write(str(each_input.dispensable)) + if each_input.duplicable: + buf.write(" Duplicatable.") + if each_input.dispensable: + buf.write(" Optional.") buf.write('\n') skip_attrs = OpProtoHolder.generated_op_attr_names() @@ -83,7 +93,7 @@ def _generate_doc_string_(op_proto): buf.write(' (') buf.write(_type_to_str_(each_attr.type)) buf.write('): ') - buf.write(each_attr.comment) + buf.write(escape_math(each_attr.comment)) buf.write('\n') if len(op_proto.outputs) != 0: @@ -92,7 +102,7 @@ def _generate_doc_string_(op_proto): for each_opt in op_proto.outputs: if not each_opt.intermediate: break - buf.write(each_opt.comment) + buf.write(escape_math(each_opt.comment)) return buf.getvalue() -- GitLab