From d7f2b1161843d437fff102419710cff30d9c79b5 Mon Sep 17 00:00:00 2001 From: channingss Date: Tue, 29 Sep 2020 16:24:40 +0800 Subject: [PATCH] fix dtype for tensorflow --- x2paddle/decoder/tf_decoder.py | 9 +++++++++ x2paddle/op_mapper/tf_op_mapper.py | 8 ++++++++ 2 files changed, 17 insertions(+) diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py index f23fc32..e55f5c0 100644 --- a/x2paddle/decoder/tf_decoder.py +++ b/x2paddle/decoder/tf_decoder.py @@ -72,6 +72,15 @@ class TFGraphNode(GraphNode): dtype, self.layer.name)) return self.dtype_map[dtype] + def set_dtype(self, dtype): + dtype_idx = 0 + for k, v in self.dtype_map.items(): + if v == dtype: + dtype_idx = k + if dtype_idx == 0: + raise Exception("Cannot set dtype of node to '{}'".format(dtype)) + self.layer.attr['dtype'].type = dtype_idx + @property def raw_dtype(self): keys = ['dtype', 'Tidx', 'T', 'DstT'] diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index 541ee3c..e944283 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -1133,18 +1133,26 @@ class TFOpMapper(OpMapper): inputs = dict() attr = dict() + dtype = 'int32' + if start.dtype.startswith('float'): + dtype = start.dtype if start.layer_type == "Const": attr["start"] = start.value else: inputs["start"] = start.name + if limit.dtype.startswith('float'): + dtype = limit.dtype if limit.layer_type == "Const": attr["end"] = limit.value else: inputs["end"] = limit.name + if delta.dtype.startswith('float'): + dtype = delta.dtype if delta.layer_type == "Const": attr["step"] = delta.value else: inputs["step"] = delta.name + node.set_dtype(dtype) attr["dtype"] = string(node.dtype) program.add_layer( -- GitLab