提交 d7f2b116 编写于 作者: C channingss

fix dtype for tensorflow

上级 97df7f87
......@@ -72,6 +72,15 @@ class TFGraphNode(GraphNode):
dtype, self.layer.name))
return self.dtype_map[dtype]
def set_dtype(self, dtype):
dtype_idx = 0
for k, v in self.dtype_map.items():
if v == dtype:
dtype_idx = k
if dtype_idx == 0:
raise Exception("Cannot set dtype of node to '{}'".format(dtype))
self.layer.attr['dtype'].type = dtype_idx
@property
def raw_dtype(self):
keys = ['dtype', 'Tidx', 'T', 'DstT']
......
......@@ -1133,18 +1133,26 @@ class TFOpMapper(OpMapper):
inputs = dict()
attr = dict()
dtype = 'int32'
if start.dtype.startswith('float'):
dtype = start.dtype
if start.layer_type == "Const":
attr["start"] = start.value
else:
inputs["start"] = start.name
if limit.dtype.startswith('float'):
dtype = limit.dtype
if limit.layer_type == "Const":
attr["end"] = limit.value
else:
inputs["end"] = limit.name
if delta.dtype.startswith('float'):
dtype = delta.dtype
if delta.layer_type == "Const":
attr["step"] = delta.value
else:
inputs["step"] = delta.name
node.set_dtype(dtype)
attr["dtype"] = string(node.dtype)
program.add_layer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册