Commit 73c4022e authored by mindspore-ci-bot, committed by Gitee

!3775 remove the dtype convert when update output

Merge pull request !3775 from lianliguang/test-xiu-bug
......@@ -51,33 +51,19 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                  const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
   AnfNodePtr trans_node = nullptr;
-  AnfNodePtr input_node = nullptr;
   CNodePtr trans_data = nullptr;
-  std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
-  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
-  std::vector<Axis> padding_axis;
   MS_EXCEPTION_IF_NULL(node);
-  // if insert transdata for input we need to change the input
-  if (is_insert_input) {
-    if (!node->isa<CNode>()) {
-      MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
-    }
-    auto cnode = node->cast<CNodePtr>();
-    dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
-    input_node = AnfAlgo::GetInputNode(cnode, insert_index);
-    padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index);
-  } else {
-    input_node = node;
-    padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
-  }
+  // Init
+  AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
+  std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, insert_index);
+  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : kOpFormat_DEFAULT;
+  std::vector<Axis> padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index)
+                                                   : AnfAlgo::GetOutputReshapeType(node, insert_index);
+  auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
+                                              : AnfAlgo::GetOutputInferShape(input_node, insert_index);
+  bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size())
+                                      : trans::IsNeedPadding(input_format, input_node_out_shape.size());
-  auto input_node_out_shape = AnfAlgo::GetOutputInferShape(input_node, 0);
-  bool need_padding = false;
-  if (is_insert_input) {
-    need_padding = (trans::IsNeedPadding(dst_format, input_node_out_shape.size()));
-  } else {
-    need_padding = (trans::IsNeedPadding(input_format, input_node_out_shape.size()));
-  }
   if (!need_padding) {
     // don't need padding insert transdata only
     trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
......@@ -89,6 +75,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
     trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
     trans_node = trans_data;
+    trans_data->set_abstract(input_node->abstract());
   } else {
     // if need padding & is output need insert a transdata
     // node -> transdata[padding shape] -> reshape[ori_shape]
......@@ -303,7 +290,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
     const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second);
     TypeId origin_type(kTypeUnknown);
     auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
-    auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
+    auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0);
     auto real_input_node = kernel_with_index.first;
     if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
       // weight
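Both VisitKernel and VisitKernelWithReturnType resolve the node that actually produces a value by walking through wrapper nodes such as tuple_getitem; judging by its name, the WithReturnType variant additionally lets certain wrapper kinds be returned as-is rather than traversed (the exact set is not visible in this diff). A generic sketch of the resolution idea, with the node model entirely hypothetical:

    #include <cstddef>
    #include <iostream>
    #include <memory>
    #include <utility>

    // Hypothetical node model: a wrapper node (think tuple_getitem) forwards
    // one output of the node it wraps; a real kernel wraps nothing.
    struct Node {
      std::shared_ptr<Node> wrapped;    // nullptr for a real kernel
      std::size_t selected_index = 0;   // output index a wrapper selects
    };

    // Follow wrappers until the real producing node is reached, keeping track
    // of which of its outputs is actually consumed.
    std::pair<std::shared_ptr<Node>, std::size_t> VisitRealKernel(std::shared_ptr<Node> node,
                                                                  std::size_t index) {
      while (node->wrapped != nullptr) {
        index = node->selected_index;
        node = node->wrapped;
      }
      return {node, index};
    }

    int main() {
      auto kernel = std::make_shared<Node>();
      auto getter = std::make_shared<Node>(Node{kernel, 1});
      auto [real, idx] = VisitRealKernel(getter, 0);
      std::cout << (real == kernel) << " " << idx << '\n';  // prints: 1 1
    }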
......
......@@ -28,7 +28,8 @@ namespace opt {
 class RectifyDoMaskKernelInfo : public PatternProcessPass {
  public:
   explicit RectifyDoMaskKernelInfo(bool multigraph = true)
-      : PatternProcessPass("batch_norm_bert_fission", multigraph), kernel_selecter(std::make_shared<KernelSelect>()) {}
+      : PatternProcessPass("rectify_do_mask_kernel_info", multigraph),
+        kernel_selecter(std::make_shared<KernelSelect>()) {}
   ~RectifyDoMaskKernelInfo() override = default;
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
......
......@@ -87,6 +87,7 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
     new_transdata_node =
       NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
     RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node);
+    new_transdata_node->set_abstract(node->abstract());
     new_replace_node = new_transdata_node;
   }
   FuncGraphManagerPtr manager = func_graph->manager();
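A TransData node changes only the device-side layout; its inferred (host) shape and dtype are those of the value it carries, which is presumably why both this hunk and the one above copy the abstract from the original node onto the new TransData. A toy illustration of that invariant, with all types hypothetical:

    #include <cassert>
    #include <cstdint>
    #include <memory>
    #include <string>
    #include <vector>

    // Hypothetical minimal model of an ANF abstract: the inferred host-side
    // shape and dtype, which a layout-only transform must not change.
    struct Abstract {
      std::vector<int64_t> shape;
      std::string dtype;
    };

    struct Node {
      std::shared_ptr<Abstract> abstract;
      std::string device_format;
    };

    // Inserting a TransData changes only the device format; the abstract is
    // carried over unchanged from the node it replaces in the graph.
    Node MakeTransData(const Node &origin, const std::string &dst_format) {
      return Node{origin.abstract, dst_format};
    }

    int main() {
      Node conv{std::make_shared<Abstract>(Abstract{{2, 3, 4, 4}, "float16"}), "NCHW"};
      Node trans = MakeTransData(conv, "NC1HWC0");
      assert(trans.abstract == conv.abstract);  // same inferred shape/dtype
    }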
......
......@@ -19,6 +19,8 @@
 #include "backend/session/anf_runtime_algorithm.h"
 #include "utils/utils.h"
 #include "base/core_ops.h"
+#include "frontend/operator/ops.h"
+#include "backend/kernel_compiler/common_utils.h"
 
 namespace mindspore {
 namespace opt {
......@@ -32,21 +34,21 @@ const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, cons
                                             const EquivPtr &equiv) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(equiv);
-  auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
-  MS_EXCEPTION_IF_NULL(reshape_op_1);
+  auto out_reshape = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
+  MS_EXCEPTION_IF_NULL(out_reshape);
   // If the reshape operator is used by more than one other operator, it can not be deleted directly
-  if (IsUsedByOthers(func_graph, reshape_op_1)) {
+  if (IsUsedByOthers(func_graph, out_reshape)) {
     return nullptr;
   }
-  auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum);
-  MS_EXCEPTION_IF_NULL(reshape_op_2);
-  if (IsUsedByOthers(func_graph, reshape_op_2)) {
+  auto in_reshape = CheckAnfNodeIfCNodeAndInputSize(AnfAlgo::GetInputNode(out_reshape, 0), kBackendReshapeInputNum);
+  MS_EXCEPTION_IF_NULL(in_reshape);
+  if (IsUsedByOthers(func_graph, in_reshape)) {
     return nullptr;
   }
-  auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0);
-  auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0);
-  if (input_shape == output_shape) {
-    auto input_node = reshape_op_2->input(1);
+  auto output_shape = AnfAlgo::GetOutputDeviceShape(out_reshape, 0);
+  auto input_shape = AnfAlgo::GetInputDeviceShape(in_reshape, 0);
+  if (kernel::IsSameShape(input_shape, output_shape)) {
+    auto input_node = AnfAlgo::GetInputNode(in_reshape, 0);
     return input_node;
   }
   return nullptr;
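The shape comparison here changes from std::vector equality to kernel::IsSameShape. A plausible motivation (an assumption on our part; the diff itself does not say) is that device shapes can encode the same scalar tensor as either an empty vector or {1}, which operator== would treat as different. A sketch of such a tolerant comparison, reusing the name IsSameShape hypothetically:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Hypothetical reimplementation: treat an empty shape and {1} (both
    // scalar tensors on device) as equal, otherwise compare element-wise.
    bool IsSameShape(const std::vector<std::size_t> &a, const std::vector<std::size_t> &b) {
      auto is_scalar = [](const std::vector<std::size_t> &s) {
        return s.empty() || (s.size() == 1 && s[0] == 1);
      };
      if (is_scalar(a) && is_scalar(b)) {
        return true;
      }
      return a == b;
    }

    int main() {
      assert(IsSameShape({}, {1}));         // two encodings of a scalar
      assert(IsSameShape({2, 3}, {2, 3}));
      assert(!IsSameShape({2, 3}, {3, 2}));
    }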
......
......@@ -71,7 +71,8 @@ bool CastEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) {
 bool TransDataOpEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) {
   return AnfAlgo::GetInputFormat(node1, 0) == AnfAlgo::GetOutputFormat(node2, 0) &&
-         AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0);
+         AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0) &&
+         kernel::IsSameShape(AnfAlgo::GetInputDeviceShape(node2, 0), AnfAlgo::GetOutputDeviceShape(node1, 0));
 }
 
 const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const CNodePtr &prev_cnode,
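With the extra conjunct, a back-to-back TransData pair is eliminated only when the second op exactly undoes the first: the formats are mirrored and the device shapes round-trip. Formats alone are not enough, since a conversion such as NCHW to NC1HWC0 packs the channel dimension and changes the device shape. A condensed sketch of the strengthened condition over plain data, assuming (as the code above suggests) that node2 feeds node1:

    #include <cassert>
    #include <string>
    #include <vector>

    struct TransDataIo {
      std::string in_format, out_format;
      std::vector<int> in_shape, out_shape;  // device shapes
    };

    // node2 feeds node1; the pair cancels only if the formats are mirrored
    // AND node1's output device shape equals node2's input device shape.
    bool CanEliminate(const TransDataIo &node1, const TransDataIo &node2) {
      return node1.in_format == node2.out_format &&
             node1.out_format == node2.in_format &&
             node2.in_shape == node1.out_shape;
    }

    int main() {
      TransDataIo to_5d{"NCHW", "NC1HWC0", {2, 3, 4, 4}, {2, 1, 4, 4, 16}};
      TransDataIo to_4d{"NC1HWC0", "NCHW", {2, 1, 4, 4, 16}, {2, 3, 4, 4}};
      assert(CanEliminate(to_4d, to_5d));  // full round trip: safe to remove both
    }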
......
......@@ -106,12 +106,12 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
   // if node is a value node, no need sync addr from device to host
-  if (node->isa<ValueNode>()) {
-    auto value_node = node->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(value_node);
-    return value_node->value();
-  }
   if (!AnfAlgo::OutputAddrExist(node, output_index)) {
+    if (node->isa<ValueNode>()) {
+      auto value_node = node->cast<ValueNodePtr>();
+      MS_EXCEPTION_IF_NULL(value_node);
+      return value_node->value();
+    }
     if (node->isa<Parameter>()) {
       for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
         if (input_idx >= input_tensors.size()) {
......@@ -252,6 +252,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
     kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
     kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
     kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
+    AnfAlgo::SetOutputAddr(device_address, 0, param.get());
   }
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
   // construct abstract of parameter
......
......@@ -481,13 +481,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
     if (op_info != nullptr) {
      is_ref = op_info->is_ref();
     }
-    MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
-    if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
-        AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
-      continue;
-    }
-    if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown &&
-        AnfAlgo::OutputAddrExist(real_input_node, 0)) {
+    if (AnfAlgo::OutputAddrExist(real_input_node, 0)) {
       continue;
     }
     if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
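This last hunk replaces two guards with one. Before, an input was skipped when its device dtype was already selected and, in graph mode, a device address was also bound, or in PyNative mode as soon as the dtype was known; now a bound device address alone is taken as proof the input is fully set up. A truth-table comparison of the two predicates, derived directly from the deleted and added lines:

    #include <cassert>

    // Old guard: skip when (pynative && dtype_known) || (dtype_known && addr_exists).
    bool SkipOld(bool pynative, bool dtype_known, bool addr_exists) {
      return (pynative && dtype_known) || (dtype_known && addr_exists);
    }

    // New guard: an existing output address alone short-circuits the setup.
    bool SkipNew(bool addr_exists) { return addr_exists; }

    int main() {
      // Behavior change: an address without a selected dtype is now skipped too.
      assert(!SkipOld(/*pynative=*/false, /*dtype_known=*/false, /*addr_exists=*/true));
      assert(SkipNew(true));
      // And a known dtype without a bound address is no longer enough to skip.
      assert(SkipOld(/*pynative=*/true, /*dtype_known=*/true, /*addr_exists=*/false));
      assert(!SkipNew(false));
    }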
......