“b742f824486459defbc8154c9051d1db3bf4e6f8”上不存在“tools/python/git@gitcode.net:xiaomi/mace.git”
Paddle的单机训练的python api
Created by: reyoung
实现使用python来驱动训练流程,达到类似下面的效果
import paddle
# Context是使用设备的上下文。Paddle究竟使用多少设备,在这里指定
context = paddle.Context(devices=[paddle.cpu_all, paddle.gpu_all]) # use all device in one node.
# 定义一个网络。前面的注解说明这个函数是一个网络定义。
@context.network()
def simple_network(network):
# network参数是一个网络定义的函数集合,包括了我们支持的layers
ipt = network.data_layer(name="input", size=784)
hidden = network.fc_layer(input=ipt, size=200)
predict = network.fc_layer(input=hidden, size=10, act=SoftmaxActivation())
cost = network.classification_cost(input=predict, label=network.data_layer(name="input", size=10))
return cost # 返回优化的目标。相当于现在paddle的outputs
# define a data provider, same as current Paddle process.
@paddle.provider()
def process_data(settings, filename):
for sample in read_from_file(filename):
yield sample
# train all networks in current context.
context.with_train_data(train_files=["a.txt", "b.txt"], method=process_data) # set train data, and data provider
.with_test_data(test_files=["c.txt"], test_period=Batch(1000), method=process_data) # set test data
.use_optimizer(SgdOptimizer()) # set optimizer.
.standard_sgd_train(num_passes=100) # set use standard sgd strategy to train 100 pass.
context.exit(0)
这个事情需要通过以下几个步骤完成:
- 清除掉Paddle运行时依赖的全局变量。否则,我们只能使用多进程来加载多个网络了。
- 添加需要的python api,来完成这项功能。