Commit 1a5e4a8a authored by: channingss

fix bug for (import torch)

Parent commit: 88858561
@@ -139,17 +139,6 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto):
 def onnx2paddle(model_path, save_dir):
     # check onnx installation and version
-    try:
-        import torch
-        version = torch.__version__
-        if '1.2.0' not in version:
-            print("torch==1.2.0 is required")
-            return
-    except:
-        print(
-            "we use caffe2 to inference graph, please use \"pip install torch==1.2.0\"."
-        )
-        return
     try:
         import onnx
         version = onnx.version.version
@@ -193,6 +182,17 @@ def main():
     assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)"
     assert args.save_dir is not None, "--save_dir is not defined"
+    try:
+        import paddle
+        v0, v1, v2 = paddle.__version__.split('.')
+        if int(v0) != 1 or int(v1) < 5:
+            print("paddlepaddle>=1.5.0 is required")
+            return
+    except:
+        print("paddlepaddle not installed, use \"pip install paddlepaddle\"")
+
+    assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)"
+    assert args.save_dir is not None, "--save_dir is not defined"
     if args.framework == "tensorflow":
         assert args.model is not None, "--model should be defined while translating tensorflow model"
         without_data_format_optimization = False
......
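For readers skimming the hunk above, the following is a minimal standalone sketch of the version gate that main() now applies before converting. The helper name and the sample version strings are illustrative assumptions, not part of this commit.

# Minimal sketch of the paddlepaddle version gate added above (assumption:
# paddle.__version__ follows the usual "major.minor.patch" form).
def paddle_version_ok(version_string):
    v0, v1, v2 = version_string.split('.')
    # only paddlepaddle 1.x with minor >= 5 passes the gate
    return int(v0) == 1 and int(v1) >= 5

assert paddle_version_ok("1.5.1")
assert not paddle_version_ok("1.4.0")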
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import paddle.fluid as fluid
 from paddle.fluid.proto import framework_pb2
 from x2paddle.core.util import *
 import inspect
@@ -46,6 +47,28 @@ def export_paddle_param(param, param_name, dir):
     fp.close()
+
+
+# This func will copy to generate code file
+def run_net(param_dir="./"):
+    import os
+    inputs, outputs = x2paddle_net()
+    for i, out in enumerate(outputs):
+        if isinstance(out, list):
+            for out_part in out:
+                outputs.append(out_part)
+            del outputs[i]
+
+    exe = fluid.Executor(fluid.CPUPlace())
+    exe.run(fluid.default_startup_program())
+
+    def if_exist(var):
+        b = os.path.exists(os.path.join(param_dir, var.name))
+        return b
+
+    fluid.io.load_vars(exe,
+                       param_dir,
+                       fluid.default_main_program(),
+                       predicate=if_exist)
+
+
 class OpMapper(object):
     def __init__(self):
         self.paddle_codes = ""
......
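The comment in the hunk above notes that run_net() is copied into the code file x2paddle generates, where x2paddle_net() is defined and the exported parameter files live alongside it. A hedged usage sketch follows; the directory layout and module name are assumptions for illustration only, not part of this commit.

# Hypothetical invocation of the copied run_net() from a generated code directory.
import os
import sys

generated_dir = "pd_model/model_with_code"    # assumed x2paddle output directory
if os.path.isdir(generated_dir):
    sys.path.insert(0, generated_dir)
    from model import run_net                 # model.py assumed to hold the copied function
    # builds the network via x2paddle_net(), then loads every parameter file
    # in param_dir whose name matches a persistable variable (see if_exist above)
    run_net(param_dir=generated_dir)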
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import paddle.fluid as fluid
 import numpy
 import math
 import os
@@ -20,25 +18,3 @@ import os
 def string(param):
     return "\'{}\'".format(param)
-
-
-# This func will copy to generate code file
-def run_net(param_dir="./"):
-    import os
-    inputs, outputs = x2paddle_net()
-    for i, out in enumerate(outputs):
-        if isinstance(out, list):
-            for out_part in out:
-                outputs.append(out_part)
-            del outputs[i]
-
-    exe = fluid.Executor(fluid.CPUPlace())
-    exe.run(fluid.default_startup_program())
-
-    def if_exist(var):
-        b = os.path.exists(os.path.join(param_dir, var.name))
-        return b
-
-    fluid.io.load_vars(exe,
-                       param_dir,
-                       fluid.default_main_program(),
-                       predicate=if_exist)
@@ -271,6 +271,17 @@ class ONNXGraph(Graph):
         return value_info
 
     def get_results_of_inference(self, model, shape):
+        try:
+            import torch
+            version = torch.__version__
+            if '1.1.0' not in version:
+                print("your model have dynamic graph, torch==1.1.0 is required")
+                return
+        except:
+            print(
+                "your model have dynamic graph, we use caff2 to inference graph, please use \"pip install torch==1.1.0\"."
+            )
+            return
         from x2paddle.decoder.onnx_backend import prepare
         np_images = np.random.rand(shape[0], shape[1], shape[2],
......
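As a rough sketch of why the torch==1.1.0 check now lives in get_results_of_inference: the method pushes a random NCHW input through the caffe2-based ONNX backend to collect reference outputs. Everything below beyond what the hunk itself shows (input shape, backend call sequence) is an assumption for illustration.

# Conceptual sketch only; the backend call is commented out because it needs
# torch==1.1.0 (which bundles caffe2).
import numpy as np

shape = (1, 3, 224, 224)                          # hypothetical input shape
np_images = np.random.rand(*shape).astype('float32')
# from x2paddle.decoder.onnx_backend import prepare
# rep = prepare(model)          # 'model' is the loaded ONNX ModelProto
# results = rep.run(np_images)  # reference outputs used for shape inference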