From 8090c270a306c959ccd4bd81159f43da38974820 Mon Sep 17 00:00:00 2001 From: guosheng Date: Mon, 4 Dec 2017 15:42:55 +0800 Subject: [PATCH] Small fix in tf2paddle --- image_classification/tf2paddle/README.md | 2 +- image_classification/tf2paddle/tf2paddle.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/image_classification/tf2paddle/README.md b/image_classification/tf2paddle/README.md index e61a781c..df56dcda 100644 --- a/image_classification/tf2paddle/README.md +++ b/image_classification/tf2paddle/README.md @@ -16,7 +16,7 @@ 1. TensorFlow网络配置中同一Operator内的`Variable`属于相同的scope,以此为依据将`Variable`划分到不同的`paddle.layer`。 1. `conv2d`、`batchnorm`、`fc`的scope需分别包含`conv`、`bn`、`fc`,以此获取对应`paddle.layer`的类型。也可以通过为`TFModelConverter`传入`layer_type_map`的`dict`,将scope映射到对应的`paddle.layer`的type来规避此项约束。 1. `conv2d`、`fc`中`Variable`的顺序为:先可学习`Weight`后`Bias`;`batchnorm`中`Variable`的顺序为:`scale`、`shift`、`mean`、`var`,请注意参数存储的顺序将`Variable`对应到`paddle.layer.batch_norm`相应位置的参数。 -1. TensorFlow网络拓扑顺序需和PaddlePaddle网络拓扑顺序一致,尤其注意具有分支时左右分支的顺序。这是针对模型转换和PaddlePaddle网络配置均使用PaddlePaddle默认参数命名的情况,此时将根据拓扑顺序进行参数命名。 +1. TensorFlow网络拓扑顺序需和PaddlePaddle网络拓扑顺序一致,尤其注意具有分支时分支的先后顺序,如ResNet的bottleneck模块中两分支定义的先后顺序。这是针对模型转换和PaddlePaddle网络配置均使用PaddlePaddle默认参数命名的情况,此时将根据拓扑顺序进行参数命名。 1. 若PaddlePaddle网络配置中需要通过调用`param_attr=paddle.attr.Param(name="XX"))`显示地设置可学习参数名字,这时可通过为`TFModelConverter`传入`layer_name_map`或`param_name_map`字典(类型为Python `dict`),在模型转换时将`Variable`的名字映射为所对应的`paddle.layer.XX`中可学习参数的名字。 1. 要求提供`build_model`接口以从此构建TensorFlow网络,加载模型并返回session。可参照如下示例进行编写: diff --git a/image_classification/tf2paddle/tf2paddle.py b/image_classification/tf2paddle/tf2paddle.py index c4c18fad..20b6cade 100644 --- a/image_classification/tf2paddle/tf2paddle.py +++ b/image_classification/tf2paddle/tf2paddle.py @@ -108,7 +108,8 @@ class TFModelConverter(ModelConverter): @wrap_name_default("conv") def convert_conv_layer(self, params, params_names=None, name=None): for i in range(len(params)): - data = np.transpose(params[i], (3, 2, 0, 1)) + data = np.transpose(params[i], ( + 3, 2, 0, 1)) if len(params[i].shape) == 4 else params[i] if len(params) == 2: suffix = "0" if i == 0 else "bias" file_name = "_%s.w%s" % (name, suffix) if not ( -- GitLab