diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py index f23fc32b2ed812e4a3adf9a5c506a8137f7f4cd0..e55f5c05b4a36d54d26cac3aadc034ec5b347b3b 100644 --- a/x2paddle/decoder/tf_decoder.py +++ b/x2paddle/decoder/tf_decoder.py @@ -72,6 +72,15 @@ class TFGraphNode(GraphNode): dtype, self.layer.name)) return self.dtype_map[dtype] + def set_dtype(self, dtype): + dtype_idx = 0 + for k, v in self.dtype_map.items(): + if v == dtype: + dtype_idx = k + if dtype_idx == 0: + raise Exception("Cannot set dtype of node to '{}'".format(dtype)) + self.layer.attr['dtype'].type = dtype_idx + @property def raw_dtype(self): keys = ['dtype', 'Tidx', 'T', 'DstT'] diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index 541ee3c845d78df5e0ec20763b6e9b7d92b6791e..e94428386a83199462682b8c7943669b9837742f 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -1133,18 +1133,26 @@ class TFOpMapper(OpMapper): inputs = dict() attr = dict() + dtype = 'int32' + if start.dtype.startswith('float'): + dtype = start.dtype if start.layer_type == "Const": attr["start"] = start.value else: inputs["start"] = start.name + if limit.dtype.startswith('float'): + dtype = limit.dtype if limit.layer_type == "Const": attr["end"] = limit.value else: inputs["end"] = limit.name + if delta.dtype.startswith('float'): + dtype = delta.dtype if delta.layer_type == "Const": attr["step"] = delta.value else: inputs["step"] = delta.name + node.set_dtype(dtype) attr["dtype"] = string(node.dtype) program.add_layer(