提交 1f1a07e6 编写于 作者: C chenfei

don't insert assign from condition to true branch of while

上级 9fff0508
......@@ -22,6 +22,7 @@
#include "backend/optimizer/pass/convert_const_input_to_tensor_input.h"
#include "backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h"
#include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h"
#include "backend/optimizer/pass/convert_const_scalar_to_tensor.h"
#include "utils/context/ms_context.h"
#include "debug/anf_ir_dump.h"
......@@ -46,8 +47,9 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
auto common_pm = std::make_shared<PassManager>("common_pm");
common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>());
common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>());
common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>());
common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
common_pm->AddPass(std::make_shared<ConvertConstScalarToTensor>());
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
optimizer->AddPassManager(common_pm);
(void)optimizer->Optimize(kernel_graph);
......
......@@ -781,5 +781,27 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString();
return false;
}
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
MS_EXCEPTION_IF_NULL(value_node);
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
new_value_node->set_abstract(value_node->abstract());
// create kernel_info fo new value node
auto kernel_info = std::make_shared<device::KernelInfo>();
new_value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
// set value node initial device data type = infer data type
std::vector<TypeId> types;
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
types.push_back(kTypeUnknown);
}
kernel_build_info_builder->SetOutputsDeviceType(types);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
return new_value_node;
}
} // namespace opt
} // namespace mindspore
......@@ -194,6 +194,9 @@ bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name);
// Check node's data type is in supported data type set
bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &supported_data_type_set);
// Create a new value node of func graph,not kernel graph
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_HELPER_H_
......@@ -29,28 +29,8 @@
namespace mindspore {
namespace opt {
namespace {
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
MS_EXCEPTION_IF_NULL(value_node);
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
new_value_node->set_abstract(value_node->abstract());
// create kernel_info fo new value node
auto kernel_info = std::make_shared<device::KernelInfo>();
new_value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
// set value node initial device data type = infer data type
std::vector<TypeId> types;
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
types.push_back(kTypeUnknown);
}
kernel_build_info_builder->SetOutputsDeviceType(types);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
return new_value_node;
}
AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) {
AnfNodePtr CreateTensorInput(const AnfNodePtr &node, const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(input_node);
auto value_node = input_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
......@@ -60,6 +40,9 @@ AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePt
if (value->isa<Scalar>()) {
tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
} else if (value->isa<ValueTuple>()) {
if (!AnfAlgo::IsRealCNodeKernel(node)) {
return nullptr;
}
tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
} else {
MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple";
......@@ -93,7 +76,7 @@ AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePt
for (size_t i = 0; i < inputs.size() - 1; ++i) {
auto input_node = inputs[i + 1];
if (IsValueNode<Scalar>(input_node) || IsValueNode<ValueTuple>(input_node)) {
auto tensor_input = CreateTensorInput(kernel_graph, input_node);
auto tensor_input = CreateTensorInput(cnode, kernel_graph, input_node);
if (tensor_input == nullptr) {
new_inputs.push_back(input_node);
continue;
......@@ -139,7 +122,7 @@ AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {
const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
return nullptr;
}
if (!node->isa<CNode>()) {
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/pass/convert_const_scalar_to_tensor.h"
#include <vector>
#include <memory>
#include <utility>
#include "utils/graph_utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/kernel_info.h"
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(input_node);
if (!input_node->isa<ValueNode>()) {
return nullptr;
}
auto value_node = input_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (!value->isa<Scalar>()) {
return nullptr;
}
tensor::TensorPtr tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
if (tensor_ptr == nullptr) {
MS_LOG(WARNING) << "Create tensor of" << input_node->DebugString() << "failed";
return nullptr;
}
auto tensor_input = std::make_shared<ValueNode>(tensor_ptr);
MS_EXCEPTION_IF_NULL(tensor_input);
tensor_input->set_abstract(tensor_ptr->ToAbstract());
if (kernel_graph != nullptr) {
tensor_input = kernel_graph->NewValueNode(tensor_input);
kernel_graph->AddValueNodeToGraph(tensor_input);
} else {
tensor_input = MakeValueNode(tensor_input);
}
tensor_input->set_scope(input_node->scope());
return tensor_input;
}
} // namespace
const AnfNodePtr ConvertConstScalarToTensor::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || func_graph == nullptr || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
return nullptr;
}
if (!node->isa<CNode>()) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
bool input_changed = false;
for (size_t i = 0; i < cnode->inputs().size(); ++i) {
auto new_input = CreateTensorInput(func_graph->cast<KernelGraphPtr>(), cnode->inputs()[i]);
if (new_input != nullptr) {
cnode->set_input(i, new_input);
input_changed = true;
}
}
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
if (kernel_graph == nullptr || !input_changed) {
return nullptr;
}
return kernel_graph->NewCNode(cnode);
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_SCALAR_TO_TENSOR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_SCALAR_TO_TENSOR_H_
#include <string>
#include "ir/anf.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class ConvertConstScalarToTensor : public PatternProcessPass {
public:
explicit ConvertConstScalarToTensor(bool multigraph = true)
: PatternProcessPass("convert_const_scalar_to_tensor", multigraph) {}
~ConvertConstScalarToTensor() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_CONST_SCALAR_TO_TENSOR_H_
......@@ -75,8 +75,10 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
}
}
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
return cnode_input_changed ? kernel_graph->NewCNode(cnode) : nullptr;
if (kernel_graph == nullptr || !cnode_input_changed) {
return nullptr;
}
return kernel_graph->NewCNode(cnode);
}
} // namespace opt
} // namespace mindspore
......@@ -881,22 +881,22 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
return;
}
memo->insert(graph.get());
graph->UpdateChildGraphOrder();
for (auto &child_graph : graph->child_graph_order()) {
CreateMultiBranchOutput(NOT_NULL(child_graph), memo);
}
// If graph has no output, the graph is the true graph of while and will call condition graph, no need insert assign
// from condition to true graph
if (graph->get_output_null()) {
return;
}
std::map<AnfNodePtr, AnfNodePtr> need_replace_list;
auto node_list = GetCNodes(TopoSort(graph->get_return()));
for (auto &node : node_list) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
// create a parameter to store the output of multiple branch and set the parameter as the condition graph's output
// auto multi_output_param = graph->NewParameter();
auto origin_inputs = graph->inputs();
auto output_param = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
MS_EXCEPTION_IF_NULL(graph->MutableInputs());
graph->MutableInputs()->operator=(origin_inputs);
graph->AddChildGraphResult(output_param);
std::vector<AnfNodePtr> depend_inputs = {
......
......@@ -133,7 +133,6 @@ AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) {
if (value_node == nullptr) {
return nullptr;
}
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
new_value_node->set_abstract(value_node->abstract());
this->SetKernelInfoForNode(new_value_node);
......
......@@ -99,18 +99,13 @@ TEST_F(TestHWConstInputToTensorInput, test_value_tuple_tensor_input) {
EXPECT_NE(ret->input(1)->cast<CNodePtr>(), nullptr);
auto cnode = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
EXPECT_EQ(AnfAlgo::GetCNodeName(cnode), prim::kPrimDropoutGenMask->name());
std::vector<int> out;
for (size_t i = 1; i <= 4; i++) {
auto input = cnode->input(i);
ASSERT_TRUE(input != nullptr);
EXPECT_TRUE(IsValueNode<tensor::Tensor>(input));
auto tensor = input->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>();
ASSERT_TRUE(tensor != nullptr);
int *data = (int *)(tensor->data_c());
ASSERT_TRUE(data != nullptr);
out.push_back(*data);
}
EXPECT_EQ(out, std::vector<int>({2, 4, 2, 2}));
auto input1 = cnode->input(1);
ASSERT_TRUE(input1 != nullptr);
EXPECT_TRUE(IsValueNode<tensor::Tensor>(input1));
auto tensor = input1->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>();
ASSERT_TRUE(tensor != nullptr);
auto data = tensor->data_c();
EXPECT_EQ(std::vector<int>((int *)data, (int *)data + 4), std::vector<int>({2, 4, 2, 2}));
}
} // namespace opt
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册