diff --git a/.gitignore b/.gitignore index 4ddcff8adc7d589edca16b507c760187464e431d..17dd1720235e209afd3599a3117547a07cd5ba56 100644 --- a/.gitignore +++ b/.gitignore @@ -71,7 +71,9 @@ paddle/fluid/pybind/eager_op_function.cc # these files (directories) are generated before build system generation paddle/fluid/operators/generated_op.cc +paddle/fluid/operators/generated_sparse_op.cc paddle/phi/ops/compat/generated_sig.cc +paddle/phi/ops/compat/generated_sparse_sig.cc paddle/phi/api/yaml/parsed_apis/ python/paddle/utils/code_gen/ paddle/fluid/pybind/tmp_eager_op_function_impl.h diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 3199582ce29a5a5c7cb13f1c74c31a58afd8241d..efad9f61ee3f9040a96a02c46ca1423a35662fb5 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -55,7 +55,9 @@ static std::unordered_set black_ops_list = {"run_program", "fused_gate_attention", "fused_feedforward", "fused_attention", - "fused_gemm_epilogue"}; + "fused_gemm_epilogue", + "sparse_divide_scalar", + "sparse_scale"}; static std::string LegalizeVariableName(const std::string& var_name) { std::string ret = var_name; @@ -3161,6 +3163,12 @@ static void DygraphCodeGeneration(const std::string& output_dir, continue; } + // Skip the sparse op + if (op_type.compare(0, 7, "sparse_") == 0 && op_type != "sparse_momentum" && + op_type != "sparse_attention") { + continue; + } + GradNodeGenerationInfo bwd_info; bool is_available = CollectGradInformationFromOpInfo(op_info, &bwd_info); diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index b31c80fe9dc52d327ee100274fa9612320ec8201..6e3897717596bd44e43d7c1dbef0e01180488cef 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -190,7 +190,7 @@ cc_test( cc_library( var_type_traits SRCS var_type_traits.cc - DEPS framework_proto scope tensor_array sparse_coo_tensor) + DEPS framework_proto scope tensor_array sparse_coo_tensor sparse_csr_tensor) if(WITH_GPU) target_link_libraries(var_type_traits dynload_cuda) endif() diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 3dbb6693e8d838af7936d448da50e8286df14b27..e99316928aba60b7fa1e5a0dc75feb69359deb6f 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -156,6 +156,8 @@ message VarType { PSTRING = 29; // the data type of phi::SparseCooTensor SPARSE_COO = 30; + // the data type of phi::SparseCsrTensor + SPARSE_CSR = 31; } required Type type = 1; @@ -189,6 +191,7 @@ message VarType { optional TensorDesc strings = 9; optional TensorDesc vocab = 10; optional TensorDesc sparse_coo = 11; + optional TensorDesc sparse_csr = 12; } message VarDesc { diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index a97f36d3b55518fcf098bc56dfed0f0969142854..0bf91764b1a7abc908f5bed321c251d4484ae2e1 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -106,6 +106,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { return var_type == proto::VarType::SPARSE_COO; } + bool IsSparseCsrTensorInput(const std::string& name) const override { + auto var_type = ctx_.GetInputVarType(name); + return var_type == proto::VarType::SPARSE_CSR; + } + bool IsDenseTensorOutput(const std::string& name) const 
override { auto var_types = ctx_.GetOutputsVarType(name); return std::all_of(var_types.begin(), diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index f4b5a6e42ca009f14efebb83993a1705f7e88da8..0eb7dbf4d88afa34e4d284007ddce9f7d39fd033 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -529,6 +529,11 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { return var->IsType(); } + bool IsSparseCsrTensorInput(const std::string& name) const override { + const auto* var = ctx_.InputVar(name); + return var->IsType(); + } + bool IsDenseTensorOutput(const std::string& name) const override { auto vars = ctx_.MultiOutputVar(name); return std::all_of(vars.begin(), vars.end(), [](const Variable* var) { diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 81ea7d8f0e7467d9325f01703d3cc9fed2fb9b99..4c80f3f7ab404e3f09ef78c18ed82e84b6157b12 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/mixed_vector.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/core/sparse_csr_tensor.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index d2a4788a5038127d7c790c4f801a1c3d7dddabb7..3d78638312ae7435961a31ec10bd813eeb43ee57 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -55,6 +55,7 @@ namespace phi { class DenseTensor; class SelectedRows; class SparseCooTensor; +class SparseCsrTensor; } // namespace phi // Users should add forward declarations here @@ -182,6 +183,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl< Tensor, phi::SelectedRows, phi::SparseCooTensor, + phi::SparseCsrTensor, std::vector, LoDRankTable, Strings, diff --git a/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc index 121b677d9dcce0d1df83d43b607a9f7630d2c250..6a9ffae51d1372a8d4c32420c7e0122dfcba9607 100644 --- a/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc +++ b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc @@ -108,6 +108,10 @@ bool PluginArgumentMappingContext::IsSparseCooTensorInput( const std::string& name) const { return false; } +bool PluginArgumentMappingContext::IsSparseCsrTensorInput( + const std::string& name) const { + return false; +} bool PluginArgumentMappingContext::IsDenseTensorVectorInput( const std::string& name) const { return false; diff --git a/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h index cbafdaeec46f1c4122be7d4a4721b460457f33ec..64d26a11b48cb66706210fb2566dca798efeb571 100644 --- a/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h +++ b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h @@ -48,6 +48,8 @@ class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext { bool IsSparseCooTensorInput(const std::string& name) const override; + bool IsSparseCsrTensorInput(const std::string& name) const override; + bool IsDenseTensorVectorInput(const std::string& name) const override; bool IsDenseTensorOutput(const std::string& name) const override; diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 
e4ad6970159a19745c2bf3498ab042d6b2969998..ac1d89ede50214f1eea5c624800407e710a971c2 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -102,7 +102,7 @@ else() cc_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor) endif() -set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta) +set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta sparse_backward_infermeta) register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) diff --git a/paddle/fluid/operators/flip_op.cc b/paddle/fluid/operators/flip_op.cc deleted file mode 100644 index 7f00fad6e3d121dafd4631caa907fd159f8d1bee..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/flip_op.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -using framework::OpKernelType; -using framework::Tensor; - -class FlipOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; - auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - layout, - library, - customized_type_value); - } -}; - -class FlipOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of flip op."); - AddOutput("Out", "(Tensor), The output tensor of flip op."); - AddAttr>("axis", "The axes to flip on."); - AddComment(R"DOC( - Flip Operator. - Reverse the order of a n-D tensor along given axis in axes. 
- )DOC"); - } -}; - -class FlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { - protected: - std::unordered_map& GetInputOutputWithSameType() - const override { - static std::unordered_map m{{"X", /*->*/ "Out"}}; - return m; - } -}; - -template -class FlipOpGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr retv) const override { - retv->SetType("flip"); - retv->SetInput("X", this->OutputGrad("Out")); - retv->SetOutput("Out", this->InputGrad("X")); - retv->SetAttrMap(this->Attrs()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -DECLARE_INFER_SHAPE_FUNCTOR(flip, - FlipInferShapeFunctor, - PD_INFER_META(phi::FlipInferMeta)); -REGISTER_OPERATOR(flip, - ops::FlipOp, - ops::FlipOpMaker, - ops::FlipOpInferVarType, - ops::FlipOpGradMaker, - ops::FlipOpGradMaker, - FlipInferShapeFunctor); - -/* ========================== register checkpoint ===========================*/ -REGISTER_OP_VERSION(flip).AddCheckpoint( - R"ROC(Upgrade flip, add new attr [axis] and delete attr [dims].)ROC", - paddle::framework::compatible::OpVersionDesc() - .NewAttr("axis", - "The added attr 'axis' doesn't set default value.", - paddle::none) - .DeleteAttr("dims", "The attr 'dims' is deleted.")); diff --git a/paddle/fluid/operators/sparse_manual_op.cc b/paddle/fluid/operators/sparse_manual_op.cc index 327e03af80506c87a949f4d9a33078e202e0af44..f95d5250c62f6f3424a7378442fe6002c6a05545 100644 --- a/paddle/fluid/operators/sparse_manual_op.cc +++ b/paddle/fluid/operators/sparse_manual_op.cc @@ -28,50 +28,6 @@ limitations under the License. */ namespace paddle { namespace operators { -class SparseSparseCooTensorOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("values", "(Tensor), input 0 of sparse_coo_tensor op."); - AddInput("indices", "(Tensor), input 1 of sparse_coo_tensor op."); - AddOutput("out", "(Tensor), output 0 of sparse_coo_tensor op."); - AddAttr>( - "dense_shape", "(vector), attribute 0 for sparse_coo_tensor op."); - AddComment(R"DOC( -TODO: Documentation of sparse_coo_tensor op. -)DOC"); - } -}; - -class SparseSparseCooTensorOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR( - sparse_sparse_coo_tensor, - SparseSparseCooTensorInferShapeFunctor, - PD_INFER_META(phi::sparse::SparseCooTensorInferMeta)); - -class SparseValuesOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_values op."); - AddOutput("out", "(Tensor), output 0 of sparse_values op."); - AddComment(R"DOC( -TODO: Documentation of sparse_values op. 
-)DOC"); - } -}; - -class SparseValuesOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_values, - SparseValuesInferShapeFunctor, - PD_INFER_META(phi::sparse::ValuesInferMeta)); - class SparseIndicesOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -92,182 +48,12 @@ DECLARE_INFER_SHAPE_FUNCTOR(sparse_indices, SparseIndicesInferShapeFunctor, PD_INFER_META(phi::sparse::IndicesInferMeta)); -class SparseToDenseOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_to_dense op."); - AddOutput("out", "(Tensor), output 0 of sparse_to_dense op."); - AddComment(R"DOC( -TODO: Documentation of sparse_to_dense op. -)DOC"); - } -}; - -class SparseToDenseOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_to_dense, - SparseToDenseInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); - -class SparseReluOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_relu op."); - AddOutput("out", "(Tensor), output 0 of sparse_relu op."); - AddComment(R"DOC( -TODO: Documentation of sparse_relu op. -)DOC"); - } -}; - -class SparseReluOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_relu, - SparseReluInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); - -class SparseConv3dOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_conv3d op."); - AddInput("kernel", "(Tensor), input 1 of sparse_conv3d op."); - AddOutput("out", "(Tensor), output 0 of sparse_conv3d op."); - AddOutput("rulebook", "(Tensor), output 1 of sparse_conv3d op."); - AddOutput("counter", "(Tensor), output 2 of sparse_conv3d op."); - AddAttr>( - "paddings", "(vector), attribute 0 for sparse_conv3d op."); - AddAttr>( - "dilations", "(vector), attribute 1 for sparse_conv3d op."); - AddAttr>( - "strides", "(vector), attribute 2 for sparse_conv3d op."); - AddAttr("groups", "(int), attribute 3 for sparse_conv3d op."); - AddAttr("subm", "(bool), attribute 4 for conv3d_coo op."); - AddAttr("key", "(string), attribute 5 for sparse_conv3d op.") - .SetDefault(""); - AddComment(R"DOC( -TODO: Documentation of sparse_conv3d op. -)DOC"); - } -}; - -class SparseConv3dOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_conv3d, - SparseConv3dInferShapeFunctor, - PD_INFER_META(phi::sparse::Conv3dInferMeta)); - -class SparseAddOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_add op."); - AddInput("y", "(Tensor), input 1 of sparse_add op."); - AddOutput("out", "(Tensor), output 0 of sparse_add op."); - AddComment(R"DOC( -TODO: Documentation of sparse_add op. 
-)DOC"); - } -}; - -class SparseAddOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_add, - SparseAddInferShapeFunctor, - PD_INFER_META(phi::UnchangedInferMeta)); - -class SparseBatchNormOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("x", "(Tensor), input 0 of sparse_batch_norm op."); - AddInput("scale", "(Tensor), input 1 of sparse_batch_norm op."); - AddInput("bias", "(Tensor), input 2 of sparse_batch_norm op."); - AddInput("mean", "(Tensor), input 3 of sparse_batch_norm op."); - AddInput("variance", "(Tensor), input 4 of sparse_batch_norm op."); - AddOutput("y", "(Tensor), output 0 of sparse_batch_norm op."); - AddOutput("mean_out", "(Tensor), output 1 of sparse_batch_norm op."); - AddOutput("variance_out", "(Tensor), output 2 of sparse_batch_norm op."); - AddOutput("saved_mean", "(Tensor), output 3 of sparse_batch_norm op."); - AddOutput("saved_variance", "(Tensor), output 4 of sparse_batch_norm op."); - AddOutput("reserve_space", "(Tensor), output 5 of sparse_batch_norm op."); - AddAttr("momentum", - "(float), attribute 0 for sparse_batch_norm op."); - AddAttr("epsilon", "(float), attribute 1 for sparse_batch_norm op."); - AddAttr("data_layout", - "(string), attribute 2 for sparse_batch_norm op."); - AddAttr("is_test", "(bool), attribute 3 for sparse_batch_norm op."); - AddAttr("use_global_stats", - "(bool), attribute 4 for sparse_batch_norm op."); - AddAttr("trainable_statistics", - "(bool), attribute 4 for sparse_batch_norm op."); - AddAttr("fuse_with_relu", - "(bool), attribute 4 for sparse_batch_norm op."); - AddComment(R"DOC( -TODO: Documentation of sparse_batch_norm op. -)DOC"); - } -}; - -class SparseBatchNormOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -DECLARE_INFER_SHAPE_FUNCTOR(sparse_batch_norm, - SparseBatchNormInferShapeFunctor, - PD_INFER_META(phi::BatchNormInferMeta)); - } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(sparse_sparse_coo_tensor, - ops::SparseSparseCooTensorOp, - ops::SparseSparseCooTensorOpMaker, - ops::SparseSparseCooTensorInferShapeFunctor); - -REGISTER_OPERATOR(sparse_values, - ops::SparseValuesOp, - ops::SparseValuesOpMaker, - ops::SparseValuesInferShapeFunctor); - REGISTER_OPERATOR(sparse_indices, ops::SparseIndicesOp, ops::SparseIndicesOpMaker, ops::SparseIndicesInferShapeFunctor); - -REGISTER_OPERATOR(sparse_to_dense, - ops::SparseToDenseOp, - ops::SparseToDenseOpMaker, - ops::SparseToDenseInferShapeFunctor); - -REGISTER_OPERATOR(sparse_relu, - ops::SparseReluOp, - ops::SparseReluOpMaker, - ops::SparseReluInferShapeFunctor); - -REGISTER_OPERATOR(sparse_conv3d, - ops::SparseConv3dOp, - ops::SparseConv3dOpMaker, - ops::SparseConv3dInferShapeFunctor); - -REGISTER_OPERATOR(sparse_add, - ops::SparseAddOp, - ops::SparseAddOpMaker, - ops::SparseAddInferShapeFunctor); - -REGISTER_OPERATOR(sparse_batch_norm, - ops::SparseBatchNormOp, - ops::SparseBatchNormOpMaker, - ops::SparseBatchNormInferShapeFunctor); diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 62aa990ca7bc826df129c3c961e779622d69f173..7cc02a0f527a9f8e0b65d3883f326ba88a3cbf68 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -108,7 +108,6 @@ register_unity_group( register_unity_group( cc 
flatten_op.cc - flip_op.cc fsp_op.cc gather_nd_op.cc gather_op.cc @@ -424,7 +423,6 @@ register_unity_group(cu expand_v2_op.cu fake_dequantize_op.cu fill_any_like_op.cu) register_unity_group( cu - flip_op.cu fsp_op.cu gather_nd_op.cu gather_op.cu diff --git a/paddle/fluid/pybind/eager_legacy_op_function_generator.cc b/paddle/fluid/pybind/eager_legacy_op_function_generator.cc index 1d27d45beb7368d14715649e1c5f8903933b095c..fff811e84ba6f60f41ee76cbba6112170c269e7b 100644 --- a/paddle/fluid/pybind/eager_legacy_op_function_generator.cc +++ b/paddle/fluid/pybind/eager_legacy_op_function_generator.cc @@ -416,6 +416,11 @@ GenerateOpFunctions() { if (CUSTOM_HANDWRITE_OPS_SET.count(op_type)) { continue; } + // Skip the sparse op + if (op_type.compare(0, 7, "sparse_") == 0 && op_type != "sparse_momentum" && + op_type != "sparse_attention") { + continue; + } // Skip operator which is not inherit form OperatorWithKernel, like while, // since only OperatorWithKernel can run in dygraph mode. // if the phi lib contains op kernel, we still generate ops method diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index 3795060d24b98b41765d7100ee6872e1ce14a46e..c4310b43f29bbc7000ab4d7753fb1d20651e1e91 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -118,8 +118,13 @@ endif() set(parsed_api_dir ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/parsed_apis) set(generated_op_path ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc) +set(generated_sparse_ops_path + ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_sparse_op.cc) set(generated_argument_mapping_path ${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sig.cc) +set(generated_sparse_argument_mapping_path + ${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sparse_sig.cc) + message( "parse api yamls: - ${api_yaml_file} @@ -130,16 +135,22 @@ execute_process( WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_api_dir} COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./ops.yaml - --output_path ./parsed_apis/api.parsed.yaml + --output_path ./parsed_apis/ops.parsed.yaml COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path - ./legacy_ops.yaml --output_path ./parsed_apis/legacy_api.parsed.yaml + ./legacy_ops.yaml --output_path ./parsed_apis/legacy_ops.parsed.yaml COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./backward.yaml - --output_path ./parsed_apis/backward_api.parsed.yaml --backward + --output_path ./parsed_apis/backward_ops.parsed.yaml --backward COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./legacy_backward.yaml --output_path - ./parsed_apis/legacy_backward_api.parsed.yaml --backward RESULTS_VARIABLE + ./parsed_apis/legacy_backward_ops.parsed.yaml --backward + COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path + ./sparse_ops.yaml --output_path ./parsed_apis/sparse_ops.parsed.yaml + COMMAND + ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path + ./sparse_backward.yaml --output_path + ./parsed_apis/sparse_backward.parsed.yaml --backward RESULTS_VARIABLE _results) foreach(_result in ${_results}) if(${_result}) @@ -149,38 +160,53 @@ endforeach() # validation of api yamls message("validate api yaml: -- ${parsed_api_dir}/api.parsed.yaml -- ${parsed_api_dir}/backward_api.parsed.yaml") +- ${parsed_api_dir}/ops.parsed.yaml +- ${parsed_api_dir}/backward_ops.parsed.yaml") execute_process( WORKING_DIRECTORY 
${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml COMMAND ${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths - ./parsed_apis/api.parsed.yaml ./parsed_apis/legacy_api.parsed.yaml - --backward_yaml_paths ./parsed_apis/backward_api.parsed.yaml - ./parsed_apis/legacy_backward_api.parsed.yaml - RESULT_VARIABLE _result) -if(${_result}) - message(FATAL_ERROR "api validation failed, exiting.") -endif() + ./parsed_apis/ops.parsed.yaml ./parsed_apis/legacy_ops.parsed.yaml + --backward_yaml_paths ./parsed_apis/backward_ops.parsed.yaml + ./parsed_apis/legacy_backward_ops.parsed.yaml + COMMAND + ${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths + ./parsed_apis/sparse_ops.parsed.yaml --backward_yaml_paths + ./parsed_apis/sparse_backward.parsed.yaml + RESULT_VARIABLE _results) +foreach(_result in ${_results}) + if(${_result}) + message(FATAL_ERROR "ops validation failed, exiting.") + endif() +endforeach() # code generation for op, op makers, and argument mapping functions message( "create or remove auto-geneated operators: ${generated_op_path}.tmp create or remove auto-geneated argument mappings: ${generated_argument_mapping_path}.tmp" ) + execute_process( WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml COMMAND - ${PYTHON_EXECUTABLE} generator/generate_op.py --api_yaml_path - ./parsed_apis/api.parsed.yaml --backward_api_yaml_path - ./parsed_apis/backward_api.parsed.yaml --api_version_yaml_path + ${PYTHON_EXECUTABLE} generator/generate_op.py --ops_yaml_path + ./parsed_apis/ops.parsed.yaml --backward_yaml_path + ./parsed_apis/backward_ops.parsed.yaml --op_version_yaml_path op_version.yaml --op_compat_yaml_path op_compat.yaml --output_op_path "${generated_op_path}.tmp" --output_arg_map_path "${generated_argument_mapping_path}.tmp" - RESULT_VARIABLE _result) -if(${_result}) - message(FATAL_ERROR "operator codegen failed, exiting.") -endif() + COMMAND + ${PYTHON_EXECUTABLE} generator/generate_sparse_op.py --ops_yaml_path + ./parsed_apis/sparse_ops.parsed.yaml --backward_ops_yaml_path + ./parsed_apis/sparse_backward.parsed.yaml --output_op_path + "${generated_sparse_ops_path}.tmp" --output_arg_map_path + "${generated_sparse_argument_mapping_path}.tmp" + RESULT_VARIABLE _results) +foreach(_result in ${_results}) + if(${_result}) + message(FATAL_ERROR "operator codegen failed, exiting.") + endif() +endforeach() if(EXISTS "${generated_op_path}.tmp" AND EXISTS "${generated_op_path}") execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different @@ -195,6 +221,25 @@ else() message("remove ${generated_op_path}") endif() +if(EXISTS "${generated_sparse_ops_path}.tmp" AND EXISTS + "${generated_sparse_ops_path}") + execute_process( + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${generated_sparse_ops_path}.tmp" "${generated_sparse_ops_path}") + message( + "copy if different ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}" + ) +elseif(EXISTS "${generated_sparse_ops_path}.tmp") + execute_process( + COMMAND ${CMAKE_COMMAND} -E copy "${generated_sparse_ops_path}.tmp" + "${generated_sparse_ops_path}") + message("copy ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}") +else() + execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f + "${generated_sparse_ops_path}") + message("remove ${generated_sparse_ops_path}") +endif() + if(EXISTS "${generated_argument_mapping_path}.tmp" AND EXISTS "${generated_argument_mapping_path}") execute_process( @@ -218,6 +263,30 @@ else() message("remove ${generated_argument_mapping_path}") endif() +if(EXISTS 
"${generated_sparse_argument_mapping_path}.tmp" + AND EXISTS "${generated_sparse_argument_mapping_path}") + execute_process( + COMMAND + ${CMAKE_COMMAND} -E copy_if_different + "${generated_sparse_argument_mapping_path}.tmp" + "${generated_sparse_argument_mapping_path}") + message( + "copy if different ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}" + ) +elseif(EXISTS "${generated_sparse_argument_mapping_path}.tmp") + execute_process( + COMMAND + ${CMAKE_COMMAND} -E copy "${generated_sparse_argument_mapping_path}.tmp" + "${generated_sparse_argument_mapping_path}") + message( + "copy ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}" + ) +else() + execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f + "${generated_sparse_argument_mapping_path}") + message("remove ${generated_sparse_argument_mapping_path}") +endif() + # generate ops extra info execute_process( COMMAND ${PYTHON_EXECUTABLE} ${ops_extra_info_gen_file} --op_compat_yaml_path diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 9d81435169c6a3b7de33aeab5deff24f6919bc79..6603c785a0f48abcc2558ddff3dd2239d6871f02 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -147,6 +147,12 @@ data_type: out_grad no_need_buffer: x +- backward_op : flip_grad + forward : flip (Tensor x, int[] axis) -> Tensor(out) + args : (Tensor out_grad, int[] axis) + output : Tensor(x_grad) + invoke : flip(out_grad, axis) + - backward_op : graph_send_uv_grad forward : graph_send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out) args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD") diff --git a/paddle/phi/api/yaml/generator/generate_op.py b/paddle/phi/api/yaml/generator/generate_op.py index 5fa3be685e487fae2ee33be818a570b637290e97..4984db8f8b3cda29f4fc10d1d4e3a7048e7110eb 100644 --- a/paddle/phi/api/yaml/generator/generate_op.py +++ b/paddle/phi/api/yaml/generator/generate_op.py @@ -21,18 +21,32 @@ from pathlib import Path import yaml from jinja2 import Environment, FileSystemLoader, StrictUndefined -from filters import to_op_attr_type, to_opmaker_name, to_opmaker_name_cstr, to_pascal_case -from tests import is_base_api, is_vec, is_scalar, is_initializer_list, supports_inplace, supports_no_need_buffer +from filters import ( + to_op_attr_type, + to_opmaker_name, + to_opmaker_name_cstr, + to_pascal_case, +) +from tests import ( + is_base_api, + is_vec, + is_scalar, + is_initializer_list, + supports_inplace, + supports_no_need_buffer, +) from filters import to_input_name, cartesian_prod_mapping from parse_utils import to_named_dict file_loader = FileSystemLoader(Path(__file__).parent / "templates") -env = Environment(loader=file_loader, - keep_trailing_newline=True, - trim_blocks=True, - lstrip_blocks=True, - undefined=StrictUndefined, - extensions=['jinja2.ext.do']) +env = Environment( + loader=file_loader, + keep_trailing_newline=True, + trim_blocks=True, + lstrip_blocks=True, + undefined=StrictUndefined, + extensions=['jinja2.ext.do'], +) env.filters["to_op_attr_type"] = to_op_attr_type env.filters["to_opmaker_name"] = to_opmaker_name env.filters["to_pascal_case"] = to_pascal_case @@ -56,7 +70,6 @@ def restruct_io(api): # replace name of op and params for OpMaker def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict): - def get_api_and_op_name(api_item): names = api_item.split('(') if len(names) 
== 1: @@ -76,7 +89,8 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict): forward_api_item['op_name'] = op_name if 'backward' in api_args and has_backward: bw_api_name, bw_op_name = get_api_and_op_name( - api_args['backward'].split(',')[0]) + api_args['backward'].split(',')[0] + ) forward_api_item['backward'] = bw_op_name backward_api_item['op_name'] = bw_op_name @@ -102,8 +116,10 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict): ] if forward_api_item['kernel']['data_type']: forward_api_item['kernel']['data_type']['candidates'] = [ - args_map[param] if param in args_map else param for param in - forward_api_item['kernel']['data_type']['candidates'] + args_map[param] if param in args_map else param + for param in forward_api_item['kernel']['data_type'][ + 'candidates' + ] ] if forward_api_item['kernel']['backend']: forward_api_item['kernel']['backend']['candidates'] = [ @@ -130,21 +146,36 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict): for args_item in backward_api_item['inputs']: if args_item['name'] in args_map: args_item['name'] = args_map[args_item['name']] - elif args_item['name'].endswith( - '_grad') and args_item['name'][:-5] in args_map: - args_map[args_item['name']] = args_map[args_item['name'] - [:-5]] + '_grad' + elif ( + args_item['name'].endswith('_grad') + and args_item['name'][:-5] in args_map + ): + args_map[args_item['name']] = ( + args_map[args_item['name'][:-5]] + '_grad' + ) args_item['name'] = args_map[args_item['name']] for args_item in backward_api_item['attrs']: if args_item['name'] in args_map: args_item['name'] = args_map[args_item['name']] for args_item in backward_api_item['outputs']: - if args_item['name'].endswith( - '_grad') and args_item['name'][:-5] in args_map: - args_map[args_item['name']] = args_map[args_item['name'] - [:-5]] + '_grad' + if ( + args_item['name'].endswith('_grad') + and args_item['name'][:-5] in args_map + ): + args_map[args_item['name']] = ( + args_map[args_item['name'][:-5]] + '_grad' + ) args_item['name'] = args_map[args_item['name']] + if 'invoke' in backward_api_item: + backward_api_item['invoke']['args'] = [ + args_map[param.strip()] + if param.strip() in args_map + else param.strip() + for param in backward_api_item['invoke']['args'].split(',') + ] + continue + backward_api_item['infer_meta']['param'] = [ args_map[param] if param in args_map else param for param in backward_api_item['infer_meta']['param'] @@ -155,18 +186,24 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict): ] if backward_api_item['kernel']['data_type']: backward_api_item['kernel']['data_type']['candidates'] = [ - args_map[param] if param in args_map else param for param in - backward_api_item['kernel']['data_type']['candidates'] + args_map[param] if param in args_map else param + for param in backward_api_item['kernel']['data_type'][ + 'candidates' + ] ] if backward_api_item['kernel']['backend']: backward_api_item['kernel']['backend']['candidates'] = [ - args_map[param] if param in args_map else param for param in - backward_api_item['kernel']['backend']['candidates'] + args_map[param] if param in args_map else param + for param in backward_api_item['kernel']['backend'][ + 'candidates' + ] ] if backward_api_item['kernel']['layout']: backward_api_item['kernel']['layout']['candidates'] = [ - args_map[param] if param in args_map else param for param in - backward_api_item['kernel']['layout']['candidates'] + args_map[param] if param in args_map else param + for 
param in backward_api_item['kernel']['layout'][ + 'candidates' + ] ] if backward_api_item['no_need_buffer']: backward_api_item['no_need_buffer'] = [ @@ -175,9 +212,56 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict): ] -def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path, - api_version_yaml_path, output_op_path, output_arg_map_path): - with open(api_yaml_path, "rt") as f: +def process_invoke_op(forward_api_dict, backward_api_dict): + for bw_api in backward_api_dict.values(): + if 'invoke' in bw_api: + invoke_op = bw_api['invoke']['func'] + args_list = bw_api['invoke']['args'] + args_index = 0 + if invoke_op in forward_api_dict: + reuse_op = forward_api_dict[invoke_op] + bw_api['invoke']['inputs'] = [] + bw_api['invoke']['attrs'] = [] + bw_api['invoke']['outputs'] = [] + for input_item in reuse_op['inputs']: + bw_api['invoke']['inputs'].append( + { + 'name': input_item['name'], + 'value': args_list[args_index], + } + ) + args_index = args_index + 1 + for attr in reuse_op['attrs']: + if args_index < len(args_list): + attr_value = ( + f"this->GetAttr(\"{args_list[args_index]}\")" + if args_list[args_index] in bw_api['attr_dict'] + else args_list[args_index] + ) + bw_api['invoke']['attrs'].append( + {'name': attr['name'], 'value': attr_value} + ) + args_index = args_index + 1 + else: + break + for idx, output_item in enumerate(reuse_op['outputs']): + bw_api['invoke']['outputs'].append( + { + 'name': output_item['name'], + 'value': bw_api['outputs'][idx]['name'], + } + ) + + +def main( + ops_yaml_path, + backward_yaml_path, + op_compat_yaml_path, + op_version_yaml_path, + output_op_path, + output_arg_map_path, +): + with open(ops_yaml_path, "rt") as f: apis = yaml.safe_load(f) apis = [restruct_io(api) for api in apis] forward_api_dict = to_named_dict(apis) @@ -187,7 +271,7 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path, backward_apis = [restruct_io(api) for api in backward_apis] backward_api_dict = to_named_dict(backward_apis) - with open(api_version_yaml_path, "rt") as f: + with open(op_version_yaml_path, "rt") as f: api_versions = yaml.safe_load(f) # add api version info into api for api_version in api_versions: @@ -203,6 +287,9 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path, replace_compat_name(api_op_map, forward_api_dict, backward_api_dict) + # prepare for invoke case + process_invoke_op(forward_api_dict, backward_api_dict) + # fill backward field for an api if another api claims it as forward for name, backward_api in backward_api_dict.items(): forward_name = backward_api["forward"]["name"] @@ -224,9 +311,9 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path, op_template = env.get_template('op.c.j2') with open(output_op_path, "wt") as f: - msg = op_template.render(apis=apis, - backward_apis=backward_apis, - api_dict=api_dict) + msg = op_template.render( + apis=apis, backward_apis=backward_apis, api_dict=api_dict + ) f.write(msg) ks_template = env.get_template('ks.c.j2') @@ -237,28 +324,35 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path, if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Generate operator file from api yaml.") - parser.add_argument('--api_yaml_path', - type=str, - help="parsed api yaml file.") - parser.add_argument('--backward_api_yaml_path', - type=str, - help="parsed backward api yaml file.") - parser.add_argument('--op_compat_yaml_path', - type=str, - help="api args compat yaml file.") - parser.add_argument('--api_version_yaml_path', 
- type=str, - help="api version yaml file.") - parser.add_argument("--output_op_path", - type=str, - help="path to save generated operators.") + description="Generate operator file from api yaml." + ) + parser.add_argument( + '--ops_yaml_path', type=str, help="parsed ops yaml file." + ) + parser.add_argument( + '--backward_yaml_path', type=str, help="parsed backward ops yaml file." + ) + parser.add_argument( + '--op_compat_yaml_path', type=str, help="ops args compat yaml file." + ) + parser.add_argument( + '--op_version_yaml_path', type=str, help="ops version yaml file." + ) + parser.add_argument( + "--output_op_path", type=str, help="path to save generated operators." + ) parser.add_argument( "--output_arg_map_path", type=str, - help="path to save generated argument mapping functions.") + help="path to save generated argument mapping functions.", + ) args = parser.parse_args() - main(args.api_yaml_path, args.backward_api_yaml_path, - args.op_compat_yaml_path, args.api_version_yaml_path, - args.output_op_path, args.output_arg_map_path) + main( + args.ops_yaml_path, + args.backward_yaml_path, + args.op_compat_yaml_path, + args.op_version_yaml_path, + args.output_op_path, + args.output_arg_map_path, + ) diff --git a/paddle/phi/api/yaml/generator/generate_sparse_op.py b/paddle/phi/api/yaml/generator/generate_sparse_op.py new file mode 100644 index 0000000000000000000000000000000000000000..48ba0d81eca3d621fb3e39c317a93da675544da1 --- /dev/null +++ b/paddle/phi/api/yaml/generator/generate_sparse_op.py @@ -0,0 +1,168 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os +from pathlib import Path + +import yaml +from jinja2 import Environment, FileSystemLoader, StrictUndefined + +from filters import ( + to_op_attr_type, + to_opmaker_name, + to_opmaker_name_cstr, + to_pascal_case, +) +from tests import ( + is_base_api, + is_vec, + is_scalar, + is_initializer_list, + supports_inplace, + supports_no_need_buffer, +) +from filters import to_input_name, cartesian_prod_mapping +from parse_utils import to_named_dict +from generate_op import process_invoke_op + +file_loader = FileSystemLoader(Path(__file__).parent / "templates") +env = Environment( + loader=file_loader, + keep_trailing_newline=True, + trim_blocks=True, + lstrip_blocks=True, + undefined=StrictUndefined, + extensions=['jinja2.ext.do'], +) +env.filters["to_op_attr_type"] = to_op_attr_type +env.filters["to_opmaker_name"] = to_opmaker_name +env.filters["to_pascal_case"] = to_pascal_case +env.filters["to_input_name"] = to_input_name +env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr +env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping +env.tests["base_api"] = is_base_api +env.tests["vec"] = is_vec +env.tests["scalar"] = is_scalar +env.tests["initializer_list"] = is_initializer_list +env.tests["supports_inplace"] = supports_inplace +env.tests["supports_no_need_buffer"] = supports_no_need_buffer + + +def restruct_io(api): + api["input_dict"] = to_named_dict(api["inputs"]) + api["attr_dict"] = to_named_dict(api["attrs"]) + api["output_dict"] = to_named_dict(api["outputs"]) + return api + + +SPARSE_OP_PREFIX = 'sparse_' + + +def main( + api_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path +): + with open(api_yaml_path, "rt") as f: + apis = yaml.safe_load(f) + apis = [restruct_io(api) for api in apis] + forward_api_dict = to_named_dict(apis) + + with open(backward_yaml_path, "rt") as f: + backward_apis = yaml.safe_load(f) + backward_apis = [restruct_io(api) for api in backward_apis] + backward_api_dict = to_named_dict(backward_apis) + + for api in apis: + api['op_name'] = SPARSE_OP_PREFIX + api['name'] + api['name'] = api['op_name'] + if api["backward"] is not None: + api["backward"] = SPARSE_OP_PREFIX + api["backward"] + for bw_api in backward_apis: + bw_api['op_name'] = SPARSE_OP_PREFIX + bw_api['name'] + bw_api['name'] = bw_api['op_name'] + if 'invoke' in bw_api: + bw_api['invoke']['args'] = [ + param.strip() for param in bw_api['invoke']['args'].split(',') + ] + + # prepare for invoke case + process_invoke_op(forward_api_dict, backward_api_dict) + for bw_api in backward_apis: + if 'invoke' in bw_api: + if bw_api['invoke']['func'] in forward_api_dict: + bw_api['invoke']['func'] = ( + SPARSE_OP_PREFIX + bw_api['invoke']['func'] + ) + + # fill backward field for an api if another api claims it as forward + for name, backward_api in backward_api_dict.items(): + forward_name = backward_api["forward"]["name"] + if forward_name in backward_api_dict: + forward_api = backward_api_dict[forward_name] + if forward_api["backward"] is None: + forward_api["backward"] = name + forward_api["backward"] = SPARSE_OP_PREFIX + forward_api["backward"] + + api_dict = {} + api_dict.update(forward_api_dict) + api_dict.update(backward_api_dict) + + if len(apis) == 0 and len(backward_apis) == 0: + if os.path.isfile(output_op_path): + os.remove(output_op_path) + if os.path.isfile(output_arg_map_path): + os.remove(output_arg_map_path) + return + + op_template = env.get_template('sparse_op.c.j2') + with open(output_op_path, "wt") as f: + msg = op_template.render( + 
apis=apis, backward_apis=backward_apis, api_dict=api_dict + ) + f.write(msg) + + ks_template = env.get_template('sparse_ks.c.j2') + with open(output_arg_map_path, 'wt') as f: + msg = ks_template.render(apis=apis, backward_apis=backward_apis) + f.write(msg) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate operator file from api yaml." + ) + parser.add_argument( + '--ops_yaml_path', type=str, help="parsed sparse ops yaml file." + ) + parser.add_argument( + '--backward_ops_yaml_path', + type=str, + help="parsed backward sparse ops yaml file.", + ) + parser.add_argument( + "--output_op_path", type=str, help="path to save generated operators." + ) + parser.add_argument( + "--output_arg_map_path", + type=str, + help="path to save generated argument mapping functions.", + ) + + args = parser.parse_args() + main( + args.ops_yaml_path, + args.backward_ops_yaml_path, + args.output_op_path, + args.output_arg_map_path, + ) diff --git a/paddle/phi/api/yaml/generator/parse_utils.py b/paddle/phi/api/yaml/generator/parse_utils.py index f617f166dd13b5e1a23ab489beb6d6651a21d0d5..0502dba0fdf8a8a1a71000872db33ffa0ac0230b 100644 --- a/paddle/phi/api/yaml/generator/parse_utils.py +++ b/paddle/phi/api/yaml/generator/parse_utils.py @@ -35,39 +35,42 @@ def parse_arg(api_name: str, s: str) -> Dict[str, str]: 2. typename name = default_value """ typename, rest = [item.strip() for item in s.split(" ", 1)] - assert len( - typename - ) > 0, f"The arg typename should not be empty. Please check the args of {api_name} in yaml." + assert ( + len(typename) > 0 + ), f"The arg typename should not be empty. Please check the args of {api_name} in yaml." - assert rest.count( - "=") <= 1, f"There is more than 1 = in an arg in {api_name}" + assert ( + rest.count("=") <= 1 + ), f"There is more than 1 = in an arg in {api_name}" if rest.count("=") == 1: name, default_value = [item.strip() for item in rest.split("=", 1)] - assert len( - name - ) > 0, f"The arg name should not be empty. Please check the args of {api_name} in yaml." - assert len( - default_value - ) > 0, f"The default value should not be empty. Please check the args of {api_name} in yaml." + assert ( + len(name) > 0 + ), f"The arg name should not be empty. Please check the args of {api_name} in yaml." + assert ( + len(default_value) > 0 + ), f"The default value should not be empty. Please check the args of {api_name} in yaml." return { "typename": typename, "name": name, - "default_value": default_value + "default_value": default_value, } else: name = rest.strip() - assert len( - name - ) > 0, f"The arg name should not be empty. Please check the args of {api_name} in yaml." + assert ( + len(name) > 0 + ), f"The arg name should not be empty. Please check the args of {api_name} in yaml." return {"typename": typename, "name": name} -def parse_input_and_attr(api_name: str, - arguments: str) -> Tuple[List, List, Dict, Dict]: +def parse_input_and_attr( + api_name: str, arguments: str +) -> Tuple[List, List, Dict, Dict]: args_str = arguments.strip() - assert args_str.startswith('(') and args_str.endswith(')'), \ - (f"Args declaration should start with '(' and end with ')', " - f"please check the args of {api_name} in yaml.") + assert args_str.startswith('(') and args_str.endswith(')'), ( + f"Args declaration should start with '(' and end with ')', " + f"please check the args of {api_name} in yaml." 
+    )
     args_str = args_str[1:-1]
     args = parse_plain_list(args_str)
@@ -81,14 +84,17 @@ def parse_input_and_attr(api_name: str,
         typename = item["typename"]
         name = item["name"]
         if is_input(typename):
-            assert len(attrs) == 0, \
-                (f"The input Tensor should appear before attributes. "
+            assert len(attrs) == 0, (
+                f"The input Tensor should appear before attributes. "
                 f"please check the position of {api_name}:input({name}) "
-                f"in yaml.")
+                f"in yaml."
+            )
             inputs.append(item)
         elif is_attr(typename):
             if met_attr_with_default_value:
-                assert "default_value" in item, f"{api_name}: Arguments with default value should not precede those without default value"
+                assert (
+                    "default_value" in item
+                ), f"{api_name}: Arguments with default value should not precede those without default value"
             elif "default_value" in item:
                 met_attr_with_default_value = True
             attrs.append(item)
@@ -101,7 +107,8 @@ def parse_output(api_name: str, s: str) -> Dict[str, str]:
     """parse an output, typename or typename(name)."""
     match = re.search(
         r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
-        s)
+        s,
+    )
     typename = match.group("out_type")
     name = match.group("name")
     size_expr = match.group("expr")
@@ -109,13 +116,15 @@
     name = name[1:-1] if name is not None else 'out'
     size_expr = size_expr[1:-1] if size_expr is not None else None

-    assert is_output(typename), \
-        (f"Invalid output type: {typename} in api: {api_name}."
-         f"Supported types are Tensor and Tensor[]")
+    assert is_output(typename), (
+        f"Invalid output type: {typename} in api: {api_name}."
+        f"Supported types are Tensor and Tensor[]"
+    )
     if size_expr is not None:
-        assert is_vec(typename), \
-            (f"Invalid output size: output {name} in api: {api_name} is "
-             f"not a vector but has size expr")
+        assert is_vec(typename), (
+            f"Invalid output size: output {name} in api: {api_name} is "
+            f"not a vector but has size expr"
+        )
         return {"typename": typename, "name": name, "size": size_expr}
     else:
         return {"typename": typename, "name": name}
@@ -149,22 +158,24 @@ def parse_plain_list(s: str, sep=",") -> List[str]:
     return items


-def parse_kernel(api_name: str, kernel_config: Dict[str,
-                                                    Any]) -> Dict[str, Any]:
+def parse_kernel(
+    api_name: str, kernel_config: Dict[str, Any]
+) -> Dict[str, Any]:
     # kernel :
     #    func : [], Kernel functions (example: scale, scale_sr)
     #    param : [], Input params of kernel
     #    backend : str, the names of param to choose the kernel backend, default is None
     #    layout : str, the names of param to choose the kernel layout, default is None
     #    data_type : str, the names of param to choose the kernel data_type, default is None
+    #    dispatch : {}, the key is kernel_func, the value is type of inputs and outputs for kernel (example: {kernel_name : (['dense','sparse_coo']#input,['sparse_coo']#output)})
     kernel = {
-        'func': None,  # up to 2 function names
+        'func': [],  # up to 2 function names
         'param': None,
         'backend': None,
         'layout': None,
-        'data_type': None
+        'data_type': None,
+        'dispatch': {},
     }
-    kernel['func'] = parse_plain_list(kernel_config['func'])
     if 'param' in kernel_config:
         kernel['param'] = kernel_config['param']
@@ -176,6 +187,42 @@
     if 'data_type' in kernel_config:
         kernel['data_type'] = parse_candidates(kernel_config["data_type"])
+
+    kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
+        kernel_config['func']
+    )
+
+    def parse_kernel_in_out_type(in_out_str):
+        if len(in_out_str) == 0:
+            return None
+        tmp_in_out_list = in_out_str[1:-1].split('->')
+        inputs = [item.strip() for item in tmp_in_out_list[0].split(',')]
+        outputs = [item.strip() for item in tmp_in_out_list[1].split(',')]
+
+        # check the tensor type
+        for item in inputs:
+            assert item in [
+                'dense',
+                'selected_rows',
+                'sparse_coo',
+                'sparse_csr',
+            ], f"{api_name} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
+        for item in outputs:
+            assert item in [
+                'dense',
+                'selected_rows',
+                'sparse_coo',
+                'sparse_csr',
+            ], f"{api_name} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
+
+        return (inputs, outputs)
+
+    for func_item in kernel_funcs:
+        kernel['func'].append(func_item[0])
+        kernel['dispatch'][func_item[0]] = parse_kernel_in_out_type(
+            func_item[1]
+        )
+
     return kernel
@@ -200,10 +247,9 @@ def parse_invoke(api_name: str, invoke_config: str) -> Dict[str, Any]:

 def extract_type_and_name(records: List[Dict]) -> List[Dict]:
     """extract type and name from forward call, it is simpler than forward api."""
-    extracted = [{
-        "name": item["name"],
-        "typename": item["typename"]
-    } for item in records]
+    extracted = [
+        {"name": item["name"], "typename": item["typename"]} for item in records
+    ]
     return extracted
@@ -211,7 +257,8 @@
 def parse_forward(api_name: str, forward_config: str) -> Dict[str, Any]:
     # api_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
     result = re.search(
         r"(?P<op>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
-        forward_config)
+        forward_config,
+    )
     api = result.group("op")
     outputs = parse_outputs(api_name, result.group("outputs"))
     outputs = extract_type_and_name(outputs)
@@ -223,7 +270,7 @@
         "name": api,
         "inputs": inputs,
         "attrs": attrs,
-        "outputs": outputs
+        "outputs": outputs,
     }
     return forward_cfg
@@ -239,13 +286,19 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
             typename = attr["typename"]
             default_value = attr["default_value"]
             if typename == "DataType":
-                assert "DataType" in default_value, f"invalid DataType default value in {api_name}"
+                assert (
+                    "DataType" in default_value
+                ), f"invalid DataType default value in {api_name}"
                 # remove namespace
-                default_value = default_value[default_value.find("DataType"):]
+                default_value = default_value[default_value.find("DataType") :]
                 attr["default_value"] = default_value
             elif typename == "DataLayout":
-                assert "DataLayout" in default_value, f"invalid DataLayout default value in {api_name}"
-                default_value = default_value[default_value.find("DataLayout"):]
+                assert (
+                    "DataLayout" in default_value
+                ), f"invalid DataLayout default value in {api_name}"
+                default_value = default_value[
+                    default_value.find("DataLayout") :
+                ]
                 attr["default_value"] = default_value

     input_names = [item["name"] for item in inputs]
@@ -258,7 +311,9 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
     if "optional" in api_entry:
         optional_args = parse_plain_list(api_entry["optional"])
         for name in optional_args:
-            assert name in input_names, f"{api_name} has an optional input: '{name}' which is not an input."
+            assert (
+                name in input_names
+            ), f"{api_name} has an optional input: '{name}' which is not an input."
for input in inputs: if input["name"] in optional_args: input["optional"] = True @@ -269,7 +324,9 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"): if "intermediate" in api_entry: intermediate_outs = parse_plain_list(api_entry["intermediate"]) for name in intermediate_outs: - assert name in output_names, f"{api_name} has an intermediate output: '{name}' which is not an output." + assert ( + name in output_names + ), f"{api_name} has an intermediate output: '{name}' which is not an output." for output in outputs: if output["name"] in intermediate_outs: output["intermediate"] = True @@ -280,7 +337,9 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"): if "no_need_buffer" in api_entry: no_buffer_args = parse_plain_list(api_entry["no_need_buffer"]) for name in no_buffer_args: - assert name in input_names, f"{api_name} has an no buffer input: '{name}' which is not an input." + assert ( + name in input_names + ), f"{api_name} has an no buffer input: '{name}' which is not an input." for input in inputs: if input["name"] in no_buffer_args: input["no_need_buffer"] = True @@ -294,7 +353,7 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"): "inputs": inputs, "attrs": attrs, "outputs": outputs, - "no_need_buffer": no_buffer_args + "no_need_buffer": no_buffer_args, } # invokes another api? @@ -316,11 +375,13 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"): inplace_pairs = parse_inplace(api_name, api_entry["inplace"]) else: inplace_pairs = None - api.update({ - "infer_meta": infer_meta, - "kernel": kernel, - "inplace": inplace_pairs - }) + api.update( + { + "infer_meta": infer_meta, + "kernel": kernel, + "inplace": inplace_pairs, + } + ) else: # invoke invoke = parse_invoke(api_name, api_entry["invoke"]) @@ -339,8 +400,9 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"): if "forward" in api_entry: forward = parse_forward(api_name, api_entry["forward"]) # validate_fb - validate_backward_inputs(api_name, forward["inputs"], - forward["outputs"], inputs) + validate_backward_inputs( + api_name, forward["inputs"], forward["outputs"], inputs + ) validate_backward_attrs(api_name, forward["attrs"], attrs) validate_backward_outputs(api_name, forward["inputs"], outputs) else: @@ -356,23 +418,27 @@ def validate_backward_attrs(api, forward_attrs, backward_attrs): # this is a not-that-clean trick to allow backward api to has more attrs # than the forward api, as long as they all have default value for i in range(-num_exceptional_attrs, 0): - assert "default_value" in backward_attrs[ - i], f"{api} has exceptional attr without default value" + assert ( + "default_value" in backward_attrs[i] + ), f"{api} has exceptional attr without default value" -def validate_backward_inputs(api, forward_inputs, forward_outputs, - backward_inputs): +def validate_backward_inputs( + api, forward_inputs, forward_outputs, backward_inputs +): foward_input_names = [item["name"] for item in forward_inputs] forward_output_names = [item["name"] for item in forward_outputs] backward_input_names = [item["name"] for item in backward_inputs] assert len(backward_input_names) <= len(foward_input_names) + 2 * len( - forward_output_names), f"{api} has too many inputs." + forward_output_names + ), f"{api} has too many inputs." 
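As a checkpoint on the parsing above, here is a minimal runnable sketch of what `parse_kernel` now records for a kernel whose `func` entry carries the new `{inputs -> outputs}` annotation. The YAML fragment and the `softmax_csr` kernel name are illustrative stand-ins, not taken from this diff:

```python
import re

# Illustrative kernel config, e.g. from a sparse_ops.yaml entry such as:
#   kernel :
#     func : softmax_csr{sparse_csr -> sparse_csr}
kernel_config = {'func': 'softmax_csr{sparse_csr -> sparse_csr}'}

# Same pattern as in parse_kernel above: a kernel name plus an
# optional "{...}" input/output type annotation.
kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
    kernel_config['func']
)
print(kernel_funcs)  # [('softmax_csr', '{sparse_csr -> sparse_csr}')]

# parse_kernel_in_out_type then splits the annotation on '->', so the
# final result is:
#   kernel['dispatch'] == {'softmax_csr': (['sparse_csr'], ['sparse_csr'])}
```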
def validate_backward_outputs(api, forward_inputs, backward_outputs): assert len(backward_outputs) <= len( - forward_inputs), f"{api} has too many outputs" + forward_inputs + ), f"{api} has too many outputs" def cross_validate(apis): @@ -391,15 +457,17 @@ def cross_validate(apis): f"Something Wrong here, {name}'s forward api({fw_name}) does not claim {name} as its backward." ) else: - assert fw_api[ - "backward"] == name, f"{name}: backward and forward name mismatch" + assert ( + fw_api["backward"] == name + ), f"{name}: backward and forward name mismatch" assert len(fw_call["inputs"]) <= len( fw_api["inputs"] ), f"{name}: forward call has more inputs than the api" for (input, input_) in zip(fw_call["inputs"], fw_api["inputs"]): - assert input["typename"] == input_[ - "typename"], f"type mismatch in {name} and {fw_name}" + assert ( + input["typename"] == input_["typename"] + ), f"type mismatch in {name} and {fw_name}" assert len(fw_call["attrs"]) <= len( fw_api["attrs"] @@ -411,13 +479,16 @@ def cross_validate(apis): r"Scalar(\(\w+\))*", attr_["typename"] ), f"type mismatch in {name} and {fw_name}" else: - assert attr["typename"] == attr_[ - "typename"], f"type mismatch in {name} and {fw_name}" + assert ( + attr["typename"] == attr_["typename"] + ), f"type mismatch in {name} and {fw_name}" assert len(fw_call["outputs"]) == len( fw_api["outputs"] ), f"{name}: forward call has more outputs than the api" - for (output, output_) in zip(fw_call["outputs"], - fw_api["outputs"]): - assert output["typename"] == output_[ - "typename"], f"type mismatch in {name} and {fw_name}" + for (output, output_) in zip( + fw_call["outputs"], fw_api["outputs"] + ): + assert ( + output["typename"] == output_["typename"] + ), f"type mismatch in {name} and {fw_name}" diff --git a/paddle/phi/api/yaml/generator/templates/op.c.j2 b/paddle/phi/api/yaml/generator/templates/op.c.j2 index 0c2708ce223c7c2bef271cafa75150d8dae0084a..4799866f993cb83602c791ee4f42159bde8e2385 100644 --- a/paddle/phi/api/yaml/generator/templates/op.c.j2 +++ b/paddle/phi/api/yaml/generator/templates/op.c.j2 @@ -1,4 +1,4 @@ -{% from "operator_utils.c.j2" import op_maker, backward_op_maker, operator, register_op_with_components, register_op_version %} +{% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version %} // this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit. 
 #include <string>
 #include "paddle/fluid/framework/infershape_utils.h"
@@ -33,6 +33,8 @@ using paddle::framework::GradVarName;

{{backward_op_maker(api, api_dict[api["forward"]["name"]])}}

{{operator(api)}}
+  {% else %}
+{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}}
   {% endif %}
 {% endfor %}
 } // namespace operators
diff --git a/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2 b/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2
index 3910da99d8ae3109b1b4fb5a29afde66c5078bf7..da497e2b3bd00bffca5855b18c282f3f45cc22a2 100644
--- a/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2
+++ b/paddle/phi/api/yaml/generator/templates/operator_utils.c.j2
@@ -81,7 +81,11 @@ AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_ty
   {% set default_value = attr["default_value"] %}
   {% set typename = attr["typename"] %}
   {% if typename == "DataType" %}{# convert back to VarType #}
+    {% if default_value == "DataType::UNDEFINED" %}
+-1
+    {%- else %}
 static_cast<int>(framework::TransToProtoVarType(experimental::{{default_value}}))
+    {%- endif %}
   {%- elif typename == "DataLayout" %} {# does DataLayout need any processing?#}
 static_cast<int>(experimental::{{default_value}})
   {%- elif typename == "Place" %}{# construct a Place to get the type #}
@@ -94,7 +98,7 @@ static_cast<int>(phi::Place({{"phi::" if not default_value is initializer_list}}

 {# --------------------------------------- name mapping ---------------------------------------------- #}
 {% macro name_map(api) %}
-KernelSignature {{api["name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
+KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
   {% set kernel_args = api["kernel"]["param"] %}
   {{get_input_list(api["inputs"], kernel_args)}};
   paddle::small_vector<const char*> attrs;
@@ -124,12 +128,64 @@ All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArg
 */
 {% endmacro %}

+{% macro get_kernel_dispatch(inputs, kernel_config) %}{# inline #}
+{%- for kernel_func in kernel_config["func"] %}
+  {% set input_idx = namespace(idx=0) %}
+  {% set kernel_in_type_list = kernel_config["dispatch"][kernel_func][0] %}
+
+  if ( {%- for input in inputs %}
+    {%- if input["name"] in kernel_config["param"] %}
+      {%- if kernel_in_type_list[input_idx.idx] == "dense" %}
+ctx.IsDenseTensorInput("{{input["name"]}}"){{" && " if not loop.last}}
+      {%- elif kernel_in_type_list[input_idx.idx] == "selected_rows" %}
+ctx.IsSelectedRowsInput("{{input["name"]}}"){{" && " if not loop.last}}
+      {%- elif kernel_in_type_list[input_idx.idx] == "sparse_coo" %}
+ctx.IsSparseCooTensorInput("{{input["name"]}}"){{" && " if not loop.last}}
+      {%- elif kernel_in_type_list[input_idx.idx] == "sparse_csr" %}
+ctx.IsSparseCsrTensorInput("{{input["name"]}}"){{" && " if not loop.last}}
+      {%- endif %}
+      {% set input_idx.idx = input_idx.idx + 1 %}
+    {%- endif %}
+  {%- endfor %}) {
+    kernel_name = "{{kernel_func}}";
+  }
+{%- endfor %}
+{%- endmacro %}
+
+{% macro sparse_op_name_map(api) %}
+KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
+  {% set kernel_args = api["kernel"]["param"] %}
+  {{get_input_list(api["inputs"], kernel_args)}};
+  paddle::small_vector<const char*> attrs;
+  {% for attr in api["attrs"]%}
+  {% filter indent(2)%}
+  {{get_an_attr(attr)}};
+  {% endfilter %}
+  {% endfor %}
+  {{get_output_list(api["outputs"], kernel_args)}};
+
+  const char* kernel_name = "unregistered";
+{{get_kernel_dispatch(api["inputs"], api["kernel"])}}
+  KernelSignature sig(kernel_name, std::move(inputs), std::move(attrs), std::move(outputs));
+  return sig;
+}
+
+/*
+******************************************************************
+NOTE: The following codes are for 'get_compat_kernel_signature.py'
+All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping:
+
+{{api | cartesian_prod_mapping}}
+******************************************************************
+*/
+{% endmacro %}
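To make get_kernel_dispatch concrete: for the sparse add op, whose kernel entry dispatches between add_coo_coo and add_coo_dense, the generated mapping should come out roughly as below. This is a sketch reconstructed from the template above and from the hand-written SparseAddOpArgumentMapping this PR deletes further down; the generated whitespace and the way the input/output lists are built will differ.

// Sketch of the mapping sparse_ks.c.j2 emits for "sparse_add".
KernelSignature SparseAddOpArgumentMapping(const ArgumentMappingContext& ctx) {
  paddle::small_vector<const char*> inputs {"x", "y"};
  paddle::small_vector<const char*> attrs;
  paddle::small_vector<const char*> outputs {"out"};

  const char* kernel_name = "unregistered";
  if (ctx.IsSparseCooTensorInput("x") && ctx.IsSparseCooTensorInput("y")) {
    kernel_name = "add_coo_coo";
  }
  if (ctx.IsSparseCooTensorInput("x") && ctx.IsDenseTensorInput("y")) {
    kernel_name = "add_coo_dense";
  }
  KernelSignature sig(kernel_name, std::move(inputs), std::move(attrs), std::move(outputs));
  return sig;
}

Falling back to an "unregistered" signature when no dispatch branch matches preserves the behavior of the manual mappings removed from sparse_manual_op_sig.cc below.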
+{{get_kernel_dispatch(api["inputs"], api["kernel"])}} + KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs)); + return sig; +} + +/* +****************************************************************** +NOTE: The following codes are for 'get_compat_kernel_signature.py' +All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping: + +{{api | cartesian_prod_mapping}} +****************************************************************** +*/ +{% endmacro %} + {% macro register_base_kernel_name(api) %} PD_REGISTER_BASE_KERNEL_NAME({{api["op_name"]}}, {{api["name"]}}); {%- endmacro %} {% macro register_name_map(api) %} -PD_REGISTER_ARG_MAPPING_FN({{api["op_name"]}}, phi::{{api["name"] | to_pascal_case}}OpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN({{api["op_name"]}}, phi::{{api["op_name"] | to_pascal_case}}OpArgumentMapping); {%- endmacro %} {% macro get_input_list(inputs, kernel_args) %}{# inline #} @@ -352,6 +408,48 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker }; {% endmacro %} +{% macro backward_op_reused_maker(bw_op, forward_op, invoke_op) %} + {% set name = bw_op["op_name"] %} + {% set forward_input_names = bw_op["forward"]["inputs"] | map(attribute="name") | list %} + {% set forward_output_names = bw_op["forward"]["outputs"] | map(attribute="name") | list %} + {% set forward_attr_names = bw_op["forward"]["attrs"] | map(attribute="name") | list %} + {% set forward_input_orig_names = forward_op["inputs"] | map(attribute="name") | list %} + {% set forward_output_orig_names = forward_op["outputs"] | map(attribute="name") | list %} + {% set forward_attr_orig_names = forward_op["attrs"] | map(attribute="name") | list %} +template +class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("{{invoke_op["func"]}}"); + + {% for input in invoke_op["inputs"] %} + grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward( + input["value"], + forward_input_names, + forward_output_names, + forward_input_orig_names, + forward_output_orig_names)}}); + {% endfor %} + + {% for output in invoke_op["outputs"] %} + grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward( + output["value"], + forward_input_names, + forward_output_names, + forward_input_orig_names, + forward_output_orig_names)}}); + {% endfor %} + + {% for attr in invoke_op["attrs"] %} + grad_op->SetAttr("{{attr["name"]}}", {{attr["value"]}}); + {% endfor %} + } +}; +{% endmacro %} + {% macro extract_input_from_forward(name, input_names, output_names, diff --git a/paddle/phi/api/yaml/generator/templates/sparse_ks.c.j2 b/paddle/phi/api/yaml/generator/templates/sparse_ks.c.j2 new file mode 100644 index 0000000000000000000000000000000000000000..1af54ca8660838930f72de6642763f50c4a2ca79 --- /dev/null +++ b/paddle/phi/api/yaml/generator/templates/sparse_ks.c.j2 @@ -0,0 +1,24 @@ +{% from "operator_utils.c.j2" import sparse_op_name_map, register_name_map, register_base_kernel_name %} +// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit. 
+#include "paddle/phi/core/compat/op_utils.h" +#include "paddle/utils/small_vector.h" + +namespace phi { + +{% for api in apis %} + {% if api is base_api %} +{{sparse_op_name_map(api)}} + {% endif %} +{% endfor %} +{% for api in backward_apis %} + {% if api is base_api %} +{{sparse_op_name_map(api)}} + {% endif %} +{% endfor %} +} // namespace phi + +{% for api in apis + backward_apis %} + {% if api is base_api %} +{{register_name_map(api)}} + {% endif %} +{% endfor %} diff --git a/paddle/phi/api/yaml/generator/templates/sparse_op.c.j2 b/paddle/phi/api/yaml/generator/templates/sparse_op.c.j2 new file mode 100644 index 0000000000000000000000000000000000000000..15d887e589e70dba164a26aaf5067f03f22f3c83 --- /dev/null +++ b/paddle/phi/api/yaml/generator/templates/sparse_op.c.j2 @@ -0,0 +1,49 @@ +{% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version %} +// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit. +#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/infermeta/binary.h" +#include "paddle/phi/infermeta/multiary.h" +#include "paddle/phi/infermeta/sparse/backward.h" +#include "paddle/phi/infermeta/sparse/binary.h" +#include "paddle/phi/infermeta/sparse/multiary.h" +#include "paddle/phi/infermeta/sparse/unary.h" +#include "paddle/phi/infermeta/ternary.h" +#include "paddle/phi/infermeta/unary.h" + +namespace paddle { +namespace operators { + +using paddle::framework::GradVarName; + +{% for api in apis %} + {% if api is base_api %} + +{{op_maker(api)}} + +{{operator(api)}} + {% endif %} +{% endfor %} + +{% for api in backward_apis %} + {% if api is base_api %} + +{{backward_op_maker(api, api_dict[api["forward"]["name"]])}} + +{{operator(api)}} + {% else %} +{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}} + {% endif %} +{% endfor %} +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +{% for api in apis + backward_apis %} +{% if api is base_api %} +{{register_op_with_components(api)}} +{% endif %} +{% endfor %} diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 2f4eff98356044343bee115cf169a050be117417..2e7d240c5f586d231edcdd2350ca6e013b13f4a4 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -839,12 +839,6 @@ layout: out_grad inplace : (out_grad -> x_grad) -- backward_op : flip_grad - forward : flip (Tensor x, int[] axis) -> Tensor(out) - args : (Tensor out_grad, int[] axis) - output : Tensor(x_grad) - invoke : flip(out_grad, axis) - - backward_op : floor_grad forward : floor(Tensor x) -> Tensor(out) args : (Tensor out_grad) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index ac67838e6c81b2bda88d1bf279b627411638b9a2..5b9aa9c68c07ac4bf5be63b80f2c5bfaa1563208 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -938,15 +938,6 @@ intermediate : xshape backward : flatten_grad -- op : flip - args : (Tensor x, int[] axis) - output : Tensor - infer_meta : - func : FlipInferMeta - kernel : - func : flip - backward : flip_grad - - op : floor args : (Tensor x) output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml 
index a6360d2ae77026fa16b7d77418531c6048480f4f..ccf3c5852adc00fe0bb4f366bf98bb3cec528b00 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -324,6 +324,12 @@
   inputs: {x: X}
   outputs: {out: Out}

+- op : flip
+  inputs :
+    x : X
+  outputs :
+    out : Out
+
 - op : floor
   backward : floor_grad
   extra :
diff --git a/paddle/phi/api/yaml/op_version.yaml b/paddle/phi/api/yaml/op_version.yaml
index 5702884533a280ef4a1d0d9f6b5696608dc71bd4..3028b927966a20ab76c22e23cabc261d585e66d1 100644
--- a/paddle/phi/api/yaml/op_version.yaml
+++ b/paddle/phi/api/yaml/op_version.yaml
@@ -1,3 +1,13 @@
+- op : flip
+  version :
+    - checkpoint : Upgrade flip, add new attr [axis] and delete attr [dims]
+      action :
+        - add_attr : axis
+          comment : The added attr 'axis' doesn't set default value
+          default : paddle::none
+        - delete_attr : dims
+          comment : The attr 'dims' is deleted.
+
 - op : trace
   version :
     - checkpoint : Upgrade trace add a new attribute [axis2]
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 02bcd1f0040a8df2dad37473883c4cbea63f5588..10e617bd91243903cf2b1c8b942ec0f92e4d174f 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -199,3 +199,12 @@
   kernel :
     func : trunc
   backward : trunc_grad
+
+- op : flip
+  args : (Tensor x, int[] axis)
+  output : Tensor (out)
+  infer_meta :
+    func : FlipInferMeta
+  kernel :
+    func : flip
+  backward : flip_grad
diff --git a/paddle/phi/api/yaml/sparse_backward.yaml b/paddle/phi/api/yaml/sparse_backward.yaml
index 5bb52b921680fa3b85f3a593bd24546938f6d09e..ffb5406436faa2f37eb0c6108366594b778ddadc 100644
--- a/paddle/phi/api/yaml/sparse_backward.yaml
+++ b/paddle/phi/api/yaml/sparse_backward.yaml
@@ -1,5 +1,5 @@
 - backward_op : abs_grad
-  forward : tanh(Tensor x) -> Tensor(out)
+  forward : abs(Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
@@ -124,8 +124,8 @@
     cast_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
   data_type : out_grad

-- backward_op : conv3d_coo_grad
-  forward : conv3d_coo (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter)
+- backward_op : conv3d_grad
+  forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter)
   args : (Tensor x, Tensor kernel, Tensor out, Tensor rulebook, Tensor counter, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
   output : Tensor(x_grad), Tensor(kernel_grad)
   infer_meta :
@@ -432,7 +432,7 @@
     transpose_csr_grad {sparse_csr -> sparse_csr}

 - backward_op : values_grad
-  forward : values_coo(Tensor x) -> Tensor(out)
+  forward : values(Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   infer_meta :
@@ -442,7 +442,7 @@
     func : values_coo_grad{sparse_coo, dense-> sparse_coo}

 - backward_op: fused_attention_grad
-  forward : fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
+  forward : fused_attention(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
   args: (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad)
   output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
   infer_meta :
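For reference, the flip entry added to op_version.yaml above is consumed by the register_op_version macro; the generated registration should look roughly like the following sketch against Paddle's op-version registry API, with the comment strings taken verbatim from the YAML (the exact generated formatting may differ):

// Illustrative expansion of the flip entry in op_version.yaml.
REGISTER_OP_VERSION(flip).AddCheckpoint(
    "Upgrade flip, add new attr [axis] and delete attr [dims]",
    paddle::framework::compatible::OpVersionDesc()
        .NewAttr("axis",
                 "The added attr 'axis' doesn't set default value",
                 paddle::none)
        .DeleteAttr("dims", "The attr 'dims' is deleted."));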
diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml
index 5ef29bdcb16f0350293d6dff86cf7971f5e1d9a3..fb1562520f9a093868c4221567043221fab9e3e1 100644
--- a/paddle/phi/api/yaml/sparse_ops.yaml
+++ b/paddle/phi/api/yaml/sparse_ops.yaml
@@ -112,7 +112,7 @@
   backward : cast_grad

 - op : conv3d
-  args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
+  args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key="")
   output : Tensor(out), Tensor(rulebook), Tensor(counter)
   infer_meta :
     func : sparse::Conv3dInferMeta
@@ -120,7 +120,7 @@
     func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense, dense}
   layout : x
   intermediate: rulebook, counter
-  backward : conv3d_coo_grad
+  backward : conv3d_grad

 - op : divide
   args : (Tensor x, Tensor y)
diff --git a/paddle/phi/core/compat/arg_map_context.h b/paddle/phi/core/compat/arg_map_context.h
index 3f039e0d62338656a9fa9d7ac194c12ff6b3a5a0..099aed0870892888b9158eed323375b9bbfff7ad 100644
--- a/paddle/phi/core/compat/arg_map_context.h
+++ b/paddle/phi/core/compat/arg_map_context.h
@@ -109,6 +109,7 @@ class ArgumentMappingContext {
   virtual bool IsDenseTensorInputs(const std::string& name) const = 0;
   virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
   virtual bool IsSparseCooTensorInput(const std::string& name) const = 0;
+  virtual bool IsSparseCsrTensorInput(const std::string& name) const = 0;

   // For compatibility with LoDTensorArray
   virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0;
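Every concrete ArgumentMappingContext now has to implement the new pure-virtual IsSparseCsrTensorInput (the test double later in this diff simply returns false). In generated mappings, the new predicate lets a single op route between COO and CSR kernels. A hedged sketch for the sparse cast op, whose CSR kernel shows up in sparse_backward.yaml above; the attribute names here are assumptions for illustration:

// Sketch: dispatching between cast_coo and cast_csr on the input's tensor type.
KernelSignature SparseCastOpArgumentMapping(const ArgumentMappingContext& ctx) {
  const char* kernel_name = "unregistered";
  if (ctx.IsSparseCooTensorInput("x")) {
    kernel_name = "cast_coo";
  } else if (ctx.IsSparseCsrTensorInput("x")) {
    kernel_name = "cast_csr";
  }
  // "index_dtype" / "value_dtype" are illustrative attr names, not confirmed
  // by the hunks in this diff.
  return KernelSignature(kernel_name, {"x"}, {"index_dtype", "value_dtype"}, {"out"});
}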
diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h
index b578afa7c2b854c058e238b2923b2fe5830243d2..10b859fdac260396618116e0da5efbe3f3de192f 100644
--- a/paddle/phi/core/compat/op_utils.h
+++ b/paddle/phi/core/compat/op_utils.h
@@ -40,7 +40,7 @@ const std::unordered_set<std::string> standard_kernel_suffixs({
  * after 2.0, and can no longer be occupied by the previously abandoned ops.
  * They are marked here uniformly.
  */
-const std::unordered_set<std::string> deprecated_op_names(
+static const std::unordered_set<std::string> deprecated_op_names(
    {"diag",
     "flatten",
     "flatten_grad",
diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc
index 6e16029ee40b551f1d35f0f034233ec5a3fd10cd..a8479f8624ba32fe3d234fdb48e5b68fa5997f86 100644
--- a/paddle/phi/core/kernel_factory.cc
+++ b/paddle/phi/core/kernel_factory.cc
@@ -16,6 +16,11 @@

 #include "glog/logging.h"
 #include "paddle/phi/core/enforce.h"
+#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
+#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
+#include "paddle/phi/core/compat/convert_utils.h"
+#endif
+#include "paddle/phi/core/compat/op_utils.h"

 DECLARE_bool(enable_api_kernel_fallback);

@@ -41,6 +46,17 @@ KernelFactory& KernelFactory::Instance() {
   return g_op_kernel_factory;
 }

+bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const {
+  if (deprecated_op_names.find(op_type) == deprecated_op_names.end()) {
+    if (phi::OpUtilsMap::Instance().Contains(op_type)) {
+      return true;
+    } else if (kernels_.find(op_type) != kernels_.end()) {
+      return true;
+    }
+  }
+  return false;
+}
+
 const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
                                           const KernelKey& kernel_key) const {
   auto iter = kernels_.find(kernel_name);
diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h
index 8e98c276646d9cf387758417b53eaedba10e0564..ed9280fa475bf5b46c689b0a5e49e470484b77a6 100644
--- a/paddle/phi/core/kernel_factory.h
+++ b/paddle/phi/core/kernel_factory.h
@@ -272,9 +272,7 @@ class KernelFactory {

   KernelNameMap& kernels() { return kernels_; }

-  bool HasCompatiblePhiKernel(const std::string& op_type) const {
-    return kernels_.find(TransToPhiKernelName(op_type)) != kernels_.end();
-  }
+  bool HasCompatiblePhiKernel(const std::string& op_type) const;

   KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
                                         const KernelKey& kernel_key,
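The rewritten HasCompatiblePhiKernel widens the check: an op now counts as covered as long as it is not in deprecated_op_names and either has an entry in OpUtilsMap (for example, an argument mapping registered by the generated sparse_ks code) or a kernel registered directly under its own name. A rough illustration of the new semantics, with op names chosen for the example:

// Sketch; assumes the generated sparse argument mappings are linked in.
#include "paddle/phi/core/kernel_factory.h"

bool SparseAddIsCovered() {
  auto& factory = phi::KernelFactory::Instance();
  // Expected true: "sparse_add" gets an OpUtilsMap entry via
  // PD_REGISTER_ARG_MAPPING_FN even though no kernel is registered
  // under the literal name "sparse_add".
  bool sparse_ok = factory.HasCompatiblePhiKernel("sparse_add");
  // Expected false: "flatten" sits in deprecated_op_names, so it is
  // rejected before any map lookup.
  bool flatten_ok = factory.HasCompatiblePhiKernel("flatten");
  return sparse_ok && !flatten_ok;
}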
diff --git a/paddle/phi/ops/compat/sparse_manual_op_sig.cc b/paddle/phi/ops/compat/sparse_manual_op_sig.cc
index 6c2a2bc9f451f90bdd53e6eca7195c6d1c9229b0..6e520cbdd96cdb53316e11c750f9505911354a89 100644
--- a/paddle/phi/ops/compat/sparse_manual_op_sig.cc
+++ b/paddle/phi/ops/compat/sparse_manual_op_sig.cc
@@ -16,23 +16,6 @@

 namespace phi {

-// TODO(zhangkaihuo): add csr op
-
-KernelSignature SparseSparseCooTensorOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "sparse_coo_tensor", {"values", "indices"}, {"dense_shape"}, {"out"});
-}
-
-KernelSignature SparseValuesOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  if (ctx.IsSparseCooTensorInput("x")) {
-    return KernelSignature("values_coo", {"x"}, {}, {"out"});
-  } else {
-    return KernelSignature("unregistered", {}, {}, {});
-  }
-}
-
 KernelSignature SparseIndicesOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   if (ctx.IsSparseCooTensorInput("x")) {
@@ -42,94 +25,6 @@ KernelSignature SparseIndicesOpArgumentMapping(
   }
 }

-KernelSignature SparseToDenseOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  if (ctx.IsSparseCooTensorInput("x")) {
-    return KernelSignature("coo_to_dense", {"x"}, {}, {"out"});
-  } else {
-    return KernelSignature("unregistered", {}, {}, {});
-  }
-}
-
-KernelSignature SparseReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  if (ctx.IsSparseCooTensorInput("x")) {
-    return KernelSignature("relu_coo", {"x"}, {}, {"out"});
-  } else {
-    return KernelSignature("unregistered", {}, {}, {});
-  }
-}
-
-KernelSignature SparseConv3dOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  if (ctx.IsSparseCooTensorInput("x")) {
-    return KernelSignature(
-        "conv3d_coo",
-        {"x", "kernel"},
-        {"paddings", "dilations", "strides", "groups", "subm", "key"},
-        {"out", "rulebook", "counter"});
-  } else {
-    return KernelSignature("unregistered", {}, {}, {});
-  }
-}
-
-KernelSignature SparseAddOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  if (ctx.IsSparseCooTensorInput("x") && ctx.IsSparseCooTensorInput("y")) {
-    return KernelSignature("add_coo_coo", {"x", "y"}, {}, {"out"});
-  } else if (ctx.IsSparseCooTensorInput("x") && ctx.IsDenseTensorInput("y")) {
-    return KernelSignature("add_coo_dense", {"x", "y"}, {}, {"out"});
-  } else {
-    return KernelSignature("unregistered", {}, {}, {});
-  }
-}
-
-KernelSignature SparseBatchNormOpArgumentMapping(
-    const ArgumentMappingContext& ctx) {
-  if (ctx.IsSparseCooTensorInput("x")) {
-    return KernelSignature("batch_norm_coo",
-                           {"x", "scale", "bias", "mean", "variance"},
-                           {"momentum",
-                            "epsilon",
-                            "data_layout",
-                            "is_test",
-                            "use_global_stats",
-                            "trainable_statistics",
-                            "fuse_with_relu"},
-                           {"y",
-                            "mean_out",
-                            "variance_out",
-                            "saved_mean",
-                            "saved_variance",
-                            "reserve_space"});
-  } else {
-    return KernelSignature("unregistered", {}, {}, {});
-  }
-}
-
 }  // namespace phi

-PD_REGISTER_BASE_KERNEL_NAME(sparse_sparse_coo_tensor, sparse_coo_tensor);
-PD_REGISTER_ARG_MAPPING_FN(sparse_sparse_coo_tensor,
-                           phi::SparseSparseCooTensorOpArgumentMapping);
-
-PD_REGISTER_BASE_KERNEL_NAME(sparse_values, values_coo);
-PD_REGISTER_ARG_MAPPING_FN(sparse_values, phi::SparseValuesOpArgumentMapping);
-
-PD_REGISTER_BASE_KERNEL_NAME(sparse_indices, indices_coo);
 PD_REGISTER_ARG_MAPPING_FN(sparse_indices, phi::SparseIndicesOpArgumentMapping);
-
-PD_REGISTER_BASE_KERNEL_NAME(sparse_to_dense, coo_to_dense);
-PD_REGISTER_ARG_MAPPING_FN(sparse_to_dense,
-                           phi::SparseToDenseOpArgumentMapping);
-
-PD_REGISTER_BASE_KERNEL_NAME(sparse_relu, relu_coo);
-PD_REGISTER_ARG_MAPPING_FN(sparse_relu, phi::SparseReluOpArgumentMapping);
-
-PD_REGISTER_BASE_KERNEL_NAME(sparse_conv3d, conv3d_coo);
-PD_REGISTER_ARG_MAPPING_FN(sparse_conv3d, phi::SparseConv3dOpArgumentMapping);
-
-PD_REGISTER_BASE_KERNEL_NAME(sparse_add, add_coo_coo);
-PD_REGISTER_ARG_MAPPING_FN(sparse_add, phi::SparseAddOpArgumentMapping);
-
-PD_REGISTER_BASE_KERNEL_NAME(sparse_batch_norm, batch_norm_coo);
-PD_REGISTER_ARG_MAPPING_FN(sparse_batch_norm,
-                           phi::SparseBatchNormOpArgumentMapping);
diff --git a/paddle/phi/tests/ops/test_op_signature.h b/paddle/phi/tests/ops/test_op_signature.h
index 7f66fa6c7629fda8433056d548ae945c824e6c2b..eda7b0f806d7467f3ccba4ddd99a625b82cfe93a 100644
--- a/paddle/phi/tests/ops/test_op_signature.h
+++ b/paddle/phi/tests/ops/test_op_signature.h
@@ -86,6 +86,10 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
     return false;
   }

+  bool IsSparseCsrTensorInput(const std::string& name) const override {
+    return false;
+  }
+
   bool IsDenseTensorOutput(const std::string& name) const override {
     return dense_tensor_outputs.count(name) > 0;
   }
diff --git a/python/paddle/sparse/nn/layer/norm.py b/python/paddle/sparse/nn/layer/norm.py
index 7166906f11f23f29814f8ee5c5e0a06dbd935a19..2b0dba5a591e405965078c942d86b94c2855b42e 100644
--- a/python/paddle/sparse/nn/layer/norm.py
+++ b/python/paddle/sparse/nn/layer/norm.py
@@ -190,19 +190,20 @@ class BatchNorm(paddle.nn.BatchNorm1D):
             reserve_space = helper.create_variable_for_type_inference(
                 dtype=dtype, stop_gradient=True
             )
-            y = helper.create_sparse_variable_for_type_inference(dtype)
+            out = helper.create_sparse_variable_for_type_inference(dtype)
             outputs = {
-                "y": y,
+                "out": out,
                 "mean_out": mean_out,
                 "variance_out": variance_out,
                 "saved_mean": saved_mean,
                 "saved_variance": saved_variance,
                 "reserve_space": reserve_space,
             }
+
             helper.append_op(
                 type=op_type, inputs=inputs, outputs=outputs, attrs=attrs
             )
-            return y
+            return out


 class SyncBatchNorm(paddle.nn.SyncBatchNorm):
diff --git a/tools/nvcc_lazy.sh b/tools/nvcc_lazy.sh
old mode 100644
new mode 100755
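The y -> out rename keeps the static-graph branch of sparse BatchNorm consistent with the generated operator, whose first output is named out. Under sparse_op_name_map, the generated batch-norm mapping should come out close to the hand-written one deleted above, just with the output renamed. A sketch, with the input/attr order taken from the deleted manual mapping:

// Expected shape of the generated mapping for sparse batch_norm.
KernelSignature SparseBatchNormOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  const char* kernel_name = "unregistered";
  if (ctx.IsSparseCooTensorInput("x")) {
    kernel_name = "batch_norm_coo";
  }
  return KernelSignature(kernel_name,
                         {"x", "scale", "bias", "mean", "variance"},
                         {"momentum", "epsilon", "data_layout", "is_test",
                          "use_global_stats", "trainable_statistics",
                          "fuse_with_relu"},
                         {"out", "mean_out", "variance_out", "saved_mean",
                          "saved_variance", "reserve_space"});
}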