提交 0925e352 编写于 作者: Z Ziyan

enable optimizer parallel with broadcast

上级 f067c209
......@@ -62,6 +62,7 @@ void ParallelContext::Reset() {
enable_all_reduce_fusion_ = false;
strategy_ckpt_load_file_ = "";
strategy_ckpt_save_file_ = "";
enable_parallel_optimizer_ = false;
}
void ParallelContext::set_device_num(int32_t device_num) {
......
......@@ -100,6 +100,11 @@ class ParallelContext {
void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
enable_parallel_optimizer_ = enable_parallel_optimizer;
}
bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
void Reset();
private:
......@@ -123,6 +128,7 @@ class ParallelContext {
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
std::string strategy_ckpt_load_file_;
std::string strategy_ckpt_save_file_;
bool enable_parallel_optimizer_;
};
void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
......
......@@ -205,6 +205,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
.def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
.def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
.def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
"Set enable/disable parallel optimizer.")
.def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
"Get enable/disable parallel optimizer.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
......
......@@ -29,8 +29,9 @@ from .optimizer import Optimizer
_adam_opt = C.MultitypeFuncGraph("adam_opt")
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag, optim_filter):
"""
Update parameters.
......@@ -44,38 +45,44 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
m (Tensor): m value of parameters.
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Applies weight decay or not.
optim_filter (bool): Applies parameter update or not.
Returns:
Tensor, the new value of v after updating.
"""
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
if optim_filter:
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta2, op_square(gradient_fp32))
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
- beta2, op_square(gradient_fp32))
update = next_m / (eps + op_sqrt(next_v))
if decay_flag:
update = op_mul(weight_decay_tensor, param_fp32) + update
update_with_lr = op_mul(lr, update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
update = next_m / (eps + op_sqrt(next_v))
if decay_flag:
update = op_mul(weight_decay_tensor, param_fp32) + update
next_v = F.depend(next_v, F.assign(param, op_cast(next_param, F.dtype(param))))
next_v = F.depend(next_v, F.assign(m, op_cast(next_m, F.dtype(m))))
next_v = F.depend(next_v, F.assign(v, op_cast(next_v, F.dtype(v))))
return next_v
update_with_lr = op_mul(lr, update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
return next_param
return gradient
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
......@@ -300,7 +307,7 @@ class AdamWeightDecay(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
tuple[bool], all elements are True.
Examples:
>>> net = Net()
......@@ -328,11 +335,13 @@ class AdamWeightDecay(Optimizer):
def construct(self, gradients):
lr = self.get_lr()
updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
return updated_velocity
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients,
self.decay_flag, self.optim_filter)
if self.use_parallel:
optim_result = self.broadcast_params(optim_result)
return optim_result
class AdamWeightDecayDynamicLR(Optimizer):
......@@ -363,7 +372,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
tuple[bool], all elements are True.
Examples:
>>> net = Net()
......@@ -424,12 +433,14 @@ class AdamWeightDecayDynamicLR(Optimizer):
warmup_lr = self.start_learning_rate * warmup_percent
is_warmup = self.cast(self.greater(self.warmup_steps, self.global_step), mstype.float32)
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
updated_velocity = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor),
self.params, self.moments1, self.moments2, gradients,
self.decay_flag, self.optim_filter)
if self.use_parallel:
optim_result = self.broadcast_params(optim_result)
added_global_step = self.global_step + self.one
F.control_depend(lr, added_global_step)
self.global_step = added_global_step
return updated_velocity
return optim_result
......@@ -32,11 +32,10 @@ num_one = Tensor(np.ones([1]), mstype.float32)
_lamb_opt = C.MultitypeFuncGraph("lamb_opt")
@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Bool")
@_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, param, m, v,
gradient, decay_flag):
gradient, decay_flag, optim_filter):
"""
Update parameters.
......@@ -52,66 +51,66 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
v (Tensor): v value of parameters.
gradient (Tensor): Gradient of parameters.
decay_flag (bool): Specifies whether param update with weight decay.
optim_filter(bool): Applies parameter update or not.
Returns:
Tensor, the new value of v after updating.
"""
op_mul = P.Mul()
op_sqrt = P.Sqrt()
op_rsqrt = P.Rsqrt()
op_square = P.Square()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
op_pow = P.Pow()
op_norm = layer.Norm()
op_select = P.Select()
op_greater = P.Greater()
op_fill = P.Fill()
op_dtype = P.DType()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one,
mstype.float32) - beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one,
mstype.float32) - beta2, op_square(gradient_fp32))
next_mm = next_m / (op_cast(num_one, mstype.float32)
- op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
next_vv = next_v / (op_cast(num_one, mstype.float32) -
op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
w_norm = op_norm(param_fp32)
g_norm = op_norm(gradient_fp32)
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(
next_vv + eps)) + weight_decay_tensor * param_fp32)
zeros = F.zeros_like(w_norm)
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
trust_ratio = op_select(
op_greater(w_norm, zeros),
op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
ones)
tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
update = next_mm / (op_sqrt(next_vv) + eps)
if decay_flag:
update = update + op_mul(weight_decay_tensor, param_fp32)
update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_v = F.depend(next_v, F.assign(param, next_param))
next_v = F.depend(next_v, F.assign(m, next_m))
next_v = F.depend(next_v, F.assign(v, next_v))
return next_v
if optim_filter:
op_mul = P.Mul()
op_sqrt = P.Sqrt()
op_rsqrt = P.Rsqrt()
op_square = P.Square()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()
op_pow = P.Pow()
op_norm = layer.Norm()
op_select = P.Select()
op_greater = P.Greater()
op_fill = P.Fill()
op_dtype = P.DType()
param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
next_mm = next_m / (op_cast(num_one, mstype.float32)
- op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
next_vv = next_v / (op_cast(num_one, mstype.float32) -
op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
w_norm = op_norm(param_fp32)
g_norm = op_norm(gradient_fp32)
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
zeros = F.zeros_like(w_norm)
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
trust_ratio = op_select(
op_greater(w_norm, zeros),
op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
ones)
tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
update = next_mm / (op_sqrt(next_vv) + eps)
if decay_flag:
update = update + op_mul(weight_decay_tensor, param_fp32)
update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
next_param = F.depend(next_param, F.assign(param, next_param))
next_param = F.depend(next_param, F.assign(m, next_m))
next_param = F.depend(next_param, F.assign(v, next_v))
return next_param
return gradient
lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel")
......@@ -238,7 +237,7 @@ class Lamb(Optimizer):
- **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
Outputs:
tuple[Parameter], the updated velocity value, the shape is the same as `params`.
tuple[bool], all elements are True.
Examples:
>>> net = Net()
......@@ -311,18 +310,21 @@ class Lamb(Optimizer):
self.warmup_steps, self.global_step), mstype.float32)
lr = (self.one - is_warmup) * lr + is_warmup * warmup_lr
if self.enable_graph_kernel:
updated_velocity = self.hyper_map(F.partial(lamb_opt_graph_kernel,
self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel,
self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
else:
updated_velocity = self.hyper_map(F.partial(_lamb_opt,
self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients, self.decay_flag)
optim_result = self.hyper_map(F.partial(_lamb_opt,
self.beta1, self.beta2, self.eps, lr,
self.weight_decay_tensor, self.global_step),
self.params, self.moments1, self.moments2, gradients,
self.decay_flag, self.optim_filter)
if self.use_parallel:
optim_result = self.broadcast_params(optim_result)
added_global_step = self.global_step + self.one
F.control_depend(lr, added_global_step)
self.global_step = added_global_step
return updated_velocity
return optim_result
......@@ -22,11 +22,14 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common.tensor import Tensor
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.parallel_utils import ParallelMode
__all__ = ['Optimizer']
......@@ -155,6 +158,27 @@ class Optimizer(Cell):
self.param_length = len(self.parameters)
self.map_ = C.Map()
use_parallel = auto_parallel_context().get_enable_parallel_optimizer()
self.use_parallel = use_parallel
if use_parallel:
if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]:
raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL,
ParallelMode.AUTO_PARALLEL]:
raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
(_get_parallel_mode()))
self.dev_num = _get_device_num()
if self.dev_num > self.param_length:
raise RuntimeError("Optimizer segmentation can not be applied when the number of parameters {} is"
" less than the number of devices {}".format(self.param_length, self.dev_num))
self.param_rank = self._get_parameter_group_id()
self.optim_filter = tuple(map(lambda x: x == _get_global_rank(), self.param_rank))
self.param_names = []
for param in self.parameters:
self.param_names.append(param.name)
else:
self.optim_filter = (True,) * self.param_length
def decay_weight(self, gradients):
"""
Weight decay.
......@@ -384,6 +408,51 @@ class Optimizer(Cell):
lr = self.learning_rate
return lr
def _get_parameter_group_id(self):
"""
Get the parameter partition group id, which is less than the number of devices.
Returns:
tuple, the group id tuple of parameters.
"""
rank_list = ()
count = 0
for _ in range(self.param_length):
rank_list = rank_list + (count,)
count = count + 1
if count == self.dev_num:
count = 0
return rank_list
def broadcast_params(self, optim_result):
"""
Apply Broadcast operations in the sequential order of parameter groups.
Returns:
bool, the status flag.
"""
param_group = []
key_group = []
for _ in range(self.dev_num):
param_group.append(F.make_tuple())
key_group.append(F.make_tuple())
for i in range(self.param_length):
param_group[self.param_rank[i]] = param_group[self.param_rank[i]] + (optim_result[i],)
key = P.MakeRefKey(self.param_names[i])()
key_group[self.param_rank[i]] = key_group[self.param_rank[i]] + (key,)
new_param_group = []
for root in range(self.dev_num):
ops = P.Broadcast(root)
next_params = ops(param_group[root])
new_param_group.append(next_params)
for i in range(F.tuple_len(next_params)):
F.assign(key_group[root][i], next_params[i])
status = True
for i in range(self.dev_num - 1):
status = F.control_depend(new_param_group[i][0], new_param_group[i+1])
return status
def construct(self, *hyper_params):
raise NotImplementedError
......
......@@ -85,6 +85,22 @@ def _list_setitem_with_List(data, number_index, value):
return F.list_setitem(data, number_index, value)
@setitem.register("List", "Number", "Tuple")
def _list_setitem_with_Tuple(data, number_index, value):
"""
Assigns value to list.
Inputs:
data (list): Data of type lis.
number_index (Number): Index of data.
value (list): Value given.
Outputs:
list, type is same as the element type of data.
"""
return F.list_setitem(data, number_index, value)
@setitem.register("Dictionary", "String", "Tensor")
def _dict_setitem_with_tensor(data, key, value):
"""
......
......@@ -98,6 +98,7 @@ class AllReduce(PrimitiveWithInfer):
self.op = op
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
self.add_prim_attr('index', 0)
def vm_impl(self, x):
"""Implement by vm mode."""
......
......@@ -59,8 +59,7 @@ class Assign(PrimitiveWithInfer):
return variable
def infer_dtype(self, variable, value):
args = {"variable": variable, "value": value}
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
# Add a type validation later when we don't have to assign a value to RefKey.
return variable
......
......@@ -400,6 +400,23 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_global_rank_is_set()
def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
"""
Set enable/disable parallel optimizer.
Args:
set_enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
"""
self.check_context_handle()
if not isinstance(enable_parallel_optimizer, bool):
raise TypeError('enable_parallel_optimizer is invalid type')
self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
def get_enable_parallel_optimizer(self):
"""Get parallel optimizer flag."""
self.check_context_handle()
return self._context_handle.get_enable_parallel_optimizer()
def reset(self):
"""Reset all settings."""
self.check_context_handle()
......@@ -433,7 +450,8 @@ _set_auto_parallel_context_func_map = {
"parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
"full_batch": auto_parallel_context().set_full_batch}
"full_batch": auto_parallel_context().set_full_batch,
"enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer}
_get_auto_parallel_context_func_map = {
......@@ -447,13 +465,15 @@ _get_auto_parallel_context_func_map = {
"parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
"strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
"strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
"full_batch": auto_parallel_context().get_full_batch}
"full_batch": auto_parallel_context().get_full_batch,
"enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool,
loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool)
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
def _set_auto_parallel_context(**kwargs):
"""
Set auto parallel context.
......@@ -493,6 +513,7 @@ def _set_auto_parallel_context(**kwargs):
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer (bool): Enable using optimizer segmentation or noe. Default: False.
Raises:
ValueError: If input key is not attribute in auto parallel context.
......@@ -535,5 +556,6 @@ def _reset_auto_parallel_context():
- parameter_broadcast: False.
- strategy_ckpt_load_file: ""
- strategy_ckpt_save_file: ""
- enable_parallel_optimizer: False
"""
auto_parallel_context().reset()
......@@ -327,6 +327,9 @@ TEST_F(TestStepParallel, CreatOpInstance) {
} else if (name == "instance_name") {
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
ASSERT_EQ(converted_ret->ToString(), "test");
} else if (name == "index") {
parse::ConvertData(py::cast<py::object>(item.second), &converted_ret);
ASSERT_EQ(converted_ret->ToString(), "0");
} else {
MS_LOG(EXCEPTION) << "Test failed";
}
......
......@@ -16,7 +16,6 @@
test assign sub
"""
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
......@@ -36,27 +35,6 @@ class AssignW(nn.Cell):
return x
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.b = Parameter(initializer('ones', [5]), name='b')
self.assign = AssignW()
def construct(self, value):
return self.assign(self.b, value)
def test_assign_through_cell():
context.set_context(mode=context.GRAPH_MODE)
net = Net()
net.to_float(ms.float16)
net.add_flags_recursive(fp16=False)
input_data = Tensor(np.ones([5]).astype(np.float32))
net(input_data)
with pytest.raises(TypeError):
net(None)
class AssignOp(nn.Cell):
def __init__(self):
super(AssignOp, self).__init__()
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test adam """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb
from mindspore.ops import operations as P
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore import context
class Net(nn.Cell):
"""Net definition"""
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Dense(128, 768, activation='relu')
self.fc2 = nn.Dense(128, 768, activation='relu')
self.fc3 = nn.Dense(128, 768, activation='relu')
self.fc4 = nn.Dense(768, 768, activation='relu')
self.relu4 = nn.ReLU()
self.relu5 = nn.ReLU()
self.transpose = P.Transpose()
self.matmul1 = P.MatMul()
self.matmul2 = P.MatMul()
def construct(self, x):
q = self.fc1(x)
k = self.fc2(x)
v = self.fc3(x)
k = self.transpose(k, (1, 0))
c = self.relu4(self.matmul1(q, k))
s = self.relu5(self.matmul2(c, v))
s = self.fc4(s)
return s
def test_AdamWeightDecayDynamicLR():
""" test_AdamWeightDecayDynamicLR """
auto_parallel_context().set_enable_parallel_optimizer(True)
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = AdamWeightDecayDynamicLR(net.trainable_params(), decay_steps=20, learning_rate=0.1)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_AdamWeightDecay():
""" test_AdamWeightDecayDynamicLR """
auto_parallel_context().set_enable_parallel_optimizer(True)
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = AdamWeightDecay(net.trainable_params(), learning_rate=0.1)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_lamb_compile():
""" test_Lamb_compile """
auto_parallel_context().set_enable_parallel_optimizer(True)
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2)
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Lamb(net.trainable_params(), decay_steps=10)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_edge_case():
""" test_edge_case """
auto_parallel_context().set_enable_parallel_optimizer(True)
net = Net()
with pytest.raises(RuntimeError):
context.set_auto_parallel_context(parallel_mode="stand_alone")
Lamb(net.trainable_params(), decay_steps=10)
with pytest.raises(RuntimeError):
Adam(net.trainable_params(), learning_rate=0.1)
with pytest.raises(RuntimeError):
context.set_auto_parallel_context(device_num=16)
Lamb(net.trainable_params(), decay_steps=10)
......@@ -81,6 +81,10 @@ def test_set_auto_parallel_context():
with pytest.raises(ValueError):
set_algo_parameters(tensor_slice_align_size=1025)
auto_parallel_context().set_enable_parallel_optimizer(True)
assert auto_parallel_context().get_enable_parallel_optimizer() is True
assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
def test_reset_auto_parallel_context():
context.reset_auto_parallel_context()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册