提交 c9ba8499 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2676 Mass text summarization update.

Merge pull request !2676 from linqingke/mass
......@@ -18,7 +18,7 @@ export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab -- "$@"`
options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab: -- "$@"`
eval set -- "$options"
echo $options
......@@ -129,6 +129,7 @@ do
esac
done
file_path=$(cd "$(dirname $0)" || exit; pwd)
for((i=0; i < $RANK_SIZE; i++))
do
if [ $RANK_SIZE -gt 1 ]
......@@ -139,7 +140,6 @@ do
fi
echo "Working on device $i"
file_path=$(cd "$(dirname $0)" || exit; pwd)
cd $file_path || exit
cd ../ || exit
......
......@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""Dataset loader to feed into model."""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as deC
......@@ -40,12 +39,6 @@ def _load_dataset(input_files, batch_size, epoch_count=1,
if not input_files:
raise FileNotFoundError("Require at least one dataset.")
if not (schema_file and
os.path.exists(schema_file)
and os.path.isfile(schema_file)
and os.path.basename(schema_file).endswith(".json")):
raise FileNotFoundError("`dataset_schema` must be a existed json file.")
if not isinstance(sink_mode, bool):
raise ValueError("`sink` must be type of bool.")
......
......@@ -47,7 +47,7 @@ def rouge(hypothesis: List[str], target: List[str]):
edited_ref.append(r + "\n")
_rouge = Rouge()
scores = _rouge.get_scores(edited_hyp, target, avg=True)
scores = _rouge.get_scores(edited_hyp, edited_ref, avg=True)
print(" | ROUGE Score:")
print(f" | RG-1(F): {scores['rouge-1']['f'] * 100:8.2f}")
print(f" | RG-2(F): {scores['rouge-2']['f'] * 100:8.2f}")
......
......@@ -120,6 +120,7 @@ def _build_training_pipeline(config: TransformerConfig,
test_dataset (Dataset): Test dataset.
"""
net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
net_with_loss.init_parameters_data()
if config.existed_ckpt:
if config.existed_ckpt.endswith(".npz"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册