提交 a4cf9028 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!275 fix bert example script bugs

Merge pull request !275 from yoonlee666/master
...@@ -39,6 +39,7 @@ import mindspore.dataset.engine.datasets as de ...@@ -39,6 +39,7 @@ import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
from mindspore import context from mindspore import context
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
...@@ -49,9 +50,9 @@ def create_train_dataset(batch_size): ...@@ -49,9 +50,9 @@ def create_train_dataset(batch_size):
"""create train dataset""" """create train dataset"""
# apply repeat operations # apply repeat operations
repeat_count = bert_train_cfg.epoch_size repeat_count = bert_train_cfg.epoch_size
ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR, ds = de.TFRecordDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"]) "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
type_cast_op = C.TypeCast(mstype.int32) type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册