提交 d7f2b116 编写于 作者: C channingss

fix dtype for tensorflow

上级 97df7f87
...@@ -72,6 +72,15 @@ class TFGraphNode(GraphNode): ...@@ -72,6 +72,15 @@ class TFGraphNode(GraphNode):
dtype, self.layer.name)) dtype, self.layer.name))
return self.dtype_map[dtype] return self.dtype_map[dtype]
def set_dtype(self, dtype):
dtype_idx = 0
for k, v in self.dtype_map.items():
if v == dtype:
dtype_idx = k
if dtype_idx == 0:
raise Exception("Cannot set dtype of node to '{}'".format(dtype))
self.layer.attr['dtype'].type = dtype_idx
@property @property
def raw_dtype(self): def raw_dtype(self):
keys = ['dtype', 'Tidx', 'T', 'DstT'] keys = ['dtype', 'Tidx', 'T', 'DstT']
......
...@@ -1133,18 +1133,26 @@ class TFOpMapper(OpMapper): ...@@ -1133,18 +1133,26 @@ class TFOpMapper(OpMapper):
inputs = dict() inputs = dict()
attr = dict() attr = dict()
dtype = 'int32'
if start.dtype.startswith('float'):
dtype = start.dtype
if start.layer_type == "Const": if start.layer_type == "Const":
attr["start"] = start.value attr["start"] = start.value
else: else:
inputs["start"] = start.name inputs["start"] = start.name
if limit.dtype.startswith('float'):
dtype = limit.dtype
if limit.layer_type == "Const": if limit.layer_type == "Const":
attr["end"] = limit.value attr["end"] = limit.value
else: else:
inputs["end"] = limit.name inputs["end"] = limit.name
if delta.dtype.startswith('float'):
dtype = delta.dtype
if delta.layer_type == "Const": if delta.layer_type == "Const":
attr["step"] = delta.value attr["step"] = delta.value
else: else:
inputs["step"] = delta.name inputs["step"] = delta.name
node.set_dtype(dtype)
attr["dtype"] = string(node.dtype) attr["dtype"] = string(node.dtype)
program.add_layer( program.add_layer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册