提交 1f809f50 编写于 作者: C chujinjin

fix precision error with fp16 input on PyNative mode

上级 e984f3ec
......@@ -167,7 +167,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
}
} // namespace
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) {
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type,
const TypeId &type_id) {
MS_EXCEPTION_IF_NULL(trans_data);
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
MS_EXCEPTION_IF_NULL(ori_build_info);
......@@ -176,6 +177,10 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
builder->SetInputReshapeType({reshape_type});
builder->SetOutputReshapeType({reshape_type});
builder->SetOutputsFormat({output_format});
if (type_id != kTypeUnknown) {
builder->SetOutputsDeviceType({type_id});
builder->SetInputsDeviceType({type_id});
}
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
}
......
......@@ -86,7 +86,8 @@ class OpFinder {
using OpFinderPtr = std::shared_ptr<OpFinder>;
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {});
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {},
const TypeId &type_id = kTypeUnknown);
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
const bool need_padding, const std::string &op_name);
......
......@@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
if (origin_format != cur_format && cur_shape.size() > 1) {
auto kernel_select = std::make_shared<KernelSelect>();
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(cur_format, origin_format, final_node);
RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type);
final_index = 0;
MS_EXCEPTION_IF_NULL(final_node);
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册