提交 01ac9740 编写于 作者: S SunAhong1993

add tf dygraph

上级 0a209d49
...@@ -103,6 +103,7 @@ def tf2paddle(model_path, ...@@ -103,6 +103,7 @@ def tf2paddle(model_path,
save_dir, save_dir,
without_data_format_optimization=False, without_data_format_optimization=False,
define_input_shape=False, define_input_shape=False,
paddle_type="dygraph",
params_merge=False): params_merge=False):
# check tensorflow installation and version # check tensorflow installation and version
try: try:
...@@ -122,23 +123,26 @@ def tf2paddle(model_path, ...@@ -122,23 +123,26 @@ def tf2paddle(model_path,
return return
from x2paddle.decoder.tf_decoder import TFDecoder from x2paddle.decoder.tf_decoder import TFDecoder
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper if paddle_type == "dygraph":
from x2paddle.op_mapper.dygraph.tf2paddle.tf_op_mapper import TFOpMapper
else:
from x2paddle.op_mapper.static.tf2paddle.tf_op_mapper import TFOpMapper
from x2paddle.optimizer.tensorflow.bias import BiasOpt from x2paddle.optimizer.tensorflow.bias import BiasOpt
from x2paddle.optimizer.tensorflow.transpose import TransposeOpt from x2paddle.optimizer.tensorflow.transpose import TransposeOpt
from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt
print("Now translating model from tensorflow to paddle.") print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape) model = TFDecoder(model_path, define_input_shape=define_input_shape)
mapper = TFOpMapper(model) mapper = TFOpMapper(model)
program.build() mapper.paddle_graph.build()
bias_opt = BiasOpt() bias_opt = BiasOpt()
transpose_opt = TransposeOpt() transpose_opt = TransposeOpt()
batch_norm_opt = BatchNormOpt() batch_norm_opt = BatchNormOpt()
bias_opt.run(program) bias_opt.run(program)
batch_norm_opt.run(program) batch_norm_opt.run(program)
transpose_opt.run(program) transpose_opt.run(program)
program.gen_model(save_dir) mapper.paddle_graph.gen_model(save_dir)
def caffe2paddle(proto, weight, save_dir, caffe_proto, def caffe2paddle(proto, weight, save_dir, caffe_proto,
...@@ -293,7 +297,7 @@ def main(): ...@@ -293,7 +297,7 @@ def main():
if args.params_merge: if args.params_merge:
params_merge = True params_merge = True
tf2paddle(args.model, args.save_dir, without_data_format_optimization, tf2paddle(args.model, args.save_dir, without_data_format_optimization,
define_input_shape, params_merge) define_input_shape, args.paddle_type, params_merge)
elif args.framework == "caffe": elif args.framework == "caffe":
assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model" assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
......
...@@ -76,6 +76,7 @@ class PaddleGraph(object): ...@@ -76,6 +76,7 @@ class PaddleGraph(object):
self.custom_code = None self.custom_code = None
self.inputs_info = None self.inputs_info = None
def set_name(self, name): def set_name(self, name):
self.name = name.replace("-", "_").replace("/", "_") self.name = name.replace("-", "_").replace("/", "_")
...@@ -446,6 +447,8 @@ class PaddleGraph(object): ...@@ -446,6 +447,8 @@ class PaddleGraph(object):
if self.source_type == "caffe": if self.source_type == "caffe":
custom_import = "from x2paddle.op_mapper.dygraph.caffe2paddle " + \ custom_import = "from x2paddle.op_mapper.dygraph.caffe2paddle " + \
"import caffe_custom_layer as x2paddle_nn" "import caffe_custom_layer as x2paddle_nn"
else:
custom_import = ""
self.head = gen_codes( self.head = gen_codes(
[ [
"from paddle.fluid.initializer import Constant", "from paddle.fluid.initializer import Constant",
...@@ -618,6 +621,9 @@ class PaddleGraph(object): ...@@ -618,6 +621,9 @@ class PaddleGraph(object):
paddle.disable_static() paddle.disable_static()
restore, _ = fluid.load_dygraph(osp.join(save_dir, "model")) restore, _ = fluid.load_dygraph(osp.join(save_dir, "model"))
model = getattr(x2paddle_code, self.name)() model = getattr(x2paddle_code, self.name)()
if self.source_type == "tf":
model.set_dict(restore, use_structured_name=False)
else:
model.set_dict(restore) model.set_dict(restore)
model.eval() model.eval()
static_model = paddle.jit.to_static(model, input_spec=sepc_list) static_model = paddle.jit.to_static(model, input_spec=sepc_list)
......
...@@ -132,6 +132,7 @@ class TFGraph(Graph): ...@@ -132,6 +132,7 @@ class TFGraph(Graph):
self.identity_map = dict() self.identity_map = dict()
self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2'] self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2']
self.tf_data_format = data_format self.tf_data_format = data_format
self.graph_name = "TFModel"
def build(self): def build(self):
for layer in self.model.node: for layer in self.model.node:
......
...@@ -278,6 +278,10 @@ def prim_list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif ...@@ -278,6 +278,10 @@ def prim_list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif
line = "{} = [{}]".format(layer.outputs[0], inputs_str) line = "{} = [{}]".format(layer.outputs[0], inputs_str)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_list_mul(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
line = "{} = [a*b for a,b in zip({}, {})]".format(layer.outputs[0],
get_value(layer, "list0", different_attrs),
get_value(layer, "list1", different_attrs))
def prim_list_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_list_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
line = "{} = {}".format(", ".join(layer.outputs), get_value(layer, "input", different_attrs)) line = "{} = {}".format(", ".join(layer.outputs), get_value(layer, "input", different_attrs))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册