提交 fa0684d1 编写于 作者: Y yujianfeng

Add pack and concat fission pass

上级 4a19e6b8
......@@ -97,6 +97,8 @@
#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "debug/anf_ir_dump.h"
......@@ -153,6 +155,8 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<PackFission>());
ir_fusion_pm->AddPass(std::make_shared<ConcatFission>());
}
} // namespace
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
#include <memory>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
namespace {
// Builds a smaller Concat node over `offset` consecutive real inputs of the origin
// concat, starting at input index `begin_index` (real inputs begin at index 1).
// Copies axis/T attrs from the origin and infers the new node's output shape.
AnfNodePtr CreateNewConcat(const FuncGraphPtr &func_graph, const CNodePtr &origin_concat_cnode, size_t begin_index,
                           size_t offset) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(origin_concat_cnode);
  std::vector<AnfNodePtr> new_concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
  for (size_t i = begin_index; i < begin_index + offset; ++i) {
    new_concat_inputs.push_back(origin_concat_cnode->input(i));
  }
  CNodePtr new_concat = func_graph->NewCNode(new_concat_inputs);
  MS_EXCEPTION_IF_NULL(new_concat);
  new_concat->set_scope(origin_concat_cnode->scope());
  // Set attrs: N, inputNums and dyn_input_sizes all describe the new real input count.
  AnfAlgo::CopyNodeAttr(kAttrAxis, origin_concat_cnode, new_concat);
  AnfAlgo::CopyNodeAttr(kAttrT, origin_concat_cnode, new_concat);
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_concat);
  AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(offset)), new_concat);
  std::vector<int> dyn_input_sizes{SizeToInt(offset)};
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_concat);
  // Infer shape. NOTE(review): assumes every input shares the shape of input 0 — the
  // concat axis of the result is input_shape[axis] * offset. Confirm upstream guarantees
  // uniform input shapes for Concat here.
  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_concat_cnode, 0);
  auto axis = AnfAlgo::GetNodeAttr<int>(origin_concat_cnode, kAttrAxis);
  if (axis < 0) {
    // Normalize a negative axis to its non-negative equivalent.
    axis += input_shape.size();
  }
  auto output_shape = AnfAlgo::GetOutputInferShape(origin_concat_cnode, 0);
  if (axis < 0 || axis >= SizeToInt(output_shape.size()) || axis >= SizeToInt(input_shape.size())) {
    // Fixed log message: add the missing space before "is out of range".
    MS_LOG(EXCEPTION) << "The concat_dim value " << axis << " is out of range";
  }
  output_shape[axis] = input_shape[axis] * offset;
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_concat_cnode, 0)}, {output_shape},
                                      new_concat.get());
  return new_concat;
}
} // namespace
// Matches any Concat node, regardless of how many inputs it carries.
const BaseRef ConcatFission::DefinePattern() const {
  VarPtr inputs_seq = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimConcat, inputs_seq});
}
// Rewrites a Concat whose real input count exceeds inputs_divisor_ into a tree of
// smaller Concat nodes, each taking at most inputs_divisor_ inputs. Returns nullptr
// when the node is already small enough (no rewrite needed).
const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // The real input begins with index 1.
  size_t origin_input_size = cnode->inputs().size() - 1;
  if (origin_input_size <= inputs_divisor_) {
    return nullptr;
  }
  CNodePtr new_cnode = cnode;
  // Each outer iteration replaces new_cnode with a shallower Concat whose inputs are
  // full-sized sub-concats plus any leftover originals, until it fits the divisor.
  while (origin_input_size > inputs_divisor_) {
    MS_EXCEPTION_IF_NULL(new_cnode);
    std::vector<AnfNodePtr> base_concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
    size_t cur_input_index = 1;
    // Divide the inputs of concat by inputs_divisor_.
    while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
      base_concat_inputs.push_back(CreateNewConcat(func_graph, new_cnode, cur_input_index, inputs_divisor_));
      cur_input_index += inputs_divisor_;
    }
    // Fewer than inputs_divisor_ inputs remain; forward them to the base concat as-is.
    for (size_t i = cur_input_index; i <= origin_input_size; i++) {
      base_concat_inputs.push_back(new_cnode->input(i));
    }
    CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
    MS_EXCEPTION_IF_NULL(base_concat);
    base_concat->set_scope(new_cnode->scope());
    base_concat->set_abstract(new_cnode->abstract());
    // Set attrs: N / inputNums / dyn_input_sizes all equal the new real input count.
    AnfAlgo::CopyNodeAttr(kAttrAxis, new_cnode, base_concat);
    AnfAlgo::CopyNodeAttr(kAttrT, new_cnode, base_concat);
    AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat);
    AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat);
    std::vector<int> dyn_input_sizes{SizeToInt(base_concat_inputs.size() - 1)};
    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_concat);
    new_cnode = base_concat;
    origin_input_size = base_concat->inputs().size() - 1;
  }
  return new_cnode;
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_CONCAT_FISSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_CONCAT_FISSION_H_
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
// Maximum real inputs a single Concat may keep after fission.
// NOTE(review): presumably mirrors a backend kernel input-count limit — confirm.
constexpr size_t kConcatInputsDivisor = 63;
// Pass that splits a Concat with more than inputs_divisor_ real inputs into a tree of
// smaller Concat nodes (implementation in concat_fission.cc).
class ConcatFission : public PatternProcessPass {
 public:
  explicit ConcatFission(bool multigraph = true)
      : PatternProcessPass("concat_fission", multigraph), inputs_divisor_(kConcatInputsDivisor) {}
  ~ConcatFission() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 private:
  // Split threshold; non-const so unit tests (via #define private public) can override it.
  size_t inputs_divisor_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_CONCAT_FISSION_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
#include <memory>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
namespace {
// Builds a smaller Pack node over `offset` consecutive real inputs of the origin pack,
// starting at input index `begin_index` (real inputs begin at index 1). Copies the axis
// attr from the origin and infers the sub-pack's output shape.
AnfNodePtr CreateNewPack(const FuncGraphPtr &func_graph, const CNodePtr &origin_pack_cnode, size_t begin_index,
                         size_t offset) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(origin_pack_cnode);
  std::vector<AnfNodePtr> new_pack_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimPack->name()))};
  for (size_t i = begin_index; i < begin_index + offset; ++i) {
    new_pack_inputs.push_back(origin_pack_cnode->input(i));
  }
  CNodePtr new_pack = func_graph->NewCNode(new_pack_inputs);
  MS_EXCEPTION_IF_NULL(new_pack);
  new_pack->set_scope(origin_pack_cnode->scope());
  new_pack->set_abstract(origin_pack_cnode->abstract());
  // N / num / dyn_input_sizes all describe the sub-pack's real input count.
  AnfAlgo::CopyNodeAttr(kAttrAxis, origin_pack_cnode, new_pack);
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_pack);
  AnfAlgo::SetNodeAttr(kAttrNum, MakeValue(SizeToInt(offset)), new_pack);
  std::vector<int> dyn_input_sizes{SizeToInt(offset)};
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_pack);
  // Infer shape: the origin pack's output shape already contains the packed axis, so the
  // sub-pack's shape is identical except that the packed dimension becomes `offset`.
  auto output_shape = AnfAlgo::GetOutputInferShape(origin_pack_cnode, 0);
  auto axis = AnfAlgo::GetNodeAttr<int>(new_pack, kAttrAxis);
  if (axis < 0) {
    // Normalize a negative axis against the output rank.
    axis += output_shape.size();
  }
  // Also reject axis >= rank: the previous insert-then-erase shape dance would have
  // erased past the end of the vector in that case. Fixed message wording/spacing too.
  if (axis < 0 || axis >= SizeToInt(output_shape.size())) {
    MS_LOG(EXCEPTION) << "The pack axis value " << axis << " is out of range";
  }
  std::vector<size_t> new_shape = output_shape;
  new_shape[IntToSize(axis)] = offset;
  // Bug fix: use new_shape here. Passing output_shape gave every sub-pack the origin's
  // full packed dimension instead of `offset`, producing a wrong inferred shape.
  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(origin_pack_cnode, 0)}, {new_shape},
                                      new_pack.get());
  return new_pack;
}
} // namespace
// Matches any Pack node, regardless of how many inputs it carries.
const BaseRef PackFission::DefinePattern() const {
  VarPtr inputs_seq = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimPack, inputs_seq});
}
// Splits a Pack with more than inputs_divisor_ real inputs into several smaller Pack
// nodes whose results are joined by a single Concat along the same axis. Returns
// nullptr when the node is already small enough (no rewrite needed).
const AnfNodePtr PackFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // The real input begins with index 1.
  size_t origin_input_size = cnode->inputs().size() - 1;
  if (origin_input_size <= inputs_divisor_) {
    return nullptr;
  }
  std::vector<AnfNodePtr> base_concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
  size_t cur_input_index = 1;
  // Divide the inputs of pack by inputs_divisor_.
  while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
    base_concat_inputs.push_back(CreateNewPack(func_graph, cnode, cur_input_index, inputs_divisor_));
    cur_input_index += inputs_divisor_;
  }
  if (cur_input_index <= origin_input_size) {
    // Pack the remaining (fewer than inputs_divisor_) inputs into one last sub-pack.
    base_concat_inputs.push_back(
      CreateNewPack(func_graph, cnode, cur_input_index, origin_input_size - cur_input_index + 1));
  }
  CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
  MS_EXCEPTION_IF_NULL(base_concat);
  base_concat->set_scope(cnode->scope());
  base_concat->set_abstract(cnode->abstract());
  // N / inputNums / dyn_input_sizes reflect the number of sub-packs feeding the concat.
  // NOTE(review): unlike ConcatFission, only one level of division is done here — a
  // follow-up ConcatFission pass presumably splits the base concat further if needed.
  AnfAlgo::CopyNodeAttr(kAttrAxis, cnode, base_concat);
  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat);
  AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToInt(base_concat_inputs.size() - 1)), base_concat);
  std::vector<int> dyn_input_sizes{SizeToInt(base_concat_inputs.size() - 1)};
  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_concat);
  return base_concat;
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_PACK_FISSION_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_PACK_FISSION_H_
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
// Maximum real inputs a single Pack may keep after fission.
// NOTE(review): presumably mirrors a backend kernel input-count limit — confirm.
constexpr size_t kPackInputsDivisor = 63;
// Pass that splits a Pack with more than inputs_divisor_ real inputs into several
// smaller Packs joined by a Concat (implementation in pack_fission.cc).
class PackFission : public PatternProcessPass {
 public:
  explicit PackFission(bool multigraph = true)
      : PatternProcessPass("pack_fission", multigraph), inputs_divisor_(kPackInputsDivisor) {}
  ~PackFission() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 private:
  // Split threshold; non-const so unit tests (via #define private public) can override it.
  size_t inputs_divisor_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_PACK_FISSION_H_
......@@ -241,6 +241,9 @@ constexpr auto kAttrOffset = "offset";
constexpr auto kAttrPsKey = "ps_key";
constexpr auto kAttrOptimizerType = "optim_type";
constexpr auto kAttrChildGraph = "child_graph";
constexpr auto kAttrInputNums = "inputNums";
constexpr auto kAttrT = "T";
constexpr auto kAttrNum = "num";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#define private public
#define protected public
#include "backend/optimizer/ascend/ir_fission/concat_fission.h"
#undef private
#undef protected
namespace mindspore {
namespace opt {
// Fixture for ConcatFission pass tests; fetches the before/after graphs from the
// Python test-input module gtest_input.pre_activate.concat_fission_test.
class TestHWConcatFission : public BackendCommon {
 public:
  TestHWConcatFission() : get_py_fun_("gtest_input.pre_activate.concat_fission_test", true) {}
  ~TestHWConcatFission() override = default;
  UT::PyFuncGraphFetcher get_py_fun_;
};
// Runs ConcatFission with divisor 2 on a 9-input concat and compares the rewritten
// graph against the hand-written "after_divided_by_2" expectation.
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_2) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before");
  EXPECT_NE(g, nullptr);
  // Nine identical float32 [2, 32, 224, 224] inputs.
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 9; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto kg = GetKernelGraph(g, args_spec_list);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto concat_fission = std::make_shared<opt::ConcatFission>();
  // inputs_divisor_ is private; reachable via the "#define private public" trick above.
  concat_fission->inputs_divisor_ = 2;
  pm->AddPass(concat_fission);
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_2");
  EXPECT_NE(g_after, nullptr);
  auto kg_after = GetKernelGraph(g_after, args_spec_list);
  EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
}
// Divisor 3 on a 9-input concat: expects three 3-input sub-concats joined by one
// 3-input base concat ("after_divided_by_3").
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_3) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before");
  EXPECT_NE(g, nullptr);
  // Nine identical float32 [2, 32, 224, 224] inputs.
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 9; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto kg = GetKernelGraph(g, args_spec_list);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto concat_fission = std::make_shared<opt::ConcatFission>();
  concat_fission->inputs_divisor_ = 3;
  pm->AddPass(concat_fission);
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_3");
  EXPECT_NE(g_after, nullptr);
  auto kg_after = GetKernelGraph(g_after, args_spec_list);
  EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
}
// Divisor 4 on a 9-input concat: expects two 4-input sub-concats plus the leftover
// ninth input on the base concat ("after_divided_by_4").
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_4) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before");
  EXPECT_NE(g, nullptr);
  // Nine identical float32 [2, 32, 224, 224] inputs.
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 9; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto kg = GetKernelGraph(g, args_spec_list);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto concat_fission = std::make_shared<opt::ConcatFission>();
  concat_fission->inputs_divisor_ = 4;
  pm->AddPass(concat_fission);
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_4");
  EXPECT_NE(g_after, nullptr);
  auto kg_after = GetKernelGraph(g_after, args_spec_list);
  EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
}
// Divisor 8 on a 9-input concat: expects one 8-input sub-concat plus the leftover
// ninth input on the base concat ("after_divided_by_8").
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_8) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before");
  EXPECT_NE(g, nullptr);
  // Nine identical float32 [2, 32, 224, 224] inputs.
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 9; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto kg = GetKernelGraph(g, args_spec_list);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto concat_fission = std::make_shared<opt::ConcatFission>();
  concat_fission->inputs_divisor_ = 8;
  pm->AddPass(concat_fission);
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_8");
  EXPECT_NE(g_after, nullptr);
  auto kg_after = GetKernelGraph(g_after, args_spec_list);
  EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
}
// Divisor 9 equals the input count, so the pass must be a no-op and the graph must
// match the unchanged "after_divided_by_9" expectation.
TEST_F(TestHWConcatFission, test_concat_fission_divided_by_9) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_concat_fission", "before");
  EXPECT_NE(g, nullptr);
  // Nine identical float32 [2, 32, 224, 224] inputs.
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 9; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto kg = GetKernelGraph(g, args_spec_list);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto concat_fission = std::make_shared<opt::ConcatFission>();
  concat_fission->inputs_divisor_ = 9;
  pm->AddPass(concat_fission);
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_concat_fission", "after_divided_by_9");
  EXPECT_NE(g_after, nullptr);
  auto kg_after = GetKernelGraph(g_after, args_spec_list);
  EXPECT_TRUE(CheckEqualGraph(kg_after, new_graph));
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#define private public
#define protected public
#include "backend/optimizer/ascend/ir_fission/pack_fission.h"
#undef private
#undef protected
namespace mindspore {
namespace opt {
// Fixture for PackFission pass tests; fetches the before/after graphs from the
// Python test-input module gtest_input.pre_activate.pack_fission_test.
class TestHWPackFission : public BackendCommon {
 public:
  TestHWPackFission() : get_py_fun_("gtest_input.pre_activate.pack_fission_test", true) {}
  ~TestHWPackFission() override = default;
  UT::PyFuncGraphFetcher get_py_fun_;
};
// Runs PackFission with divisor 3 on a 9-input pack: expects three 3-input sub-packs
// joined by a concat ("after_divided_by_3").
TEST_F(TestHWPackFission, test_pack_fission_divided_by_3) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_pack_fission", "before");
  EXPECT_NE(g, nullptr);
  // Nine identical float32 [2, 32, 224, 224] inputs.
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 9; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto kg = GetKernelGraph(g, args_spec_list);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto pack_fission = std::make_shared<opt::PackFission>();
  // inputs_divisor_ is private; reachable via the "#define private public" trick above.
  pack_fission->inputs_divisor_ = 3;
  pm->AddPass(pack_fission);
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_pack_fission", "after_divided_by_3");
  EXPECT_NE(g_after, nullptr);
  // Unlike the concat tests, the expected graph is compared without re-building a
  // kernel graph — NOTE(review): confirm this asymmetry is intentional.
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
// Divisor 4 on a 9-input pack: expects two 4-input sub-packs plus a 1-input sub-pack,
// joined by a concat ("after_divided_by_4").
TEST_F(TestHWPackFission, test_pack_fission_divided_by_4) {
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_pack_fission", "before");
  EXPECT_NE(g, nullptr);
  // Nine identical float32 [2, 32, 224, 224] inputs.
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list;
  for (size_t i = 0; i < 9; ++i) {
    args_spec_list.push_back(x_abstract);
  }
  auto kg = GetKernelGraph(g, args_spec_list);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  auto pack_fission = std::make_shared<opt::PackFission>();
  pack_fission->inputs_divisor_ = 4;
  pm->AddPass(pack_fission);
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);
  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_pack_fission", "after_divided_by_4");
  EXPECT_NE(g_after, nullptr);
  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import operations as P
concat = P.Concat()
class FnDict:
    """Decorator-style registry: apply an instance as a decorator to store a
    function under its __name__, then fetch it back by name with indexing."""

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        # Register under the function's own name; returns None, matching the
        # original decorator behavior (the decorated name is not reused directly).
        name = fn.__name__
        self.fnDict[name] = fn

