提交 8b70f0f3 编写于 作者: Y Yi Wang
上级 5a1f0617
...@@ -166,14 +166,19 @@ There are some open questions here: ...@@ -166,14 +166,19 @@ There are some open questions here:
### Training ### Training
The recommended way to training a model is to call `paddle.train`, The recommended way to training a model is to call `paddle.train`,
which simply calls `paddle.optimizer.Default`, a global variable of which simply calls `paddle.trainer.Default`, a global variable of
type `paddle.optimizer.SGD`. Equivalently, we can do type `paddle.trainer.SGD`. Equivalently, we can do
```python ```python
opt = paddle.optimizer.SGD(...) opt = paddle.trainer.SGD(..., paddle.updater.Adam(...))
opt.train(model, reader=read, ...) opt.train(model, reader=read, ...)
``` ```
Please be aware that a trainer requires an updater as its data
member. This is to make it easier to customize trainers, as
discussed [here](https://github.com/PaddlePaddle/Paddle/issues/1319).
#### Distributed Training #### Distributed Training
If users want to do distributed training on a cluster, s/he should If users want to do distributed training on a cluster, s/he should
...@@ -185,8 +190,9 @@ access a Kubernetes cluster, s/he should be able to call ...@@ -185,8 +190,9 @@ access a Kubernetes cluster, s/he should be able to call
```python ```python
paddle.dist_train(model, paddle.dist_train(model,
trainer=paddle.trainer.SGD(...,
paddle.updater.Adam(...)),
reader=read, reader=read,
optimizer=paddle.optimizer.SGDOptimizer(...),
k8s_user="yi", k8s_user="yi",
k8s_token="kube_cluster_tls.pem", k8s_token="kube_cluster_tls.pem",
k8s_job="hello", k8s_job="hello",
...@@ -196,7 +202,7 @@ paddle.dist_train(model, ...@@ -196,7 +202,7 @@ paddle.dist_train(model,
The pseudo code if `paddle.dist_train` is as follows: The pseudo code if `paddle.dist_train` is as follows:
```python ```python
def dist_train(): def dist_train(topology, parameters, trainer, reader, ...):
if os.getenv("KUBERNETES_SERVICE_HOST") == None: if os.getenv("KUBERNETES_SERVICE_HOST") == None:
image_name = k8s_user + '/' + k8s_job image_name = k8s_user + '/' + k8s_job
docker_build(image_name) docker_build(image_name)
...@@ -209,7 +215,7 @@ def dist_train(): ...@@ -209,7 +215,7 @@ def dist_train():
elif rank < 15: elif rank < 15:
parameter_server() parameter_server()
else: else:
optimizer.train(model, reader=read) trainer.train(model, reader=read)
``` ```
Please be aware that if a process is running on the Kubernetes Please be aware that if a process is running on the Kubernetes
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册