未验证 提交 48bc7e4e 编写于 作者: J Jason 提交者: GitHub

Merge pull request #44 from jiangjiajun/master

try to fix bug in windows
...@@ -22,6 +22,7 @@ import logging ...@@ -22,6 +22,7 @@ import logging
import math import math
import struct import struct
import numpy import numpy
import os
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
...@@ -166,10 +167,10 @@ class PaddleEmitter(object): ...@@ -166,10 +167,10 @@ class PaddleEmitter(object):
"float64": "d" "float64": "d"
} }
shape = weight.shape shape = weight.shape
filew = open(dir + "/" + paddle_var_name, "wb") filew = open(os.path.join(dir, paddle_var_name), "wb")
filew.write(struct.pack('i', 0)) numpy.array([0], dtype=numpy.int32).tofile(filew)
filew.write(struct.pack('L', 0)) numpy.array([0], dtype=numpy.int64).tofile(filew)
filew.write(struct.pack('i', 0)) numpy.array([0], dtype=numpy.int32).tofile(filew)
tensor_desc = framework.VarType.TensorDesc() tensor_desc = framework.VarType.TensorDesc()
if str(weight.dtype) in numpy_dtype_map: if str(weight.dtype) in numpy_dtype_map:
tensor_desc.data_type = numpy_dtype_map[str(weight.dtype)] tensor_desc.data_type = numpy_dtype_map[str(weight.dtype)]
...@@ -177,7 +178,7 @@ class PaddleEmitter(object): ...@@ -177,7 +178,7 @@ class PaddleEmitter(object):
raise Exception("Unexpected array dtype [{}]".format(weight.dtype)) raise Exception("Unexpected array dtype [{}]".format(weight.dtype))
tensor_desc.dims.extend(shape) tensor_desc.dims.extend(shape)
desc_size = tensor_desc.ByteSize() desc_size = tensor_desc.ByteSize()
filew.write(struct.pack('i', desc_size)) numpy.array([desc_size], dtype=numpy.int32).tofile(filew)
filew.write(tensor_desc.SerializeToString()) filew.write(tensor_desc.SerializeToString())
weight.tofile(filew) weight.tofile(filew)
filew.close() filew.close()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册