trainer_develop.md 6.4 KB
Newer Older
C
Chengmo 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
# 如何添加自定义流程

模型训练的流程也可以像`model``reader`一样,由用户自定义,并在`config.yaml`中指定路径,由PaddleRec调用。

PaddleRec可自定义的流程有如下5个:
1. **instance**: 执行训练前的所有操作
2. **network**:执行组网的前向/反向,训练策略的添加
3. **startup**:执行模型的初始化,加载
4. **runnner**: 执行模型的训练
5. **terminal**: 执行训练后的所有操作

## instance

instance由GeneralTrainer首先调用,执行模型组网前的所有操作。用户可以在这里进行下载数据,import不同的包,配置环境变量等操作。instance的官方实现位于[instance.py](../core/trainers/framework/instance.py),instance基类定义如下:

```python
class InstanceBase(object):
    def __init__(self, context):
        pass

    def instance(self, context):
        pass
```

您需要继承`InstanceBase`并命名为`Instance`,完成`instance`的实现,通过上下文信息字典`context`拿到模型所需信息,及保存相关配置。

## network

network将在instanc后调用,执行模型的组网。network的官方实现位于[network.py](../core/trainers/framework/network.py),network基类定义如下:

```python
class NetworkBase(object):
    def __init__(self, context):
        pass

    def build_network(self, context):
        pass
```

可参照其他模式的实现方式,自定其中的部分步骤。您需要您需要继承`NetworkBase`并命名为`Network`,完成`build_network`的实现,通过上下文信息字典`context`拿到模型所需信息,并在context中保存模型的program与scope信息,例如:

```python
context["model"][model_dict["name"]][
                "main_program"] = train_program
context["model"][model_dict["name"]][
    "startup_program"] = startup_program
context["model"][model_dict["name"]]["scope"] = scope
context["model"][model_dict["name"]]["model"] = model
context["model"][model_dict["name"]][
    "default_main_program"] = train_program.clone()
```

## startup

startup执行网络参数的初始化,抑或模型的热启动,主要功能是执行`exe.run(fluid.default_startup_program())`。 startup的官方实现在[startup](../core/trainers/framework/startup.py)

```python
class StartupBase(object):
    def __init__(self, context):
        pass

    def startup(self, context):
        pass

    def load(self, context, is_fleet=False, main_program=None):
        dirname = envs.get_global_env(
            "runner." + context["runner_name"] + ".init_model_path", None)
        if dirname is None or dirname == "":
            return
        print("going to load ", dirname)
        if is_fleet:
            context["fleet"].load_persistables(context["exe"], dirname)
        else:
            fluid.io.load_persistables(
                context["exe"], dirname, main_program=main_program)
```

自定义startup流程,您需要您需要继承`StartupBase`并命名为`Startup`,实现该类型中startup成员函数。

## runner

runner是运行的主要流程,主要功能是reader的运行,网络的运行,指标的打印以及模型的保存。以参数服务器Runner为示例,如下:

```python
class PSRunner(RunnerBase):
    def __init__(self, context):
        print("Running PSRunner.")
        pass

    def run(self, context):
        # 通过超参拿到迭代次数
        epochs = int(
            envs.get_global_env("runner." + context["runner_name"] +
                                ".epochs"))
        # 取第一个phase的模型与reader
        model_dict = context["env"]["phase"][0]
        for epoch in range(epochs):
            begin_time = time.time()
            # 调用run进行训练
            self._run(context, model_dict)
            end_time = time.time()
            seconds = end_time - begin_time
            print("epoch {} done, use time: {}".format(epoch, seconds))
            with fluid.scope_guard(context["model"][model_dict["name"]][
                    "scope"]):
                train_prog = context["model"][model_dict["name"]][
                    "main_program"]
                startup_prog = context["model"][model_dict["name"]][
                    "startup_program"]
                with fluid.program_guard(train_prog, startup_prog):
                    # 保存模型
                    self.save(epoch, context, True)
        context["status"] = "terminal_pass"
```

自定义runner需要参照官方实现[runner.py](../core/trainers/framework/startup.py),继承基类`RunnerBase`,命名为`Runner`,并实现`run`成员函数。

## terminal

terminal主要进行分布式训练结束后的`stop worker`,以及其他需要在模型训练完成后进行的工作,比如数据整理,模型上传等等。

```python
class TerminalBase(object):
    def __init__(self, context):
        pass

    def terminal(self, context):
        print("PaddleRec Finish")
```

自定义terminal需要继承`TerminalBase`命名为`Terminal`,并实现成员函数`terminal`

## 自定义流程参与训练

假如我们自定义了某个流程,将其与model/reader一样,放在workspace下,并同时更改yaml配置中的runner相关选项,PaddleRec会自动用指定的流程替换原始的类别。

```yaml
runner:
  - name: train_runner
    class: single_train
    epochs: 2
    device: cpu
    instance_class_path: "{workspace}/your_instance.py"
    network_class_path: "{workspace}/your_network.py"
    startup_class_path: "{workspace}/your_startup.py"
    runner_class_path: "{workspace}/your_runner.py"
    terminal_class_path: "{workspace}/your_terminal.py"
    print_interval: 1
```

## 示例

官方模型中的TDM是一个很好的示例,该模型自定义了`startup`的实现,可以参考[tdm_startup.py](../models/treebased/tdm/tdm_startup.py)

```python
class Startup(StartupBase):
    def startup(self, context):
        logger.info("Run TDM Trainer Startup Pass")
        if context["engine"] == EngineMode.SINGLE:
            self._single_startup(context)
        else:
            self._cluster_startup(context)

        context['status'] = 'train_pass'

    def _single_startup(self, context):
        # single process

    def _cluster_startup(self, context):
        # cluster process
```

于此同时,在yaml中更改了默认的startup执行类:
```yaml
runner:
- name: runner1
  class: single_train
  startup_class_path: "{workspace}/tdm_startup.py"
  epochs: 10
  device: cpu
```