Created by: qingqing01
- 1 去掉 shape_hints,并增加 Input
- 2 去掉Loss种的 infer_shape and infer_dtype
- 3 去掉train/test/eval接口的device和 device_ids,放到prepare接口中
- 4 device和device_ids支持设置为None, 依据安装的Paddle版本自动选择
- 5 静态图在prepare中构建program
- 6 去掉lazy加载。
- 7 prepare接口,optimizer和loss_function支持设置为None,例如eval和test时可以没有opt和loss
Now, there is no unit tests for any module, will add tests in the PR of Fit API.
简单例子:
# 对于静态图,必需声明inputs和labels。动态图也能使用
inputs=[Input([None, 1, 28, 28], 'float32', name='image')] # 单个输入,可不用[]
labels=[Input([None, 1], 'int64', name='label')] # 单个label,可不用[]
model = MNIST()
optim = Momentum(learning_rate=FLAGS.lr, momentum=.9,
parameter_list=model.parameters())
model.prepare(optim, CrossEntropy(),inputs, labels)
原先实现方式:
-
优点:
- 一些case自动推断inputs和label的shape、dtype:
- 简单的情况,如图像分类,固定长度训练,通过输入numpy.array的数据自动输入shape和dtype。
- 当label和组网outputs (通常是forward的返回)shape和dtype相同,自动推断。
- label和Model可以完全解耦。
- 一些case自动推断inputs和label的shape、dtype:
-
缺点:
- 无法自动推断的情况:
- forward函数需设置装饰器shape_hints进行推断。
- Loss需实现infer_shape和infer_dtype进行推断。
- 两处设置地方也不统一。
- shape_hints无法配置化设置shape参数。
- 使得使用DataLoader变得困难: 当DataLoader非iterable模式,无法获取到dtype。
- 在没有拿到数据,网络运行前,静态图中,对graph的变换导致变得困难。
- 无法自动推断的情况:
-
train和eval复用Model导致的缺点: train和eval复用同一个Model,但输入不同,并且输入shape无法自动推断时
- shape_hints必需指定所有forward输入(包含train和eval)
- train和eval的reader必需返回所有输入的数据,如果没有数据,也必须为None
当前的实现方式:
-
优点:
- 可解决原先实现方式的缺点
-
缺点:
- 对于静态图,简单的case,也需要声明inputs和labels。
- 对于静态图,label作为了Model的输入,没有完全解耦。
-
train和eval复用Model导致的缺点仍然存在