diff --git a/mindspore/ccsrc/parallel/context.cc b/mindspore/ccsrc/parallel/context.cc index 69383eaa86d63800f662d9992b9df0f2c937a2ee..de92bba507ced1819531df49c21b6f6bcdd98237 100644 --- a/mindspore/ccsrc/parallel/context.cc +++ b/mindspore/ccsrc/parallel/context.cc @@ -22,12 +22,15 @@ #include #include #include +#include #include "common/utils.h" #include "parallel/device_manager.h" namespace mindspore { namespace parallel { +static std::map> param_shapes; + std::vector PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL}; std::vector STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING}; @@ -136,5 +139,56 @@ const std::vector ParallelContext::GetAllReduceFusionSplitSizes(const } return {}; } + +// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode +void ParallelParameterContextInit(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { + return; + } + param_shapes.clear(); +} + +// Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode +void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, + AbstractBasePtr ptr) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(param_node); + MS_EXCEPTION_IF_NULL(ptr); + if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->flags().count(TRAINING) == 0) || + func_graph->flags()[TRAINING]) { + return; + } + + auto iter = param_shapes.find(param_node->name()); + if (iter == param_shapes.end()) { + MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name(); + return; + } + std::vector shape = iter->second; + std::shared_ptr base_shape = std::make_shared(shape); + ptr->set_shape(base_shape); + MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; +} + +// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode +void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, + const AbstractBasePtr &ptr) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(param_node); + MS_EXCEPTION_IF_NULL(ptr); + if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { + return; + } + + std::vector shape = dyn_cast(ptr->GetShapeTrack())->shape(); + auto ret = param_shapes.try_emplace(param_node->name(), shape); + if (!ret.second) { + MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed"; + return; + } + + MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; +} } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/context.h b/mindspore/ccsrc/parallel/context.h index 3e750c176063b442e72fd8dd46d99aae5439dbb7..cc17f536159be69d50dc515f2b014101918ef06a 100644 --- a/mindspore/ccsrc/parallel/context.h +++ b/mindspore/ccsrc/parallel/context.h @@ -26,6 +26,9 @@ #include "parallel/ops_info/ops_utils.h" #include "parallel/status.h" #include "utils/convert_utils.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "debug/info.h" namespace mindspore { namespace parallel { @@ -38,6 +41,8 @@ constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel"; constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming"; constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming"; +constexpr char TRAINING[] = "training"; + class ParallelContext { public: ~ParallelContext() = default; @@ -114,6 +119,12 @@ class ParallelContext { std::string strategy_ckpt_load_file_; std::string strategy_ckpt_save_file_; }; + +void ParallelParameterContextInit(const FuncGraphPtr &func_graph); +void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, + AbstractBasePtr ptr); +void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, + const AbstractBasePtr &ptr); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index 5e43293983c69b0b30465b5a2c6305a2593fb02b..5c8edd7c8692527c68565ca1fb87fca97cf6718c 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -25,6 +25,7 @@ #include "ir/func_graph_cloner.h" #include "parallel/costmodel_context.h" +#include "parallel/context.h" #include "pipeline/pass.h" #include "pipeline/parse/parse_base.h" #include "pipeline/parse/data_converter.h" @@ -217,6 +218,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); abstract::AbstractBasePtrList args_spec = res->args_spec(); + parallel::ParallelParameterContextInit(func_graph); + // suppose that there is not KeywordArgument for the top graph // get the hyper parameter for (const auto ¶m : func_graph->parameters()) { @@ -224,7 +227,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { if (param_node->has_default()) { AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_node->default_param()), true); + + parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); args_spec.push_back(ptr); + parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr); } } // Analyze diff --git a/mindspore/common/api.py b/mindspore/common/api.py index eb740374f56d4848fff812a270c7de80274b01e1..16df9a00ee97335bbc22c25213f031f9670b838e 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -379,7 +379,7 @@ class _Executor: self._params_init_data(obj, params) if not enable_debug_runtime or enable_ge: - if auto_parallel_mode: + if auto_parallel_mode and "train" in phase: obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) obj.load_parameter_slice(params) diff --git a/tests/ut/python/parallel/test_get_parameter_layout.py b/tests/ut/python/parallel/test_get_parameter_layout.py index 6a3b729c1b68ea7d93e31d9483f01940ced70444..292aba1c6bcd562a48ed60564da8693620527946 100644 --- a/tests/ut/python/parallel/test_get_parameter_layout.py +++ b/tests/ut/python/parallel/test_get_parameter_layout.py @@ -47,7 +47,7 @@ def test_get_parameter_layout(): net = Net(strategy1, strategy2, weight) net.set_auto_parallel() exe = me._executor - exe.compile(net, x, auto_parallel_mode=True) + exe.compile(net, x, phase='train', auto_parallel_mode=True) x_layout = [[2, 4], [1, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [1, -1] weight_layout = [[2, 4], [0, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [0, -1] expect_dict = {'x': x_layout, 'w1': weight_layout} diff --git a/tests/ut/python/parallel/test_train_and_eval.py b/tests/ut/python/parallel/test_train_and_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..3d760a27eef91693f694006f4985aceb359e7f99 --- /dev/null +++ b/tests/ut/python/parallel/test_train_and_eval.py @@ -0,0 +1,68 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import mindspore as ms +from mindspore import context, Tensor, Parameter +from mindspore.nn import Cell, TrainOneStepCell, Momentum +from mindspore.ops import operations as P +from mindspore.common.api import _executor + + +class Net(Cell): + def __init__(self, mul_weight, strategy1=None, strategy2=None): + super().__init__() + self.mul = P.Mul().set_strategy(strategy1) + self.neg = P.Neg().set_strategy(strategy2) + self.mul_weight = Parameter(mul_weight, "w1") + + def construct(self, x, b): + out = self.mul(x, self.mul_weight) + out = self.neg(out) + return out + + +class EvalNet(Cell): + def __init__(self, network, strategy2=None): + super().__init__() + self.network = network + self.relu = P.ReLU().set_strategy(strategy2) + + def construct(self, x, b): + out = self.network(x, b) + out = self.relu(out) + return out + + +_x = Tensor(np.ones([8, 8]), dtype=ms.float32) +_w1 = Tensor(np.ones([8, 8]), dtype=ms.float32) +_b = Tensor(np.ones([8, 8]), dtype=ms.float32) + + +def test_train_and_eval(): + context.set_context(save_graphs=True, mode=0) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16) + strategy1 = ((4, 4), (4, 4)) + strategy2 = ((4, 4), ) + net = Net(_w1, strategy1, strategy2) + eval_net = EvalNet(net, strategy2=strategy2) + net.set_train() + net.set_auto_parallel() + _executor.compile(net, _x, _b, phase='train', auto_parallel_mode=True) + + eval_net.set_train(mode=False) + eval_net.set_auto_parallel() + _executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True) + + context.reset_auto_parallel_context() \ No newline at end of file