提交 01ac9740 编写于 作者: S SunAhong1993

add tf dygraph

上级 0a209d49
......@@ -103,6 +103,7 @@ def tf2paddle(model_path,
save_dir,
without_data_format_optimization=False,
define_input_shape=False,
paddle_type="dygraph",
params_merge=False):
# check tensorflow installation and version
try:
......@@ -120,25 +121,28 @@ def tf2paddle(model_path,
"[ERROR] Tensorflow is not installed, use \"pip install tensorflow\"."
)
return
from x2paddle.decoder.tf_decoder import TFDecoder
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
if paddle_type == "dygraph":
from x2paddle.op_mapper.dygraph.tf2paddle.tf_op_mapper import TFOpMapper
else:
from x2paddle.op_mapper.static.tf2paddle.tf_op_mapper import TFOpMapper
from x2paddle.optimizer.tensorflow.bias import BiasOpt
from x2paddle.optimizer.tensorflow.transpose import TransposeOpt
from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt
print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape)
mapper = TFOpMapper(model)
program.build()
mapper.paddle_graph.build()
bias_opt = BiasOpt()
transpose_opt = TransposeOpt()
batch_norm_opt = BatchNormOpt()
bias_opt.run(program)
batch_norm_opt.run(program)
transpose_opt.run(program)
program.gen_model(save_dir)
mapper.paddle_graph.gen_model(save_dir)
def caffe2paddle(proto, weight, save_dir, caffe_proto,
......@@ -293,7 +297,7 @@ def main():
if args.params_merge:
params_merge = True
tf2paddle(args.model, args.save_dir, without_data_format_optimization,
define_input_shape, params_merge)
define_input_shape, args.paddle_type, params_merge)
elif args.framework == "caffe":
assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
......
......@@ -76,6 +76,7 @@ class PaddleGraph(object):
self.custom_code = None
self.inputs_info = None
def set_name(self, name):
self.name = name.replace("-", "_").replace("/", "_")
......@@ -285,8 +286,8 @@ class PaddleGraph(object):
for input_name in self.inputs:
input_shapes.append(self.inputs_info[input_name][0])
input_types.append(self.inputs_info[input_name][1])
# 如果input_files非空,则导出推理模型;其值类似[[None, 3, 224, 224]]
self.dygraph2static(save_dir, input_shapes, input_types)
# 如果input_files非空,则导出推理模型;其值类似[[None, 3, 224, 224]]
self.dygraph2static(save_dir, input_shapes, input_types)
def gen_static_code(self, code_dir):
def write_code(f, code_list, indent=0):
......@@ -446,6 +447,8 @@ class PaddleGraph(object):
if self.source_type == "caffe":
custom_import = "from x2paddle.op_mapper.dygraph.caffe2paddle " + \
"import caffe_custom_layer as x2paddle_nn"
else:
custom_import = ""
self.head = gen_codes(
[
"from paddle.fluid.initializer import Constant",
......@@ -618,7 +621,10 @@ class PaddleGraph(object):
paddle.disable_static()
restore, _ = fluid.load_dygraph(osp.join(save_dir, "model"))
model = getattr(x2paddle_code, self.name)()
model.set_dict(restore)
if self.source_type == "tf":
model.set_dict(restore, use_structured_name=False)
else:
model.set_dict(restore)
model.eval()
static_model = paddle.jit.to_static(model, input_spec=sepc_list)
paddle.jit.save(static_model, osp.join(save_dir, "inference_model/model"))
\ No newline at end of file
......@@ -132,6 +132,7 @@ class TFGraph(Graph):
self.identity_map = dict()
self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2']
self.tf_data_format = data_format
self.graph_name = "TFModel"
def build(self):
for layer in self.model.node:
......
......@@ -277,7 +277,11 @@ def prim_list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif
inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str)
forward_func.extend(gen_codes([line], indent=indent))
def prim_list_mul(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
line = "{} = [a*b for a,b in zip({}, {})]".format(layer.outputs[0],
get_value(layer, "list0", different_attrs),
get_value(layer, "list1", different_attrs))
def prim_list_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
line = "{} = {}".format(", ".join(layer.outputs), get_value(layer, "input", different_attrs))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册