diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 5a9f7705369b384e5fd7cbffb625a674e4e76708..3a164a2fd0c5610f2c546a94d50200bcb405d615 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -225,13 +225,15 @@ def monkey_patch_variable(): other_var = tmp out = create_new_tmp_var(current_block(self), dtype=lhs_dtype) - + axis = -1 + if other_var.shape[0] == -1: + axis = 0 current_block(self).append_op( type=op_type, inputs={'X': [self], 'Y': [other_var]}, outputs={'Out': out}, - attrs={'axis': -1}) + attrs={'axis': axis}) return out comment = OpProtoHolder.instance().get_op_proto(op_type).comment