提交 24f6b9d7 编写于 作者: Y yujianfeng

Add input2output pass

上级 f5dc6fbe
......@@ -94,6 +94,7 @@
#include "pre_activate/ascend/ir_fission/split_fission.h"
#include "pre_activate/ascend/format_type/modify_ops_attrs.h"
#include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h"
#include "pre_activate/ascend/ir_fusion/add_input_to_output.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "debug/anf_ir_dump.h"
......@@ -259,6 +260,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
}
ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForHcclOp>());
ir_fusion_pm->AddPass(std::make_shared<AddInputToOutput>());
optimizer->AddPassManager(ir_fusion_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
......
......@@ -70,6 +70,21 @@ class KernelQuery {
}
};
using KernelQueryPtr = std::shared_ptr<KernelQuery>;
// Queries operator metadata from the TBE operator registry.
// Virtual so unit tests can substitute a mock implementation.
class OpFinder {
 public:
  OpFinder() = default;
  virtual ~OpFinder() = default;
  // Returns the number of outputs registered for `op_name` in the TBE op
  // library, or -1 when the op is not registered as a TBE op.
  virtual int GetOpRegisteredOutputNum(const std::string &op_name) {
    auto op_info = kernel::OpLib::FindOp(op_name, kernel::kTBE);
    if (op_info == nullptr) {
      return -1;
    }
    // outputs_ptr().size() is size_t; cast explicitly instead of relying on
    // an implicit narrowing conversion to the int return type.
    return static_cast<int>(op_info->outputs_ptr().size());
  }
};
using OpFinderPtr = std::shared_ptr<OpFinder>;
// Declared here, defined elsewhere: rebuilds the kernel build info of
// `trans_data` with the given input/output formats (and optional reshape
// type) — see the implementation for the exact contract.
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {});
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/add_input_to_output.h"
#include <vector>
#include <algorithm>
#include "pre_activate/ascend/ir_fusion/input_to_output_registry.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
namespace mindspore {
namespace opt {
namespace {
// Reads the string-list attribute `attr_name` (e.g. input_names/output_names)
// from `cnode`'s primitive into `names_vec`; leaves it untouched when the
// attribute is absent.
void GetInputOrOutputNames(const CNodePtr &cnode, const std::string &attr_name, std::vector<std::string> *names_vec) {
  MS_EXCEPTION_IF_NULL(names_vec);
  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(prim);
  auto attr_value = prim->GetAttr(attr_name);
  if (attr_value != nullptr) {
    *names_vec = GetValue<std::vector<std::string>>(attr_value);
  }
}
// Appends the abstracts (and, when name attrs exist, the names) of the inputs
// at `input_indices` to `cnode`'s outputs, replacing its abstract with an
// AbstractTuple that contains the original output(s) plus the added ones.
void AddOutputs(const CNodePtr &cnode, const std::vector<size_t> &input_indices) {
MS_EXCEPTION_IF_NULL(cnode);
std::vector<std::string> input_names_vec;
GetInputOrOutputNames(cnode, kAttrInputNames, &input_names_vec);
std::vector<std::string> output_names_vec;
GetInputOrOutputNames(cnode, kAttrOutputNames, &output_names_vec);
AbstractBasePtrList abstract_list;
auto origin_abstract = cnode->abstract();
MS_EXCEPTION_IF_NULL(origin_abstract);
// Seed the list with the existing outputs: flatten a tuple abstract, or take
// the single original abstract as-is.
if (origin_abstract->isa<abstract::AbstractTuple>()) {
auto origin_abstract_tuple = dyn_cast<abstract::AbstractTuple>(origin_abstract);
MS_EXCEPTION_IF_NULL(origin_abstract_tuple);
AbstractBasePtrList origin_abstract_list = origin_abstract_tuple->elements();
(void)std::copy(origin_abstract_list.begin(), origin_abstract_list.end(), std::back_inserter(abstract_list));
} else {
abstract_list.emplace_back(origin_abstract);
}
for (size_t i = 0; i < input_indices.size(); ++i) {
size_t index = input_indices[i];
// index + 1 because cnode->input(0) is the primitive, not a data input.
if (index + 1 >= cnode->inputs().size()) {
MS_LOG(INFO) << "The input index " << index << " for converting to output is out of range, "
<< "node: " << cnode->DebugString();
continue;
}
auto node_to_output = cnode->input(index + 1);
MS_EXCEPTION_IF_NULL(node_to_output);
abstract_list.emplace_back(node_to_output->abstract());
// Mirror the converted input's name into output_names so the attrs stay
// consistent with the widened output list (only when both attrs exist).
if (!input_names_vec.empty() && !output_names_vec.empty() && index < input_names_vec.size()) {
output_names_vec.emplace_back(input_names_vec[index]);
}
}
if (!output_names_vec.empty()) {
AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names_vec), cnode);
}
// Publish the widened output set on the node.
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
cnode->set_abstract(abstract_tuple);
}
} // namespace
// Pass entry: for registered optimizer ops whose actual output count is below
// the TBE-registered count, appends the registered input indices as extra
// outputs and (for originally single-output nodes) creates tuple_getitems.
const AnfNodePtr AddInputToOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
// Only real (non-virtual) CNode kernels are candidates.
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::string op_name = AnfAlgo::GetCNodeName(cnode);
InputToOutputRegister reg;
if (!InputToOutputRegistry::Instance().GetRegisterByOpName(op_name, &reg)) {
return nullptr;
}
// NOTE(review): the registry entry carries a dtype pre-check predicate, but
// it is never invoked here — confirm whether the pre-check is intentionally
// unused (the register class currently exposes no accessor for it).
int output_num = op_finder_->GetOpRegisteredOutputNum(op_name);
// No need add output when it is not a tbe op.
if (output_num == -1) {
return nullptr;
}
// No need add output if the output num matches the registered output num for tbe.
if (AnfAlgo::GetOutputTensorNum(cnode) >= IntToSize(output_num)) {
return nullptr;
}
bool is_origin_tuple_output = AnfAlgo::IsTupleOutput(cnode);
AddOutputs(cnode, reg.input_indices());
// No need to create tuple_getitem if the origin output is a tuple because there has already been some tuple_getitems
// pointed to the outputs.
if (is_origin_tuple_output) {
return nullptr;
}
std::vector<AnfNodePtr> new_outputs;
auto new_abstract_tuple = dyn_cast<abstract::AbstractTuple>(cnode->abstract());
MS_EXCEPTION_IF_NULL(new_abstract_tuple);
CreateMultipleOutputsOfAnfNode(func_graph, cnode, new_abstract_tuple->size(), &new_outputs);
if (new_outputs.size() != new_abstract_tuple->size()) {
MS_LOG(EXCEPTION) << "Failed to create outputs of " << cnode->DebugString();
}
// Return the getitem for output 0 so callers of the original single output
// are re-pointed at the first element of the new tuple.
return new_outputs[0];
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_
#include <string>
#include <memory>
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
// IR fusion pass that appends selected inputs of registered optimizer ops to
// their outputs (see InputToOutputRegistry for the op list).
class AddInputToOutput : public PatternProcessPass {
public:
explicit AddInputToOutput(bool multigraph = true)
: PatternProcessPass("add_input_to_output", multigraph), op_finder_(std::make_shared<OpFinder>()) {}
~AddInputToOutput() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
// Indirection for querying TBE-registered output counts; unit tests replace
// this with a mock.
OpFinderPtr op_finder_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/input_to_output_registry.h"
#include <utility>
#include "utils/utils.h"
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
namespace {
// Dtype pre-checks: an op qualifies for input-to-output conversion only when
// the inferred data type of its first data input is supported by the kernel.
bool ApplyRMSPropPreCheck(const CNodePtr &node) {
  return AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) == kNumberTypeFloat32;
}
bool FusedMulApplyMomentumPreCheck(const CNodePtr &node) {
  auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
  return dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16;
}
bool SparseApplyRMSPropPreCheck(const CNodePtr &node) {
  return AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) == kNumberTypeFloat32;
}
bool ApplyAdagradV2PreCheck(const CNodePtr &node) {
  auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
  return dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16;
}
bool ApplyKerasMomentumPreCheck(const CNodePtr &node) {
  auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
  return dtype == kNumberTypeFloat32 || dtype == kNumberTypeFloat16;
}
bool SparseApplyFtrlPreCheck(const CNodePtr &node) {
  return AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) == kNumberTypeFloat32;
}
bool SparseApplyFtrlV2PreCheck(const CNodePtr &node) {
  return AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) == kNumberTypeFloat32;
}
bool SparseApplyAdagradV2PreCheck(const CNodePtr &node) {
  return AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) == kNumberTypeFloat32;
}
bool SparseApplyAdadeltaPreCheck(const CNodePtr &node) {
  return AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) == kNumberTypeFloat32;
}
}  // namespace
// Registers every supported optimizer op together with the data-input indices
// whose values should also be exposed as outputs; some entries carry a dtype
// pre-check predicate (defined in the anonymous namespace above).
InputToOutputRegistry::InputToOutputRegistry() {
Register(kApplyRMSPropOpName, {1, 2}, ApplyRMSPropPreCheck);
Register(kFusedMulApplyMomentumOpName, {1}, FusedMulApplyMomentumPreCheck);
Register(kApplyAdagradOpName, {1});
Register(kApplyAdagradDAName, {1, 2});
Register(kApplyAdadeltaOpName, {1, 2});
Register(kApplyPowerSignOpName, {1});
Register(kApplyProximalAdagradOpName, {1});
Register(kApplyAdaMaxOpName, {1, 2});
Register(kApplyAdagradV2OpName, {1}, ApplyAdagradV2PreCheck);
Register(kApplyKerasMomentumOpName, {1}, ApplyKerasMomentumPreCheck);
Register(kSparseApplyFtrlOpName, {1, 2}, SparseApplyFtrlPreCheck);
Register(kSparseApplyFtrlV2OpName, {1, 2}, SparseApplyFtrlV2PreCheck);
Register(kSparseApplyAdagradV2OpName, {1}, SparseApplyAdagradV2PreCheck);
Register(kSparseApplyProximalAdagradOpName, {1});
Register(kSparseApplyAdagradOpName, {1});
Register(kApplyFtrlV2OpName, {1, 2});
Register(kApplyMomentumOpName, {1});
Register(kApplyFtrlOpName, {1, 2});
Register(kApplyAdamOpName, {1, 2});
Register(kApplyCenteredRMSPropOpName, {1, 2, 3});
Register(kApplyAddSignOpName, {1});
Register(kSparseApplyRMSPropOpName, {1, 2}, SparseApplyRMSPropPreCheck);
Register(kSparseApplyAdadeltaOpName, {1, 2}, SparseApplyAdadeltaPreCheck);
Register(kApplyAdamWithAmsgradOpName, {1, 2});
}
// Returns the process-wide registry (Meyers singleton; function-local static
// initialization is thread-safe since C++11).
InputToOutputRegistry &InputToOutputRegistry::Instance() {
static InputToOutputRegistry instance;
return instance;
}
// Registers `reg` under its op name; duplicate registrations are ignored,
// so the first registration for a name wins.
void InputToOutputRegistry::Register(const InputToOutputRegister &reg) {
  auto op_name = reg.op_name();
  // emplace performs a single lookup and inserts only when the key is absent,
  // replacing the previous find-then-insert double lookup.
  if (op_input_to_output_map_.emplace(op_name, reg).second) {
    MS_LOG(DEBUG) << op_name << " input2output register successfully!";
  }
}
// Registers an op by name with the input indices to surface as outputs and an
// optional pre-check predicate; duplicate registrations are ignored.
void InputToOutputRegistry::Register(const std::string &op_name, const std::vector<size_t> &input_indices,
                                     const PreCheckFunc &pre_check_func) {
  InputToOutputRegister reg(op_name, pre_check_func);
  reg.set_input_indices(input_indices);
  // Single-lookup insert: emplace only stores `reg` when op_name is absent,
  // replacing the previous find-then-insert double lookup.
  if (op_input_to_output_map_.emplace(op_name, std::move(reg)).second) {
    MS_LOG(DEBUG) << op_name << " input2output register successfully!";
  }
}
bool InputToOutputRegistry::GetRegisterByOpName(const std::string &op_name, InputToOutputRegister *reg) const {
if (op_input_to_output_map_.find(op_name) != op_input_to_output_map_.end()) {
*reg = op_input_to_output_map_.at(op_name);
MS_LOG(DEBUG) << op_name << " input2output find in registry.";
return true;
}
return false;
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_
#include <string>
#include <unordered_map>
#include <vector>
#include <utility>
#include "ir/anf.h"
#include "common/utils.h"
namespace mindspore {
namespace opt {
using PreCheckFunc = std::function<bool(const CNodePtr &node)>;
// Describes one op's conversion: which data-input indices should be appended
// as extra outputs, plus a pre-check predicate guarding the conversion.
class InputToOutputRegister {
 public:
  explicit InputToOutputRegister(
    const std::string &op_name = "", const PreCheckFunc &pre_check_func = [](const CNodePtr &node) { return true; })
      : op_name_(op_name), pre_check_func_(pre_check_func) {}
  virtual ~InputToOutputRegister() = default;
  void set_input_indices(const std::vector<size_t> &input_indices) { input_indices_ = input_indices; }
  const std::vector<size_t> &input_indices() const { return input_indices_; }
  const std::string &op_name() const { return op_name_; }
  // Accessor for the registered pre-check predicate. Without it the stored
  // pre_check_func_ was write-only: registered but unreachable by the pass.
  const PreCheckFunc &pre_check_func() const { return pre_check_func_; }

 private:
  std::string op_name_;
  std::vector<size_t> input_indices_;
  PreCheckFunc pre_check_func_;
};
// Process-wide singleton registry of input-to-output conversion entries,
// keyed by op name. Populated once in the private constructor.
class InputToOutputRegistry {
public:
static InputToOutputRegistry &Instance();
// Registers a pre-built entry; the first registration for an op name wins.
void Register(const InputToOutputRegister &reg);
// Convenience overload building the entry from name/indices/pre-check.
void Register(
const std::string &op_name, const std::vector<size_t> &input_indices,
const PreCheckFunc &pre_check_func = [](const CNodePtr &node) { return true; });
// Copies the entry for op_name into *reg; returns false when unregistered.
bool GetRegisterByOpName(const std::string &op_name, InputToOutputRegister *reg) const;
private:
InputToOutputRegistry();
~InputToOutputRegistry() = default;
DISABLE_COPY_AND_ASSIGN(InputToOutputRegistry)
std::unordered_map<std::string, InputToOutputRegister> op_input_to_output_map_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_
......@@ -164,6 +164,15 @@ constexpr auto kStridedReadOpName = "StridedRead";
constexpr auto kStridedWriteOpName = "StridedWrite";
constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
constexpr auto kFusedAdamName = "FusedAdam";
// Optimizer op names referenced by the input-to-output fusion registry.
constexpr auto kApplyAdagradV2OpName = "ApplyAdagradV2";
constexpr auto kSparseApplyAdagradV2OpName = "SparseApplyAdagradV2";
constexpr auto kSparseApplyFtrlOpName = "SparseApplyFtrl";
constexpr auto kSparseApplyFtrlV2OpName = "SparseApplyFtrlV2";
constexpr auto kApplyKerasMomentumOpName = "ApplyKerasMomentum";
constexpr auto kSparseApplyProximalAdagradOpName = "SparseApplyProximalAdagrad";
constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp";
constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta";
constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad";
// attr key name
constexpr auto kAttrInputNames = "input_names";
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "debug/anf_ir_dump.h"
#define private public
#define protected public
#include "pre_activate/ascend/ir_fusion/add_input_to_output.h"
#undef private
#undef protected
namespace mindspore {
namespace opt {
// Test fixture that loads graph fixtures from the python helper module
// gtest_input.pre_activate.add_input_to_output_test.
class TestHWAddInputToOutput : public BackendCommon {
public:
TestHWAddInputToOutput() : getPyFun_("gtest_input.pre_activate.add_input_to_output_test", true) {}
~TestHWAddInputToOutput() override = default;
public:
// Fetches python-defined test graphs by module/function name.
UT::PyFuncGraphFetcher getPyFun_;
};
// Test double reporting a fixed registered-output count of 2, so the pass
// always sees room to append outputs regardless of the real TBE registry.
class MockOpFinder : public OpFinder {
 public:
  MockOpFinder() = default;
  ~MockOpFinder() override = default;
  // Parameter left unnamed: the mock ignores the op name (and an unnamed
  // parameter avoids an unused-parameter warning).
  int GetOpRegisteredOutputNum(const std::string & /* op_name */) override { return 2; }
};
// End-to-end check: after running AddInputToOutput, the ApplyMomentum node's
// abstract becomes a tuple, i.e. extra outputs were appended.
TEST_F(TestHWAddInputToOutput, test_add_input_to_output) {
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_add_input_to_output", "before");
EXPECT_NE(g, nullptr);
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
// ApplyMomentum takes five inputs; give each the same tensor abstract.
for (size_t i = 0; i < 5; ++i) {
args_spec_list.push_back(x_abstract);
}
auto kg = GetKernelGraph(g, args_spec_list);
EXPECT_NE(kg, nullptr);
auto ret = kg->get_return();
EXPECT_NE(ret, nullptr);
auto make_tuple = ret->input(1);
EXPECT_NE(make_tuple, nullptr);
auto momentum = make_tuple->cast<CNodePtr>()->input(1);
EXPECT_NE(momentum, nullptr);
EXPECT_NE(momentum->abstract(), nullptr);
// Before the pass: a single (non-tuple) output.
EXPECT_FALSE(momentum->abstract()->isa<abstract::AbstractTuple>());
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
auto pass = std::make_shared<opt::AddInputToOutput>();
// Inject the mock finder (member reachable via the #define private public
// trick used at the top of this file).
pass->op_finder_ = std::make_shared<MockOpFinder>();
pm->AddPass(pass);
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kg);
// After the pass: the node publishes multiple outputs as a tuple.
EXPECT_TRUE(momentum->abstract()->isa<abstract::AbstractTuple>());
}
} // namespace opt
} // namespace mindspore
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import operations as P
# Shared ApplyMomentum primitive instance used by the test graphs below.
ApplyMomentum = P.ApplyMomentum()
class FnDict:
    """Decorator-based registry mapping a function's __name__ to the function."""

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        # Used as a decorator: record fn under its name (returns None, so the
        # decorated name itself is rebound to None in the defining scope).
        self.fnDict[fn.__name__] = fn

    def __getitem__(self, name):
        # Look up a previously registered function by name.
        return self.fnDict[name]
def test_add_input_to_output(tag):
    """Return the graph builder registered under `tag` for the C++ gtest."""

    def before(input0, input1, input2, input3, input4):
        return ApplyMomentum(input0, input1, input2, input3, input4)

    # Dispatch table keyed by tag; raises KeyError for unknown tags, matching
    # the FnDict-based lookup used elsewhere in these fixtures.
    return {"before": before}[tag]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册