Commit 4a0ddaef authored by yujianfeng

Support specifying reshape type for batchnorm fused op

Parent b45b6a9f
@@ -321,9 +321,11 @@ void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_
   MS_LOG(INFO) << "Dynamic select format response successful, use dynamic format.";
   for (size_t i = 0; i < inputs_static.size(); i++) {
     inputs_dyn[i]->set_param_type(inputs_static[i]->param_type());
+    inputs_dyn[i]->set_reshape_type(inputs_static[i]->reshape_type());
   }
   for (size_t j = 0; j < outputs_static.size(); j++) {
     outputs_dyn[j]->set_param_type(outputs_static[j]->param_type());
+    outputs_dyn[j]->set_reshape_type(outputs_static[j]->reshape_type());
   }
   op_info_new_ptr->set_inputs_ptr(inputs_dyn);
   op_info_new_ptr->set_outputs_ptr(outputs_dyn);
@@ -335,6 +337,29 @@ void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_
   op_info_new_ptr->set_fusion_type(op_info_ptr->fusion_type());
 }
+bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
+  for (const auto &c : reshape_type_str) {
+    switch (c) {
+      case 'N':
+        reshape_type_vec->push_back(kernel::N);
+        break;
+      case 'C':
+        reshape_type_vec->push_back(kernel::C);
+        break;
+      case 'H':
+        reshape_type_vec->push_back(kernel::H);
+        break;
+      case 'W':
+        reshape_type_vec->push_back(kernel::W);
+        break;
+      default:
+        MS_LOG(ERROR) << "Unknown axis " << c << " in reshape type.";
+        return false;
+    }
+  }
+  return true;
+}
 bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
                                size_t builder_idex, const std::vector<int> &dyn_input_sizes,
                                const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
@@ -347,6 +372,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
   MS_EXCEPTION_IF_NULL(inputs[0]);
   size_t kernel_info_cnt = inputs[0]->dtypes().size();
+  std::vector<std::vector<Axis>> reshape_types;
   for (const auto &input : inputs) {
     MS_EXCEPTION_IF_NULL(input);
     std::string param_type = input->param_type();
@@ -384,8 +410,14 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
         inputs_format.push_back(formats[builder_idex]);
       }
     }
+    std::vector<Axis> reshape_type;
+    if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
+      return false;
+    }
+    reshape_types.push_back(reshape_type);
   }
+  builder->SetInputReshapeType(reshape_types);
   builder->SetInputsDeviceType(inputs_device_type);
   builder->SetInputsFormat(inputs_format);
   return true;
@@ -403,6 +435,7 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
   MS_EXCEPTION_IF_NULL(outputs[0]);
   size_t kernel_info_cnt = outputs[0]->dtypes().size();
+  std::vector<std::vector<Axis>> reshape_types;
   for (const auto &output : outputs) {
     MS_EXCEPTION_IF_NULL(output);
     if (output_idx >= real_output_num) {
@@ -436,8 +469,14 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
       outputs_format.push_back(formats[builder_idex]);
       output_idx++;
     }
+    std::vector<Axis> reshape_type;
+    if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
+      return false;
+    }
+    reshape_types.push_back(reshape_type);
   }
+  builder->SetOutputReshapeType(reshape_types);
   builder->SetOutputsFormat(outputs_format);
   builder->SetOutputsDeviceType(outputs_device_type);
   return true;
@@ -515,7 +554,7 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
   const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
                                                kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
                                                kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
-                                               kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
+                                               kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
   // if format is default, it means all formats are supported
   if (kOpFormatList.find(format) == kOpFormatList.end()) {
@@ -528,13 +567,13 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
   if (shape.empty()) {
     return true;
   }
-  if (shape.size() > kShapeSupportFormatMap.size()) {
+  if (shape.size() > kShape4dDims) {
     return false;
   }
-  if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
-    return true;
+  if (format == kOpFormat_FRAC_NZ && shape.size() < 2) {
+    return false;
   }
-  return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
+  return true;
 }
 bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
......
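For illustration, here is a minimal standalone sketch of what the new StringToAxisVector helper computes, assuming an Axis enum that mirrors kernel::N/C/H/W (the sketch is self-contained and not MindSpore source):

#include <iostream>
#include <string>
#include <vector>

// Stand-in for kernel::Axis in this sketch; assumed to enumerate N, C, H, W in order.
enum Axis { N, C, H, W };

// Same mapping as the StringToAxisVector above: each character of a
// reshape_type string ("NC", "NCHW", ...) names one axis of the tensor.
bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *out) {
  for (char c : reshape_type_str) {
    switch (c) {
      case 'N': out->push_back(N); break;
      case 'C': out->push_back(C); break;
      case 'H': out->push_back(H); break;
      case 'W': out->push_back(W); break;
      default: return false;  // unknown axis character
    }
  }
  return true;
}

int main() {
  std::vector<Axis> axes;
  // "NC", as registered for the BN fused ops in this commit, parses to {N, C}:
  // the tensor's two dims are the N and C axes of the eventual 4-D shape.
  if (StringToAxisVector("NC", &axes)) {
    std::cout << "parsed " << axes.size() << " axes\n";  // prints: parsed 2 axes
  }
  return 0;
}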
@@ -55,12 +55,17 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
   trans_inputs.push_back(input);
   CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
   MS_EXCEPTION_IF_NULL(trans_node);
+  std::vector<kernel::Axis> padding_axis;
+  if (AnfAlgo::IsRealKernel(input)) {
+    padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
+  } else {
+    padding_axis = AnfAlgo::GetPrevNodeOutputReshapeType(input, 0);
+  }
   if (need_padding) {
     // if need padding we should set the transdata node's shape to the padding shape
-    AnfAlgo::SetOutputInferTypeAndShape(
-      {AnfAlgo::GetOutputInferDataType(input, 0)},
-      {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputReshapeType(input, 0))},
-      trans_node.get());
+    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
+                                        {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
+                                        trans_node.get());
   } else {
     AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
                                         {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
@@ -194,8 +199,14 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     MS_EXCEPTION_IF_NULL(cnode);
     input_node = AnfAlgo::GetInputNode(cnode, insert_index);
   }
-  bool need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
-                       op_name == kTransDataOpName);
+  bool need_padding = false;
+  if (is_insert_input) {
+    need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
+                    op_name == kTransDataOpName);
+  } else {
+    need_padding = (trans::IsNeedPadding(origin_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
+                    op_name == kTransDataOpName);
+  }
   if (!need_padding) {
     // don't need padding, insert transdata only
     trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
......
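NewTransOpNode now takes the padding axes from the input node itself when it is a real kernel and from its producer otherwise, and feeds them to trans::PaddingShapeTo4d. A hedged sketch of what such axis-directed padding does (PadShapeTo4d below is a hypothetical stand-in; the real trans::PaddingShapeTo4d may differ in detail):

#include <cstddef>
#include <iostream>
#include <vector>

enum Axis { N, C, H, W };  // stand-in for kernel::Axis

// Hypothetical PaddingShapeTo4d-style expansion: each input dim is placed on
// the axis named by padding_axis, and every unnamed axis is filled with 1.
std::vector<size_t> PadShapeTo4d(const std::vector<size_t> &shape,
                                 const std::vector<Axis> &padding_axis) {
  std::vector<size_t> padded(4, 1);
  for (size_t i = 0; i < shape.size() && i < padding_axis.size(); ++i) {
    padded[padding_axis[i]] = shape[i];
  }
  return padded;
}

int main() {
  // A 2-D tensor with reshape type "NC" pads to [N, C, 1, 1]:
  for (size_t d : PadShapeTo4d({32, 64}, {N, C})) std::cout << d << ' ';  // 32 64 1 1
  std::cout << '\n';
  // "CH" would place the same two dims differently: [1, 32, 64, 1].
  for (size_t d : PadShapeTo4d({32, 64}, {C, H})) std::cout << d << ' ';  // 1 32 64 1
  std::cout << '\n';
  return 0;
}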
@@ -86,7 +86,6 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
   AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);
   (*bn_reduce_grad_outputs).push_back(bn_reduce_grad);
 }
 }  // namespace
 const BaseRef BatchNormGradSplit::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
......
@@ -344,7 +344,7 @@ bool IsNopNode(const AnfNodePtr &node) {
   return true;
 }
-bool IsAllNopNode(session::KernelGraph *const graph) {
+bool IsAllNopNode(const session::KernelGraph *const graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto execution_order = graph->execution_order();
   for (auto &cnode : execution_order) {
......
@@ -347,6 +347,11 @@ std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_n
   return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
 }
+std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
+  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
+  return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
+}
 std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
   MS_EXCEPTION_IF_NULL(node);
   abstract::BaseShapePtr base_shape = node->Shape();
......
@@ -95,6 +95,8 @@ class AnfRuntimeAlgorithm {
   static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
   // get output format from prev node; input_index is the input index of the current node relative to the prev node
   static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
+  // get the reshape type of the output of the input node
+  static std::vector<kernel::Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
   // get output shapes inferred by ME from input nodes.
   static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
   // get input shapes inferred by ME from input nodes.
......
@@ -204,6 +204,7 @@ constexpr auto kOpFormat_FRAC_Z = "FracZ";
 constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
 constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
 constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
+constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
 const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
                                                 kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
                                                 kOpFormat_C1HWNCoC0};
@@ -225,8 +226,9 @@ const std::set<std::string> kOptOperatorSet = {
   kApplyRMSPropOpName,
 };
 const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
-                                                   kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
+                                                   kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
+                                                   kOpFormat_FRACTAL_Z_C04};
 static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
   if (access(file_name.c_str(), F_OK) != 0) {
......
@@ -23,7 +23,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \
     .compute_cost(10) \
     .kernel_name("bn_training_reduce") \
     .partial_flag(True) \
-    .input(0, "x", False, "required", "all") \
+    .input(0, "x", False, "required", "all", reshape_type="NC") \
     .output(0, "sum", False, "required", "all") \
     .output(1, "square_sum", False, "required", "all") \
     .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \
......
@@ -24,14 +24,14 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \
     .kernel_name("bn_training_reduce_grad") \
     .partial_flag(True) \
     .attr("epsilon", "optional", "float", "all") \
-    .input(0, "grads", False, "required", "all") \
-    .input(1, "x_norm", False, "required", "all") \
+    .input(0, "grads", False, "required", "all", reshape_type="NC") \
+    .input(1, "x_norm", False, "required", "all", reshape_type="NC") \
     .input(2, "diff_scale", False, "required", "all") \
     .input(3, "diff_offset", False, "required", "all") \
     .input(4, "scale", False, "required", "all") \
     .input(5, "batch_mean", False, "required", "all") \
     .input(6, "batch_variance", False, "required", "all") \
-    .output(0, "y", False, "required", "all") \
+    .output(0, "y", False, "required", "all", reshape_type="NC") \
     .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                   DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
......
@@ -26,14 +26,14 @@ bn_training_update_op_info = TBERegOp("BNTrainingUpdate") \
     .attr("factor", "optional", "float", "all") \
     .attr("epsilon", "optional", "float", "all") \
     .attr("isRef", "optional", "bool", "all", "true") \
-    .input(0, "x", False, "required", "all") \
+    .input(0, "x", False, "required", "all", reshape_type="NC") \
     .input(1, "sum", False, "required", "all") \
     .input(2, "square_sum", False, "required", "all") \
     .input(3, "scale", False, "required", "all") \
     .input(4, "offset", False, "required", "all") \
     .input(5, "mean", False, "required", "all") \
     .input(6, "variance", False, "required", "all") \
-    .output(0, "y", False, "required", "all") \
+    .output(0, "y", False, "required", "all", reshape_type="NC") \
     .output(1, "mean", False, "required", "all") \
     .output(2, "variance", False, "required", "all") \
     .output(3, "batch_mean", False, "required", "all") \
......
@@ -24,8 +24,8 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \
     .kernel_name("bn_training_update_grad") \
     .partial_flag(True) \
     .attr("epsilon", "optional", "float", "all") \
-    .input(0, "grads", False, "required", "all") \
-    .input(1, "x", False, "required", "all") \
+    .input(0, "grads", False, "required", "all", reshape_type="NC") \
+    .input(1, "x", False, "required", "all", reshape_type="NC") \
     .input(2, "batch_mean", False, "required", "all") \
     .input(3, "batch_variance", False, "required", "all") \
     .output(0, "diff_scale", False, "required", "all") \
......
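Tying the registrations back to the C++ changes: reshape_type="NC" tells the kernel-build pipeline that the 2-D tensors of these BN fused ops occupy the N and C axes, so a [batch, channels] input expands to [batch, channels, 1, 1] before the NC1HWC0 transformation, and the expansion need not guess which axes the two dims occupy. A compact sketch of that end-to-end flow, under the same assumptions as the sketches above (not MindSpore source):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

enum Axis { N, C, H, W };  // stand-in for kernel::Axis

int main() {
  // BNTrainingReduce input "x" is registered with reshape_type="NC".
  std::string reshape_type = "NC";
  std::vector<size_t> x_shape = {32, 64};  // [batch, channels]

  // Parse the registered string into axes (as StringToAxisVector does);
  // unknown characters are not handled in this sketch.
  std::vector<Axis> axes;
  for (char c : reshape_type) {
    axes.push_back(c == 'N' ? N : c == 'C' ? C : c == 'H' ? H : W);
  }

  // Assumed padding semantics: named axes receive the dims, the rest become 1.
  std::vector<size_t> padded(4, 1);
  for (size_t i = 0; i < x_shape.size(); ++i) {
    padded[axes[i]] = x_shape[i];
  }

  for (size_t d : padded) std::cout << d << ' ';  // prints: 32 64 1 1
  std::cout << '\n';  // this 4-D NCHW shape is what then converts to NC1HWC0
  return 0;
}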