未验证 提交 bdd3dde3 编写于 作者: Z zyfncg 提交者: GitHub

[code-gen] Support code-gen for opmaker of sparse op (#46993)

* support generating code of opmaker for backward op invoke forward op

* gsupport code-gen of opmaker for sparse op

* refind logic of choose phi kernrel

* fix complie budg

* fix code_gen bug

* fix bug

* fix kernel signature code-gen

* fix complie bug of VarType

* fix complie bug of VarType

* fix test_sparse_conv_op

* fix test_sparse_norm_op
上级 a9c20660
......@@ -72,7 +72,9 @@ tools/nvcc_lazy
# these files (directories) are generated before build system generation
paddle/fluid/operators/generated_op.cc
paddle/fluid/operators/generated_sparse_op.cc
paddle/phi/ops/compat/generated_sig.cc
paddle/phi/ops/compat/generated_sparse_sig.cc
paddle/phi/api/yaml/parsed_apis/
python/paddle/utils/code_gen/
paddle/fluid/pybind/tmp_eager_op_function_impl.h
......
......@@ -55,7 +55,9 @@ static std::unordered_set<std::string> black_ops_list = {"run_program",
"fused_gate_attention",
"fused_feedforward",
"fused_attention",
"fused_gemm_epilogue"};
"fused_gemm_epilogue",
"sparse_divide_scalar",
"sparse_scale"};
static std::string LegalizeVariableName(const std::string& var_name) {
std::string ret = var_name;
......@@ -3161,6 +3163,12 @@ static void DygraphCodeGeneration(const std::string& output_dir,
continue;
}
// Skip the sparse op
if (op_type.compare(0, 7, "sparse_") == 0 && op_type != "sparse_momentum" &&
op_type != "sparse_attention") {
continue;
}
GradNodeGenerationInfo bwd_info;
bool is_available = CollectGradInformationFromOpInfo(op_info, &bwd_info);
......
......@@ -237,7 +237,7 @@ cc_test(
cc_library(
var_type_traits
SRCS var_type_traits.cc
DEPS framework_proto scope tensor_array sparse_coo_tensor)
DEPS framework_proto scope tensor_array sparse_coo_tensor sparse_csr_tensor)
if(WITH_GPU)
target_link_libraries(var_type_traits dynload_cuda)
endif()
......
......@@ -156,6 +156,8 @@ message VarType {
PSTRING = 29;
// the data type of phi::SparseCooTensor
SPARSE_COO = 30;
// the data type of phi::SparseCsrTensor
SPARSE_CSR = 31;
}
required Type type = 1;
......@@ -189,6 +191,7 @@ message VarType {
optional TensorDesc strings = 9;
optional TensorDesc vocab = 10;
optional TensorDesc sparse_coo = 11;
optional TensorDesc sparse_csr = 12;
}
message VarDesc {
......
......@@ -117,6 +117,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
return var_type == proto::VarType::SPARSE_COO;
}
bool IsSparseCsrTensorInput(const std::string& name) const override {
auto var_type = ctx_.GetInputVarType(name);
return var_type == proto::VarType::SPARSE_CSR;
}
bool IsDenseTensorOutput(const std::string& name) const override {
auto var_types = ctx_.GetOutputsVarType(name);
return std::all_of(var_types.begin(),
......
......@@ -543,6 +543,11 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
return var->IsType<phi::SparseCooTensor>();
}
bool IsSparseCsrTensorInput(const std::string& name) const override {
const auto* var = ctx_.InputVar(name);
return var->IsType<phi::SparseCsrTensor>();
}
bool IsDenseTensorOutput(const std::string& name) const override {
auto vars = ctx_.MultiOutputVar(name);
return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace paddle {
namespace framework {
......
......@@ -55,6 +55,7 @@ namespace phi {
class DenseTensor;
class SelectedRows;
class SparseCooTensor;
class SparseCsrTensor;
} // namespace phi
// Users should add forward declarations here
......@@ -182,6 +183,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
phi::DenseTensor,
phi::SelectedRows,
phi::SparseCooTensor,
phi::SparseCsrTensor,
std::vector<Scope *>,
LoDRankTable,
Strings,
......
......@@ -112,6 +112,10 @@ bool PluginArgumentMappingContext::IsSparseCooTensorInput(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsSparseCsrTensorInput(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
const std::string& name) const {
return false;
......
......@@ -50,6 +50,8 @@ class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {
bool IsSparseCooTensorInput(const std::string& name) const override;
bool IsSparseCsrTensorInput(const std::string& name) const override;
bool IsDenseTensorVectorInput(const std::string& name) const override;
bool IsDenseTensorOutput(const std::string& name) const override;
......
......@@ -101,7 +101,7 @@ else()
cc_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor)
endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta)
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta sparse_backward_infermeta)
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
......
......@@ -28,50 +28,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
class SparseSparseCooTensorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("values", "(Tensor), input 0 of sparse_coo_tensor op.");
AddInput("indices", "(Tensor), input 1 of sparse_coo_tensor op.");
AddOutput("out", "(Tensor), output 0 of sparse_coo_tensor op.");
AddAttr<std::vector<int>>(
"dense_shape", "(vector<int>), attribute 0 for sparse_coo_tensor op.");
AddComment(R"DOC(
TODO: Documentation of sparse_coo_tensor op.
)DOC");
}
};
class SparseSparseCooTensorOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(
sparse_sparse_coo_tensor,
SparseSparseCooTensorInferShapeFunctor,
PD_INFER_META(phi::sparse::SparseCooTensorInferMeta));
class SparseValuesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_values op.");
AddOutput("out", "(Tensor), output 0 of sparse_values op.");
AddComment(R"DOC(
TODO: Documentation of sparse_values op.
)DOC");
}
};
class SparseValuesOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_values,
SparseValuesInferShapeFunctor,
PD_INFER_META(phi::sparse::ValuesInferMeta));
class SparseIndicesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -92,182 +48,12 @@ DECLARE_INFER_SHAPE_FUNCTOR(sparse_indices,
SparseIndicesInferShapeFunctor,
PD_INFER_META(phi::sparse::IndicesInferMeta));
class SparseToDenseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_to_dense op.");
AddOutput("out", "(Tensor), output 0 of sparse_to_dense op.");
AddComment(R"DOC(
TODO: Documentation of sparse_to_dense op.
)DOC");
}
};
class SparseToDenseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_to_dense,
SparseToDenseInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
class SparseReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_relu op.");
AddOutput("out", "(Tensor), output 0 of sparse_relu op.");
AddComment(R"DOC(
TODO: Documentation of sparse_relu op.
)DOC");
}
};
class SparseReluOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_relu,
SparseReluInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
class SparseConv3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_conv3d op.");
AddInput("kernel", "(Tensor), input 1 of sparse_conv3d op.");
AddOutput("out", "(Tensor), output 0 of sparse_conv3d op.");
AddOutput("rulebook", "(Tensor), output 1 of sparse_conv3d op.");
AddOutput("counter", "(Tensor), output 2 of sparse_conv3d op.");
AddAttr<std::vector<int>>(
"paddings", "(vector<int>), attribute 0 for sparse_conv3d op.");
AddAttr<std::vector<int>>(
"dilations", "(vector<int>), attribute 1 for sparse_conv3d op.");
AddAttr<std::vector<int>>(
"strides", "(vector<int>), attribute 2 for sparse_conv3d op.");
AddAttr<int>("groups", "(int), attribute 3 for sparse_conv3d op.");
AddAttr<bool>("subm", "(bool), attribute 4 for conv3d_coo op.");
AddAttr<std::string>("key", "(string), attribute 5 for sparse_conv3d op.")
.SetDefault("");
AddComment(R"DOC(
TODO: Documentation of sparse_conv3d op.
)DOC");
}
};
class SparseConv3dOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_conv3d,
SparseConv3dInferShapeFunctor,
PD_INFER_META(phi::sparse::Conv3dInferMeta));
class SparseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_add op.");
AddInput("y", "(Tensor), input 1 of sparse_add op.");
AddOutput("out", "(Tensor), output 0 of sparse_add op.");
AddComment(R"DOC(
TODO: Documentation of sparse_add op.
)DOC");
}
};
class SparseAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_add,
SparseAddInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
class SparseBatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_batch_norm op.");
AddInput("scale", "(Tensor), input 1 of sparse_batch_norm op.");
AddInput("bias", "(Tensor), input 2 of sparse_batch_norm op.");
AddInput("mean", "(Tensor), input 3 of sparse_batch_norm op.");
AddInput("variance", "(Tensor), input 4 of sparse_batch_norm op.");
AddOutput("y", "(Tensor), output 0 of sparse_batch_norm op.");
AddOutput("mean_out", "(Tensor), output 1 of sparse_batch_norm op.");
AddOutput("variance_out", "(Tensor), output 2 of sparse_batch_norm op.");
AddOutput("saved_mean", "(Tensor), output 3 of sparse_batch_norm op.");
AddOutput("saved_variance", "(Tensor), output 4 of sparse_batch_norm op.");
AddOutput("reserve_space", "(Tensor), output 5 of sparse_batch_norm op.");
AddAttr<float>("momentum",
"(float), attribute 0 for sparse_batch_norm op.");
AddAttr<float>("epsilon", "(float), attribute 1 for sparse_batch_norm op.");
AddAttr<std::string>("data_layout",
"(string), attribute 2 for sparse_batch_norm op.");
AddAttr<bool>("is_test", "(bool), attribute 3 for sparse_batch_norm op.");
AddAttr<bool>("use_global_stats",
"(bool), attribute 4 for sparse_batch_norm op.");
AddAttr<bool>("trainable_statistics",
"(bool), attribute 4 for sparse_batch_norm op.");
AddAttr<bool>("fuse_with_relu",
"(bool), attribute 4 for sparse_batch_norm op.");
AddComment(R"DOC(
TODO: Documentation of sparse_batch_norm op.
)DOC");
}
};
class SparseBatchNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_batch_norm,
SparseBatchNormInferShapeFunctor,
PD_INFER_META(phi::BatchNormInferMeta));
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sparse_sparse_coo_tensor,
ops::SparseSparseCooTensorOp,
ops::SparseSparseCooTensorOpMaker,
ops::SparseSparseCooTensorInferShapeFunctor);
REGISTER_OPERATOR(sparse_values,
ops::SparseValuesOp,
ops::SparseValuesOpMaker,
ops::SparseValuesInferShapeFunctor);
REGISTER_OPERATOR(sparse_indices,
ops::SparseIndicesOp,
ops::SparseIndicesOpMaker,
ops::SparseIndicesInferShapeFunctor);
REGISTER_OPERATOR(sparse_to_dense,
ops::SparseToDenseOp,
ops::SparseToDenseOpMaker,
ops::SparseToDenseInferShapeFunctor);
REGISTER_OPERATOR(sparse_relu,
ops::SparseReluOp,
ops::SparseReluOpMaker,
ops::SparseReluInferShapeFunctor);
REGISTER_OPERATOR(sparse_conv3d,
ops::SparseConv3dOp,
ops::SparseConv3dOpMaker,
ops::SparseConv3dInferShapeFunctor);
REGISTER_OPERATOR(sparse_add,
ops::SparseAddOp,
ops::SparseAddOpMaker,
ops::SparseAddInferShapeFunctor);
REGISTER_OPERATOR(sparse_batch_norm,
ops::SparseBatchNormOp,
ops::SparseBatchNormOpMaker,
ops::SparseBatchNormInferShapeFunctor);
......@@ -416,6 +416,11 @@ GenerateOpFunctions() {
if (CUSTOM_HANDWRITE_OPS_SET.count(op_type)) {
continue;
}
// Skip the sparse op
if (op_type.compare(0, 7, "sparse_") == 0 && op_type != "sparse_momentum" &&
op_type != "sparse_attention") {
continue;
}
// Skip operator which is not inherit form OperatorWithKernel, like while,
// since only OperatorWithKernel can run in dygraph mode.
// if the phi lib contains op kernel, we still generate ops method
......
......@@ -118,8 +118,13 @@ endif()
set(parsed_api_dir ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/parsed_apis)
set(generated_op_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
set(generated_sparse_ops_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_sparse_op.cc)
set(generated_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sig.cc)
set(generated_sparse_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sparse_sig.cc)
message(
"parse api yamls:
- ${api_yaml_file}
......@@ -130,16 +135,22 @@ execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_api_dir}
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./ops.yaml
--output_path ./parsed_apis/api.parsed.yaml
--output_path ./parsed_apis/ops.parsed.yaml
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./legacy_ops.yaml --output_path ./parsed_apis/legacy_api.parsed.yaml
./legacy_ops.yaml --output_path ./parsed_apis/legacy_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./backward.yaml
--output_path ./parsed_apis/backward_api.parsed.yaml --backward
--output_path ./parsed_apis/backward_ops.parsed.yaml --backward
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./legacy_backward.yaml --output_path
./parsed_apis/legacy_backward_api.parsed.yaml --backward RESULTS_VARIABLE
./parsed_apis/legacy_backward_ops.parsed.yaml --backward
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./sparse_ops.yaml --output_path ./parsed_apis/sparse_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./sparse_backward.yaml --output_path
./parsed_apis/sparse_backward.parsed.yaml --backward RESULTS_VARIABLE
_results)
foreach(_result in ${_results})
if(${_result})
......@@ -149,19 +160,25 @@ endforeach()
# validation of api yamls
message("validate api yaml:
- ${parsed_api_dir}/api.parsed.yaml
- ${parsed_api_dir}/backward_api.parsed.yaml")
- ${parsed_api_dir}/ops.parsed.yaml
- ${parsed_api_dir}/backward_ops.parsed.yaml")
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND
${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths
./parsed_apis/api.parsed.yaml ./parsed_apis/legacy_api.parsed.yaml
--backward_yaml_paths ./parsed_apis/backward_api.parsed.yaml
./parsed_apis/legacy_backward_api.parsed.yaml
RESULT_VARIABLE _result)
if(${_result})
message(FATAL_ERROR "api validation failed, exiting.")
endif()
./parsed_apis/ops.parsed.yaml ./parsed_apis/legacy_ops.parsed.yaml
--backward_yaml_paths ./parsed_apis/backward_ops.parsed.yaml
./parsed_apis/legacy_backward_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths
./parsed_apis/sparse_ops.parsed.yaml --backward_yaml_paths
./parsed_apis/sparse_backward.parsed.yaml
RESULT_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "ops validation failed, exiting.")
endif()
endforeach()
# code generation for op, op makers, and argument mapping functions
message(
......@@ -172,15 +189,23 @@ execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND
${PYTHON_EXECUTABLE} generator/generate_op.py --ops_yaml_path
./parsed_apis/api.parsed.yaml --backward_yaml_path
./parsed_apis/backward_api.parsed.yaml --op_version_yaml_path
./parsed_apis/ops.parsed.yaml --backward_yaml_path
./parsed_apis/backward_ops.parsed.yaml --op_version_yaml_path
op_version.yaml --op_compat_yaml_path op_compat.yaml --output_op_path
"${generated_op_path}.tmp" --output_arg_map_path
"${generated_argument_mapping_path}.tmp"
RESULT_VARIABLE _result)
if(${_result})
message(FATAL_ERROR "operator codegen failed, exiting.")
endif()
COMMAND
${PYTHON_EXECUTABLE} generator/generate_sparse_op.py --ops_yaml_path
./parsed_apis/sparse_ops.parsed.yaml --backward_ops_yaml_path
./parsed_apis/sparse_backward.parsed.yaml --output_op_path
"${generated_sparse_ops_path}.tmp" --output_arg_map_path
"${generated_sparse_argument_mapping_path}.tmp"
RESULT_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "operator codegen failed, exiting.")
endif()
endforeach()
if(EXISTS "${generated_op_path}.tmp" AND EXISTS "${generated_op_path}")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
......@@ -195,6 +220,25 @@ else()
message("remove ${generated_op_path}")
endif()
if(EXISTS "${generated_sparse_ops_path}.tmp" AND EXISTS
"${generated_sparse_ops_path}")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${generated_sparse_ops_path}.tmp" "${generated_sparse_ops_path}")
message(
"copy if different ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}"
)
elseif(EXISTS "${generated_sparse_ops_path}.tmp")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy "${generated_sparse_ops_path}.tmp"
"${generated_sparse_ops_path}")
message("copy ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}")
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_sparse_ops_path}")
message("remove ${generated_sparse_ops_path}")
endif()
if(EXISTS "${generated_argument_mapping_path}.tmp"
AND EXISTS "${generated_argument_mapping_path}")
execute_process(
......@@ -218,6 +262,30 @@ else()
message("remove ${generated_argument_mapping_path}")
endif()
if(EXISTS "${generated_sparse_argument_mapping_path}.tmp"
AND EXISTS "${generated_sparse_argument_mapping_path}")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy_if_different
"${generated_sparse_argument_mapping_path}.tmp"
"${generated_sparse_argument_mapping_path}")
message(
"copy if different ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}"
)
elseif(EXISTS "${generated_sparse_argument_mapping_path}.tmp")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy "${generated_sparse_argument_mapping_path}.tmp"
"${generated_sparse_argument_mapping_path}")
message(
"copy ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}"
)
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_sparse_argument_mapping_path}")
message("remove ${generated_sparse_argument_mapping_path}")
endif()
# generate ops extra info
execute_process(
COMMAND ${PYTHON_EXECUTABLE} ${ops_extra_info_gen_file} --op_compat_yaml_path
......
......@@ -181,41 +181,13 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
]
def main(ops_yaml_path, backward_yaml_path, op_compat_yaml_path,
op_version_yaml_path, output_op_path, output_arg_map_path):
with open(ops_yaml_path, "rt") as f:
apis = yaml.safe_load(f)
apis = [restruct_io(api) for api in apis]
forward_api_dict = to_named_dict(apis)
with open(backward_yaml_path, "rt") as f:
backward_apis = yaml.safe_load(f)
backward_apis = [restruct_io(api) for api in backward_apis]
backward_api_dict = to_named_dict(backward_apis)
with open(op_version_yaml_path, "rt") as f:
api_versions = yaml.safe_load(f)
# add api version info into api
for api_version in api_versions:
forward_api_dict[api_version['op']]['version'] = api_version['version']
with open(op_compat_yaml_path, "rt") as f:
api_op_map = yaml.safe_load(f)
for api in apis:
api['op_name'] = api['name']
for bw_api in backward_apis:
bw_api['op_name'] = bw_api['name']
replace_compat_name(api_op_map, forward_api_dict, backward_api_dict)
# prepare for invoke case
for bw_name, bw_api in backward_api_dict.items():
def process_invoke_op(forward_api_dict, backward_api_dict):
for bw_api in backward_api_dict.values():
if 'invoke' in bw_api:
invoke_op = bw_api['invoke']['func']
args_list = bw_api['invoke']['args']
args_index = 0
if invoke_op in forward_api_dict.keys():
if invoke_op in forward_api_dict:
reuse_op = forward_api_dict[invoke_op]
bw_api['invoke']['inputs'] = []
bw_api['invoke']['attrs'] = []
......@@ -248,6 +220,38 @@ def main(ops_yaml_path, backward_yaml_path, op_compat_yaml_path,
bw_api['outputs'][idx]['name']
})
def main(ops_yaml_path, backward_yaml_path, op_compat_yaml_path,
op_version_yaml_path, output_op_path, output_arg_map_path):
with open(ops_yaml_path, "rt") as f:
apis = yaml.safe_load(f)
apis = [restruct_io(api) for api in apis]
forward_api_dict = to_named_dict(apis)
with open(backward_yaml_path, "rt") as f:
backward_apis = yaml.safe_load(f)
backward_apis = [restruct_io(api) for api in backward_apis]
backward_api_dict = to_named_dict(backward_apis)
with open(op_version_yaml_path, "rt") as f:
api_versions = yaml.safe_load(f)
# add api version info into api
for api_version in api_versions:
forward_api_dict[api_version['op']]['version'] = api_version['version']
with open(op_compat_yaml_path, "rt") as f:
api_op_map = yaml.safe_load(f)
for api in apis:
api['op_name'] = api['name']
for bw_api in backward_apis:
bw_api['op_name'] = bw_api['name']
replace_compat_name(api_op_map, forward_api_dict, backward_api_dict)
# prepare for invoke case
process_invoke_op(forward_api_dict, backward_api_dict)
# fill backward field for an api if another api claims it as forward
for name, backward_api in backward_api_dict.items():
forward_name = backward_api["forward"]["name"]
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from filters import to_op_attr_type, to_opmaker_name, to_opmaker_name_cstr, to_pascal_case
from tests import is_base_api, is_vec, is_scalar, is_initializer_list, supports_inplace, supports_no_need_buffer
from filters import to_input_name, cartesian_prod_mapping
from parse_utils import to_named_dict
from generate_op import process_invoke_op
file_loader = FileSystemLoader(Path(__file__).parent / "templates")
env = Environment(loader=file_loader,
keep_trailing_newline=True,
trim_blocks=True,
lstrip_blocks=True,
undefined=StrictUndefined,
extensions=['jinja2.ext.do'])
env.filters["to_op_attr_type"] = to_op_attr_type
env.filters["to_opmaker_name"] = to_opmaker_name
env.filters["to_pascal_case"] = to_pascal_case
env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.tests["base_api"] = is_base_api
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
env.tests["supports_inplace"] = supports_inplace
env.tests["supports_no_need_buffer"] = supports_no_need_buffer
def restruct_io(api):
api["input_dict"] = to_named_dict(api["inputs"])
api["attr_dict"] = to_named_dict(api["attrs"])
api["output_dict"] = to_named_dict(api["outputs"])
return api
SPARSE_OP_PREFIX = 'sparse_'
def main(api_yaml_path, backward_yaml_path, output_op_path,
output_arg_map_path):
with open(api_yaml_path, "rt") as f:
apis = yaml.safe_load(f)
apis = [restruct_io(api) for api in apis]
forward_api_dict = to_named_dict(apis)
with open(backward_yaml_path, "rt") as f:
backward_apis = yaml.safe_load(f)
backward_apis = [restruct_io(api) for api in backward_apis]
backward_api_dict = to_named_dict(backward_apis)
for api in apis:
api['op_name'] = SPARSE_OP_PREFIX + api['name']
api['name'] = api['op_name']
if api["backward"] is not None:
api["backward"] = SPARSE_OP_PREFIX + api["backward"]
for bw_api in backward_apis:
bw_api['op_name'] = SPARSE_OP_PREFIX + bw_api['name']
bw_api['name'] = bw_api['op_name']
if 'invoke' in bw_api:
bw_api['invoke']['args'] = [
param.strip() for param in bw_api['invoke']['args'].split(',')
]
# prepare for invoke case
process_invoke_op(forward_api_dict, backward_api_dict)
for bw_api in backward_apis:
if 'invoke' in bw_api:
if bw_api['invoke']['func'] in forward_api_dict:
bw_api['invoke'][
'func'] = SPARSE_OP_PREFIX + bw_api['invoke']['func']
# fill backward field for an api if another api claims it as forward
for name, backward_api in backward_api_dict.items():
forward_name = backward_api["forward"]["name"]
if forward_name in backward_api_dict:
forward_api = backward_api_dict[forward_name]
if forward_api["backward"] is None:
forward_api["backward"] = name
forward_api["backward"] = SPARSE_OP_PREFIX + forward_api["backward"]
api_dict = {}
api_dict.update(forward_api_dict)
api_dict.update(backward_api_dict)
if len(apis) == 0 and len(backward_apis) == 0:
if os.path.isfile(output_op_path):
os.remove(output_op_path)
if os.path.isfile(output_arg_map_path):
os.remove(output_arg_map_path)
return
op_template = env.get_template('sparse_op.c.j2')
with open(output_op_path, "wt") as f:
msg = op_template.render(apis=apis,
backward_apis=backward_apis,
api_dict=api_dict)
f.write(msg)
ks_template = env.get_template('sparse_ks.c.j2')
with open(output_arg_map_path, 'wt') as f:
msg = ks_template.render(apis=apis, backward_apis=backward_apis)
f.write(msg)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate operator file from api yaml.")
parser.add_argument('--ops_yaml_path',
type=str,
help="parsed sparse ops yaml file.")
parser.add_argument('--backward_ops_yaml_path',
type=str,
help="parsed backward sparse ops yaml file.")
parser.add_argument("--output_op_path",
type=str,
help="path to save generated operators.")
parser.add_argument(
"--output_arg_map_path",
type=str,
help="path to save generated argument mapping functions.")
args = parser.parse_args()
main(args.ops_yaml_path, args.backward_ops_yaml_path, args.output_op_path,
args.output_arg_map_path)
......@@ -156,14 +156,15 @@ def parse_kernel(api_name: str, kernel_config: Dict[str,
# backend : str, the names of param to choose the kernel backend, default is None
# layout : str, the names of param to choose the kernel layout, default is None
# data_type : str, the names of param to choose the kernel data_type, default is None
# dispatch : {}, the key is kernel_func, the value is type of inputs and outputs for kernel (example: {kernel_name : (['dense','sparse_coo']#input,['sparse_coo']#output)})
kernel = {
'func': None, # up to 2 function names
'func': [], # up to 2 function names
'param': None,
'backend': None,
'layout': None,
'data_type': None
'data_type': None,
'dispatch': {}
}
kernel['func'] = parse_plain_list(kernel_config['func'])
if 'param' in kernel_config:
kernel['param'] = kernel_config['param']
......@@ -175,6 +176,34 @@ def parse_kernel(api_name: str, kernel_config: Dict[str,
if 'data_type' in kernel_config:
kernel['data_type'] = parse_candidates(kernel_config["data_type"])
kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
kernel_config['func'])
def parse_kernel_in_out_type(in_out_str):
if len(in_out_str) == 0:
return None
tmp_in_out_list = in_out_str[1:-1].split('->')
inputs = [item.strip() for item in tmp_in_out_list[0].split(',')]
outputs = [item.strip() for item in tmp_in_out_list[1].split(',')]
# check the tensor type
for item in inputs:
assert item in [
'dense', 'selected_rows', 'sparse_coo', 'sparse_csr'
], f"{api_name} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
for item in outputs:
assert item in [
'dense', 'selected_rows', 'sparse_coo', 'sparse_csr'
], f"{api_name} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
return (inputs, outputs)
for func_item in kernel_funcs:
kernel['func'].append(func_item[0])
kernel['dispatch'][func_item[0]] = parse_kernel_in_out_type(
func_item[1])
return kernel
......
......@@ -81,7 +81,11 @@ AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_ty
{% set default_value = attr["default_value"] %}
{% set typename = attr["typename"] %}
{% if typename == "DataType" %}{# convert back to VarType #}
{% if default_value == "DataType::UNDEFINED" %}
-1
{%- else %}
static_cast<int>(framework::TransToProtoVarType(experimental::{{default_value}}))
{%- endif %}
{%- elif typename == "DataLayout" %} {# does DataLayout need any processing?#}
static_cast<int>(experimental::{{default_value}})
{%- elif typename == "Place" %}{# construct a Place to get the type #}
......@@ -94,7 +98,7 @@ static_cast<int>(phi::Place({{"phi::" if not default_value is initializer_list}}
{# --------------------------------------- name mapping ---------------------------------------------- #}
{% macro name_map(api) %}
KernelSignature {{api["name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
{% set kernel_args = api["kernel"]["param"] %}
{{get_input_list(api["inputs"], kernel_args)}};
paddle::small_vector<const char*> attrs;
......@@ -124,12 +128,64 @@ All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArg
*/
{% endmacro %}
{% macro get_kernel_dispatch(inputs, kernel_config) %}{# inline #}
{%- for kernel_func in kernel_config["func"] %}
{% set input_idx = namespace(idx=0) %}
{% set kernel_in_type_list = kernel_config["dispatch"][kernel_func][0] %}
if ( {%- for input in inputs %}
{%- if input["name"] in kernel_config["param"] %}
{%- if kernel_in_type_list[input_idx.idx] == "dense" %}
ctx.IsDenseTensorInput("{{input["name"]}}"){{" && " if not loop.last}}
{%- elif kernel_in_type_list[input_idx.idx] == "selected_rows" %}
ctx.IsSelectedRowsInput("{{input["name"]}}"){{" && " if not loop.last}}
{%- elif kernel_in_type_list[input_idx.idx] == "sparse_coo" %}
ctx.IsSparseCooTensorInput("{{input["name"]}}"){{" && " if not loop.last}}
{%- elif kernel_in_type_list[input_idx.idx] == "sparse_csr" %}
ctx.IsSparseCsrTensorInput("{{input["name"]}}"){{" && " if not loop.last}}
{%- endif %}
{% set input_idx.idx = input_idx.idx + 1 %}
{%- endif %}
{%- endfor %}) {
kernel_name = "{{kernel_func}}";
}
{%- endfor %}
{%- endmacro %}
{% macro sparse_op_name_map(api) %}
KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
{% set kernel_args = api["kernel"]["param"] %}
{{get_input_list(api["inputs"], kernel_args)}};
paddle::small_vector<const char*> attrs;
{% for attr in api["attrs"]%}
{% filter indent(2)%}
{{get_an_attr(attr)}};
{% endfilter %}
{% endfor %}
{{get_output_list(api["outputs"], kernel_args)}};
const char* kernel_name = "unregistered";
{{get_kernel_dispatch(api["inputs"], api["kernel"])}}
KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs));
return sig;
}
/*
******************************************************************
NOTE: The following codes are for 'get_compat_kernel_signature.py'
All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping:
{{api | cartesian_prod_mapping}}
******************************************************************
*/
{% endmacro %}
{% macro register_base_kernel_name(api) %}
PD_REGISTER_BASE_KERNEL_NAME({{api["op_name"]}}, {{api["name"]}});
{%- endmacro %}
{% macro register_name_map(api) %}
PD_REGISTER_ARG_MAPPING_FN({{api["op_name"]}}, phi::{{api["name"] | to_pascal_case}}OpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN({{api["op_name"]}}, phi::{{api["op_name"] | to_pascal_case}}OpArgumentMapping);
{%- endmacro %}
{% macro get_input_list(inputs, kernel_args) %}{# inline #}
......
{% from "operator_utils.c.j2" import sparse_op_name_map, register_name_map, register_base_kernel_name %}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h"
namespace phi {
{% for api in apis %}
{% if api is base_api %}
{{sparse_op_name_map(api)}}
{% endif %}
{% endfor %}
{% for api in backward_apis %}
{% if api is base_api %}
{{sparse_op_name_map(api)}}
{% endif %}
{% endfor %}
} // namespace phi
{% for api in apis + backward_apis %}
{% if api is base_api %}
{{register_name_map(api)}}
{% endif %}
{% endfor %}
{% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version %}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/sparse/backward.h"
#include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/multiary.h"
#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
using paddle::framework::GradVarName;
{% for api in apis %}
{% if api is base_api %}
{{op_maker(api)}}
{{operator(api)}}
{% endif %}
{% endfor %}
{% for api in backward_apis %}
{% if api is base_api %}
{{backward_op_maker(api, api_dict[api["forward"]["name"]])}}
{{operator(api)}}
{% else %}
{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}}
{% endif %}
{% endfor %}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
{% for api in apis + backward_apis %}
{% if api is base_api %}
{{register_op_with_components(api)}}
{% endif %}
{% endfor %}
- backward_op : abs_grad
forward : tanh(Tensor x) -> Tensor(out)
forward : abs(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
......@@ -432,7 +432,7 @@
transpose_csr_grad {sparse_csr -> sparse_csr}
- backward_op : values_grad
forward : values_coo(Tensor x) -> Tensor(out)
forward : values(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
......@@ -442,7 +442,7 @@
func : values_coo_grad{sparse_coo, dense-> sparse_coo}
- backward_op: fused_attention_grad
forward : fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
forward : fused_attention(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
args: (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad)
output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
infer_meta :
......
......@@ -111,7 +111,7 @@
backward : cast_grad
- op : conv3d
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key="")
output : Tensor(out), Tensor(rulebook), Tensor(counter)
infer_meta :
func : sparse::Conv3dInferMeta
......
......@@ -110,6 +110,7 @@ class ArgumentMappingContext {
virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
virtual bool IsSelectedRowsInputs(const std::string& name) const = 0;
virtual bool IsSparseCooTensorInput(const std::string& name) const = 0;
virtual bool IsSparseCsrTensorInput(const std::string& name) const = 0;
// For compatibility with LoDTensorArray
virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0;
......
......@@ -16,23 +16,6 @@
namespace phi {
// TODO(zhangkaihuo): add csr op
KernelSignature SparseSparseCooTensorOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"sparse_coo_tensor", {"values", "indices"}, {"dense_shape"}, {"out"});
}
KernelSignature SparseValuesOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("values_coo", {"x"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseIndicesOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
......@@ -42,94 +25,6 @@ KernelSignature SparseIndicesOpArgumentMapping(
}
}
KernelSignature SparseToDenseOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("coo_to_dense", {"x"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("relu_coo", {"x"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseConv3dOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature(
"conv3d_coo",
{"x", "kernel"},
{"paddings", "dilations", "strides", "groups", "subm", "key"},
{"out", "rulebook", "counter"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseAddOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x") && ctx.IsSparseCooTensorInput("y")) {
return KernelSignature("add_coo_coo", {"x", "y"}, {}, {"out"});
} else if (ctx.IsSparseCooTensorInput("x") && ctx.IsDenseTensorInput("y")) {
return KernelSignature("add_coo_dense", {"x", "y"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseBatchNormOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("batch_norm_coo",
{"x", "scale", "bias", "mean", "variance"},
{"momentum",
"epsilon",
"data_layout",
"is_test",
"use_global_stats",
"trainable_statistics",
"fuse_with_relu"},
{"y",
"mean_out",
"variance_out",
"saved_mean",
"saved_variance",
"reserve_space"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(sparse_sparse_coo_tensor, sparse_coo_tensor);
PD_REGISTER_ARG_MAPPING_FN(sparse_sparse_coo_tensor,
phi::SparseSparseCooTensorOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_values, values_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_values, phi::SparseValuesOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_indices, indices_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_indices, phi::SparseIndicesOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_to_dense, coo_to_dense);
PD_REGISTER_ARG_MAPPING_FN(sparse_to_dense,
phi::SparseToDenseOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_relu, relu_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_relu, phi::SparseReluOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_conv3d, conv3d_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_conv3d, phi::SparseConv3dOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_add, add_coo_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_add, phi::SparseAddOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_batch_norm, batch_norm_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_batch_norm,
phi::SparseBatchNormOpArgumentMapping);
......@@ -90,6 +90,10 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
return false;
}
bool IsSparseCsrTensorInput(const std::string& name) const override {
return false;
}
bool IsDenseTensorOutput(const std::string& name) const override {
return dense_tensor_outputs.count(name) > 0;
}
......
......@@ -170,9 +170,9 @@ class BatchNorm(paddle.nn.BatchNorm1D):
dtype=dtype, stop_gradient=True)
reserve_space = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
y = helper.create_sparse_variable_for_type_inference(dtype)
out = helper.create_sparse_variable_for_type_inference(dtype)
outputs = {
"y": y,
"out": out,
"mean_out": mean_out,
"variance_out": variance_out,
"saved_mean": saved_mean,
......@@ -183,7 +183,7 @@ class BatchNorm(paddle.nn.BatchNorm1D):
inputs=inputs,
outputs=outputs,
attrs=attrs)
return y
return out
class SyncBatchNorm(paddle.nn.SyncBatchNorm):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册