diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index b73c284aab7a7e4c74f923c8fbf0cbc079784a9e..d33adb04eefb3324bdaeda2cce303514938f58ef 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -388,7 +388,7 @@ class AdamWeightDecayDynamicLR(Optimizer): beta2=0.999, eps=1e-6, weight_decay=0.0, - decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): + decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower()): super(AdamWeightDecayDynamicLR, self).__init__(0.0, params) if self.is_group: raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") diff --git a/model_zoo/bert/src/dataset.py b/model_zoo/bert/src/dataset.py index 7985ca8559ffbdf1a2a46353250c0a063fe0213f..4e7d48605e9bf31a517b9b693bf8c92a17c9cf70 100644 --- a/model_zoo/bert/src/dataset.py +++ b/model_zoo/bert/src/dataset.py @@ -36,8 +36,8 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e ds = de.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"], - shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank, - shard_equal_rows=True) + shuffle=de.Shuffle.FILES if do_shuffle == "true" else False, + num_shards=device_num, shard_id=rank, shard_equal_rows=True) ori_dataset_size = ds.get_dataset_size() print('origin dataset size: ', ori_dataset_size) new_size = ori_dataset_size