提交 3b6de893 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1187 Checkpoint and restore parameter's shape

Merge pull request !1187 from yangzhenzhang/ckpt-and-restore-parameter-shape
......@@ -22,12 +22,15 @@
#include <memory>
#include <numeric>
#include <utility>
#include <map>
#include "common/utils.h"
#include "parallel/device_manager.h"
namespace mindspore {
namespace parallel {
static std::map<std::string, std::vector<int>> param_shapes;
std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL,
AUTO_PARALLEL};
std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};
......@@ -136,5 +139,56 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
}
return {};
}
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
void ParallelParameterContextInit(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
return;
}
param_shapes.clear();
}
// Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode
void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
AbstractBasePtr ptr) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr);
if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->flags().count(TRAINING) == 0) ||
func_graph->flags()[TRAINING]) {
return;
}
auto iter = param_shapes.find(param_node->name());
if (iter == param_shapes.end()) {
MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name();
return;
}
std::vector<int> shape = iter->second;
std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
ptr->set_shape(base_shape);
MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}
// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
const AbstractBasePtr &ptr) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr);
if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
return;
}
std::vector<int> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
auto ret = param_shapes.try_emplace(param_node->name(), shape);
if (!ret.second) {
MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed";
return;
}
MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}
} // namespace parallel
} // namespace mindspore
......@@ -26,6 +26,9 @@
#include "parallel/ops_info/ops_utils.h"
#include "parallel/status.h"
#include "utils/convert_utils.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "debug/info.h"
namespace mindspore {
namespace parallel {
......@@ -38,6 +41,8 @@ constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel";
constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming";
constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
constexpr char TRAINING[] = "training";
class ParallelContext {
public:
~ParallelContext() = default;
......@@ -114,6 +119,12 @@ class ParallelContext {
std::string strategy_ckpt_load_file_;
std::string strategy_ckpt_save_file_;
};
void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
AbstractBasePtr ptr);
void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
const AbstractBasePtr &ptr);
} // namespace parallel
} // namespace mindspore
......
......@@ -25,6 +25,7 @@
#include "ir/func_graph_cloner.h"
#include "parallel/costmodel_context.h"
#include "parallel/context.h"
#include "pipeline/pass.h"
#include "pipeline/parse/parse_base.h"
#include "pipeline/parse/data_converter.h"
......@@ -217,6 +218,8 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
abstract::AbstractBasePtrList args_spec = res->args_spec();
parallel::ParallelParameterContextInit(func_graph);
// suppose that there is not KeywordArgument for the top graph
// get the hyper parameter
for (const auto &param : func_graph->parameters()) {
......@@ -224,7 +227,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
if (param_node->has_default()) {
AbstractBasePtr ptr =
abstract::FromValue(parse::data_converter::PyDataToValue(param_node->default_param()), true);
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
args_spec.push_back(ptr);
parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr);
}
}
// Analyze
......
......@@ -379,7 +379,7 @@ class _Executor:
self._params_init_data(obj, params)
if not enable_debug_runtime or enable_ge:
if auto_parallel_mode:
if auto_parallel_mode and "train" in phase:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
obj.load_parameter_slice(params)
......
......@@ -47,7 +47,7 @@ def test_get_parameter_layout():
net = Net(strategy1, strategy2, weight)
net.set_auto_parallel()
exe = me._executor
exe.compile(net, x, auto_parallel_mode=True)
exe.compile(net, x, phase='train', auto_parallel_mode=True)
x_layout = [[2, 4], [1, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [1, -1]
weight_layout = [[2, 4], [0, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [0, -1]
expect_dict = {'x': x_layout, 'w1': weight_layout}
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
from mindspore.common.api import _executor
class Net(Cell):
def __init__(self, mul_weight, strategy1=None, strategy2=None):
super().__init__()
self.mul = P.Mul().set_strategy(strategy1)
self.neg = P.Neg().set_strategy(strategy2)
self.mul_weight = Parameter(mul_weight, "w1")
def construct(self, x, b):
out = self.mul(x, self.mul_weight)
out = self.neg(out)
return out
class EvalNet(Cell):
def __init__(self, network, strategy2=None):
super().__init__()
self.network = network
self.relu = P.ReLU().set_strategy(strategy2)
def construct(self, x, b):
out = self.network(x, b)
out = self.relu(out)
return out
_x = Tensor(np.ones([8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 8]), dtype=ms.float32)
_b = Tensor(np.ones([8, 8]), dtype=ms.float32)
def test_train_and_eval():
context.set_context(save_graphs=True, mode=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16)
strategy1 = ((4, 4), (4, 4))
strategy2 = ((4, 4), )
net = Net(_w1, strategy1, strategy2)
eval_net = EvalNet(net, strategy2=strategy2)
net.set_train()
net.set_auto_parallel()
_executor.compile(net, _x, _b, phase='train', auto_parallel_mode=True)
eval_net.set_train(mode=False)
eval_net.set_auto_parallel()
_executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True)
context.reset_auto_parallel_context()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册