From 1e43c609e0ef110242d59b573810ff0d42133d49 Mon Sep 17 00:00:00 2001 From: duxiutao Date: Wed, 24 Jun 2020 16:16:35 +0800 Subject: [PATCH] Add test case and fix two bugs 1. add case to guard precision 2. fix a shape bug 3. fix a funcGraph bug --- .../device/ascend/ascend_stream_assign.cc | 25 +-- .../ccsrc/kernel/akg/akg_kernel_build.cc | 3 +- .../ccsrc/session/anf_runtime_algorithm.cc | 2 +- .../ccsrc/session/anf_runtime_algorithm.h | 11 +- .../models/bert/test_bert_graph_kernel.py | 193 ++++++++++++++++++ tests/st/ops/graph_kernel/test_lamb.py | 130 ++++++++++++ 6 files changed, 342 insertions(+), 22 deletions(-) create mode 100644 tests/st/networks/models/bert/test_bert_graph_kernel.py create mode 100644 tests/st/ops/graph_kernel/test_lamb.py diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc index e3491536e..736d6203e 100644 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc @@ -348,16 +348,13 @@ void AscendStreamAssign::GetProcessedStream(const NotNull &graph uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { - auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); - MS_EXCEPTION_IF_NULL(primitive); - auto true_stream_id = GetValue(primitive->GetAttr(kAttrTrueBranchStream)); + auto true_stream_id = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrTrueBranchStream); processed_streams_.emplace(true_stream_id); - auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); - if (value_ptr == nullptr) { + if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { continue; } - auto need_active = GetValue(value_ptr); + auto need_active = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kStreamNeedActivedFirst); if (need_active) { processed_streams_.emplace(cur_stream_id); } @@ -371,20 +368,17 @@ void AscendStreamAssign::GetProcessedStream(const NotNull &graph void AscendStreamAssign::UpdateStreamSwitch(const NotNull &graph_ptr, const CNodePtr &switch_ptr, vector *orders) { orders->emplace_back(switch_ptr); - auto primitive = AnfAlgo::GetCNodePrimitive(switch_ptr); - MS_EXCEPTION_IF_NULL(primitive); - auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); - if (value_ptr == nullptr) { + if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) { return; } - auto need_active = GetValue(value_ptr); + auto need_active = AnfAlgo::GetNodeAttr(switch_ptr, kStreamNeedActivedFirst); if (!need_active) { return; } MS_EXCEPTION_IF_NULL(switch_ptr); - auto true_stream_id = GetValue(primitive->GetAttr(kAttrTrueBranchStream)); + auto true_stream_id = AnfAlgo::GetNodeAttr(switch_ptr, kAttrTrueBranchStream); MS_LOG(INFO) << "Streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr) << "; active stream id:" << true_stream_id; @@ -677,14 +671,11 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull &gra for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { cur_cnode_ptr = cnode_ptr_list[i]; MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); - MS_EXCEPTION_IF_NULL(primitive); - auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); - if (value_ptr == nullptr) { + if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { continue; } - auto need_active = GetValue(value_ptr); + auto need_active = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kStreamNeedActivedFirst); if (need_active) { auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); MS_LOG(INFO) << "Stream id:" << stream_id << " is need actived at first"; diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_build.cc b/mindspore/ccsrc/kernel/akg/akg_kernel_build.cc index 6bd1e7747..0e8d93d47 100644 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_build.cc +++ b/mindspore/ccsrc/kernel/akg/akg_kernel_build.cc @@ -276,7 +276,8 @@ bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::j input_desc_json[kName] = op_input_name; input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index)); auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index); - if (GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) { + if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && + GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) { MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2) << "] as const tensor, shape: [" << Vector2Str(input_shape) << "], value: " << input_desc_json[kValue]; diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 56983a4d2..5f896282d 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -291,7 +291,7 @@ bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &no // graph kernel cnode. auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); MS_EXCEPTION_IF_NULL(fg); - return fg->has_flag(key); + return fg->has_attr(key); } size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) { diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index c46f0b595..ae8b450e2 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -68,9 +68,14 @@ class AnfRuntimeAlgorithm { std::string node_debug_log = node->DebugString(); MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str(); } - auto primitive = GetCNodePrimitive(node); - MS_EXCEPTION_IF_NULL(primitive); - return GetValue(primitive->GetAttr(key)); + // single op cnode. + if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) { + return GetValue(primitive->GetAttr(key)); + } + // graph kernel cnode. + auto fg = GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + return GetValue(fg->get_attr(key)); } static bool IsTupleOutput(const AnfNodePtr &anf); // set attr of anf node diff --git a/tests/st/networks/models/bert/test_bert_graph_kernel.py b/tests/st/networks/models/bert/test_bert_graph_kernel.py new file mode 100644 index 000000000..ec71cbaa4 --- /dev/null +++ b/tests/st/networks/models/bert/test_bert_graph_kernel.py @@ -0,0 +1,193 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""train bert network without lossscale""" + +import os +import pytest +import numpy as np + +import mindspore.common.dtype as mstype +import mindspore.dataset.engine.datasets as de +import mindspore.dataset.transforms.c_transforms as C +from mindspore import context +from mindspore import log as logger +from mindspore.common.tensor import Tensor +from mindspore.nn.optim import Lamb +from mindspore.train.callback import Callback +from mindspore.train.loss_scale_manager import DynamicLossScaleManager +from mindspore.train.model import Model +from src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepWithLossScaleCell +from src.bert_model import BertConfig + +DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] +SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json" + +def get_config(version='base', batch_size=1): + """get config""" + if version == 'base': + bert_config = BertConfig( + batch_size=batch_size, + seq_length=128, + vocab_size=21136, + hidden_size=768, + num_hidden_layers=2, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=True, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float32) + elif version == 'large': + bert_config = BertConfig( + batch_size=batch_size, + seq_length=128, + vocab_size=30522, + hidden_size=1024, + num_hidden_layers=2, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + use_relative_positions=True, + input_mask_from_dataset=True, + token_type_ids_from_dataset=True, + dtype=mstype.float32, + compute_type=mstype.float16, + enable_fused_layernorm=True) + else: + bert_config = BertConfig(batch_size=batch_size) + return bert_config + + +def me_de_train_dataset(): + """test me de train dataset""" + # apply repeat operations + repeat_count = 1 + ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", + "next_sentence_labels", "masked_lm_positions", + "masked_lm_ids", "masked_lm_weights"], shuffle=False) + type_cast_op = C.TypeCast(mstype.int32) + ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) + ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) + ds = ds.map(input_columns="next_sentence_labels", operations=type_cast_op) + ds = ds.map(input_columns="segment_ids", operations=type_cast_op) + ds = ds.map(input_columns="input_mask", operations=type_cast_op) + ds = ds.map(input_columns="input_ids", operations=type_cast_op) + # apply batch operations + batch_size = int(os.getenv('BATCH_SIZE', '16')) + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.repeat(repeat_count) + return ds + + +def weight_variable(shape): + """weight variable""" + np.random.seed(1) + ones = np.random.uniform(-0.1, 0.1, size=shape).astype(np.float32) + return Tensor(ones) + + +class ModelCallback(Callback): + def __init__(self): + super(ModelCallback, self).__init__() + self.loss_list = [] + self.overflow_list = [] + self.lossscale_list = [] + + def step_end(self, run_context): + cb_params = run_context.original_args() + self.loss_list.append(cb_params.net_outputs[0].asnumpy()[0]) + self.overflow_list.append(cb_params.net_outputs[1].asnumpy()) + self.lossscale_list.append(cb_params.net_outputs[2].asnumpy()) + print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs))) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_bert_tdt(): + """test bert tdt""" + np.random.seed(0) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) + context.set_context(enable_graph_kernel=True) + ds = me_de_train_dataset() + config = get_config(version='large', batch_size=16) + netwithloss = BertNetworkWithLoss(config, True) + optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size()*ds.get_repeat_count(), + start_learning_rate=5e-5, end_learning_rate=1e-9, + power=10.0, warmup_steps=0, weight_decay=0.01) + scale_window = 3 + scale_manager = DynamicLossScaleManager(262144, 2, scale_window) + netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, + scale_update_cell=scale_manager.get_update_cell()) + netwithgrads.set_train(True) + model = Model(netwithgrads) + callback = ModelCallback() + params = netwithloss.trainable_params() + for param in params: + param.init_data() + value = param.default_input + name = param.name + if isinstance(value, Tensor): + if name.split('.')[-1] in ['weight']: + if name.split('.')[-3] in ['cls2']: + logger.info("***************** BERT param name is 1 {}".format(name)) + param.default_input = weight_variable(value.asnumpy().shape) + else: + logger.info("***************** BERT param name is 2 {}".format(name)) + tempshape = value.asnumpy().shape + shape = (tempshape[1], tempshape[0]) + weight_value = weight_variable(shape).asnumpy() + param.default_input = Tensor(np.transpose(weight_value, [1, 0])) + else: + logger.info("***************** BERT param name is 3 {}".format(name)) + param.default_input = weight_variable(value.asnumpy().shape) + model.train(1, ds, callbacks=callback, dataset_sink_mode=False) + + # assertion occurs while the loss value, overflow state or loss_scale value is wrong + loss_value = np.array(callback.loss_list) + expect_loss_value = [12.559319, 12.333815, 12.339806, 12.350235, 12.343947, 12.830965, 12.375336, 12.973715, + 12.57929, 12.7766905] + error = loss_value - expect_loss_value + print("loss value: {}".format(loss_value)) + print("error value: {}".format(error)) + assert np.allclose(loss_value, expect_loss_value, 0, 0.0005) + + overflow = np.array(callback.overflow_list) + expect_overflow = [True, True, True, True, False, False, False, True, False, False] + print("overflow: {}".format(overflow)) + assert (overflow == expect_overflow).all() + + loss_scale = np.array(callback.lossscale_list) + expect_loss_scale = [131072.0, 65536.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0] + print("loss scale: {}".format(loss_scale)) + assert np.allclose(loss_scale, expect_loss_scale, 0, 0) + + +if __name__ == '__main__': + test_bert_tdt() diff --git a/tests/st/ops/graph_kernel/test_lamb.py b/tests/st/ops/graph_kernel/test_lamb.py new file mode 100644 index 000000000..d34c0eea5 --- /dev/null +++ b/tests/st/ops/graph_kernel/test_lamb.py @@ -0,0 +1,130 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +import numpy as np +import mindspore.context as context +from mindspore import Tensor, Parameter +from mindspore.nn import Cell +from mindspore.nn.graph_kernels import LambUpdateWithLR, LambNextMV + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class LambNet(Cell): + def __init__(self, i2, i5, x6): + super(LambNet, self).__init__() + self.i2 = Parameter(i2, name='i2') + self.i5 = Parameter(i5, name='i5') + self.x6 = Parameter(x6, name='x6') + self.lamb_next = LambNextMV() + self.lamb_update = LambUpdateWithLR() + + def construct(self, i1, i3, i4, i6, i7, i8, i9, ix0, ix1, ix2, ix3, + x1, x2, x3, x4, x5, gy, se, my): + return self.lamb_next(i1, self.i2, i3, i4, self.i5, i6, i7, i8, i9, ix0, + ix1, ix2, ix3), \ + self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my) + +def LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my): + trust_ratio = np.where(np.greater(x2, gy), + np.where(np.greater(x1, gy), np.divide(x2, x3), se), + se) + trust_ratio = np.maximum(np.minimum(trust_ratio, my), gy) + update_with_lr = trust_ratio * x4 * x5 + next_param = x6 - np.reshape(update_with_lr, x6.shape) + return next_param + +def LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, x0, x1, x2, x3): + m_fp32 = i5.astype(np.float32) + v_fp32 = i2.astype(np.float32) + next_m = i8 * m_fp32 + i9 * i4 + next_v = x0 * v_fp32 + x1 * i1 + next_mm = next_m / i6 + next_vv = next_v / i3 + update = next_mm / (np.sqrt(next_vv) + x3) + add3 = next_mm / np.sqrt(next_vv + x3) + x2 * i7 + return add3, next_m, next_v, update + + +def tensor_all(*args): + res = [Tensor(a) for a in args] + return res + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_graph_kernel_lamb(): + shape = [1, 16] + oshape = [1] + np.random.seed(0) + x1 = np.random.normal(0, 1, oshape).astype(np.float32) + x2 = np.random.normal(0, 1, oshape).astype(np.float32) + x3 = np.random.normal(0, 1, oshape).astype(np.float32) + x4 = np.random.normal(0, 1, oshape).astype(np.float32) + x5 = np.random.normal(0, 1, shape).astype(np.float32) + x6 = np.random.normal(0, 1, shape).astype(np.float32) + gy = np.random.normal(0, 1, oshape).astype(np.float32) + se = np.random.normal(0, 1, oshape).astype(np.float32) + my = np.random.normal(0, 1, oshape).astype(np.float32) + + tx1, tx2, tx3, tx4, tx5, tx6, tgy, tse, tmy = tensor_all( + x1, x2, x3, x4, x5, x6, gy, se, my) + + np.random.seed(1) + i1 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) + i2 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) + i3 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) + i4 = np.random.normal(0, 1, shape).astype(np.float32) + i5 = np.random.normal(0, 1, shape).astype(np.float32) + i6 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) + i7 = np.random.normal(0, 1, shape).astype(np.float32) + i8 = np.random.normal(0, 1, shape).astype(np.float32) + i9 = np.random.normal(0, 1, shape).astype(np.float32) + ix0 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) + ix1 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) + ix2 = np.random.normal(0, 1, shape).astype(np.float32) + ix3 = np.ones(shape).astype(np.float32) * 1e-6 + + ti1, ti2, ti3, ti4, ti5, ti6, ti7, ti8, ti9, tix0, tix1, tix2, tix3 = \ + tensor_all(i1, i2, i3, i4, i5, i6, i7, i8, i9, ix0, ix1, ix2, ix3) + + context.set_context(enable_graph_kernel=True) + + net = LambNet(ti2, ti5, tx6) + (wa3, wup), _ = net(ti1, ti3, ti4, ti6, ti7, ti8, ti9, tix0, tix1, tix2, tix3, + tx1, tx2, tx3, tx4, tx5, tgy, tse, tmy) + + wi2 = net.i2.data.asnumpy().copy() + wi5 = net.i5.data.asnumpy().copy() + ares = net.x6.data.asnumpy().copy() + + context.set_context(enable_graph_kernel=False) + + a3, a0, a1, up = LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, ix0, + ix1, ix2, ix3) + + np_res = LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my) + + rtol = 0.0001 + atol = 0.0001 + + wres = (wa3.asnumpy().copy(), wi5, wi2, wup.asnumpy().copy()) + bres = (a3, a0, a1, up) + + cmp_res = list(map(lambda x, y: np.allclose(x, y, rtol, atol), + wres, bres)) + + assert all(cmp_res) and np.allclose(ares, np_res, rtol, atol) -- GitLab