Opened Jan 05, 2017 by saxon_zh (@saxon_zh, Guest)

Paddle Python API Design Document (Draft)

Created by: jacquesqiao

A typical training process

gradient_machine.startPass()
updater.startPass()
for each_batch in data:
    gradient_machine.startBatch()
    updater.startBatch()

    gradient_machine.train()

    updater.finishBatch()
    gradient_machine.finishBatch()
updater.finishPass()
gradient_machine.finishPass()

A structure resembling a call chain can be used to separate these operations. For example, the loop above can be split into two RunnerChainItems:

  • GradientMachineOperations
  • UpdaterOperations
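
The split described above can be sketched as follows. This is only an illustration under stated assumptions: the `log` list stands in for the real gradient machine and updater calls, and the `call_chain` and `run_one_pass` helpers are hypothetical, not part of Paddle. It shows how two chain items, each forwarding to `next_callback`, can reproduce the call order of the training loop above.

```python
# Hypothetical sketch: `log` records the calls that a real gradient machine
# and updater would receive. None of these helpers are Paddle APIs.
log = []


class GradientMachineOperations(object):
    """Chain item that owns the gradient machine calls."""

    def on_pass_begin(self, next_callback):
        log.append("gradient_machine.startPass")
        next_callback()

    def on_batch_begin(self, next_callback):
        log.append("gradient_machine.startBatch")
        next_callback()

    def on_batch_end(self, next_callback):
        log.append("gradient_machine.train")
        next_callback()  # the updater finishes the batch first
        log.append("gradient_machine.finishBatch")

    def on_pass_end(self, next_callback):
        next_callback()  # the updater finishes the pass first
        log.append("gradient_machine.finishPass")


class UpdaterOperations(object):
    """Chain item that owns the updater calls."""

    def on_pass_begin(self, next_callback):
        log.append("updater.startPass")
        next_callback()

    def on_batch_begin(self, next_callback):
        log.append("updater.startBatch")
        next_callback()

    def on_batch_end(self, next_callback):
        log.append("updater.finishBatch")
        next_callback()

    def on_pass_end(self, next_callback):
        log.append("updater.finishPass")
        next_callback()


def call_chain(items, method_name):
    """Call method_name on the first item; each item gets a callback that
    forwards to the same method on the next item (no-op after the last)."""
    def callback_for(index):
        if index == len(items):
            return lambda: None
        return lambda: getattr(items[index], method_name)(
            callback_for(index + 1))
    callback_for(0)()


def run_one_pass(items, num_batches):
    call_chain(items, "on_pass_begin")
    for _ in range(num_batches):
        call_chain(items, "on_batch_begin")
        call_chain(items, "on_batch_end")
    call_chain(items, "on_pass_end")


run_one_pass([GradientMachineOperations(), UpdaterOperations()], num_batches=1)
```

Because each item decides whether its own work happens before or after `next_callback()`, the nested begin/finish ordering of the original loop falls out of the chain order alone.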

Core concepts

  • Runner
  • RunnerItem
  • RunnerBuilder

Runner

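The Runner wraps the logic that used to live in Trainer.cpp, building mainly on the API exposed at the GradientMachine level. A Runner holds a number of RunnerItems, each of which implements part of the Trainer logic. The user calls the Runner's run_one_pass in a loop, and the Runner internally works through its RunnerItems one by one to do what the individual components did before (the updater, the GradientMachine forward/backward, parameter save/load, and so on), so the user does not need to deal with those details.

Runner implementation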

class Runner(object):
    def add_item(self, item):
        """
        Add a runner item to runner.
        """
    def run_one_pass(self):
        """
        Run one pass for runner. The parent argument will be passed to context.
        """

Constructing a runner

    runner = Runner()

    runner.add_item(ItemA())
    runner.add_item(ItemB())

    with runner:
        runner.run_one_pass()
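
One way the Runner skeleton above could be filled in is sketched below. Everything beyond `add_item`, `run_one_pass`, and the `with runner:` usage is an assumption made for illustration: the empty `RunnerContext`, the `_invoke` helper, the batch loop that stops when a hook returns True, and the `PassThroughItem` and `StopAfter` demo items are all hypothetical.

```python
class RunnerContext(object):
    """Hypothetical shared state; its real fields are not specified here."""


class Runner(object):
    def __init__(self):
        self._items = []
        self._context = None

    def add_item(self, item):
        self._items.append(item)

    def _invoke(self, method_name, *args):
        """Call method_name on the first item. Each item receives a
        callback that forwards to the same method on the next item."""
        def callback_for(index):
            if index == len(self._items):
                # Assumed default callback for the last item: do nothing
                # and report that processing may continue.
                return lambda *unused: False
            method = getattr(self._items[index], method_name)
            return lambda *call_args: method(
                *(call_args + (callback_for(index + 1),)))
        return callback_for(0)(*args)

    def __enter__(self):
        self._context = RunnerContext()
        self._invoke("initialize", self._context)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._invoke("finalize")
        return False  # do not swallow exceptions

    def run_one_pass(self):
        self._invoke("on_pass_begin")
        # Assumption: a True return value means "no more batches".
        while not self._invoke("on_batch_begin"):
            if self._invoke("on_batch_end"):
                break
        self._invoke("on_pass_end")


class PassThroughItem(object):
    """Minimal stand-in for RunnerItem: every hook just forwards."""

    def initialize(self, context, next_callback):
        next_callback(context)

    def finalize(self, next_callback):
        next_callback()

    def on_pass_begin(self, next_callback):
        next_callback()

    def on_pass_end(self, next_callback):
        next_callback()

    def on_batch_begin(self, next_callback):
        return next_callback()

    def on_batch_end(self, next_callback):
        return next_callback()


class StopAfter(PassThroughItem):
    """Hypothetical item that ends each pass after `limit` batches."""

    def __init__(self, limit):
        self.limit = limit
        self.seen = 0

    def on_pass_begin(self, next_callback):
        self.seen = 0
        next_callback()

    def on_batch_begin(self, next_callback):
        if self.seen >= self.limit:
            return True
        self.seen += 1
        return next_callback()


item = StopAfter(3)
runner = Runner()
runner.add_item(item)
with runner:
    runner.run_one_pass()
```

Tying `initialize` and `finalize` to `__enter__` and `__exit__` is one plausible reading of the `with runner:` usage; the original text does not spell this out.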

RunnerItem

A RunnerItem is an item in a Runner. The Runner composes its RunnerItems into a chain and invokes the methods of the first item, passing the next item's corresponding method as next_callback. If the current item is the last one in the chain, a default next_callback is passed.

Context is a global object shared by items.

class RunnerItem(object):
    """
    RunnerItem is an item in Runner. Runner composes the RunnerItems
    into a chain and invokes the methods of the first item, passing the
    next item's corresponding method as `next_callback`. If the current
    item is the last one in the chain, a default next_callback is
    passed.

    Context is a global object shared by items.
    """

    def __init__(self):
        pass

    def initialize(self, context, next_callback):
        """
        Initialize method. It is invoked when Runner starts to run.

        :param context: a global object shared by items.
        :type context: RunnerContext
        :param next_callback: next item's initialize method.
        :type next_callback: callable
        :return: None
        :rtype: None
        """
        next_callback(context)

    def finalize(self, next_callback):
        """
        Finalize method. It is invoked when Runner completes its run, and
        cleans up state held by the RunnerItem.

        :param next_callback: next item's finalize method.
        :type next_callback: callable
        :return: None
        :rtype: None
        """
        next_callback()

    def on_pass_begin(self, next_callback):
        """
        Pass Begin Method. Invoked when a pass begins.

        :param next_callback: next item's on_pass_begin method.
        :type next_callback: callable
        :return: None
        :rtype: None
        """

        next_callback()

    def on_pass_end(self, next_callback):
        """
        Pass End Method. Invoked when a pass ends.

        :param next_callback: next item's on_pass_end method.
        :type next_callback: callable
        :return: None
        :rtype: None
        """
        next_callback()

    def on_batch_begin(self, next_callback):
        """
        Batch Begin Method. Invoked when a batch begins. Returns True if no
        more batches can be processed.

        :param next_callback: next item's on_batch_begin method.
        :type next_callback: callable
        :return: True if no more batches can be processed.
        :rtype: bool
        """
        return next_callback()

    def on_batch_end(self, next_callback):
        """
        Batch End Method. Invoked when a batch ends. Returns True if no
        more batches can be processed.

        :param next_callback: next item's on_batch_end method.
        :type next_callback: callable
        :return: True if no more batches can be processed.
        :rtype: bool
        """
        return next_callback()
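
Since only `initialize` receives the context, an item that needs it in later hooks has to keep a reference. The sketch below illustrates that pattern with two hypothetical items sharing one context. The attribute-bag `RunnerContext`, the `batch_count` field, and both item classes are assumptions for illustration, and the chain is wired by hand instead of by a Runner.

```python
class RunnerContext(object):
    """Hypothetical context: a plain attribute bag shared by all items."""


class CreateCounter(object):
    """Publishes shared state before later items initialize."""

    def initialize(self, context, next_callback):
        context.batch_count = 0
        next_callback(context)


class CountBatches(object):
    """Keeps the context from initialize and updates it on each batch."""

    def initialize(self, context, next_callback):
        self._context = context
        next_callback(context)

    def on_batch_end(self, next_callback):
        self._context.batch_count += 1
        return next_callback()


context = RunnerContext()
first, second = CreateCounter(), CountBatches()

# Hand-wired chain: first -> second -> default no-op callback.
first.initialize(
    context, lambda ctx: second.initialize(ctx, lambda ctx: None))

for _ in range(3):
    # The default callback reports that processing may continue.
    second.on_batch_end(lambda: False)
```

Item order matters here: `CreateCounter` must come before `CountBatches` in the chain so that `batch_count` exists when the first batch ends.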

RunnerItems already implemented

  • CreateGradientMachine
  • BasicLocalParameterUpdaterOps
  • BasicGradientMachineTrainOps
  • BasicGradientMachineTestOps
  • SaveParamsOnPassEnd
  • Counter

RunnerBuilder

RunnerBuilder encapsulates the process of building a Runner: helper methods such as with_std_local_trainer assemble a Runner that is ready to run.

import paddle.trainer.PyDataProvider2 as dp
from paddle.trainer_config_helpers import *

import mnist_provider
from py_paddle.trainer import *


@network(
    inputs={
        'pixel': dp.dense_vector(784),
        'label': dp.integer_value(10),
    },
    learning_rate=1e-4,
    learning_method=AdamOptimizer(),
    batch_size=1000,
    model_average=ModelAverage(average_window=0.5),
    regularization=L2Regularization(rate=0.5))
def mnist_network(pixel, label):
    hidden1 = fc_layer(input=pixel, size=200)
    hidden2 = fc_layer(input=hidden1, size=200)
    inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
    cost = classification_cost(input=inference, label=label)
    return cost


def main():
    mnist = mnist_network()
    runner = RunnerBuilder(
        network=mnist, device_count=2).with_std_local_trainer(
            method=mnist_provider.process,
            file_list=['./data/raw_data/train']).with_std_tester(
                method=mnist_provider.process,
                file_list=['./data/raw_data/t10k']).build()
    with runner:
        for _ in xrange(2):
            runner.run_one_pass()


if __name__ == '__main__':
    main()
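
The chained calls in the example above suggest a fluent builder in which each `with_*` method returns the builder itself. A minimal sketch of that pattern follows. Which items each helper adds is an assumption, the tuples are placeholders for real RunnerItem instances, and the `Runner` here is a bare stand-in.

```python
class Runner(object):
    """Bare stand-in for the real Runner: it only collects items."""

    def __init__(self):
        self.items = []

    def add_item(self, item):
        self.items.append(item)


class RunnerBuilder(object):
    def __init__(self, network, device_count=1):
        # Assumption: the gradient machine is always created first.
        self._items = [("CreateGradientMachine", network, device_count)]

    def with_std_local_trainer(self, method, file_list):
        # Assumption: a local trainer needs the updater and train ops.
        self._items.append(("BasicLocalParameterUpdaterOps",))
        self._items.append(("BasicGradientMachineTrainOps", method, file_list))
        return self  # returning self is what allows chaining

    def with_std_tester(self, method, file_list):
        self._items.append(("BasicGradientMachineTestOps", method, file_list))
        return self

    def build(self):
        runner = Runner()
        for item in self._items:
            runner.add_item(item)
        return runner


def process(*args):
    """Placeholder for a data provider method such as mnist_provider.process."""


runner = RunnerBuilder(
    network="mnist_network", device_count=2).with_std_local_trainer(
        method=process, file_list=["./data/raw_data/train"]).with_std_tester(
            method=process, file_list=["./data/raw_data/t10k"]).build()
```

Keeping the item ordering inside the builder means the caller only states what kind of run is wanted and never has to know the order in which items must be chained.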
Reference: paddlepaddle/Paddle#1069