未验证 提交 6e67d0fb 编写于 作者: W Wu Yi 提交者: GitHub

layer fixes (#14591)

* layer fixes test=develop

* follow update test=develop
上级 dc458b14
......@@ -20,7 +20,7 @@ import string
from six.moves import cStringIO
from ..proto import framework_pb2
from ..framework import OpProtoHolder, Variable
from ..framework import OpProtoHolder, Variable, core, convert_np_dtype_to_dtype_
from ..layer_helper import LayerHelper
__all__ = [
......@@ -178,6 +178,15 @@ def generate_layer_fn(op_type):
"operator {0} must input same dtype. {1} vs {2}".format(
op_type, dtype, each.dtype))
if dtype is None:
arg_dtype = kwargs.get("dtype")
if arg_dtype:
if not isinstance(arg_dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(arg_dtype)
else:
dtype = arg_dtype
else:
dtype = core.VarDesc.VarType.FP32
return dtype
def func(*args, **kwargs):
......
......@@ -622,7 +622,7 @@ def reverse(x, axis):
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='reverse',
inputs={'Input': x},
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis})
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册