Commit 8556eee3, authored by mindspore-ci-bot, committed by Gitee

!1578 rectify pretrained path and revert AdjustAllReduceMulAdd use

Merge pull request !1578 from gengdongjie/master
......@@ -16,7 +16,7 @@
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
......@@ -32,7 +32,7 @@ PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ $# == 3 ]
then
PATH3=$(get_real_path $3)
PATH3=$(get_real_path $3)
fi
if [ ! -f "$PATH1" ]
......@@ -47,11 +47,11 @@ then
exit 1
fi
if [ ! -f "$PATH3" ]
then
if [ $# == 3 ] && [ ! -f "$PATH3" ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
exit 1
fi
fi
ulimit -u unlimited
export DEVICE_NUM=8
......
......@@ -34,13 +34,13 @@ PATH2=$(get_real_path $2)
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$1 is not a directory"
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$2 is not a file"
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
......
......@@ -31,17 +31,17 @@ get_real_path(){
PATH1=$(get_real_path $1)
if [ $# == 2 ]
then
PATH2=$(get_real_path $2)
PATH2=$(get_real_path $2)
fi
if [ ! -d "$PATH1" ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
fi
if [ ! -f "$PATH2" ]
then
if [ $# == 2 ] && [ ! -f "$PATH2" ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
exit 1
fi
......@@ -62,7 +62,7 @@ cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 1 ]
then
then
python train.py --do_train=True --dataset_path=$PATH1 &> log &
else
python train.py --do_train=True --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
......
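Note: the script changes above make PRETRAINED_CKPT_PATH genuinely optional — the file-existence check now runs only when the extra argument is actually passed, and train.py receives --pre_trained only in that case. train.py itself is not part of this diff; the following is an assumed, minimal Python sketch of how it might consume the flags the scripts pass (--do_train, --dataset_path, --pre_trained), for orientation only.

import argparse

parser = argparse.ArgumentParser(description="training entry point (illustrative sketch)")
parser.add_argument("--do_train", type=str, default="True", help="run training when 'True'")
parser.add_argument("--dataset_path", type=str, required=True, help="dataset directory")
parser.add_argument("--pre_trained", type=str, default=None,
                    help="optional checkpoint to initialize the network from")
args = parser.parse_args()

do_train = args.do_train.lower() == "true"
if args.pre_trained:
    # Mirrors the new `if [ $# == 2 ]` / `if [ $# == 3 ]` guards in the scripts:
    # the checkpoint path is only touched when it was actually supplied.
    print("would load pretrained checkpoint from", args.pre_trained)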
......@@ -246,7 +246,6 @@ const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
// Debug ops
const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
......
......@@ -252,7 +252,6 @@ extern const PrimitivePtr kPrimInDict;
extern const PrimitivePtr kPrimNotInDict;
// Comm ops
extern const PrimitivePtr kPrimAllReduce;
extern const PrimitivePtr kPrimMirror;
extern const PrimitivePtr kPrimVirtualDiv;
extern const PrimitivePtr kPrimVirtualDataset;
......
......@@ -54,7 +54,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
{prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
// ops eliminate
item_tuple_eliminate_ =
......
......@@ -35,7 +35,6 @@ class OptimizeIRPassLib {
SubstitutionPtr arithmetic_simplify_;
SubstitutionPtr special_op_eliminate_;
SubstitutionPtr zero_like_fill_zero_;
SubstitutionPtr adjust_all_reduce_mul_add_;
// ops eliminate
SubstitutionPtr item_tuple_eliminate_;
......
......@@ -228,115 +228,6 @@ class ConstantDuplicateMul : public AnfVisitor {
CNodePtr cnode_;
};
// grad = AllReduce(grad) / worker_number
// grad = grad + weight * decy
// ->
// grad = grad + weight * decy
// grad = AllReduce(grad) / worker_number
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
class AdjustAllReduceMulAdd : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
// {prim::kPrimAddN, Zs}
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
return nullptr;
}
auto addn = node->cast<CNodePtr>();
if (addn->size() != 2) {
return nullptr;
}
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
return nullptr;
}
auto addn_maketuple = addn->input(1);
auto fg = all_reduce_fg_;
// addn inputs cross the graph, make the inputs same as allreduce node.
if (z_->isa<CNode>() && fg != z_->func_graph()) {
auto cnode_z = z_->cast<CNodePtr>();
z_ = NewCNode(cnode_z->inputs(), fg);
}
auto addn_op_node = addn->input(0);
auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
ProcessDependEdge(fg, addn_maketuple, all_reduce);
return mul;
}
void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node) {
// If has dynamic loss scale.
auto &users_map = fg->manager()->node_users();
auto it = users_map.find(mul_cnode_);
if (it != users_map.end()) {
auto users = it->second;
for (auto &user_pair : users) {
auto node = user_pair.first;
if (node != addn_maketuple) {
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
fg->manager()->SetEdge(node, user_pair.second, new_node);
}
}
}
}
}
void Visit(const AnfNodePtr &node) override {
if (level_ == 0) {
level_ = 1;
is_reduce_match_ = false;
// {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
AnfVisitor::Match(prim::kPrimMul)(node);
level_ = 0;
if (is_reduce_match_) {
mul_ = node->cast<CNodePtr>()->input(0);
mul_cnode_ = node->cast<CNodePtr>();
y_ = tmp_;
} else {
z_ = node;
}
}
if (level_ == 1) {
// {prim::kPrimAllReduce, X}
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
auto cnode = node->cast<CNodePtr>();
if (cnode->size() > 1) {
all_reduce_ = cnode->input(0);
x_ = cnode->input(1);
is_reduce_match_ = true;
all_reduce_fg_ = cnode->func_graph();
}
} else {
tmp_ = node;
}
}
}
void Reset() {
level_ = 0;
is_reduce_match_ = false;
x_ = nullptr;
y_ = nullptr;
z_ = nullptr;
tmp_ = nullptr;
all_reduce_fg_ = nullptr;
}
private:
int level_{0};
bool is_reduce_match_{false};
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr};
FuncGraphPtr all_reduce_fg_{nullptr};
};
class ArithmeticSimplify {
public:
ArithmeticSimplify()
......
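Note on the removed pass: the comment block above describes the rewrite AdjustAllReduceMulAdd performed — moving the weight-decay AddN underneath the AllReduce, so that grad = AllReduce(grad + weight*decay) / worker_number replaces grad = AllReduce(grad) / worker_number + weight*decay. The two forms agree only when the weight*decay term is replicated identically on every worker and the post-AllReduce multiplier is 1/worker_number; this commit removes the substitution and its registration. A small numpy sketch of that equivalence (not MindSpore code), with a plain sum standing in for a synchronous AllReduce:

import numpy as np

n = 8                                  # worker_number
grads = np.random.randn(n, 4)          # per-worker gradients x_i
z = 1e-4 * np.random.randn(4)          # weight * decay, identical on every worker
y = 1.0 / n                            # scaling applied after AllReduce

def allreduce(per_worker):
    # Synchronous sum-AllReduce stand-in: sum the per-worker copies.
    return per_worker.sum(axis=0)

before = z + y * allreduce(grads)      # pattern before the pass: z + y * AllReduce(x)
after = y * allreduce(grads + z)       # pattern after the pass:  y * AllReduce(z + x)

assert np.allclose(before, after)      # holds because AllReduce(z) == n * z here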
......@@ -28,7 +28,6 @@
#include <utility>
#include "pipeline/parse/parse_base.h"
#include "utils/log_adapter.h"
#include "utils/ordered_map.h"
namespace mindspore {
namespace parse {
......@@ -100,7 +99,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
// set state nodes need to insert before function return nodes.
OrderedMap<AnfNodePtr, std::string> state_assign_;
std::unordered_map<AnfNodePtr, std::string> state_assign_;
// hold declared global variables in function
std::set<std::string> global_vars_;
......
......@@ -82,7 +82,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
// Arithmetic simplifications
irpass.arithmetic_simplify_,
irpass.addn_zero_filter_,
irpass.adjust_all_reduce_mul_add_,
// Miscellaneous
irpass.item_tuple_eliminate_,
......
......@@ -1275,7 +1275,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
Examples:
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
>>> input_x = Tensor([1, 2, 3, 4], mindspore.float)
>>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
>>> num_segments = 4
>>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
......
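For reference, the docstring example above produces [3. 3. 4. 0.]: the elements with segment id 0 (1 and 2) are summed into output[0], id 1 contributes 3, id 2 contributes 4, and segment 3 receives nothing. A numpy sketch of the same semantics (not the primitive itself):

import numpy as np

input_x = np.array([1, 2, 3, 4], dtype=np.float32)
segment_ids = np.array([0, 0, 1, 2], dtype=np.int32)
num_segments = 4

out = np.zeros(num_segments, dtype=input_x.dtype)
np.add.at(out, segment_ids, input_x)   # element i is accumulated into out[segment_ids[i]]
print(out)                             # [3. 3. 4. 0.]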
......@@ -1855,7 +1855,7 @@ class LayerNorm(Primitive):
`Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
.. math::
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
......
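The math block above is standard layer normalization, y = (x - mean) / sqrt(variance + eps) * gamma + beta. A minimal numpy sketch, normalizing over the last axis for simplicity (the primitive's begin_norm_axis / begin_params_axis handling is not reproduced here):

import numpy as np

def layer_norm(x, gamma, beta, eps=1e-7):
    # Normalize over the last axis, then apply the learned scale and bias.
    mean = x.mean(axis=-1, keepdims=True)
    variance = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(variance + eps) * gamma + beta

x = np.array([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]], dtype=np.float32)
print(layer_norm(x, gamma=np.ones(3, np.float32), beta=np.zeros(3, np.float32)))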
......@@ -284,8 +284,7 @@ def prim_attr_register(fn):
def constexpr(fn=None, get_instance=True, name=None):
"""
Makes a PrimitiveWithInfer operator, which infer the value while compiling. We can define a function
to compute between constant variable and used in constructß.
Makes a PrimitiveWithInfer operator, which infer the value while compiling.
Args:
fn (function): A `fn` use as the infer_value of the output operator.
......
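The constexpr docstring above is terse; below is a hedged usage sketch of the decorator it describes — the decorated function becomes a PrimitiveWithInfer whose value is computed while the graph is compiled, so it can be called from construct() with constant arguments. The import path and the surrounding Cell are illustrative assumptions, not taken from this diff, and may differ between MindSpore versions.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import constexpr   # assumed import path; may vary by version

@constexpr
def compute_pad(kernel_size):
    # Evaluated at compile time, since kernel_size is a constant in the graph.
    return (kernel_size - 1) // 2

class PadInfo(nn.Cell):
    def construct(self, x):
        pad = compute_pad(3)          # folded to the constant 1 during compilation
        return x + pad

net = PadInfo()
print(net(Tensor(np.zeros((2, 2), np.float32))))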
......@@ -556,24 +556,5 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) {
ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
}
TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r");
FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_});
ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
}
} // namespace opt
} // namespace mindspore
......@@ -1045,8 +1045,8 @@ def test_print_tuple_wrapper(tag):
def test_constant_duplicate_mul(tag):
fns = FnDict()
Mul = Primitive('Mul')
Sqrt = Primitive('Sqrt')
Mul = Primitive('Mul');
Sqrt = Primitive('Sqrt');
x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
......@@ -1073,44 +1073,3 @@ def test_constant_duplicate_mul(tag):
return Mul(Sqrt(x), Mul(tensor1, tensor2))
return fns[tag]
def test_adjust_allreduce_mul_add(tag):
fns = FnDict()
Mul = Primitive('Mul')
AddN = Primitive('AddN')
AllReduce = Primitive('AllReduce')
@fns
def beforell(x, y, z):
return AddN((z, Mul(y, AllReduce(x))))
@fns
def beforelr(x, y, z):
return AddN((z, Mul(AllReduce(x), y)))
@fns
def beforerl(x, y, z):
return AddN((Mul(y, AllReduce(x)), z))
@fns
def beforerr(x, y, z):
return AddN((Mul(AllReduce(x), y), z))
@fns
def after1(x, y, z):
return Mul(AllReduce(AddN((z, x))), y)
@fns
def before2r(x, y, z):
return AddN((Mul(AllReduce(x), y), Mul(z, z)))
@fns
def before2l(x, y, z):
return AddN((Mul(z, z), Mul(AllReduce(x), y)))
@fns
def after2(x, y, z):
return Mul(AllReduce(AddN((Mul(z, z), x))), y)
return fns[tag]
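The removed test_adjust_allreduce_mul_add above, like test_constant_duplicate_mul before it, builds its before/after graphs through the FnDict helper: each case is registered with the @fns decorator and later fetched by name via fns[tag] from getPyFun.CallAndParseRet on the C++ side. FnDict is defined elsewhere in the test utilities; an assumed minimal sketch of that pattern, not its actual implementation:

class FnDict:
    def __init__(self):
        self.fns = {}

    def __call__(self, fn):
        # Used as the @fns decorator: register each case under its function name.
        self.fns[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        # Used as fns[tag]: look a registered case back up by its name.
        return self.fns[name]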
......@@ -20,14 +20,9 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore import amp
from mindspore import nn
from mindspore.train import Model, ParallelMode
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.context as context
from mindspore.model_zoo.resnet import resnet50
from mindspore.train import Model
from ....dataset_mock import MindData
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import init
def setup_module(module):
context.set_context(mode=context.GRAPH_MODE)
......@@ -143,22 +138,3 @@ def test_compile_model_train_O2():
with pytest.raises(ValueError):
# not actual run, the metrics step will fail, check if compile ok.
model.eval(dataset)
def test_compile_model_train_O2_parallel():
dataset_types = (np.float32, np.float32)
dataset_shapes = ((16, 16), (16, 16))
dataset = MindDataSet(dataset_types, dataset_shapes)
net = NetNoLoss(16, 16)
loss = nn.MSELoss()
optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
context.set_auto_parallel_context(
global_rank=0, device_num=8,
mirror_mean=True, parameter_broadcast=True,
parallel_mode=ParallelMode.DATA_PARALLEL)
init()
model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
model.train(2, dataset, dataset_sink_mode=False)