    def __getitem__(self, name):
        return self.fnDict[name]
def test_concat_fission(tag):
    """ Graphs for the ConcatFission pass tests: a 9-input Concat ("before") and
    the expected fission results for divisors 2, 3, 4, 8 and 9. """
    fns = FnDict()
    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        return concat((input0, input1, input2, input3, input4, input5, input6, input7, input8))
    @fns
    def after_divided_by_2(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # Three rounds of pairwise splitting; input8 is the leftover at each level.
        a = concat((input0, input1))
        b = concat((input2, input3))
        c = concat((input4, input5))
        d = concat((input6, input7))
        f = concat((a, b))
        g = concat((c, d))
        i = concat((f, g))
        return concat((i, input8))
    @fns
    def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # 9 inputs split evenly into three 3-input concats.
        a = concat((input0, input1, input2))
        b = concat((input3, input4, input5))
        c = concat((input6, input7, input8))
        return concat((a, b, c))
    @fns
    def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # Two full 4-input concats; input8 forwarded to the base concat.
        a = concat((input0, input1, input2, input3))
        b = concat((input4, input5, input6, input7))
        return concat((a, b, input8))
    @fns
    def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # One full 8-input concat; input8 forwarded to the base concat.
        a = concat((input0, input1, input2, input3, input4, input5, input6, input7))
        return concat((a, input8))
    @fns
    def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # Divisor equals the input count: the pass leaves the graph unchanged.
        return concat((input0, input1, input2, input3, input4, input5, input6, input7, input8))
    return fns[tag]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import operations as P
from mindspore.ops import Primitive
pack = P.Pack()
concat = P.Concat()
make_tuple = Primitive('make_tuple')
class FnDict:
    """Decorator-style registry: apply an instance as a decorator to store a
    function under its __name__, then fetch it back by name with indexing."""

    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        # Register under the function's own name; returns None, matching the
        # original decorator behavior (the decorated name is not reused directly).
        name = fn.__name__
        self.fnDict[name] = fn

    def __getitem__(self, name):
        return self.fnDict[name]
def test_pack_fission(tag):
    """ Graphs for the PackFission pass tests: a 9-input Pack ("before") and the
    expected sub-pack + concat results for divisors 3 and 4. """
    fns = FnDict()
    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        return pack((input0, input1, input2, input3, input4, input5, input6, input7, input8))
    @fns
    def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # Three full 3-input packs joined by one concat.
        pack1 = pack(input0, input1, input2)
        pack2 = pack(input3, input4, input5)
        pack3 = pack(input6, input7, input8)
        return make_tuple(concat(pack1, pack2, pack3))
    @fns
    def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8):
        # Two full 4-input packs plus a 1-input remainder pack, joined by one concat.
        pack1 = pack(input0, input1, input2, input3)
        pack2 = pack(input4, input5, input6, input7)
        pack3 = pack(input8)
        return make_tuple(concat(pack1, pack2, pack3))
    return fns[tag]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册