提交 3d9f1087 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!695 Check topk supported before converting input to attr

Merge pull request !695 from YuJianfeng/master
......@@ -21,6 +21,7 @@
#include <vector>
#include "device/ascend/kernel_select_ascend.h"
#include "kernel/kernel_query.h"
#include "kernel/tbe/tbe_kernel_select.h"
namespace mindspore {
namespace opt {
......@@ -36,6 +37,16 @@ class KernelSelect {
};
using KernelSelectPtr = std::shared_ptr<KernelSelect>;
// Seam class wrapping the kernel-supported query so unit tests can override it
// (see the mock in the topk_split test further down in this change).
class SupportedChecker {
 public:
  SupportedChecker() = default;
  virtual ~SupportedChecker() = default;
  // Returns whether `select_kernel_build_info` is a supported kernel choice for
  // `anf_node`; the default implementation delegates to kernel::CheckSupported.
  virtual bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
    return kernel::CheckSupported(anf_node, select_kernel_build_info);
  }
};
using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>;
class KernelQuery {
public:
KernelQuery() = default;
......
......@@ -16,6 +16,9 @@
#include "pre_activate/ascend/ir_fission/topk_split.h"
#include <vector>
#include <memory>
#include <unordered_set>
#include "pre_activate/common/helper.h"
#include "kernel/kernel_build_info.h"
#include "utils/utils.h"
#include "session/kernel_graph.h"
#include "session/anf_runtime_algorithm.h"
......@@ -25,6 +28,7 @@
namespace mindspore {
namespace opt {
constexpr size_t kFloat16Len = 2; // size of float16;
constexpr size_t kTopkIndexK = 1;
namespace {
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
// 1 create tensor
......@@ -70,37 +74,68 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), indices_const.get());
return indices_const;
}
// Builds the kernel build info used to probe TopK support: two float16 inputs
// in default format, and (values: float16, indices: int32) outputs.
kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
  // Input side: both operands use the default format and float16 device type.
  info_builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
  info_builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
  // Output side: values stay float16, indices are int32.
  info_builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
  info_builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32});
  return info_builder.Build();
}
} // namespace
// Pattern: a TopK primitive with two inputs (data tensor and k).
// Fix: the rendered diff left both the old single-input pattern and the new
// two-input pattern in place, making the second return unreachable and X1/X2
// unused; only the two-input (post-change) form is kept.
const BaseRef TopKSplit::DefinePattern() const {
  VarPtr X1 = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(X1);
  VarPtr X2 = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(X2);
  auto prim = std::make_shared<Primitive>(kTopKOpName);
  MS_EXCEPTION_IF_NULL(prim);
  return VectorRef({prim, X1, X2});
}
// Rewrites a TopK node: converts the constant k input into a primitive attr,
// appends a generated indices value-node input, and only commits the rewrite
// when the resulting kernel is reported as supported.
// Fix: the rendered diff interleaved the pre-change and post-change bodies,
// leaving `indices_const` and `new_cnode` each declared twice (a redefinition
// error) plus dead old-path code; this is the coherent post-change body.
const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Copy a new node to check supported, so the original graph node stays
  // untouched if the rewritten form turns out to be unsupported.
  std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kTopKOpName))};
  new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
  CNodePtr new_cnode = func_graph->NewCNode(new_inputs);
  MS_EXCEPTION_IF_NULL(new_cnode);
  new_cnode->set_abstract(cnode->abstract());
  new_cnode->set_scope(cnode->scope());
  AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
  CheckCNodeInputSize(new_cnode, kTopkInputNum);
  // Convert the tensor input to scalar and convert it to attr
  auto input_k = new_cnode->input(kTopkIndexK + 1);
  MS_EXCEPTION_IF_NULL(input_k);
  if (!IsValueNode<tensor::Tensor>(input_k)) {
    // k is not a compile-time constant; nothing to fold, leave the node alone.
    return nullptr;
  }
  ValuePtr value = GetValueNode(input_k);
  MS_EXCEPTION_IF_NULL(value);
  auto tensor = value->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(tensor);
  // NOTE(review): assumes the k tensor holds int32 data — confirm upstream guarantees.
  int32_t *data = reinterpret_cast<int32_t *>(tensor->data_c());
  MS_EXCEPTION_IF_NULL(data);
  auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data));
  new_cnode->set_input(kTopkIndexK + 1, new_value_node);
  std::unordered_set<size_t> attr_index{kTopkIndexK};
  ConstInputToAttr(new_cnode, attr_index);
  auto indices_const = CreateValueNode(new_cnode);
  new_cnode->add_input(indices_const);
  MS_EXCEPTION_IF_NULL(supported_checker_);
  if (!supported_checker_->CheckSupported(new_cnode, CreateKernelBuildInfo())) {
    // Rewritten kernel is not supported on this backend; keep the original node.
    return nullptr;
  }
  if (kernel_graph != nullptr) {
    kernel_graph->AddValueNodeToGraph(indices_const);
  }
  return new_cnode;
}
} // namespace opt
......
......@@ -16,15 +16,22 @@
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_
#include <memory>
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
// Pass that splits/rewrites TopK nodes; the supported-checker is injectable so
// tests can bypass the real backend query.
// Fix: the rendered diff kept both the old and the new definition of the same
// constructor `TopKSplit(bool)` — a redefinition error; only the post-change
// constructor (which initializes supported_checker_) is kept.
class TopKSplit : public PatternProcessPass {
 public:
  explicit TopKSplit(bool multigraph = true)
      : PatternProcessPass("topk_split", multigraph), supported_checker_(std::make_shared<SupportedChecker>()) {}
  ~TopKSplit() override = default;
  const BaseRef DefinePattern() const override;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  SupportedCheckerPtr supported_checker_;
};
} // namespace opt
} // namespace mindspore
......
......@@ -422,5 +422,47 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
return tuple_getitem;
}
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) {
MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> new_inputs;
std::vector<std::string> new_input_names;
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(primitive);
auto input_names = primitive->GetAttr(kAttrInputNames);
if (input_names == nullptr) {
MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
return;
}
auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
auto inputs = cnode->inputs();
new_inputs.push_back(inputs[0]);
bool need_update = false;
for (size_t i = 0; i < inputs.size() - 1; ++i) {
auto input_node = inputs[i + 1];
MS_EXCEPTION_IF_NULL(input_node);
if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) {
auto value_node = input_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]";
if (i >= input_names_vec.size()) {
MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
}
primitive->set_attr(input_names_vec[i], value_node->value());
need_update = true;
} else {
new_inputs.push_back(input_node);
if (i < input_names_vec.size()) {
new_input_names.push_back(input_names_vec[i]);
}
}
}
if (need_update) {
// Update cnode's inputs
cnode->set_inputs(new_inputs);
// Update cnode's input_names attr
primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));
}
}
} // namespace opt
} // namespace mindspore
......@@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include <string>
#include <unordered_set>
#include "ir/func_graph.h"
#include "session/kernel_graph.h"
#include "common/utils.h"
......@@ -86,6 +87,7 @@ constexpr size_t kAdamApplyOneOutputNum = 3;
constexpr size_t kBackendTransDataInputNum = 2;
constexpr size_t kApplyMomentumInputNum = 6;
constexpr size_t kBiasAddInputNum = 3;
constexpr size_t kTopkInputNum = 3;
enum FusedBatchNormInput {
kX = 1,
......@@ -150,6 +152,8 @@ void RemoveNopNode(session::KernelGraph *const graph);
AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx);
bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_
......@@ -52,7 +52,6 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
Register(kFlattenGradOpName, {1});
Register(kExpandDimsOpName, {1});
Register(kSplitOpName, {0});
Register(kTopKOpName, {1});
Register(kErfOpName, {1});
Register(kSparseApplyAdagradOpName, {2});
Register(kResizeNearestNeighborGrad, {1});
......
......@@ -18,10 +18,10 @@
#include <vector>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <memory>
#include "pre_activate/pass/const_input_to_attr_registry.h"
#include "pre_activate/common/helper.h"
#include "utils/utils.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
......@@ -29,50 +29,6 @@
namespace mindspore {
namespace opt {
namespace {
// Folds constant (ValueNode) inputs of `cnode` into primitive attributes.
// NOTE(review): this file-local copy is byte-identical to opt::ConstInputToAttr
// now declared in pre_activate/common/helper.h — this commit appears to be in
// the middle of deleting it in favor of the shared helper; do not let both live.
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> new_inputs;
  std::vector<std::string> new_input_names;
  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  auto input_names = primitive->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    // Without an input_names attr there is no key to store the value under.
    MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
    return;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  auto inputs = cnode->inputs();
  // inputs[0] is the primitive node itself; it is always kept.
  new_inputs.push_back(inputs[0]);
  bool need_update = false;
  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    auto input_node = inputs[i + 1];
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) {
      // Selected index holds a constant: move its value onto the primitive attr.
      auto value_node = input_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]";
      if (i >= input_names_vec.size()) {
        MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
      }
      primitive->set_attr(input_names_vec[i], value_node->value());
      need_update = true;
    } else {
      // Keep the input (and its declared name, when present).
      new_inputs.push_back(input_node);
      if (i < input_names_vec.size()) {
        new_input_names.push_back(input_names_vec[i]);
      }
    }
  }
  if (need_update) {
    // Update cnode's inputs
    cnode->set_inputs(new_inputs);
    // Update cnode's input_names attr
    primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));
  }
}
} // namespace
const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
......
......@@ -17,8 +17,13 @@
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "device/kernel_info.h"
#include "pre_activate/ascend/ir_fission/topk_split.h"
#include "pre_activate/pass/convert_const_input_to_attr.h"
#include "debug/anf_ir_dump.h"
#define private public
#define protected public
#include "pre_activate/ascend/ir_fission/topk_split.h"
#undef private
#undef protected
namespace mindspore {
namespace opt {
......@@ -30,6 +35,15 @@ class TestHWTopKSplit : public BackendCommon {
UT::PyFuncGraphFetcher get_py_fun_;
};
// Test double: reports every kernel build info as supported so the TopKSplit
// pass can be exercised without a real TBE backend.
class MockSupportedChecker : public SupportedChecker {
 public:
  MockSupportedChecker() = default;
  ~MockSupportedChecker() override = default;
  bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
    return true;
  }
};  // class MockSupportedChecker (original trailing comment wrongly said "namespace opt")
TEST_F(TestHWTopKSplit, test_topk_split) {
/*
* def before(input):
......@@ -40,19 +54,25 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_topk_split", "before");
std::vector<int> shp{4, 4};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
g->parameters()[0]->set_abstract(x_abstract);
auto ret = g->get_return();
EXPECT_NE(ret, nullptr);
auto tuple_getitem = ret->input(1);
EXPECT_NE(tuple_getitem, nullptr);
auto topk = tuple_getitem->cast<CNodePtr>()->input(1);
topk->set_abstract(x_abstract);
AbstractBasePtrList args_spec_list{x_abstract};
auto kernel_graph = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::TopKSplit>());
pm->AddPass(std::make_shared<opt::ConvertConstInputToAttr>());
auto topk_split = std::make_shared<opt::TopKSplit>();
topk_split->supported_checker_ = std::make_shared<MockSupportedChecker>();
pm->AddPass(topk_split);
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(g);
FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);
auto ret = new_graph->get_return();
EXPECT_NE(ret, nullptr);
auto make_tuple = ret->input(1);
EXPECT_NE(make_tuple, nullptr);
auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1);
EXPECT_NE(tuple_getitem, nullptr);
auto topk = tuple_getitem->cast<CNodePtr>()->input(1);
auto topk_cnode = topk->cast<CNodePtr>();
EXPECT_EQ(topk_cnode->inputs().size(), 3);
EXPECT_TRUE(topk_cnode->input(2)->isa<ValueNode>());
......
......@@ -35,7 +35,7 @@ def test_topk_split(tag):
@fns
def before(input):
    # Fix: the rendered diff kept both the old call `TopK(input)` and the new
    # two-argument call; the first assignment was dead and is removed.
    # Build TopK with an explicit k=2 so the pass can fold k into an attr.
    topk = TopK(input, 2)
    output = tuple_getitem(topk, 0)
    return output
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册