提交 1f809f50 编写于 作者: C chujinjin

Fix precision error with fp16 input in PyNative mode

上级 e984f3ec
...@@ -167,7 +167,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const ...@@ -167,7 +167,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
} }
} // namespace } // namespace
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) { const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type,
const TypeId &type_id) {
MS_EXCEPTION_IF_NULL(trans_data); MS_EXCEPTION_IF_NULL(trans_data);
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
MS_EXCEPTION_IF_NULL(ori_build_info); MS_EXCEPTION_IF_NULL(ori_build_info);
...@@ -176,6 +177,10 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string & ...@@ -176,6 +177,10 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
builder->SetInputReshapeType({reshape_type}); builder->SetInputReshapeType({reshape_type});
builder->SetOutputReshapeType({reshape_type}); builder->SetOutputReshapeType({reshape_type});
builder->SetOutputsFormat({output_format}); builder->SetOutputsFormat({output_format});
if (type_id != kTypeUnknown) {
builder->SetOutputsDeviceType({type_id});
builder->SetInputsDeviceType({type_id});
}
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get()); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
} }
......
...@@ -86,7 +86,8 @@ class OpFinder { ...@@ -86,7 +86,8 @@ class OpFinder {
using OpFinderPtr = std::shared_ptr<OpFinder>; using OpFinderPtr = std::shared_ptr<OpFinder>;
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {}); const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {},
const TypeId &type_id = kTypeUnknown);
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
const bool need_padding, const std::string &op_name); const bool need_padding, const std::string &op_name);
......
...@@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP ...@@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
if (origin_format != cur_format && cur_shape.size() > 1) { if (origin_format != cur_format && cur_shape.size() > 1) {
auto kernel_select = std::make_shared<KernelSelect>(); auto kernel_select = std::make_shared<KernelSelect>();
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(cur_format, origin_format, final_node); RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type);
final_index = 0; final_index = 0;
MS_EXCEPTION_IF_NULL(final_node); MS_EXCEPTION_IF_NULL(final_node);
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册