magicwindyyd / mindspore — forked from MindSpore / mindspore (in sync with the fork source)
Commit 8556eee3
Authored May 28, 2020 by mindspore-ci-bot; committed by Gitee on May 28, 2020

!1578 rectify pretrained path and revert AdjustAllReduceMulAdd use

Merge pull request !1578 from gengdongjie/master

Parents: ec5363ad, 6930c737
Showing 16 changed files with 20 additions and 220 deletions (+20, −220).
example/resnet50_imagenet2012/run_distribute_train.sh          +5    -5
example/resnet50_imagenet2012/run_infer.sh                     +2    -2
example/resnet50_imagenet2012/run_standalone_train.sh          +5    -5
mindspore/ccsrc/operator/ops.cc                                +0    -1
mindspore/ccsrc/operator/ops.h                                 +0    -1
mindspore/ccsrc/optimizer/irpass.cc                            +0    -1
mindspore/ccsrc/optimizer/irpass.h                             +0    -1
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h         +0    -109
mindspore/ccsrc/pipeline/parse/function_block.h                +1    -2
mindspore/ccsrc/pipeline/pass.cc                               +0    -1
mindspore/ops/operations/array_ops.py                          +1    -1
mindspore/ops/operations/nn_ops.py                             +1    -1
mindspore/ops/primitive.py                                     +1    -2
tests/ut/cpp/optimizer/lib_test.cc                             +0    -19
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py    +2    -43
tests/ut/python/train/test_amp.py                              +2    -26
example/resnet50_imagenet2012/run_distribute_train.sh

```diff
@@ -16,7 +16,7 @@
 if [ $# != 2 ] && [ $# != 3 ]
 then
-echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
+    echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
 exit 1
 fi
@@ -32,7 +32,7 @@ PATH1=$(get_real_path $1)
 PATH2=$(get_real_path $2)
 if [ $# == 3 ]
 then
-PATH3=$(get_real_path $3)
+    PATH3=$(get_real_path $3)
 fi
 
 if [ ! -f "$PATH1" ]
@@ -47,11 +47,11 @@ then
     exit 1
 fi
 
-if [ ! -f "$PATH3" ]
+if [ $# == 3 ] && [ ! -f "$PATH3" ]
 then
-echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
-exit 1
+    echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
+    exit 1
 fi
 
 ulimit -u unlimited
 export DEVICE_NUM=8
```
example/resnet50_imagenet2012/run_infer.sh

```diff
@@ -34,13 +34,13 @@ PATH2=$(get_real_path $2)
 if [ ! -d $PATH1 ]
 then
-    echo "error: DATASET_PATH=$1 is not a directory"
+    echo "error: DATASET_PATH=$PATH1 is not a directory"
     exit 1
 fi
 
 if [ ! -f $PATH2 ]
 then
-    echo "error: CHECKPOINT_PATH=$2 is not a file"
+    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
     exit 1
 fi
```
example/resnet50_imagenet2012/run_standalone_train.sh

```diff
@@ -31,17 +31,17 @@ get_real_path(){
 PATH1=$(get_real_path $1)
 if [ $# == 2 ]
 then
-PATH2=$(get_real_path $2)
+    PATH2=$(get_real_path $2)
 fi
 
 if [ ! -d "$PATH1" ]
 then
     echo "error: DATASET_PATH=$PATH1 is not a directory"
     exit 1
 fi
 
-if [ ! -f "$PATH2" ]
+if [ $# == 2 ] && [ ! -f "$PATH2" ]
 then
-echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
-exit 1
+    echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
+    exit 1
 fi
@@ -62,7 +62,7 @@ cd ./train || exit
 echo "start training for device $DEVICE_ID"
 env > env.log
 if [ $# == 1 ]
-then 
+then
     python train.py --do_train=True --dataset_path=$PATH1 &> log &
 else
     python train.py --do_train=True --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
```
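The change in all three scripts is the same: the optional pretrained-checkpoint argument is resolved and checked only when it is actually supplied, whereas previously a test like `[ ! -f "$PATH3" ]` was evaluated against an empty string and aborted runs that legitimately omitted the optional argument. A rough Python rendering of the corrected validation flow (illustrative only; `validate_args` and its messages are not part of this repository):

```python
import os
import sys

def validate_args(args):
    # args: positional arguments, e.g. [dataset_path] or [dataset_path, ckpt_path]
    if len(args) not in (1, 2):
        sys.exit("Usage: run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)")
    dataset = os.path.realpath(args[0])
    if not os.path.isdir(dataset):
        sys.exit("error: DATASET_PATH=%s is not a directory" % dataset)
    ckpt = None
    if len(args) == 2:  # validate the optional path only when it was passed
        ckpt = os.path.realpath(args[1])
        if not os.path.isfile(ckpt):
            sys.exit("error: PRETRAINED_CKPT_PATH=%s is not a file" % ckpt)
    return dataset, ckpt
```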
mindspore/ccsrc/operator/ops.cc

```diff
@@ -246,7 +246,6 @@ const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
 const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
 const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
 const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
-const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
 
 // Debug ops
 const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
```
mindspore/ccsrc/operator/ops.h

```diff
@@ -252,7 +252,6 @@ extern const PrimitivePtr kPrimInDict;
 extern const PrimitivePtr kPrimNotInDict;
 
 // Comm ops
-extern const PrimitivePtr kPrimAllReduce;
 extern const PrimitivePtr kPrimMirror;
 extern const PrimitivePtr kPrimVirtualDiv;
 extern const PrimitivePtr kPrimVirtualDataset;
```
mindspore/ccsrc/optimizer/irpass.cc

```diff
@@ -54,7 +54,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
                      {prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
                       prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
   zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
-  adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
 
   // ops eliminate
   item_tuple_eliminate_ =
```
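For orientation: the removed line registered the pass by pairing a node visitor (`AdjustAllReduceMulAdd`), a pass name, and the root primitive the optimizer anchors the match on (`prim::kPrimAddN`). A toy Python model of that registration idea — not MindSpore's actual API, just a sketch of the mechanism:

```python
from dataclasses import dataclass
from typing import Callable, Optional

# Nodes are modeled as nested tuples: ("Prim", arg0, arg1, ...).
@dataclass
class Substitution:
    name: str
    root_prim: str                               # primitive the match anchors on
    rewrite: Callable[[tuple], Optional[tuple]]  # returns a new node, or None

def run_pass(node, substitutions):
    if not isinstance(node, tuple):
        return node
    # Rewrite children first, then try every substitution whose root matches.
    node = (node[0],) + tuple(run_pass(a, substitutions) for a in node[1:])
    for sub in substitutions:
        if node[0] == sub.root_prim:
            replaced = sub.rewrite(node)
            if replaced is not None:
                return run_pass(replaced, substitutions)  # rewrite to a fixed point
    return node
```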
mindspore/ccsrc/optimizer/irpass.h

```diff
@@ -35,7 +35,6 @@ class OptimizeIRPassLib {
   SubstitutionPtr arithmetic_simplify_;
   SubstitutionPtr special_op_eliminate_;
   SubstitutionPtr zero_like_fill_zero_;
-  SubstitutionPtr adjust_all_reduce_mul_add_;
 
   // ops eliminate
   SubstitutionPtr item_tuple_eliminate_;
```
mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h

```diff
@@ -228,115 +228,6 @@ class ConstantDuplicateMul : public AnfVisitor {
   CNodePtr cnode_;
 };
 
-// grad = AllReduce(grad) / worker_number
-// grad = grad + weight * decy
-// ->
-// grad = grad + weight * decy
-// grad = AllReduce(grad) / worker_number
-//
-// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
-// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN, {prim::kPrimMakeTuple, Z, X}}}, Y}
-class AdjustAllReduceMulAdd : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    Reset();
-    // {prim::kPrimAddN, Zs}
-    if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
-      return nullptr;
-    }
-    auto addn = node->cast<CNodePtr>();
-    if (addn->size() != 2) {
-      return nullptr;
-    }
-    AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
-    if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
-      return nullptr;
-    }
-    auto addn_maketuple = addn->input(1);
-
-    auto fg = all_reduce_fg_;
-    // addn inputs cross the graph, make the inputs same as allreduce node.
-    if (z_->isa<CNode>() && fg != z_->func_graph()) {
-      auto cnode_z = z_->cast<CNodePtr>();
-      z_ = NewCNode(cnode_z->inputs(), fg);
-    }
-
-    auto addn_op_node = addn->input(0);
-    auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
-
-    AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
-    AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
-    AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
-    AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
-    ProcessDependEdge(fg, addn_maketuple, all_reduce);
-    return mul;
-  }
-
-  void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node) {
-    // If has dynamic loss scale.
-    auto &users_map = fg->manager()->node_users();
-    auto it = users_map.find(mul_cnode_);
-    if (it != users_map.end()) {
-      auto users = it->second;
-      for (auto &user_pair : users) {
-        auto node = user_pair.first;
-        if (node != addn_maketuple) {
-          if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
-            fg->manager()->SetEdge(node, user_pair.second, new_node);
-          }
-        }
-      }
-    }
-  }
-
-  void Visit(const AnfNodePtr &node) override {
-    if (level_ == 0) {
-      level_ = 1;
-      is_reduce_match_ = false;
-      // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
-      AnfVisitor::Match(prim::kPrimMul)(node);
-      level_ = 0;
-      if (is_reduce_match_) {
-        mul_ = node->cast<CNodePtr>()->input(0);
-        mul_cnode_ = node->cast<CNodePtr>();
-        y_ = tmp_;
-      } else {
-        z_ = node;
-      }
-    }
-
-    if (level_ == 1) {
-      // {prim::kPrimAllReduce, X}
-      if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
-        auto cnode = node->cast<CNodePtr>();
-        if (cnode->size() > 1) {
-          all_reduce_ = cnode->input(0);
-          x_ = cnode->input(1);
-          is_reduce_match_ = true;
-          all_reduce_fg_ = cnode->func_graph();
-        }
-      } else {
-        tmp_ = node;
-      }
-    }
-  }
-
-  void Reset() {
-    level_ = 0;
-    is_reduce_match_ = false;
-    x_ = nullptr;
-    y_ = nullptr;
-    z_ = nullptr;
-    tmp_ = nullptr;
-    all_reduce_fg_ = nullptr;
-  }
-
- private:
-  int level_{0};
-  bool is_reduce_match_{false};
-  AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
-  AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr};
-  FuncGraphPtr all_reduce_fg_{nullptr};
-};
-
 class ArithmeticSimplify {
  public:
   ArithmeticSimplify()
```
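The header comment of the removed class states when the reorder is sound: with a sum-AllReduce over `n` workers, `Y = 1/worker_number`, and `Z` (the weight-decay term) replicated on every worker, `Mul(AllReduce(X), Y) + Z` equals `Mul(AllReduce(X + Z), Y)`. A quick NumPy check of that identity — a sketch under those assumptions, not MindSpore code:

```python
import numpy as np

workers = 4
rng = np.random.default_rng(0)
grads = [rng.normal(size=3) for _ in range(workers)]  # per-worker gradient X
z = rng.normal(size=3)       # weight * decay term, identical on every worker
y = 1.0 / workers            # the averaging factor Y = 1 / worker_number

def all_reduce(tensors):
    """Sum-AllReduce: every worker ends up holding the elementwise sum."""
    return np.sum(tensors, axis=0)

# Before the rewrite: add the decay term after averaging.
before = all_reduce(grads) * y + z
# After the rewrite: each worker adds Z locally, then the result is averaged.
after = all_reduce([g + z for g in grads]) * y
assert np.allclose(before, after)
```

The identity only holds for `Y = 1/worker_number`, which is exactly the gradient-averaging pattern the pass targeted.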
mindspore/ccsrc/pipeline/parse/function_block.h

```diff
@@ -28,7 +28,6 @@
 #include <utility>
 #include "pipeline/parse/parse_base.h"
 #include "utils/log_adapter.h"
-#include "utils/ordered_map.h"
 
 namespace mindspore {
 namespace parse {
@@ -100,7 +99,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
   std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
 
   // set state nodes need to insert before function return nodes.
-  OrderedMap<AnfNodePtr, std::string> state_assign_;
+  std::unordered_map<AnfNodePtr, std::string> state_assign_;
 
   // hold declared global variables in function
   std::set<std::string> global_vars_;
```
mindspore/ccsrc/pipeline/pass.cc

```diff
@@ -82,7 +82,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     // Arithmetic simplifications
     irpass.arithmetic_simplify_,
     irpass.addn_zero_filter_,
-    irpass.adjust_all_reduce_mul_add_,
 
     // Miscellaneous
     irpass.item_tuple_eliminate_,
```
mindspore/ops/operations/array_ops.py

```diff
@@ -1275,7 +1275,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
 
     Examples:
-        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
+        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float)
         >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
        >>> num_segments = 4
         >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
```
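For reference, `UnsortedSegmentSum` accumulates `input_x[i]` into `output[segment_ids[i]]`; for the docstring inputs above the result is `[3, 3, 4, 0]`. A minimal NumPy sketch of those semantics (`unsorted_segment_sum` here is a hypothetical helper, not the MindSpore operator):

```python
import numpy as np

def unsorted_segment_sum(x, segment_ids, num_segments):
    out = np.zeros((num_segments,) + x.shape[1:], dtype=x.dtype)
    for value, seg in zip(x, segment_ids):
        out[seg] += value  # accumulate each element into its target segment
    return out

x = np.array([1, 2, 3, 4], dtype=np.float32)
print(unsorted_segment_sum(x, [0, 0, 1, 2], 4))  # [3. 3. 4. 0.]
```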
mindspore/ops/operations/nn_ops.py

```diff
@@ -1855,7 +1855,7 @@ class LayerNorm(Primitive):
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
 
     .. math::
-        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
+        y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
 
     where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
```
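Stray bracket aside, the intended formula standardizes `x` to zero mean and unit variance over the normalized axis and then applies the learned scale and shift. A NumPy illustration of the math (a sketch, not the MindSpore kernel; the epsilon default is illustrative):

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-7):
    # y = (x - mean) / sqrt(variance + eps) * gamma + beta
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

x = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
print(layer_norm(x, gamma=np.ones(3), beta=np.zeros(3)))
```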
mindspore/ops/primitive.py

```diff
@@ -284,8 +284,7 @@ def prim_attr_register(fn):
 def constexpr(fn=None, get_instance=True, name=None):
     """
-    Makes a PrimitiveWithInfer operator, which infer the value while compiling. We can define a function
-    to compute between constant variable and used in construct.
+    Makes a PrimitiveWithInfer operator, which infer the value while compiling.
 
     Args:
         fn (function): A `fn` use as the infer_value of the output operator.
```
tests/ut/cpp/optimizer/lib_test.cc

```diff
@@ -556,24 +556,5 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) {
   ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
   ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
 }
-
-TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
-  FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
-  FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
-  FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
-  FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
-  FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
-  FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r");
-  FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
-  FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
-  auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_});
-  ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
-  ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
-  ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
-  ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
-  ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
-  ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
-}
-
 }  // namespace opt
 }  // namespace mindspore
```
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py

```diff
@@ -1045,8 +1045,8 @@ def test_print_tuple_wrapper(tag):
 def test_constant_duplicate_mul(tag):
     fns = FnDict()
-    Mul = Primitive('Mul')
-    Sqrt = Primitive('Sqrt')
+    Mul = Primitive('Mul');
+    Sqrt = Primitive('Sqrt');
     x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
     tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
@@ -1073,44 +1073,3 @@ def test_constant_duplicate_mul(tag):
         return Mul(Sqrt(x), Mul(tensor1, tensor2))
 
     return fns[tag]
-
-
-def test_adjust_allreduce_mul_add(tag):
-    fns = FnDict()
-    Mul = Primitive('Mul')
-    AddN = Primitive('AddN')
-    AllReduce = Primitive('AllReduce')
-
-    @fns
-    def beforell(x, y, z):
-        return AddN((z, Mul(y, AllReduce(x))))
-
-    @fns
-    def beforelr(x, y, z):
-        return AddN((z, Mul(AllReduce(x), y)))
-
-    @fns
-    def beforerl(x, y, z):
-        return AddN((Mul(y, AllReduce(x)), z))
-
-    @fns
-    def beforerr(x, y, z):
-        return AddN((Mul(AllReduce(x), y), z))
-
-    @fns
-    def after1(x, y, z):
-        return Mul(AllReduce(AddN((z, x))), y)
-
-    @fns
-    def before2r(x, y, z):
-        return AddN((Mul(AllReduce(x), y), Mul(z, z)))
-
-    @fns
-    def before2l(x, y, z):
-        return AddN((Mul(z, z), Mul(AllReduce(x), y)))
-
-    @fns
-    def after2(x, y, z):
-        return Mul(AllReduce(AddN((Mul(z, z), x))), y)
-
-    return fns[tag]
```
tests/ut/python/train/test_amp.py

```diff
@@ -20,14 +20,9 @@ import mindspore.context as context
-from mindspore import Tensor
 from mindspore import amp
 from mindspore import nn
-from mindspore.train import Model, ParallelMode
+from mindspore import Tensor
 from mindspore.common import dtype as mstype
 import mindspore.context as context
 from mindspore.model_zoo.resnet import resnet50
+from mindspore.train import Model
 from ....dataset_mock import MindData
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
-from mindspore.communication.management import init
 
 def setup_module(module):
     context.set_context(mode=context.GRAPH_MODE)
@@ -143,22 +138,3 @@ def test_compile_model_train_O2():
     with pytest.raises(ValueError):
         # not actual run, the metrics step will fail, check if compile ok.
         model.eval(dataset)
-
-
-def test_compile_model_train_O2_parallel():
-    dataset_types = (np.float32, np.float32)
-    dataset_shapes = ((16, 16), (16, 16))
-
-    dataset = MindDataSet(dataset_types, dataset_shapes)
-
-    net = NetNoLoss(16, 16)
-    loss = nn.MSELoss()
-    optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
-
-    context.set_auto_parallel_context(global_rank=0, device_num=8, mirror_mean=True,
-                                      parameter_broadcast=True, parallel_mode=ParallelMode.DATA_PARALLEL)
-    init()
-
-    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
-    model.train(2, dataset, dataset_sink_mode=False)
```