Commit 4a0ddaef authored by: Y yujianfeng

Support specifying reshape type for batchnorm fused op

Parent b45b6a9f
......@@ -321,9 +321,11 @@ void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_
MS_LOG(INFO) << "Dynamic select format response successful, use dynamic format.";
for (size_t i = 0; i < inputs_static.size(); i++) {
inputs_dyn[i]->set_param_type(inputs_static[i]->param_type());
inputs_dyn[i]->set_reshape_type(inputs_static[i]->reshape_type());
}
for (size_t j = 0; j < outputs_static.size(); j++) {
outputs_dyn[j]->set_param_type(outputs_static[j]->param_type());
outputs_dyn[j]->set_reshape_type(outputs_static[j]->reshape_type());
}
op_info_new_ptr->set_inputs_ptr(inputs_dyn);
op_info_new_ptr->set_outputs_ptr(outputs_dyn);
......@@ -335,6 +337,29 @@ void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_
op_info_new_ptr->set_fusion_type(op_info_ptr->fusion_type());
}
bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
for (const auto &c : reshape_type_str) {
switch (c) {
case 'N':
reshape_type_vec->push_back(kernel::N);
break;
case 'C':
reshape_type_vec->push_back(kernel::C);
break;
case 'H':
reshape_type_vec->push_back(kernel::H);
break;
case 'W':
reshape_type_vec->push_back(kernel::W);
break;
default:
MS_LOG(ERROR) << "Unknown axis " << c << " in reshape type.";
return false;
}
}
return true;
}
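A quick usage sketch of the new helper (hypothetical call site inside a bool-returning function; the Axis values are the kernel::N/C/H/W enumerators matched above):

// "NC" parses to {kernel::N, kernel::C}; any letter outside NCHW,
// e.g. the 'X' in "NX", logs an error and makes the call return false.
std::vector<Axis> reshape_type;
if (!StringToAxisVector("NC", &reshape_type)) {
  return false;
}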
bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
size_t builder_idex, const std::vector<int> &dyn_input_sizes,
const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
......@@ -347,6 +372,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
MS_EXCEPTION_IF_NULL(inputs[0]);
size_t kernel_info_cnt = inputs[0]->dtypes().size();
std::vector<std::vector<Axis>> reshape_types;
for (const auto &input : inputs) {
MS_EXCEPTION_IF_NULL(input);
std::string param_type = input->param_type();
......@@ -384,8 +410,14 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
inputs_format.push_back(formats[builder_idex]);
}
}
std::vector<Axis> reshape_type;
if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
return false;
}
reshape_types.push_back(reshape_type);
}
builder->SetInputReshapeType(reshape_types);
builder->SetInputsDeviceType(inputs_device_type);
builder->SetInputsFormat(inputs_format);
return true;
......@@ -403,6 +435,7 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
MS_EXCEPTION_IF_NULL(outputs[0]);
size_t kernel_info_cnt = outputs[0]->dtypes().size();
std::vector<std::vector<Axis>> reshape_types;
for (const auto &output : outputs) {
MS_EXCEPTION_IF_NULL(output);
if (output_idx >= real_output_num) {
......@@ -436,8 +469,14 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
outputs_format.push_back(formats[builder_idex]);
output_idx++;
}
std::vector<Axis> reshape_type;
if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
return false;
}
reshape_types.push_back(reshape_type);
}
builder->SetOutputReshapeType(reshape_types);
builder->SetOutputsFormat(outputs_format);
builder->SetOutputsDeviceType(outputs_device_type);
return true;
......@@ -515,7 +554,7 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
// if the format is default, it means all formats are supported
if (kOpFormatList.find(format) == kOpFormatList.end()) {
......@@ -528,13 +567,13 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
if (shape.empty()) {
return true;
}
if (shape.size() > kShapeSupportFormatMap.size()) {
if (shape.size() > kShape4dDims) {
return false;
}
if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
return true;
if (format == kOpFormat_FRAC_NZ && shape.size() < 2) {
return false;
}
return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
return true;
}
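With this change the check rejects any shape above 4D outright (kShape4dDims is assumed here to be 4) and keeps only the rank floor for FRACTAL_NZ. Hypothetical examples of the resulting behavior:

// Assuming kShape4dDims == 4 and all formats below are in kOpFormatList:
IsShapeMatchFormat({2, 3, 4, 5}, kOpFormat_NC1HWC0);     // true: rank <= 4
IsShapeMatchFormat({2, 3, 4, 5, 6}, kOpFormat_NC1HWC0);  // false: rank > kShape4dDims
IsShapeMatchFormat({16}, kOpFormat_FRAC_NZ);             // false: FRACTAL_NZ needs rank >= 2
IsShapeMatchFormat({}, kOpFormat_FRAC_NZ);               // true: an empty shape matches any listed format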
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
......
......@@ -55,12 +55,17 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
trans_inputs.push_back(input);
CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
MS_EXCEPTION_IF_NULL(trans_node);
std::vector<kernel::Axis> padding_axis;
if (AnfAlgo::IsRealKernel(input)) {
padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
} else {
padding_axis = AnfAlgo::GetPrevNodeOutputReshapeType(input, 0);
}
if (need_padding) {
// if padding is needed, set the transdata node's shape to the padded shape
AnfAlgo::SetOutputInferTypeAndShape(
{AnfAlgo::GetOutputInferDataType(input, 0)},
{trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputReshapeType(input, 0))},
trans_node.get());
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
{trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
trans_node.get());
} else {
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
{AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
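For intuition, trans::PaddingShapeTo4d presumably expands a low-rank inferred shape to 4D along the given padding axes. A standalone sketch of that behavior (the AxisSketch enum and PadTo4dSketch are illustrative stand-ins, not the MindSpore helpers):

#include <vector>

enum AxisSketch { N = 0, C = 1, H = 2, W = 3 };  // stand-in for kernel::Axis

// Each axis named by the reshape type consumes the next input dimension;
// the remaining axes of the 4D result are filled with 1.
std::vector<size_t> PadTo4dSketch(const std::vector<size_t> &shape,
                                  const std::vector<AxisSketch> &padding_axis) {
  std::vector<size_t> padded(4, 1);
  for (size_t i = 0; i < shape.size() && i < padding_axis.size(); ++i) {
    padded[padding_axis[i]] = shape[i];
  }
  return padded;
}

// PadTo4dSketch({32, 64}, {N, C}) -> {32, 64, 1, 1}
// PadTo4dSketch({64}, {C})        -> {1, 64, 1, 1}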
......@@ -194,8 +199,14 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
MS_EXCEPTION_IF_NULL(cnode);
input_node = AnfAlgo::GetInputNode(cnode, insert_index);
}
bool need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
op_name == kTransDataOpName);
bool need_padding = false;
if (is_insert_input) {
need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
op_name == kTransDataOpName);
} else {
need_padding = (trans::IsNeedPadding(origin_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
op_name == kTransDataOpName);
}
if (!need_padding) {
// no padding needed, insert the transdata node only
trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
......
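The padding check is now direction-aware: a TransData inserted on an input converts toward dest_format, while one inserted on an output converts away from origin_format, so each branch tests the format on the device side of the transition. The same decision, condensed into one expression (illustrative rewrite, not a further change):

const std::string &device_format = is_insert_input ? dest_format : origin_format;
bool need_padding = op_name == kTransDataOpName &&
                    trans::IsNeedPadding(device_format, AnfAlgo::GetOutputInferShape(input_node, 0).size());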
......@@ -86,7 +86,6 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);
(*bn_reduce_grad_outputs).push_back(bn_reduce_grad);
}
} // namespace
const BaseRef BatchNormGradSplit::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
......
......@@ -344,7 +344,7 @@ bool IsNopNode(const AnfNodePtr &node) {
return true;
}
bool IsAllNopNode(session::KernelGraph *const graph) {
bool IsAllNopNode(const session::KernelGraph *const graph) {
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
for (auto &cnode : execution_order) {
......
......@@ -347,6 +347,11 @@ std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_n
return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
}
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
}
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
abstract::BaseShapePtr base_shape = node->Shape();
......
......@@ -95,6 +95,8 @@ class AnfRuntimeAlgorithm {
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
// get output format from the prev node; input_idx is the input index of the current node relative to the prev node
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
// get the reshape_type of the output of the input node.
static std::vector<kernel::Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
// get output shapes inferred by ME from input nodes.
static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
// get input shapes inferred by ME from input nodes.
......
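A hypothetical call site for the new accessor, e.g. when the current node is not itself a real kernel and the reshape type must come from whatever feeds its first input:

std::vector<kernel::Axis> padding_axis = AnfAlgo::GetPrevNodeOutputReshapeType(node, 0);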
......@@ -204,6 +204,7 @@ constexpr auto kOpFormat_FRAC_Z = "FracZ";
constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
kOpFormat_C1HWNCoC0};
......@@ -225,8 +226,9 @@ const std::set<std::string> kOptOperatorSet = {
kApplyRMSPropOpName,
};
const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
kOpFormat_FRACTAL_Z_C04};
static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
if (access(file_name.c_str(), F_OK) != 0) {
......
......@@ -23,7 +23,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \
.compute_cost(10) \
.kernel_name("bn_training_reduce") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(0, "x", False, "required", "all", reshape_type="NC") \
.output(0, "sum", False, "required", "all") \
.output(1, "square_sum", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \
......
......@@ -24,14 +24,14 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \
.kernel_name("bn_training_reduce_grad") \
.partial_flag(True) \
.attr("epsilon", "optional", "float", "all") \
.input(0, "grads", False, "required", "all") \
.input(1, "x_norm", False, "required", "all") \
.input(0, "grads", False, "required", "all", reshape_type="NC") \
.input(1, "x_norm", False, "required", "all", reshape_type="NC") \
.input(2, "diff_scale", False, "required", "all") \
.input(3, "diff_offset", False, "required", "all") \
.input(4, "scale", False, "required", "all") \
.input(5, "batch_mean", False, "required", "all") \
.input(6, "batch_variance", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.output(0, "y", False, "required", "all", reshape_type="NC") \
.dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
......
......@@ -26,14 +26,14 @@ bn_training_update_op_info = TBERegOp("BNTrainingUpdate") \
.attr("factor", "optional", "float", "all") \
.attr("epsilon", "optional", "float", "all") \
.attr("isRef", "optional", "bool", "all", "true") \
.input(0, "x", False, "required", "all") \
.input(0, "x", False, "required", "all", reshape_type="NC") \
.input(1, "sum", False, "required", "all") \
.input(2, "square_sum", False, "required", "all") \
.input(3, "scale", False, "required", "all") \
.input(4, "offset", False, "required", "all") \
.input(5, "mean", False, "required", "all") \
.input(6, "variance", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.output(0, "y", False, "required", "all", reshape_type="NC") \
.output(1, "mean", False, "required", "all") \
.output(2, "variance", False, "required", "all") \
.output(3, "batch_mean", False, "required", "all") \
......
......@@ -24,8 +24,8 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \
.kernel_name("bn_training_update_grad") \
.partial_flag(True) \
.attr("epsilon", "optional", "float", "all") \
.input(0, "grads", False, "required", "all") \
.input(1, "x", False, "required", "all") \
.input(0, "grads", False, "required", "all", reshape_type="NC") \
.input(1, "x", False, "required", "all", reshape_type="NC") \
.input(2, "batch_mean", False, "required", "all") \
.input(3, "batch_variance", False, "required", "all") \
.output(0, "diff_scale", False, "required", "all") \
......
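In each registration above, reshape_type="NC" marks tensors whose inferred shape spans only the N and C axes. On the C++ side the string is parsed by StringToAxisVector and ends up as the padding axes used before the 5HD (NC1HWC0) transposition, so a 2D tensor is padded as {N, C, 1, 1} rather than being interpreted positionally as NCHW. Roughly (illustrative fragment reusing names from the diff):

std::vector<Axis> axes;
if (StringToAxisVector("NC", &axes)) {
  // axes == {kernel::N, kernel::C}; a {batch, channel} tensor is then
  // padded to {batch, channel, 1, 1} before the 5HD transpose.
}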