提交 1958654d 编写于 作者: L Luo Tao

refine \odot in elementwise_mul

上级 1e2acd97
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_ELEMWISE_OP(elementwise_mul, "Mul", "Out = X \\odot\\ Y"); REGISTER_ELEMWISE_OP(elementwise_mul, "Mul", "Out = X \\\\odot Y");
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_mul, elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>, ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -44,6 +44,11 @@ def _type_to_str_(tp): ...@@ -44,6 +44,11 @@ def _type_to_str_(tp):
return framework_pb2.AttrType.Name(tp) return framework_pb2.AttrType.Name(tp)
_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$")
_single_dollar_pattern_ = re.compile(r"\$([^\$]+)\$")
_two_bang_pattern_ = re.compile(r"!!([^!]+)!!")
def _generate_doc_string_(op_proto): def _generate_doc_string_(op_proto):
""" """
Generate docstring by OpProto Generate docstring by OpProto
...@@ -55,22 +60,27 @@ def _generate_doc_string_(op_proto): ...@@ -55,22 +60,27 @@ def _generate_doc_string_(op_proto):
str: the document string str: the document string
""" """
def escape_math(text):
return _two_bang_pattern_.sub(
r'$$\1$$',
_single_dollar_pattern_.sub(
r':math:`\1`', _two_dollar_pattern_.sub(r"!!\1!!", text)))
if not isinstance(op_proto, framework_pb2.OpProto): if not isinstance(op_proto, framework_pb2.OpProto):
raise TypeError("OpProto should be `framework_pb2.OpProto`") raise TypeError("OpProto should be `framework_pb2.OpProto`")
buf = cStringIO.StringIO() buf = cStringIO.StringIO()
buf.write(op_proto.comment) buf.write(escape_math(op_proto.comment))
buf.write('\nArgs:\n') buf.write('\nArgs:\n')
for each_input in op_proto.inputs: for each_input in op_proto.inputs:
line_begin = ' {0}: '.format(_convert_(each_input.name)) line_begin = ' {0}: '.format(_convert_(each_input.name))
buf.write(line_begin) buf.write(line_begin)
buf.write(each_input.comment) buf.write(escape_math(each_input.comment))
buf.write('\n') buf.write('\n')
buf.write(' ' * len(line_begin)) if each_input.duplicable:
buf.write('Duplicable: ') buf.write(" Duplicatable.")
buf.write(str(each_input.duplicable)) if each_input.dispensable:
buf.write(' Optional: ') buf.write(" Optional.")
buf.write(str(each_input.dispensable))
buf.write('\n') buf.write('\n')
skip_attrs = OpProtoHolder.generated_op_attr_names() skip_attrs = OpProtoHolder.generated_op_attr_names()
...@@ -83,7 +93,7 @@ def _generate_doc_string_(op_proto): ...@@ -83,7 +93,7 @@ def _generate_doc_string_(op_proto):
buf.write(' (') buf.write(' (')
buf.write(_type_to_str_(each_attr.type)) buf.write(_type_to_str_(each_attr.type))
buf.write('): ') buf.write('): ')
buf.write(each_attr.comment) buf.write(escape_math(each_attr.comment))
buf.write('\n') buf.write('\n')
if len(op_proto.outputs) != 0: if len(op_proto.outputs) != 0:
...@@ -92,7 +102,7 @@ def _generate_doc_string_(op_proto): ...@@ -92,7 +102,7 @@ def _generate_doc_string_(op_proto):
for each_opt in op_proto.outputs: for each_opt in op_proto.outputs:
if not each_opt.intermediate: if not each_opt.intermediate:
break break
buf.write(each_opt.comment) buf.write(escape_math(each_opt.comment))
return buf.getvalue() return buf.getvalue()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册