Commit 5ec91683
Authored on Jul 09, 2020 by linqingke
Parent: 1f4944fa

mass eval metric update.

Showing 13 changed files with 270 additions and 59 deletions (+270, -59)
model_zoo/faster_rcnn/eval.py                               +1   -1
model_zoo/faster_rcnn/src/FasterRcnn/fpn_neck.py            +1   -1
model_zoo/faster_rcnn/src/FasterRcnn/proposal_generator.py  +1   -1
model_zoo/faster_rcnn/src/FasterRcnn/resnet50.py            +1   -1
model_zoo/faster_rcnn/train.py                              +1   -1
model_zoo/mass/eval.py                                      +13  -29
model_zoo/mass/scripts/run.sh                               +10  -4
model_zoo/mass/src/transformer/__init__.py                  +2   -1
model_zoo/mass/src/transformer/embedding.py                 +1   -1
model_zoo/mass/src/transformer/infer_mass.py                +129 -0
model_zoo/mass/src/utils/__init__.py                        +3   -1
model_zoo/mass/src/utils/eval_score.py                      +92  -0
model_zoo/mass/src/utils/ppl_score.py                       +15  -18
model_zoo/faster_rcnn/eval.py

@@ -40,7 +40,7 @@ parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoi
 parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
 args_opt = parser.parse_args()

-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)

 def FasterRcnn_eval(dataset_path, ckpt_path, ann_file):
     """FasterRcnn evaluation."""
model_zoo/faster_rcnn/src/FasterRcnn/fpn_neck.py

@@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.common.initializer import initializer

-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

 def bias_init_zeros(shape):
     """Bias init method."""
model_zoo/faster_rcnn/src/FasterRcnn/proposal_generator.py

@@ -22,7 +22,7 @@ from mindspore import Tensor
 from mindspore import context

-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

 class Proposal(nn.Cell):
model_zoo/faster_rcnn/src/FasterRcnn/resnet50.py

@@ -22,7 +22,7 @@ from mindspore.ops import functional as F
 from mindspore import context

-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

 def weight_init_ones(shape):
model_zoo/faster_rcnn/train.py

@@ -52,7 +52,7 @@ parser.add_argument("--device_num", type=int, default=1, help="Use device nums,
 parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default is 0.")
 args_opt = parser.parse_args()

-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=args_opt.device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)

 if __name__ == '__main__':
     if not args_opt.do_eval and args_opt.run_distribute:
model_zoo/mass/eval.py

@@ -15,15 +15,13 @@
 """Evaluation api."""
 import argparse
 import pickle

 import numpy as np
 from mindspore.common import dtype as mstype

 from config import TransformerConfig
-from src.transformer import infer
-from src.utils import ngram_ppl
+from src.transformer import infer, infer_ppl
 from src.utils import Dictionary
-from src.utils import rouge
+from src.utils import get_score

 parser = argparse.ArgumentParser(description='Evaluation MASS.')
 parser.add_argument("--config", type=str, required=True,

@@ -32,6 +30,8 @@ parser.add_argument("--vocab", type=str, required=True,
                     help="Vocabulary to use.")
 parser.add_argument("--output", type=str, required=True,
                     help="Result file path.")
+parser.add_argument("--metric", type=str, default='rouge',
+                    help='Set eval method.')

 def get_config(config):

@@ -45,31 +45,15 @@ if __name__ == '__main__':
     args, _ = parser.parse_known_args()
     vocab = Dictionary.load_from_persisted_dict(args.vocab)
     _config = get_config(args.config)
-    result = infer(_config)
+    if args.metric == 'rouge':
+        result = infer(_config)
+    else:
+        result = infer_ppl(_config)
+
     with open(args.output, "wb") as f:
         pickle.dump(result, f, 1)

-    ppl_score = 0.
-    preds = []
-    tgts = []
-    _count = 0
-    for sample in result:
-        sentence_prob = np.array(sample['prediction_prob'], dtype=np.float32)
-        sentence_prob = sentence_prob[:, 1:]
-        _ppl = []
-        for path in sentence_prob:
-            _ppl.append(ngram_ppl(path, log_softmax=True))
-        ppl = np.min(_ppl)
-        preds.append(' '.join([vocab[t] for t in sample['prediction']]))
-        tgts.append(' '.join([vocab[t] for t in sample['target']]))
-        print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
-        print(f" | target: {tgts[-1]}")
-        print(f" | prediction: {preds[-1]}")
-        print(f" | ppl: {ppl}.")
-        if np.isinf(ppl):
-            continue
-        ppl_score += ppl
-        _count += 1
-    print(f" | PPL={ppl_score / _count}.")
-    rouge(preds, tgts)
+    # get score by given metric
+    score = get_score(result, vocab, metric=args.metric)
+    print(score)
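With this change eval.py no longer computes ROUGE and PPL inline: it only picks the inference path (infer for ROUGE, infer_ppl for PPL), pickles the raw result, and hands scoring off to get_score. A minimal sketch of the resulting flow, assuming the MASS src package is importable; the evaluate() helper and the args namespace are hypothetical stand-ins for the argparse handling in eval.py, and config is the TransformerConfig built by eval.py's get_config:

    # Sketch only: `args` mimics the argparse namespace (vocab, output, metric).
    import pickle

    from src.transformer import infer, infer_ppl
    from src.utils import Dictionary, get_score

    def evaluate(args, config):
        """Score a MASS model with the metric chosen on the command line."""
        vocab = Dictionary.load_from_persisted_dict(args.vocab)
        # 'rouge' scores decoded predictions; 'ppl' scores per-token log probabilities.
        result = infer(config) if args.metric == 'rouge' else infer_ppl(config)
        with open(args.output, "wb") as f:      # keep the raw inference result around
            pickle.dump(result, f, 1)
        return get_score(result, vocab, metric=args.metric)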
model_zoo/mass/scripts/run.sh

@@ -18,7 +18,7 @@ export DEVICE_ID=0
 export RANK_ID=0
 export RANK_SIZE=1

-options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab: -- "$@"`
+options=`getopt -u -o ht:n:i:j:c:o:v:m: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab:,metric: -- "$@"`

 eval set -- "$options"
 echo $options

@@ -35,6 +35,7 @@ echo_help()
     echo " -c --config set the configuration file"
     echo " -o --output set the output file of inference"
     echo " -v --vocab set the vocabulary"
+    echo " -m --metric set the metric"
 }

 set_hccl_json()

@@ -43,8 +44,8 @@ set_hccl_json()
     do
         if [[ "$1" == "-j" || "$1" == "--hccl_json" ]]
         then
-            export MINDSPORE_HCCL_CONFIG_PATH=$2 #/data/wsc/hccl_2p_01.json
-            export RANK_TABLE_FILE=$2 #/data/wsc/hccl_2p_01.json
+            export MINDSPORE_HCCL_CONFIG_PATH=$2
+            export RANK_TABLE_FILE=$2
            break
         fi
         shift

@@ -119,6 +120,11 @@ do
            vocab=$2
            shift 2
            ;;
+        -m | --metric)
+            echo "metric"
+            metric=$2
+            shift 2
+            ;;
        --)
            shift
            break

@@ -163,7 +169,7 @@ do
        python train.py --config ${configurations##*/} >> log.log 2>&1 &
    elif [ "$task" == "infer" ]
    then
-        python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} >> log_infer.log 2>&1 &
+        python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} >> log_infer.log 2>&1 &
    fi
    cd ../
 done
model_zoo/mass/src/transformer/__init__.py

@@ -19,10 +19,11 @@ from .decoder import TransformerDecoder
 from .beam_search import BeamSearchDecoder
 from .transformer_for_train import TransformerTraining, LabelSmoothedCrossEntropyCriterion, \
     TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
-from .infer_mass import infer
+from .infer_mass import infer, infer_ppl

 __all__ = [
     "infer",
+    "infer_ppl",
     "TransformerTraining",
     "LabelSmoothedCrossEntropyCriterion",
     "TransformerTrainOneStepWithLossScaleCell",
model_zoo/mass/src/transformer/embedding.py

@@ -41,7 +41,7 @@ class EmbeddingLookup(nn.Cell):
         self.vocab_size = vocab_size
         self.use_one_hot_embeddings = use_one_hot_embeddings

-        init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim])
+        init_weight = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim]).astype(np.float32)
         # 0 is Padding index, thus init it as 0.
         init_weight[0, :] = 0
         self.embedding_table = Parameter(Tensor(init_weight),
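The only change here is the trailing .astype(np.float32): np.random.normal returns a float64 array by default, so without the cast the embedding table would be created in double precision. A quick NumPy check of what the cast does, with tiny made-up sizes standing in for the real vocab_size and embed_dim:

    import numpy as np

    vocab_size, embed_dim = 8, 4                   # tiny stand-ins for the real sizes
    w = np.random.normal(0, embed_dim ** -0.5, size=[vocab_size, embed_dim])
    print(w.dtype)                                 # float64 -- NumPy's default
    w32 = w.astype(np.float32)
    print(w32.dtype)                               # float32, as the updated code now uses
    w32[0, :] = 0                                  # index 0 is the padding token, zeroed as before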
model_zoo/mass/src/transformer/infer_mass.py

@@ -17,13 +17,16 @@ import time
 import mindspore.nn as nn
 import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore import context

 from src.dataset import load_dataset
 from .transformer_for_infer import TransformerInferModel
+from .transformer_for_train import TransformerTraining
 from ..utils.load_weights import load_infer_weights

 context.set_context(

@@ -156,3 +159,129 @@ def infer(config):
                                 shuffle=False) if config.test_dataset else None

     prediction = transformer_infer(config, eval_dataset)
     return prediction
+
+
+class TransformerInferPPLCell(nn.Cell):
+    """
+    Encapsulation class of transformer network infer for PPL.
+
+    Args:
+        config(TransformerConfig): Config.
+
+    Returns:
+        Tuple[Tensor, Tensor], predicted log prob and label lengths.
+    """
+    def __init__(self, config):
+        super(TransformerInferPPLCell, self).__init__()
+        self.transformer = TransformerTraining(config,
+                                               is_training=False,
+                                               use_one_hot_embeddings=False)
+        self.batch_size = config.batch_size
+        self.vocab_size = config.vocab_size
+        self.one_hot = P.OneHot()
+        self.on_value = Tensor(float(1), mstype.float32)
+        self.off_value = Tensor(float(0), mstype.float32)
+        self.reduce_sum = P.ReduceSum()
+        self.reshape = P.Reshape()
+        self.cast = P.Cast()
+        self.flat_shape = (config.batch_size * config.seq_length,)
+        self.batch_shape = (config.batch_size, config.seq_length)
+        self.last_idx = (-1,)
+
+    def construct(self,
+                  source_ids,
+                  source_mask,
+                  target_ids,
+                  target_mask,
+                  label_ids,
+                  label_mask):
+        """Defines the computation performed."""
+        predicted_log_probs = self.transformer(source_ids, source_mask, target_ids, target_mask)
+
+        label_ids = self.reshape(label_ids, self.flat_shape)
+        label_mask = self.cast(label_mask, mstype.float32)
+        one_hot_labels = self.one_hot(label_ids, self.vocab_size, self.on_value, self.off_value)
+
+        label_log_probs = self.reduce_sum(predicted_log_probs * one_hot_labels, self.last_idx)
+        label_log_probs = self.reshape(label_log_probs, self.batch_shape)
+        log_probs = label_log_probs * label_mask
+        lengths = self.reduce_sum(label_mask, self.last_idx)
+
+        return log_probs, lengths
+
+
+def transformer_infer_ppl(config, dataset):
+    """
+    Run infer with Transformer for PPL.
+
+    Args:
+        config (TransformerConfig): Config.
+        dataset (Dataset): Dataset.
+
+    Returns:
+        List[Dict], prediction, each example has 4 keys, "source",
+        "target", "log_prob" and "length".
+    """
+    tfm_infer = TransformerInferPPLCell(config=config)
+    tfm_infer.init_parameters_data()
+
+    parameter_dict = load_checkpoint(config.existed_ckpt)
+    load_param_into_net(tfm_infer, parameter_dict)
+
+    model = Model(tfm_infer)
+
+    log_probs = []
+    lengths = []
+    source_sentences = []
+    target_sentences = []
+    for batch in dataset.create_dict_iterator():
+        source_sentences.append(batch["source_eos_ids"])
+        target_sentences.append(batch["target_eos_ids"])
+
+        source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
+        source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
+        target_ids = Tensor(batch["target_sos_ids"], mstype.int32)
+        target_mask = Tensor(batch["target_sos_mask"], mstype.int32)
+        label_ids = Tensor(batch["target_eos_ids"], mstype.int32)
+        label_mask = Tensor(batch["target_eos_mask"], mstype.int32)
+
+        start_time = time.time()
+        log_prob, length = model.predict(source_ids, source_mask,
+                                         target_ids, target_mask,
+                                         label_ids, label_mask)
+        print(f" | Batch size: {config.batch_size}, "
+              f"Time cost: {time.time() - start_time}.")
+
+        log_probs.append(log_prob.asnumpy())
+        lengths.append(length.asnumpy())
+
+    output = []
+    for inputs, ref, log_prob, length in zip(source_sentences,
+                                             target_sentences,
+                                             log_probs,
+                                             lengths):
+        for i in range(config.batch_size):
+            example = {
+                "source": inputs[i].tolist(),
+                "target": ref[i].tolist(),
+                "log_prob": log_prob[i].tolist(),
+                "length": length[i]
+            }
+            output.append(example)
+
+    return output
+
+
+def infer_ppl(config):
+    """
+    Transformer infer PPL api.
+
+    Args:
+        config (TransformerConfig): Config.
+
+    Returns:
+        list, inference result.
+    """
+    eval_dataset = load_dataset(data_files=config.test_dataset,
+                                batch_size=config.batch_size,
+                                epoch_count=1,
+                                sink_mode=config.dataset_sink_mode,
+                                shuffle=False) if config.test_dataset else None
+    prediction = transformer_infer_ppl(config, eval_dataset)
+
+    return prediction
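Each example that transformer_infer_ppl emits carries the masked per-token log probabilities ("log_prob") and the number of real target tokens ("length"); padded positions contribute zero because construct() multiplies by the label mask. A small NumPy sketch, with made-up numbers, of how these two fields combine into a perplexity the same way ngram_ppl(..., log_softmax=True) does downstream:

    import numpy as np

    # Hypothetical output for one example (5 positions, 3 real tokens).
    example = {"log_prob": [-1.2, -0.7, -2.3, 0.0, 0.0], "length": 3.0}

    avg_log_prob = np.sum(example["log_prob"]) / example["length"]  # mean over real tokens
    ppl = 1.0 / np.power(np.e, avg_log_prob)                        # == e ** (-avg_log_prob)
    print(round(ppl, 2))                                            # ~4.06 for these numbers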
model_zoo/mass/src/utils/__init__.py

@@ -20,6 +20,7 @@ from .loss_monitor import LossCallBack
 from .byte_pair_encoding import bpe_encode
 from .initializer import zero_weight, one_weight, normal_weight, weight_variable
 from .rouge_score import rouge
+from .eval_score import get_score

 __all__ = [
     "Dictionary",

@@ -31,5 +32,6 @@ __all__ = [
     "one_weight",
     "zero_weight",
     "normal_weight",
-    "weight_variable"
+    "weight_variable",
+    "get_score"
 ]
model_zoo/mass/src/utils/eval_score.py (new file, 0 → 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Get score by given metric."""
from .ppl_score import ngram_ppl
from .rouge_score import rouge


def get_ppl_score(result):
    """
    Calculate Perplexity (PPL) score.

    Args:
        result (List[Dict]): Prediction list; each example has 4 keys,
            "source", "target", "log_prob" and "length".

    Returns:
        Float, ppl score.
    """
    log_probs = []
    total_length = 0

    for sample in result:
        log_prob = sample['log_prob']
        length = sample['length']
        log_probs.extend(log_prob)
        total_length += length

        print(f" | log_prob: {log_prob}")
        print(f" | length: {length}")

    ppl = ngram_ppl(log_probs, total_length, log_softmax=True)
    print(f" | final PPL={ppl}.")
    return ppl


def get_rouge_score(result, vocab):
    """
    Calculate ROUGE score.

    Args:
        result (List[Dict]): Prediction list; each example has 4 keys,
            "source", "target", "prediction" and "prediction_prob".
        vocab (Dictionary): Dictionary instance.

    Returns:
        Str, rouge score.
    """
    predictions = []
    targets = []
    for sample in result:
        predictions.append(' '.join([vocab[t] for t in sample['prediction']]))
        targets.append(' '.join([vocab[t] for t in sample['target']]))
        print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
        print(f" | target: {targets[-1]}")

    return rouge(predictions, targets)


def get_score(result, vocab=None, metric='rouge'):
    """
    Get eval score.

    Args:
        result (List[Dict]): Prediction list.
        vocab (Dictionary): Dictionary instance.
        metric (str): Metric function, default is rouge.

    Returns:
        Str, score.
    """
    score = None
    if metric == 'rouge':
        score = get_rouge_score(result, vocab)
    elif metric == 'ppl':
        score = get_ppl_score(result)
    else:
        print(" | metric not in (rouge, ppl).")
    return score
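A hedged usage sketch of the new dispatcher; the toy result below is made up for illustration, while in practice the list comes from infer_ppl(config) (for PPL) or infer(config) (for ROUGE) and vocab from Dictionary.load_from_persisted_dict:

    from src.utils import get_score

    # Toy PPL-style result (illustrative numbers only).
    result = [{"source": [1, 2], "target": [1, 2, 3],
               "log_prob": [-0.5, -1.0, -0.8], "length": 3.0}]

    print(get_score(result, metric='ppl'))       # dispatches to get_ppl_score -> ngram_ppl
    # get_score(result, vocab, metric='rouge')   # ROUGE path expects 'prediction' ids and a Dictionary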
model_zoo/mass/src/utils/ppl_score.py

@@ -17,10 +17,7 @@ from typing import Union
 import numpy as np

 NINF = -1.0 * 1e9

-def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = np.e):
+def ngram_ppl(prob: Union[np.ndarray, list], length: int, log_softmax=False, index: float = np.e):
     """
     Calculate Perplexity(PPL) score under N-gram language model.

@@ -39,7 +36,8 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n
     Returns:
         float, ppl score.
     """
-    eps = 1e-8
+    if not length:
+        return np.inf
     if not isinstance(prob, (np.ndarray, list)):
         raise TypeError("`prob` must be type of list or np.ndarray.")
     if not isinstance(prob, np.ndarray):

@@ -47,18 +45,17 @@ def ngram_ppl(prob: Union[np.ndarray, list], log_softmax=False, index: float = n
     if prob.shape[0] == 0:
         raise ValueError("`prob` length must greater than 0.")

-    p = 1.0
-    sen_len = 0
-    for t in range(prob.shape[0]):
-        s = prob[t]
-        if s <= NINF:
-            break
-        if log_softmax:
-            s = np.power(index, s)
-        p *= (1 / (s + eps))
-        sen_len += 1
+    print(f'length: {length}, log_prob: {prob}')

-    if sen_len == 0:
-        return np.inf
+    if log_softmax:
+        prob = np.sum(prob) / length
+        ppl = 1. / np.power(index, prob)
+        print(f'avg log prob: {prob}')
+    else:
+        p = 1.
+        for i in range(prob.shape[0]):
+            p *= (1. / prob[i])
+        ppl = pow(p, 1 / length)

-    return pow(p, 1 / sen_len)
+    print(f'ppl val: {ppl}')
+    return ppl
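The reworked ngram_ppl drops the per-token product with the NINF sentinel and eps smoothing; the caller now passes the effective token count, and in log-softmax mode the score reduces to index ** (-sum(prob) / length). A condensed restatement of the two branches, stripped of type checks and prints, useful only for checking the arithmetic against hand-computed values:

    import numpy as np

    def ngram_ppl_sketch(prob, length, log_softmax=False, index=np.e):
        # Mirror of the updated ngram_ppl, for checking the math only.
        prob = np.asarray(prob, dtype=np.float64)
        if not length:
            return np.inf
        if log_softmax:
            return 1. / np.power(index, np.sum(prob) / length)   # index ** (-avg log prob)
        return pow(np.prod(1. / prob), 1 / length)               # geometric mean of 1/p

    print(ngram_ppl_sketch([-1.0, -2.0, -3.0], 3, log_softmax=True))  # e ** 2 ~= 7.39
    print(ngram_ppl_sketch([0.5, 0.25], 2))                           # sqrt(8) ~= 2.83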