未验证 提交 4e087fd6 编写于 作者: J Jason 提交者: GitHub

Merge pull request #122 from mamingjie-China/develop

support for python2
......@@ -13,8 +13,9 @@
# limitations under the License.
from x2paddle.core.graph import GraphNode
import collections
from x2paddle.core.util import *
import collections
import six
class Layer(object):
......@@ -28,7 +29,7 @@ class Layer(object):
def get_code(self):
layer_code = ""
if self.output is not None:
if isinstance(self.output, str):
if isinstance(self.output, six.string_types):
layer_code = self.output + " = "
else:
layer_code = self.output.layer_name + " = "
......@@ -47,7 +48,7 @@ class Layer(object):
"[{}]".format(input.index) + ", ")
else:
in_list += (input.layer_name + ", ")
elif isinstance(input, str):
elif isinstance(input, six.string_types):
in_list += (input + ", ")
else:
raise Exception(
......@@ -72,7 +73,7 @@ class Layer(object):
"[{}]".format(self.inputs.index) + ", ")
else:
layer_code += (self.inputs.layer_name + ", ")
elif isinstance(self.inputs, str):
elif isinstance(self.inputs, six.string_types):
layer_code += (self.inputs + ", ")
else:
raise Exception("Unknown type of inputs.")
......@@ -119,6 +120,6 @@ class FluidCode(object):
for layer in self.layers:
if isinstance(layer, Layer):
codes.append(layer.get_code())
elif isinstance(layer, str):
elif isinstance(layer, six.string_types):
codes.append(layer)
return codes
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
import collections
import copy as cp
......
......@@ -236,11 +236,7 @@ class CaffeDecoder(object):
data.MergeFromString(open(self.model_path, 'rb').read())
pair = lambda layer: (layer.name, self.normalize_pb_data(layer))
layers = data.layers or data.layer
import time
start = time.time()
self.params = [pair(layer) for layer in layers if layer.blobs]
end = time.time()
print('cost:', str(end - start))
def normalize_pb_data(self, layer):
transformed = []
......
......@@ -94,7 +94,7 @@ class ONNXOpMapper(OpMapper):
print(op)
return False
def directly_map(self, node, *args, name='', **kwargs):
def directly_map(self, node, name='', *args, **kwargs):
inputs = node.layer.input
outputs = node.layer.output
op_type = node.layer_type
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册