Commit c254669f authored by mindspore-ci-bot, committed by Gitee

!363 [Host+device training] Fix some tutorial errors

Merge pull request !363 from Xiaoda/master
......@@ -72,6 +72,12 @@ This tutorial introduces how to train [Wide&Deep](https://gitee.com/mindspore/mi
```python
self.embedding_table = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=[184968, 80]).astype(dtype=np_type)), name='V_l2', sparse_grad='V_l2')
```
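The initializer above draws every entry of the 184968 x 80 embedding table from a normal distribution with mean 0.0 and standard deviation 0.01. A numpy-only sketch of the raw values that get wrapped into the MindSpore `Tensor`/`Parameter` (here `np_type` is assumed to be `np.float32`, which the tutorial derives from its config):

```python
import numpy as np

# Assumed dtype; the tutorial derives np_type from its configuration.
np_type = np.float32

# Same initialization as the tutorial snippet: N(0, 0.01^2), shape [184968, 80].
table = np.random.normal(loc=0.0, scale=0.01, size=[184968, 80]).astype(np_type)

print(table.shape)  # (184968, 80)
```

The small standard deviation keeps initial embedding values near zero, which is a common choice for embedding tables so that early training is dominated by the gradient signal rather than the random initialization.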
In the same file, add the following import at the top:
```python
from mindspore import Tensor
```
In the `construct` function of `class WideDeepModel(nn.Cell)` in `src/wide_and_deep.py`, replace the return value with the following to adapt for sparse parameters:
......@@ -103,10 +109,9 @@ This tutorial introduces how to train [Wide&Deep](https://gitee.com/mindspore/mi
## Training the Model
-Use the script `script/run_auto_parallel_train.sh`, and run the command `bash run_auto_parallel_train.sh 1 1 DATASET RANK_TABLE_FILE MINDSPORE_HCCL_CONFIG_PATH`,
+Use the script `script/run_auto_parallel_train.sh`. Run the command `bash run_auto_parallel_train.sh 1 1 DATASET RANK_TABLE_FILE`,
 where the first `1` is the number of accelerators, the second `1` is the number of epochs, `DATASET` is the path of dataset,
-and `RANK_TABLE_FILE` and `MINDSPORE_HCCL_CONFIG_PATH` is the path of the above `rank_table_1p_0.json` file.
+and `RANK_TABLE_FILE` is the path of the above `rank_table_1p_0.json` file.
The running log is in the directory of `device_0`, where `loss.log` contains every loss value of every step in the epoch:
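A minimal sketch for pulling the per-step loss values out of such a log, assuming each line follows the common MindSpore format `epoch: E step: S, loss is L` (the exact layout of `loss.log` may differ, in which case the pattern needs adjusting):

```python
import re

def parse_losses(log_text):
    """Extract loss values from training log text.

    Assumes lines like 'epoch: 1 step: 100, loss is 0.6934';
    adjust the pattern if the actual loss.log format differs.
    """
    pattern = re.compile(r"loss is ([0-9.eE+-]+)")
    return [float(m.group(1)) for m in pattern.finditer(log_text)]

sample = "epoch: 1 step: 1, loss is 0.6934\nepoch: 1 step: 2, loss is 0.6920\n"
print(parse_losses(sample))  # [0.6934, 0.692]
```

Plotting the resulting list over step index is a quick way to confirm that the sparse-parameter changes above did not break convergence.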
......
......@@ -70,7 +70,11 @@
```python
self.embedding_table = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=[184968, 80]).astype(dtype=np_type)), name='V_l2', sparse_grad='V_l2')
```
In addition, add the following import at the top of `src/wide_and_deep.py`:
```python
from mindspore import Tensor
```
In the `construct` function of `class WideDeepModel(nn.Cell)` in `src/wide_and_deep.py`, replace the return value with the following to adapt for sparse parameters:
```
......@@ -101,9 +105,8 @@
## Training the Model
-Use the training script `script/run_auto_parallel_train.sh`
-and run the command: `bash run_auto_parallel_train.sh 1 1 DATASET RANK_TABLE_FILE MINDSPORE_HCCL_CONFIG_PATH`,
-where the first `1` is the number of devices used, the second `1` is the number of training epochs, `DATASET` is the path of the dataset, and `RANK_TABLE_FILE` and `MINDSPORE_HCCL_CONFIG_PATH` are the path of the above `rank_table_1p_0.json` file.
+Use the training script `script/run_auto_parallel_train.sh`. Run the command: `bash run_auto_parallel_train.sh 1 1 DATASET RANK_TABLE_FILE`,
+where the first `1` is the number of devices used, the second `1` is the number of training epochs, `DATASET` is the path of the dataset, and `RANK_TABLE_FILE` is the path of the above `rank_table_1p_0.json` file.
The running log is saved in the `device_0` directory, where `loss.log` contains the loss values of every step in an epoch, as follows:
......