magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 68731921
Authored Jul 02, 2020 by mindspore-ci-bot; committed by Gitee on Jul 02, 2020
!2565 Add case for precision of bert network
Merge pull request !2565 from ddwolf/add_case_for_precisoin_of_bert
Parents: fff45a11, 1e43c609
Showing 6 changed files with 342 additions and 22 deletions (+342 −22)
mindspore/ccsrc/device/ascend/ascend_stream_assign.cc       +8    -17
mindspore/ccsrc/kernel/akg/akg_kernel_build.cc              +2    -1
mindspore/ccsrc/session/anf_runtime_algorithm.cc            +1    -1
mindspore/ccsrc/session/anf_runtime_algorithm.h             +8    -3
tests/st/networks/models/bert/test_bert_graph_kernel.py     +193  -0
tests/st/ops/graph_kernel/test_lamb.py                      +130  -0
mindspore/ccsrc/device/ascend/ascend_stream_assign.cc

@@ -348,16 +348,13 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph
     uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
     if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
-      auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
-      MS_EXCEPTION_IF_NULL(primitive);
-      auto true_stream_id = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream));
+      auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream);
       processed_streams_.emplace(true_stream_id);
-      auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
-      if (value_ptr == nullptr) {
+      if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) {
         continue;
       }
-      auto need_active = GetValue<bool>(value_ptr);
+      auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst);
       if (need_active) {
         processed_streams_.emplace(cur_stream_id);
       }

@@ -371,20 +368,17 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph
 void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr,
                                             vector<CNodePtr> *orders) {
   orders->emplace_back(switch_ptr);
-  auto primitive = AnfAlgo::GetCNodePrimitive(switch_ptr);
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
-  if (value_ptr == nullptr) {
+  if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) {
     return;
   }
-  auto need_active = GetValue<bool>(value_ptr);
+  auto need_active = AnfAlgo::GetNodeAttr<bool>(switch_ptr, kStreamNeedActivedFirst);
   if (!need_active) {
     return;
   }
   MS_EXCEPTION_IF_NULL(switch_ptr);
-  auto true_stream_id = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream));
+  auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrTrueBranchStream);
   MS_LOG(INFO) << "Streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr)
                << "; active stream id:" << true_stream_id;

@@ -677,14 +671,11 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
     cur_cnode_ptr = cnode_ptr_list[i];
     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
-    auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
-    MS_EXCEPTION_IF_NULL(primitive);
-    auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
-    if (value_ptr == nullptr) {
+    if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) {
       continue;
     }
-    auto need_active = GetValue<bool>(value_ptr);
+    auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst);
     if (need_active) {
       auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
       MS_LOG(INFO) << "Stream id:" << stream_id << " is need actived at first";
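All three hunks above make the same substitution: instead of fetching the node's Primitive, null-checking it, and reading raw ValuePtr attributes by hand, the code goes through AnfAlgo::HasNodeAttr and the typed AnfAlgo::GetNodeAttr<T>. Besides being shorter, those helpers also resolve attributes for graph-kernel cnodes, whose attributes live on the enclosing func graph rather than on a primitive; see the sketch after the anf_runtime_algorithm.h hunk below.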
mindspore/ccsrc/kernel/akg/akg_kernel_build.cc

@@ -276,7 +276,8 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j
       input_desc_json[kName] = op_input_name;
       input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
       auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
-      if (GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
+      if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) &&
+          GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
         MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
                         << "] as const tensor, shape: [" << Vector2Str(input_shape)
                         << "], value: " << input_desc_json[kValue];
mindspore/ccsrc/session/anf_runtime_algorithm.cc

@@ -291,7 +291,7 @@ bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &no
   // graph kernel cnode.
   auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
   MS_EXCEPTION_IF_NULL(fg);
-  return fg->has_flag(key);
+  return fg->has_attr(key);
 }

 size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
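A FuncGraph carries flags and attrs as separate tables; the old code asked the flag table for a node-attribute key, so the lookup on a graph-kernel cnode always missed. A minimal sketch of that distinction, using hypothetical dict-based stand-ins for the two tables:

# Hypothetical stand-in for a FuncGraph's two separate tables.
fg = {
    "flags": {"graph_kernel": True},               # boolean feature flags
    "attrs": {"stream_need_actived_first": True},  # node attributes
}
key = "stream_need_actived_first"
assert key not in fg["flags"]  # old has_flag(key): wrong table, always misses
assert fg["attrs"][key]        # new has_attr(key): finds the attribute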
mindspore/ccsrc/session/anf_runtime_algorithm.h

@@ -68,9 +68,14 @@ class AnfRuntimeAlgorithm {
       std::string node_debug_log = node->DebugString();
       MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str();
     }
-    auto primitive = GetCNodePrimitive(node);
-    MS_EXCEPTION_IF_NULL(primitive);
-    return GetValue<T>(primitive->GetAttr(key));
+    // single op cnode.
+    if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) {
+      return GetValue<T>(primitive->GetAttr(key));
+    }
+    // graph kernel cnode.
+    auto fg = GetCNodeFuncGraphPtr(node);
+    MS_EXCEPTION_IF_NULL(fg);
+    return GetValue<T>(fg->get_attr(key));
   }
   static bool IsTupleOutput(const AnfNodePtr &anf);
   // set attr of anf node
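Taken together with the HasNodeAttr fix above, GetNodeAttr<T> now resolves attributes in two steps: try the node's primitive (the single-op case), then fall back to the attrs of the cnode's func graph (the graph-kernel case). A rough Python sketch of that lookup order, with hypothetical dict-based nodes standing in for the real MindSpore types:

# Hypothetical nodes: a single-op cnode carries a primitive with attrs;
# a graph-kernel cnode carries a func graph with attrs instead.
def get_node_attr(node, key):
    primitive = node.get("primitive")
    if primitive is not None:   # single op cnode
        return primitive["attrs"][key]
    fg = node["func_graph"]     # graph kernel cnode
    assert fg is not None
    return fg["attrs"][key]

single_op = {"primitive": {"attrs": {"stream_need_actived_first": True}}}
graph_kernel = {"primitive": None,
                "func_graph": {"attrs": {"stream_need_actived_first": True}}}
assert get_node_attr(single_op, "stream_need_actived_first")
assert get_node_attr(graph_kernel, "stream_need_actived_first")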
tests/st/networks/models/bert/test_bert_graph_kernel.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train bert network without lossscale"""
import os

import pytest
import numpy as np

import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore import context
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.nn.optim import Lamb
from mindspore.train.callback import Callback
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.model import Model
from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell
from src.bert_model import BertConfig

DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"]
SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json"


def get_config(version='base', batch_size=1):
    """get config"""
    if version == 'base':
        bert_config = BertConfig(batch_size=batch_size,
                                 seq_length=128,
                                 vocab_size=21136,
                                 hidden_size=768,
                                 num_hidden_layers=2,
                                 num_attention_heads=12,
                                 intermediate_size=3072,
                                 hidden_act="gelu",
                                 hidden_dropout_prob=0.1,
                                 attention_probs_dropout_prob=0.1,
                                 max_position_embeddings=512,
                                 type_vocab_size=2,
                                 initializer_range=0.02,
                                 use_relative_positions=True,
                                 input_mask_from_dataset=True,
                                 token_type_ids_from_dataset=True,
                                 dtype=mstype.float32,
                                 compute_type=mstype.float32)
    elif version == 'large':
        bert_config = BertConfig(batch_size=batch_size,
                                 seq_length=128,
                                 vocab_size=30522,
                                 hidden_size=1024,
                                 num_hidden_layers=2,
                                 num_attention_heads=16,
                                 intermediate_size=4096,
                                 hidden_act="gelu",
                                 hidden_dropout_prob=0.0,
                                 attention_probs_dropout_prob=0.0,
                                 max_position_embeddings=512,
                                 type_vocab_size=2,
                                 initializer_range=0.02,
                                 use_relative_positions=True,
                                 input_mask_from_dataset=True,
                                 token_type_ids_from_dataset=True,
                                 dtype=mstype.float32,
                                 compute_type=mstype.float16,
                                 enable_fused_layernorm=True)
    else:
        bert_config = BertConfig(batch_size=batch_size)
    return bert_config


def me_de_train_dataset():
    """test me de train dataset"""
    # apply repeat operations
    repeat_count = 1
    ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR,
                            columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
                                          "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"],
                            shuffle=False)
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
    ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)
    ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    # apply batch operations
    batch_size = int(os.getenv('BATCH_SIZE', '16'))
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_count)
    return ds


def weight_variable(shape):
    """weight variable"""
    np.random.seed(1)
    ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32)
    return Tensor(ones)


class ModelCallback(Callback):
    def __init__(self):
        super(ModelCallback, self).__init__()
        self.loss_list = []
        self.overflow_list = []
        self.lossscale_list = []

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        self.loss_list.append(cb_params.net_outputs[0].asnumpy()[0])
        self.overflow_list.append(cb_params.net_outputs[1].asnumpy())
        self.lossscale_list.append(cb_params.net_outputs[2].asnumpy())
        print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_bert_tdt():
    """test bert tdt"""
    np.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
    context.set_context(enable_graph_kernel=True)
    ds = me_de_train_dataset()
    config = get_config(version='large', batch_size=16)
    netwithloss = BertNetworkWithLoss(config, True)
    optimizer = Lamb(netwithloss.trainable_params(),
                     decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
                     start_learning_rate=5e-5, end_learning_rate=1e-9,
                     power=10.0, warmup_steps=0, weight_decay=0.01)
    scale_window = 3
    scale_manager = DynamicLossScaleManager(262144, 2, scale_window)
    netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer,
                                                     scale_update_cell=scale_manager.get_update_cell())
    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    callback = ModelCallback()
    params = netwithloss.trainable_params()
    for param in params:
        param.init_data()
        value = param.default_input
        name = param.name
        if isinstance(value, Tensor):
            if name.split('.')[-1] in ['weight']:
                if name.split('.')[-3] in ['cls2']:
                    logger.info("***************** BERT param name is 1 {}".format(name))
                    param.default_input = weight_variable(value.asnumpy().shape)
                else:
                    logger.info("***************** BERT param name is 2 {}".format(name))
                    tempshape = value.asnumpy().shape
                    shape = (tempshape[1], tempshape[0])
                    weight_value = weight_variable(shape).asnumpy()
                    param.default_input = Tensor(np.transpose(weight_value, [1, 0]))
            else:
                logger.info("***************** BERT param name is 3 {}".format(name))
                param.default_input = weight_variable(value.asnumpy().shape)
    model.train(1, ds, callbacks=callback, dataset_sink_mode=False)

    # assertion occurs while the loss value, overflow state or loss_scale value is wrong
    loss_value = np.array(callback.loss_list)
    expect_loss_value = [12.559319, 12.333815, 12.339806, 12.350235, 12.343947, 12.830965, 12.375336,
                         12.973715, 12.57929, 12.7766905]
    error = loss_value - expect_loss_value
    print("loss value: {}".format(loss_value))
    print("error value: {}".format(error))
    assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)

    overflow = np.array(callback.overflow_list)
    expect_overflow = [True, True, True, True, False, False, False, True, False, False]
    print("overflow: {}".format(overflow))
    assert (overflow == expect_overflow).all()

    loss_scale = np.array(callback.lossscale_list)
    expect_loss_scale = [131072.0, 65536.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0,
                         16384.0, 16384.0]
    print("loss scale: {}".format(loss_scale))
    assert np.allclose(loss_scale, expect_loss_scale, 0, 0)


if __name__ == '__main__':
    test_bert_tdt()
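The expected overflow and loss-scale sequences above are consistent with the usual dynamic loss-scale rule: halve the scale on overflow, double it after scale_window consecutive overflow-free steps. A quick sanity check of that rule against the expected lists (a minimal sketch, not MindSpore's DynamicLossScaleManager):

def simulate_loss_scale(overflows, init_scale=262144.0, factor=2, window=3):
    scale, clean_steps, history = init_scale, 0, []
    for overflow in overflows:
        if overflow:
            scale /= factor            # overflow: shrink the scale
            clean_steps = 0
        else:
            clean_steps += 1
            if clean_steps == window:  # a full clean window: grow the scale
                scale *= factor
                clean_steps = 0
        history.append(scale)
    return history

expect_overflow = [True, True, True, True, False, False, False, True, False, False]
assert simulate_loss_scale(expect_overflow) == [131072.0, 65536.0, 32768.0, 16384.0, 16384.0,
                                                16384.0, 32768.0, 16384.0, 16384.0, 16384.0]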
tests/st/ops/graph_kernel/test_lamb.py (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np

import mindspore.context as context
from mindspore import Tensor, Parameter
from mindspore.nn import Cell
from mindspore.nn.graph_kernels import LambUpdateWithLR, LambNextMV

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class LambNet(Cell):
    def __init__(self, i2, i5, x6):
        super(LambNet, self).__init__()
        self.i2 = Parameter(i2, name='i2')
        self.i5 = Parameter(i5, name='i5')
        self.x6 = Parameter(x6, name='x6')
        self.lamb_next = LambNextMV()
        self.lamb_update = LambUpdateWithLR()

    def construct(self, i1, i3, i4, i6, i7, i8, i9, ix0, ix1, ix2, ix3,
                  x1, x2, x3, x4, x5, gy, se, my):
        return self.lamb_next(i1, self.i2, i3, i4, self.i5, i6, i7, i8, i9,
                              ix0, ix1, ix2, ix3), \
               self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my)


def LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my):
    trust_ratio = np.where(np.greater(x2, gy),
                           np.where(np.greater(x1, gy), np.divide(x2, x3), se),
                           se)
    trust_ratio = np.maximum(np.minimum(trust_ratio, my), gy)
    update_with_lr = trust_ratio * x4 * x5
    next_param = x6 - np.reshape(update_with_lr, x6.shape)
    return next_param


def LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, x0, x1, x2, x3):
    m_fp32 = i5.astype(np.float32)
    v_fp32 = i2.astype(np.float32)
    next_m = i8 * m_fp32 + i9 * i4
    next_v = x0 * v_fp32 + x1 * i1
    next_mm = next_m / i6
    next_vv = next_v / i3
    update = next_mm / (np.sqrt(next_vv) + x3)
    add3 = next_mm / np.sqrt(next_vv + x3) + x2 * i7
    return add3, next_m, next_v, update


def tensor_all(*args):
    res = [Tensor(a) for a in args]
    return res


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_graph_kernel_lamb():
    shape = [1, 16]
    oshape = [1]
    np.random.seed(0)
    x1 = np.random.normal(0, 1, oshape).astype(np.float32)
    x2 = np.random.normal(0, 1, oshape).astype(np.float32)
    x3 = np.random.normal(0, 1, oshape).astype(np.float32)
    x4 = np.random.normal(0, 1, oshape).astype(np.float32)
    x5 = np.random.normal(0, 1, shape).astype(np.float32)
    x6 = np.random.normal(0, 1, shape).astype(np.float32)
    gy = np.random.normal(0, 1, oshape).astype(np.float32)
    se = np.random.normal(0, 1, oshape).astype(np.float32)
    my = np.random.normal(0, 1, oshape).astype(np.float32)

    tx1, tx2, tx3, tx4, tx5, tx6, tgy, tse, tmy = tensor_all(x1, x2, x3, x4, x5, x6, gy, se, my)

    np.random.seed(1)
    i1 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32)
    i2 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32)
    i3 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32)
    i4 = np.random.normal(0, 1, shape).astype(np.float32)
    i5 = np.random.normal(0, 1, shape).astype(np.float32)
    i6 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32)
    i7 = np.random.normal(0, 1, shape).astype(np.float32)
    i8 = np.random.normal(0, 1, shape).astype(np.float32)
    i9 = np.random.normal(0, 1, shape).astype(np.float32)
    ix0 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32)
    ix1 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32)
    ix2 = np.random.normal(0, 1, shape).astype(np.float32)
    ix3 = np.ones(shape).astype(np.float32) * 1e-6

    ti1, ti2, ti3, ti4, ti5, ti6, ti7, ti8, ti9, tix0, tix1, tix2, tix3 = \
        tensor_all(i1, i2, i3, i4, i5, i6, i7, i8, i9, ix0, ix1, ix2, ix3)

    context.set_context(enable_graph_kernel=True)
    net = LambNet(ti2, ti5, tx6)
    (wa3, wup), _ = net(ti1, ti3, ti4, ti6, ti7, ti8, ti9,
                        tix0, tix1, tix2, tix3,
                        tx1, tx2, tx3, tx4, tx5, tgy, tse, tmy)
    wi2 = net.i2.data.asnumpy().copy()
    wi5 = net.i5.data.asnumpy().copy()
    ares = net.x6.data.asnumpy().copy()

    context.set_context(enable_graph_kernel=False)
    a3, a0, a1, up = LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9,
                                     ix0, ix1, ix2, ix3)
    np_res = LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my)

    rtol = 0.0001
    atol = 0.0001
    wres = (wa3.asnumpy().copy(), wi5, wi2, wup.asnumpy().copy())
    bres = (a3, a0, a1, up)
    cmp_res = list(map(lambda x, y: np.allclose(x, y, rtol, atol), wres, bres))
    assert all(cmp_res) and np.allclose(ares, np_res, rtol, atol)
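For reference, the trust-ratio logic in LambUpdateNumpy reads like the LAMB clipping step: take x2/x3 only when both guard inputs exceed gy, otherwise fall back to se, then clamp the result into [gy, my]. The exact role of each input is an assumption here; the test itself treats them as opaque tensors. A tiny worked instance of just that branch-and-clamp:

import numpy as np

x1, x2, x3 = np.float32(2.0), np.float32(4.0), np.float32(8.0)
gy, se, my = np.float32(0.0), np.float32(1.0), np.float32(10.0)
ratio = np.where(np.greater(x2, gy),
                 np.where(np.greater(x1, gy), np.divide(x2, x3), se),
                 se)                           # both guards pass -> 4/8
ratio = np.maximum(np.minimum(ratio, my), gy)  # clamp into [0, 10]
assert ratio == np.float32(0.5)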