Commit 07d4531a authored by wanghaoshuang

Refine document and scripts.

1. Add document.
2. Add arguments for saving model and init model.
3. Refine inference.py and eval.py.
4. Make ctc_reader.py support for custom data.
Parent 97cfb9de
# OCR Model

[toc]
This model, built with PaddlePaddle Fluid, is still under active development and is not
the final version. We welcome feedback.

Running the example programs in this directory requires PaddlePaddle v0.11.0. If your installed version of PaddlePaddle is lower than this, please update your installation by following the instructions in the installation documentation.
# Optical Character Recognition
This section describes how to use the CRNN-CTC and CRNN-Attention models under PaddlePaddle Fluid to recognize the text contained in images.
## 1. CRNN-CTC
The task in this chapter is to recognize images that contain a single line of Chinese characters. The image is first passed through convolution layers to obtain a feature map; the `im2sequence` op then converts the feature map into a sequence, and a bidirectional GRU produces, for every time step, a probability distribution over the Chinese characters. The loss function used for training is the CTC loss, and the final evaluation metric is the `instance error rate`.
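A rough, hypothetical sketch of this pipeline against the Fluid layers API is shown below, just to make the data flow concrete. Layer depths, filter sizes, and the pooling scheme are placeholders, and the code assumes a fixed image height so that the feature-map height is known statically; the network actually used for training is the one defined in `crnn_ctc_model.py`.

```python
import paddle.fluid as fluid

def toy_crnn_ctc(images, label, num_classes, rnn_hidden_size=200):
    # Convolution + pooling: turn the input image into a feature map.
    conv = fluid.layers.conv2d(
        input=images, num_filters=64, filter_size=3, padding=1, act='relu')
    conv = fluid.layers.pool2d(input=conv, pool_size=2, pool_stride=2)
    # im2sequence: slice the feature map column by column; each column
    # (spanning the full feature-map height) becomes one time step.
    seq = fluid.layers.im2sequence(
        input=conv, filter_size=[conv.shape[2], 1], stride=[1, 1])
    # Bidirectional GRU; dynamic_gru expects an input projection of
    # size 3 * hidden_size.
    fc_fw = fluid.layers.fc(input=seq, size=rnn_hidden_size * 3)
    fc_bw = fluid.layers.fc(input=seq, size=rnn_hidden_size * 3)
    gru_fw = fluid.layers.dynamic_gru(input=fc_fw, size=rnn_hidden_size)
    gru_bw = fluid.layers.dynamic_gru(
        input=fc_bw, size=rnn_hidden_size, is_reverse=True)
    # Per-step scores over the character set plus one extra CTC blank class.
    logits = fluid.layers.fc(input=[gru_fw, gru_bw], size=num_classes + 1)
    # CTC loss; `label` is a variable-length sequence of character indices.
    cost = fluid.layers.warpctc(
        input=logits, label=label, blank=num_classes, norm_by_times=True)
    return fluid.layers.mean(x=cost)
```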
The files in this directory serve the following purposes:
- **ctc_reader.py:** downloads, reads, and preprocesses the data. It provides the functions `train()` and `test()`, which produce the data iterators for the training set and the test set respectively.
- **crnn_ctc_model.py:** defines the training network, the inference network, and the evaluation network.
- **ctc_train.py:** trains the model; run `python ctc_train.py --help` for usage.
- **inference.py:** loads a trained model and runs prediction on new data; run `python inference.py --help` for usage.
- **eval.py:** evaluates the model on a given dataset; run `python eval.py --help` for usage.
- **utility.py:** common utilities, including argument parsing and tensor construction.
### 1.1 Data
Downloading and basic preprocessing of the data are implemented in `ctc_reader.py`.
#### 1.1.1 Data format
The training and test data we use are illustrated in Figure 1. Each image contains a single line of Chinese characters of variable length, and all images have been pre-cropped by a detection algorithm.
<p align="center">
<img src="images/demo.jpg" width="620" hspace='10'/> <br/>
<strong>Figure 1</strong>
</p>
In the training set, the label of each image is a sequence of integers, where each integer is the index of one character in the dictionary. The label corresponding to Figure 1 is shown below:
```
3835,8371,7191,2369,6876,4162,1938,168,1517,4590,3793
```
In the label above, `3835` is the index of the character '两' and `4590` is the index of the Chinese comma.
#### 1.1.2 Data preparation
**A. Training set**
Put all images used for training into one directory, referred to below as `train_images`. Then use a list file to record the information of every image, namely the image size, the image file name, and the corresponding label. This list file is referred to below as `train_list`, and its format is as follows:
```
185 48 00508_0215.jpg 7740,5332,2369,3201,4162
48 48 00197_1893.jpg 6569
338 48 00007_0219.jpg 4590,4788,3015,1994,3402,999,4553
150 48 00107_4517.jpg 5936,3382,1437,3382
...
157 48 00387_0622.jpg 2397,1707,5919,1278
```
<center>File: train_list</center>
Every line of the file above describes one image and is split by spaces into four columns: the first two columns are the width and the height of the image, the third column is the image file name, and the fourth column is the sequence label of the image.
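For illustration, a line in this format can be parsed with a few lines of Python (this helper is not part of the repository; the actual reading of the list file is done inside `ctc_reader.py`):

```python
def parse_list_line(line):
    # e.g. "185 48 00508_0215.jpg 7740,5332,2369,3201,4162"
    width, height, name, label = line.strip().split()
    # the label column is a comma-separated list of character indices
    indices = [int(i) for i in label.split(',')]
    return int(width), int(height), name, indices

with open('train_list') as f:
    samples = [parse_list_line(line) for line in f if line.strip()]
```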
In the end the files should be organized roughly as follows:
```
|-train_data
    |- train_list
    |- train_images
        |- 00508_0215.jpg
        |- 00197_1893.jpg
        |- 00007_0219.jpg
        | ...
```
During training, use the options `--train_images` and `--train_list` to point to the prepared `train_images` directory and `train_list` file respectively.
>**Note:** If `--train_images` and `--train_list` are both unset or set to None, ctc_reader.py automatically downloads the [sample dataset](http://cloud.dlnel.org/filepub/?uuid=df937251-3c0b-480d-9a7b-0080dfeee65c) and caches it under `$HOME/.cache/paddle/dataset/ctc_data/data/`.
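The reader can also be called directly from Python. A minimal usage sketch (the paths are placeholders) based on the `train()`/`test()` interfaces of this version of `ctc_reader.py`:

```python
import ctc_reader

# Custom data: point the reader at your own image directory and list file.
train_reader = ctc_reader.train(
    batch_size=32,
    train_images_dir='./train_data/train_images',
    train_list_file='./train_data/train_list')

# Default data: leaving both arguments as None downloads and caches the
# sample dataset automatically.
test_reader = ctc_reader.test(test_images_dir=None, test_list_file=None)

for batch in train_reader():
    pass  # each batch holds (image, label-sequence) samples
```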
**B. Test set and evaluation set**
The test set and the evaluation set are prepared in the same way as the training set.
During training, the path of the test set is set with the options `--test_images` and `--test_list` of ctc_train.py.
During evaluation, the path of the evaluation set is set with the options `--input_images_dir` and `--input_images_list` of eval.py.
**C. Data for prediction**
The dataset used for prediction has a format similar to the training set, except that the last column of the list file can hold an arbitrary placeholder character or string, for example:
```
185 48 00508_0215.jpg s
48 48 00197_1893.jpg s
338 48 00007_0219.jpg s
...
```
When running inference, set the path of the input data with the options `--input_images_dir` and `--input_images_list` of inference.py.
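For example, the following small helper (not part of this repository; it assumes the Pillow package and uses made-up paths) generates such a list file for a directory of images to be predicted:

```python
import os
from PIL import Image  # assumption: Pillow is installed

def write_infer_list(images_dir, list_path, placeholder='s'):
    with open(list_path, 'w') as f:
        for name in sorted(os.listdir(images_dir)):
            width, height = Image.open(os.path.join(images_dir, name)).size
            f.write('%d %d %s %s\n' % (width, height, name, placeholder))

# hypothetical paths
write_infer_list('./infer_data/images', './infer_data/infer_list')
```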
### 1.2 Training
Start training by invoking the training script:
```
python ctc_train.py [options]
```
Here, options supports the following training-related arguments:
**--batch_size:** Minibatch size. Default: 32.
**--pass_num:** Number of training passes (epochs). Default: 100.
**--log_period:** Print a training log every this many minibatches. Default: 1000.
**--save_model_period:** Save the model every this many minibatches. Default: 15000. If set to -1, the model is never saved.
**--eval_period:** Evaluate the model on the test set every this many minibatches. Default: 15000. If set to -1, no evaluation is performed.
**--save_model_dir:** Directory the model files are saved to. Default: "./models". The directory is created automatically if it does not exist.
**--init_model:** Path of the model used for initialization. If the model is stored as a single file, specify the path of that file; if the model is stored as multiple files, specify the path of the directory that contains them. Default: None, which means no pretrained model is used for initialization.
**--learning_rate:** Global learning rate. Default: 0.001.
**--l2:** L2 regularization coefficient. Default: 0.0004.
**--max_clip:** Maximum gradient-clipping threshold. Default: 10.0.
**--momentum:** Momentum. Default: 0.9.
**--rnn_hidden_size:** Hidden size of the RNN layers. Default: 200.
**--device DEVICE:** Device ID. Set to -1 to train on the CPU, or to 0 to train on the GPU. Default: 0.
**--min_average_window:** Minimum average window. Default: 10000.
**--max_average_window:** Maximum average window. It is recommended to set this to the number of minibatches in one pass. Default: 15625.
**--average_window:** Average window. Default: 0.15.
**--parallel:** Whether to train with multiple GPUs. Default: True.
**--train_images:** Directory that holds the training images. If set to None, ctc_reader automatically downloads and uses the default dataset; change this option to train on your own data. Default: None.
**--train_list:** List file describing the training images. If set to None, ctc_reader automatically downloads and uses the default dataset; change this option to train on your own data. Default: None.
**--test_images:** Directory that holds the test images. If set to None, ctc_reader automatically downloads and uses the default dataset; change this option to test on your own data. Default: None.
**--test_list:** List file describing the test images. If set to None, ctc_reader automatically downloads and uses the default dataset; change this option to test on your own data. Default: None.
**--num_classes:** Size of the character set. If set to None, the character-set size provided by ctc_reader is used; change this option to train on your own data. Default: None.
### 1.3 Inference
Run prediction by invoking the inference script:
```
python inference.py [options]
```
Here, options supports the following inference-related arguments:
**--model_path:** Model to use for inference. If the model is stored as a single file, specify the path of that file; if it is stored as multiple files, specify the path of the directory that contains them. This option is required.
**--input_images_dir:** Directory of the images to be predicted. If set to None, the default data provided by ctc_reader is used. Default: None.
**--input_images_list:** List file describing the images to be predicted. If set to None, the default data provided by ctc_reader is used. Default: None.
**--device DEVICE:** Device ID. Set to -1 to run on the CPU, or to 0 to run on the GPU. Default: 0.
The prediction results are printed to standard output.
### 1.4 Evaluation
Evaluate the model on a given dataset by invoking the evaluation script:
```
python eval.py [options]
```
Here, options supports the following evaluation-related arguments:
**--model_path:** Path of the model to be evaluated. If the model is stored as a single file, specify the path of that file; if it is stored as multiple files, specify the path of the directory that contains them. This option is required.
**--input_images_dir:** Directory of the images to be evaluated. If set to None, the default data provided by ctc_reader is used. Default: None.
**--input_images_list:** List file describing the images to be evaluated. If set to None, the default data provided by ctc_reader is used. Default: None.
**--device DEVICE:** Device ID. Set to -1 to run on the CPU, or to 0 to run on the GPU. Default: 0.
**crnn_ctc_model.py**
```diff
@@ -143,7 +143,7 @@ def ctc_train_net(images, label, args, num_classes):
     gradient_clip = None
     if args.parallel:
         places = fluid.layers.get_places()
-        pd = fluid.layers.ParallelDo(places)
+        pd = fluid.layers.ParallelDo(places, use_nccl=True)
         with pd.do():
             images_ = pd.read_input(images)
             label_ = pd.read_input(label)
```
**ctc_reader.py**
```diff
@@ -124,21 +124,25 @@ def data_shape():
     return DATA_SHAPE


-def train(batch_size):
+def train(batch_size, train_images_dir=None, train_list_file=None):
     generator = DataGenerator()
-    data_dir = download_data()
-    return generator.train_reader(
-        path.join(data_dir, TRAIN_DATA_DIR_NAME),
-        path.join(data_dir, TRAIN_LIST_FILE_NAME), batch_size)
+    if train_images_dir is None:
+        data_dir = download_data()
+        train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
+    if train_list_file is None:
+        train_list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
+    return generator.train_reader(train_images_dir, train_list_file, batch_size)


-def test(batch_size=1):
+def test(batch_size=1, test_images_dir=None, test_list_file=None):
     generator = DataGenerator()
-    data_dir = download_data()
+    if test_images_dir is None:
+        data_dir = download_data()
+        test_images_dir = path.join(data_dir, TEST_DATA_DIR_NAME)
+    if test_list_file is None:
+        test_list_file = path.join(data_dir, TEST_LIST_FILE_NAME)
     return paddle.batch(
-        generator.test_reader(
-            path.join(data_dir, TRAIN_DATA_DIR_NAME),
-            path.join(data_dir, TRAIN_LIST_FILE_NAME)), batch_size)
+        generator.test_reader(test_images_dir, test_list_file), batch_size)


 def download_data():
```
"""Trainer for OCR CTC model."""
import paddle.fluid as fluid
import dummy_reader
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net
import ctc_reader
import argparse
from load_model import load_param
import functools
import sys
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net
import time
import os
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('pass_num', int, 100, "# of training epochs.")
add_arg('log_period', int, 1000, "Log period.")
add_arg('learning_rate', float, 1.0e-3, "Learning rate.")
add_arg('l2', float, 0.0004, "L2 regularizer.")
add_arg('max_clip', float, 10.0, "Max clip threshold.")
add_arg('min_clip', float, -10.0, "Min clip threshold.")
add_arg('momentum', float, 0.9, "Momentum.")
add_arg('rnn_hidden_size',int, 200, "Hidden size of rnn layers.")
add_arg('device', int, 0, "Device id.'-1' means running on CPU"
"while '0' means GPU-0.")
add_arg('min_average_window', int, 10000, "Min average window.")
add_arg('max_average_window', int, 15625, "Max average window.")
add_arg('average_window', float, 0.15, "Average window.")
add_arg('parallel', bool, False, "Whether use parallel training.")
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('pass_num', int, 100, "# of training epochs.")
add_arg('log_period', int, 1000, "Log period.")
add_arg('save_model_period', int, 15000, "Save model period.")
add_arg('eval_period', int, 15000, "Evaluate period.")
add_arg('save_model_dir', str, "./models", "The directory the model to be saved to.")
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('learning_rate', float, 1.0e-3, "Learning rate.")
add_arg('l2', float, 0.0004, "L2 regularizer.")
add_arg('max_clip', float, 10.0, "Max clip threshold.")
add_arg('min_clip', float, -10.0, "Min clip threshold.")
add_arg('momentum', float, 0.9, "Momentum.")
add_arg('rnn_hidden_size', int, 200, "Hidden size of rnn layers.")
add_arg('device', int, 0, "Device id.'-1' means running on CPU"
"while '0' means GPU-0.")
add_arg('min_average_window',int, 10000, "Min average window.")
add_arg('max_average_window',int, 15625, "Max average window.")
add_arg('average_window', float, 0.15, "Average window.")
add_arg('parallel', bool, False, "Whether use parallel training.")
add_arg('train_images', str, None, "The directory of training images."
"None means using the default training images of reader.")
add_arg('train_list', str, None, "The list file of training images."
"None means using the default train_list file of reader.")
add_arg('test_images', str, None, "The directory of training images."
"None means using the default test images of reader.")
add_arg('test_list', str, None, "The list file of training images."
"None means using the default test_list file of reader.")
add_arg('num_classes', int, None, "The number of classes."
"None means using the default num_classes from reader.")
# yapf: disable
def load_parameter(place):
params = load_param('./name.map', './data/model/results_without_avg_window/pass-00000/')
for name in params:
t = fluid.global_scope().find_var(name).get_tensor()
t.set(params[name], place)
def train(args, data_reader=dummy_reader):
def train(args, data_reader=ctc_reader):
"""OCR CTC training"""
num_classes = data_reader.num_classes()
num_classes = data_reader.num_classes() if args.num_classes is None else args.num_classes
data_shape = data_reader.data_shape()
# define network
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
......@@ -47,15 +54,30 @@ def train(args, data_reader=dummy_reader):
sum_cost, error_evaluator, inference_program, model_average = ctc_train_net(images, label, args, num_classes)
# data reader
train_reader = data_reader.train(args.batch_size)
test_reader = data_reader.test()
train_reader = data_reader.train(
args.batch_size,
train_images_dir=args.train_images,
train_list_file=args.train_list)
test_reader = data_reader.test(
test_images_dir=args.test_images,
test_list_file=args.test_list)
# prepare environment
place = fluid.CPUPlace()
if args.device >= 0:
place = fluid.CUDAPlace(args.device)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
#load_parameter(place)
# load init model
if args.init_model is not None:
model_dir = args.init_model
model_file_name = None
if not os.path.isdir(args.init_model):
model_dir=os.path.dirname(args.init_model)
model_file_name=os.path.basename(args.init_model)
fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
print "Init model from: %s." % args.init_model
for pass_id in range(args.pass_num):
error_evaluator.reset(exe)
......@@ -70,24 +92,31 @@ def train(args, data_reader=dummy_reader):
fetch_list=[sum_cost] + error_evaluator.metrics)
total_loss += batch_loss[0]
total_seq_error += batch_seq_error[0]
if batch_id % 100 == 1:
print '.',
sys.stdout.flush()
if batch_id % args.log_period == 1:
# training log
if batch_id % args.log_period == 0:
print "\nTime: %s; Pass[%d]-batch[%d]; Avg Warp-CTC loss: %s; Avg seq error: %s." % (
time.time(),
pass_id, batch_id, total_loss / (batch_id * args.batch_size), total_seq_error / (batch_id * args.batch_size))
sys.stdout.flush()
batch_id += 1
# evaluate
if batch_id % args.eval_period == 0:
with model_average.apply(exe):
error_evaluator.reset(exe)
for data in test_reader():
exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe)
with model_average.apply(exe):
error_evaluator.reset(exe)
for data in test_reader():
exe.run(inference_program, feed=get_feeder_data(data, place))
_, test_seq_error = error_evaluator.eval(exe)
print "\nTime: %s; Pass[%d]-batch[%d]; Test seq error: %s.\n" % (
time.time(), pass_id, batch_id, str(test_seq_error[0]))
# save model
if batch_id % args.save_model_period == 0:
with model_average.apply(exe):
filename = "model_%05d_%d" % (pass_id, batch_id)
fluid.io.save_params(exe, dirname=args.save_model_dir, filename=filename)
print "Saved model to: %s/%s." %(args.save_model_dir, filename)
batch_id += 1
print "\nEnd pass[%d]; Test seq error: %s.\n" % (
pass_id, str(test_seq_error[0]))
def main():
args = parser.parse_args()
......
"""A dummy reader for test."""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import paddle.v2 as paddle
DATA_SHAPE = [1, 512, 512]
NUM_CLASSES = 20
def _read_creater(num_sample=1024, min_seq_len=1, max_seq_len=10):
def reader():
for i in range(num_sample):
sequence_len = np.random.randint(min_seq_len, max_seq_len)
x = np.random.uniform(0.1, 1, DATA_SHAPE).astype("float32")
y = np.random.randint(0, NUM_CLASSES + 1,
[sequence_len]).astype("int32")
yield x, y
return reader
def train(batch_size, num_sample=128):
"""Get train dataset reader."""
return paddle.batch(_read_creater(num_sample=num_sample), batch_size)
def test(batch_size=1, num_sample=16):
"""Get test dataset reader."""
return paddle.batch(_read_creater(num_sample=num_sample), batch_size)
def data_shape():
"""Get image shape in CHW order."""
return DATA_SHAPE
def num_classes():
"""Get number of total classes."""
return NUM_CLASSES
**eval.py**
```diff
 import paddle.v2 as paddle
 import paddle.fluid as fluid
-from load_model import load_param
-from utility import get_feeder_data
+from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
 from crnn_ctc_model import ctc_infer
 from crnn_ctc_model import ctc_eval
 import ctc_reader
 import dummy_reader
+import argparse
+import functools
+import os
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('model_path', str, None, "The model path to be used for inference.")
+add_arg('input_images_dir', str, None, "The directory of images.")
+add_arg('input_images_list', str, None, "The list file of images.")
+add_arg('device', int, 0, "Device id.'-1' means running on CPU")
+# yapf: disable


-def load_parameter(place):
-    params = load_param('./name.map', './data/model/results/pass-00062/')
-    for name in params:
-        print "param: %s" % name
-        t = fluid.global_scope().find_var(name).get_tensor()
-        t.set(params[name], place)
-
-
-def evaluate(eval=ctc_eval, data_reader=dummy_reader):
+def evaluate(args, eval=ctc_eval, data_reader=ctc_reader):
     """OCR inference"""
     num_classes = data_reader.num_classes()
     data_shape = data_reader.data_shape()
@@ -26,29 +29,39 @@ def evaluate(eval=ctc_eval, data_reader=dummy_reader):
     evaluator, cost = eval(images, label, num_classes)

     # data reader
-    test_reader = data_reader.test()
+    test_reader = data_reader.test(test_images_dir=args.input_images_dir, test_list_file=args.input_images_list)

     # prepare environment
-    place = fluid.CUDAPlace(0)
-    #place = fluid.CPUPlace()
+    place = fluid.CPUPlace()
+    if args.device >= 0:
+        place = fluid.CUDAPlace(args.device)

     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())

-    print fluid.default_main_program()
-    load_parameter(place)
+    # load init model
+    model_dir = args.model_path
+    model_file_name = None
+    if not os.path.isdir(args.model_path):
+        model_dir = os.path.dirname(args.model_path)
+        model_file_name = os.path.basename(args.model_path)
+    fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
+    print "Init model from: %s." % args.model_path

     evaluator.reset(exe)
+    count = 0
     for data in test_reader():
+        count += 1
+        print 'Process samples: %d\r' % (count, ),
-        result, avg_distance, avg_seq_error = exe.run(
+        exe.run(
             fluid.default_main_program(),
-            feed=get_feeder_data(data, place),
-            fetch_list=[cost] + evaluator.metrics)
+            feed=get_feeder_data(data, place))

     avg_distance, avg_seq_error = evaluator.eval(exe)
-    print "avg_distance: %s; avg_seq_error: %s" % (avg_distance, avg_seq_error)
+    print "Read %d samples; avg_distance: %s; avg_seq_error: %s" % (count, avg_distance, avg_seq_error)


 def main():
-    evaluate(data_reader=ctc_reader)
+    args = parser.parse_args()
+    print_arguments(args)
+    evaluate(args, data_reader=ctc_reader)


 if __name__ == "__main__":
```
**inference.py**
```diff
 import paddle.v2 as paddle
-import paddle.v2.fluid as fluid
-from load_model import load_param
-from utility import get_feeder_data
+import paddle.fluid as fluid
+from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
 from crnn_ctc_model import ctc_infer
 import numpy as np
 import ctc_reader
 import dummy_reader
-
-
-def load_parameter(place):
-    params = load_param('./name.map', './data/model/results/pass-00062/')
-    for name in params:
-        print "param: %s" % name
-        t = fluid.global_scope().find_var(name).get_tensor()
-        t.set(params[name], place)
-
-
-def inference(infer=ctc_infer, data_reader=dummy_reader):
+import argparse
+import functools
+import os
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('model_path', str, None, "The model path to be used for inference.")
+add_arg('input_images_dir', str, None, "The directory of images.")
+add_arg('input_images_list', str, None, "The list file of images.")
+add_arg('device', int, 0, "Device id.'-1' means running on CPU")
+# yapf: disable
+
+
+def inference(args, infer=ctc_infer, data_reader=ctc_reader):
     """OCR inference"""
     num_classes = data_reader.num_classes()
     data_shape = data_reader.data_shape()
     # define network
     images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
-    sequence, tmp = infer(images, num_classes)
-    fluid.layers.Print(tmp)
+    sequence = infer(images, num_classes)
     # data reader
-    test_reader = data_reader.test()
+    test_reader = data_reader.test(test_images_dir=args.input_images_dir, test_list_file=args.input_images_list)
     # prepare environment
-    place = fluid.CUDAPlace(0)
+    place = fluid.CPUPlace()
+    if args.device >= 0:
+        place = fluid.CUDAPlace(args.device)
+
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
-    load_parameter(place)
+    # load init model
+    model_dir = args.model_path
+    model_file_name = None
+    if not os.path.isdir(args.model_path):
+        model_dir = os.path.dirname(args.model_path)
+        model_file_name = os.path.basename(args.model_path)
+    fluid.io.load_params(exe, dirname=model_dir, filename=model_file_name)
+    print "Init model from: %s." % args.model_path
+
     for data in test_reader():
         result = exe.run(fluid.default_main_program(),
                          feed=get_feeder_data(
                              data, place, need_label=False),
-                         fetch_list=[tmp])
-        print "result: %s" % (list(result[0].flatten()), )
+                         fetch_list=[sequence],
+                         return_numpy=False)
+        print "result: %s" % (np.array(result[0]).flatten(), )


 def main():
-    inference(data_reader=ctc_reader)
+    args = parser.parse_args()
+    print_arguments(args)
+    inference(args, data_reader=ctc_reader)


 if __name__ == "__main__":
     main()
```
**load_model.py**
```python
import sys
import numpy as np
import ast


def load_parameter(file_name):
    with open(file_name, 'rb') as f:
        f.read(16)  # skip header.
        return np.fromfile(f, dtype=np.float32)


def load_param(name_map_file, old_param_dir):
    result = {}
    name_map = {}
    shape_map = {}
    with open(name_map_file, 'r') as map_file:
        for param in map_file:
            old_name, new_name, shape = param.strip().split('=')
            name_map[new_name] = old_name
            shape_map[new_name] = ast.literal_eval(shape)
    for new_name in name_map:
        result[new_name] = load_parameter("/".join(
            [old_param_dir, name_map[new_name]])).reshape(shape_map[new_name])
    return result


if __name__ == "__main__":
    name_map_file = "./name.map"
    old_param_dir = "./data/model/results/pass-00062/"
    result = load_param(name_map_file, old_param_dir)
    for p in result:
        print "name: %s; param.shape: %s" % (p, result[p].shape)
```