diff --git a/mindspore/ccsrc/kernel/mng/label_goto.cc b/mindspore/ccsrc/kernel/mng/label_goto.cc new file mode 100644 index 0000000000000000000000000000000000000000..674e48fb00e24ab13105de29a0692359d411950e --- /dev/null +++ b/mindspore/ccsrc/kernel/mng/label_goto.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/mng/label_goto.h" +#include +#include +#include "runtime/stream.h" +#include "framework/ge_runtime/task_info.h" +#include "session/anf_runtime_algorithm.h" +#include "common/utils.h" + +using ge::model_runner::LabelGotoTaskInfo; +using LabelGotoTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +LabelGotoKernel::LabelGotoKernel() { label_ = 0; } + +LabelGotoKernel::~LabelGotoKernel() {} + +bool LabelGotoKernel::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "LabelGotoKernel init"; + if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, anf_node)) { + MS_LOG(EXCEPTION) << "LabelGotoKernel has no attr label_index"; + } + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + label_ = GetValue(primitive->GetAttr(kAttrLabelIndex)); + MS_LOG(INFO) << "LabelGotoKernel get attr label:" << label_; + return true; +} + +bool LabelGotoKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uintptr_t stream_ptr) { + MS_LOG(INFO) << "LabelGotoKernel launch"; + return true; +} + +std::vector LabelGotoKernel::GenTask(const std::vector &, const std::vector &, + const std::vector &, uint32_t stream_id) { + MS_LOG(INFO) << "LabelGotoKernel GenTask label:" << label_ << ", stream id:" << stream_id; + std::vector task_info_list; + std::shared_ptr task_info_ptr = std::make_shared(stream_id, label_); + MS_EXCEPTION_IF_NULL(task_info_ptr); + task_info_list.emplace_back(task_info_ptr); + return task_info_list; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/mng/label_goto.h b/mindspore/ccsrc/kernel/mng/label_goto.h new file mode 100644 index 0000000000000000000000000000000000000000..093c95f2f508a1e20f98a59d03a989e9ae4ce07b --- /dev/null +++ b/mindspore/ccsrc/kernel/mng/label_goto.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_MNG_LABEL_GOTO_H +#define MINDSPORE_CCSRC_KERNEL_MNG_LABEL_GOTO_H + +#include +#include +#include "kernel/mng/rt_kernel.h" +#include "kernel/mng/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class LabelGotoKernel : public RtKernel { + public: + LabelGotoKernel(); + ~LabelGotoKernel() override; + + bool Init(const AnfNodePtr &anf_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uintptr_t stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + uint32_t label_; +}; + +MS_REG_RTKERNEL(labelgoto, LabelGotoKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_MNG_LABEL_GOTO_H diff --git a/mindspore/ccsrc/kernel/mng/label_set.cc b/mindspore/ccsrc/kernel/mng/label_set.cc new file mode 100644 index 0000000000000000000000000000000000000000..9041867e5f5a093ca9ceec24a2fa8e1bab39193f --- /dev/null +++ b/mindspore/ccsrc/kernel/mng/label_set.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/mng/label_set.h" +#include +#include +#include "runtime/stream.h" +#include "framework/ge_runtime/task_info.h" +#include "session/anf_runtime_algorithm.h" +#include "common/utils.h" + +using ge::model_runner::LabelSetTaskInfo; +using LabelSetTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +LabelSetKernel::LabelSetKernel() { label_ = 0; } + +LabelSetKernel::~LabelSetKernel() {} + +bool LabelSetKernel::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "LabelSetKernel init"; + if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, anf_node)) { + MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index"; + } + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + label_ = GetValue(primitive->GetAttr(kAttrLabelIndex)); + MS_LOG(INFO) << "LabelSetKernel get attr label:" << label_; + return true; +} + +bool LabelSetKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uintptr_t stream_ptr) { + MS_LOG(INFO) << "LabelSetKernel launch"; + return true; +} + +std::vector LabelSetKernel::GenTask(const std::vector &, const std::vector &, + const std::vector &, uint32_t stream_id) { + MS_LOG(INFO) << "LabelSetKernel GenTask label:" << label_ << ", stream id:" << stream_id; + std::vector task_info_list; + std::shared_ptr task_info_ptr = std::make_shared(stream_id, label_); + MS_EXCEPTION_IF_NULL(task_info_ptr); + task_info_list.emplace_back(task_info_ptr); + return task_info_list; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/mng/label_set.h b/mindspore/ccsrc/kernel/mng/label_set.h new file mode 100644 index 0000000000000000000000000000000000000000..1ee51e4a4e976998cb37b61dd585e47589d322c8 --- /dev/null +++ b/mindspore/ccsrc/kernel/mng/label_set.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_MNG_LABEL_SET_H +#define MINDSPORE_CCSRC_KERNEL_MNG_LABEL_SET_H + +#include +#include +#include "kernel/mng/rt_kernel.h" +#include "kernel/mng/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class LabelSetKernel : public RtKernel { + public: + LabelSetKernel(); + ~LabelSetKernel() override; + + bool Init(const AnfNodePtr &anf_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uintptr_t stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + uint32_t label_; +}; + +MS_REG_RTKERNEL(labelset, LabelSetKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_MNG_LABEL_SET_H diff --git a/mindspore/ccsrc/kernel/mng/label_switch.cc b/mindspore/ccsrc/kernel/mng/label_switch.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac8dafa933be8fbf9c5120f01cf05d72a5ee09f8 --- /dev/null +++ b/mindspore/ccsrc/kernel/mng/label_switch.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/mng/label_switch.h" +#include +#include +#include "runtime/stream.h" +#include "framework/ge_runtime/task_info.h" +#include "session/anf_runtime_algorithm.h" +#include "common/utils.h" + +using ge::model_runner::LabelSwitchTaskInfo; +using LabelSwitchTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +LabelSwitchKernel::LabelSwitchKernel() { + label_list_ = {}; + cond_ = nullptr; + label_size_ = 0; +} + +LabelSwitchKernel::~LabelSwitchKernel() {} + +bool LabelSwitchKernel::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "LabelSwitchKernel init"; + if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, anf_node)) { + MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; + } + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + label_list_ = GetValue>(primitive->GetAttr(kAttrLabelSwitchList)); + label_size_ = label_list_.size(); + MS_LOG(INFO) << "LabelSwitchKernel get attr label size:" << label_size_; + for (auto label : label_list_) { + MS_LOG(INFO) << "label: " << label; + } + return true; +} + +bool LabelSwitchKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uintptr_t stream_ptr) { + MS_LOG(INFO) << "LabelSwitchKernel launch"; + return true; +} + +std::vector LabelSwitchKernel::GenTask(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) { + MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id; + std::vector task_info_list; + cond_ = inputs[0]->addr; + // std::shared_ptr task_info_ptr = + // std::make_shared(stream_id, label_size_, &label_list_, cond_); + // need updata ge task info define + std::shared_ptr task_info_ptr = std::make_shared(stream_id, label_size_); + MS_EXCEPTION_IF_NULL(task_info_ptr); + task_info_list.emplace_back(task_info_ptr); + return task_info_list; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/mng/label_switch.h b/mindspore/ccsrc/kernel/mng/label_switch.h new file mode 100644 index 0000000000000000000000000000000000000000..57be7fc9477820e12224142ff5f893a93b953012 --- /dev/null +++ b/mindspore/ccsrc/kernel/mng/label_switch.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_MNG_LABEL_SWITCH_H +#define MINDSPORE_CCSRC_KERNEL_MNG_LABEL_SWITCH_H + +#include +#include +#include "kernel/mng/rt_kernel.h" +#include "kernel/mng/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class LabelSwitchKernel : public RtKernel { + public: + LabelSwitchKernel(); + ~LabelSwitchKernel() override; + + bool Init(const AnfNodePtr &anf_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uintptr_t stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + std::vector label_list_; + uint32_t label_size_; + void *cond_; +}; + +MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_MNG_LABEL_SWITCH_H diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index d5313247d2e9022ffc183433f73063f0a3be73f5..726db410841094b5c07f9ba17d2bdc0e0b22c5f0 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -136,6 +136,9 @@ constexpr auto kPadOpName = "Pad"; constexpr auto kConv2DBackpropInputOpName = "Conv2DBackpropInput"; constexpr auto kFusionOpConv2DBackpropInputReluGradV2Name = "FusionOp_Conv2DBackpropInput_ReluGradV2"; constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2DBackpropInput_AddN_ReluGradV2"; +constexpr auto kLabelSetOpName = "LabelSet"; +constexpr auto kLabelSwitchOpName = "LabelSwitch"; +constexpr auto kLabelGotoOpName = "LabelGoto"; // attr key name constexpr auto kAttrInputNames = "input_names"; @@ -174,6 +177,8 @@ constexpr auto kAttrGroup = "group"; constexpr auto kAttrOp = "op"; constexpr auto kAttrIsTraining = "is_training"; constexpr auto kAttrFusionId = "fusion_id"; +constexpr auto kAttrLabelIndex = "label_index"; +constexpr auto kAttrLabelSwitchList = "label_switch_list"; // attr value constexpr auto kValueTargetSwitch = "target_switch";