未验证 提交 1accd53d 编写于 作者: J Jason 提交者: GitHub

Merge pull request #426 from SunAhong1993/paddle-2.0

add tf static
......@@ -103,6 +103,7 @@ def tf2paddle(model_path,
save_dir,
without_data_format_optimization=False,
define_input_shape=False,
paddle_type="dygraph",
params_merge=False):
# check tensorflow installation and version
try:
......@@ -122,23 +123,26 @@ def tf2paddle(model_path,
return
from x2paddle.decoder.tf_decoder import TFDecoder
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
if paddle_type == "dygraph":
from x2paddle.op_mapper.dygraph.tf2paddle.tf_op_mapper import TFOpMapper
else:
from x2paddle.op_mapper.static.tf2paddle.tf_op_mapper import TFOpMapper
from x2paddle.optimizer.tensorflow.bias import BiasOpt
from x2paddle.optimizer.tensorflow.transpose import TransposeOpt
from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt
print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape)
mapper = TFOpMapper(model)
program.build()
mapper.paddle_graph.build()
bias_opt = BiasOpt()
transpose_opt = TransposeOpt()
batch_norm_opt = BatchNormOpt()
bias_opt.run(program)
batch_norm_opt.run(program)
transpose_opt.run(program)
program.gen_model(save_dir)
mapper.paddle_graph.gen_model(save_dir)
def caffe2paddle(proto, weight, save_dir, caffe_proto,
......@@ -293,7 +297,7 @@ def main():
if args.params_merge:
params_merge = True
tf2paddle(args.model, args.save_dir, without_data_format_optimization,
define_input_shape, params_merge)
define_input_shape, args.paddle_type, params_merge)
elif args.framework == "caffe":
assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
......
......@@ -76,6 +76,7 @@ class PaddleGraph(object):
self.custom_code = None
self.inputs_info = None
def set_name(self, name):
self.name = name.replace("-", "_").replace("/", "_")
......@@ -446,6 +447,8 @@ class PaddleGraph(object):
if self.source_type == "caffe":
custom_import = "from x2paddle.op_mapper.dygraph.caffe2paddle " + \
"import caffe_custom_layer as x2paddle_nn"
else:
custom_import = ""
self.head = gen_codes(
[
"from paddle.fluid.initializer import Constant",
......@@ -581,6 +584,9 @@ class PaddleGraph(object):
line = ','.join(layer.outputs)
line += " = {}(".format(layer.kernel)
for k, v in layer.inputs.items():
if isinstance(v, list):
line += "{}=[{}], ".format(k, ", ".join(v))
else:
line += "{}={}, ".format(k, v)
for k, v in layer.attrs.items():
line += "{}={}, ".format(k, v)
......@@ -618,6 +624,9 @@ class PaddleGraph(object):
paddle.disable_static()
restore, _ = fluid.load_dygraph(osp.join(save_dir, "model"))
model = getattr(x2paddle_code, self.name)()
if self.source_type == "tf":
model.set_dict(restore, use_structured_name=False)
else:
model.set_dict(restore)
model.eval()
static_model = paddle.jit.to_static(model, input_spec=sepc_list)
......
......@@ -132,6 +132,7 @@ class TFGraph(Graph):
self.identity_map = dict()
self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2']
self.tf_data_format = data_format
self.graph_name = "TFModel"
def build(self):
for layer in self.model.node:
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册