Unverified commit 601626ac authored by zyfncg and committed by GitHub

[cherry-pick][code-gen] Support code-gen for opmaker of sparse op (#46993) (#47417)

* support generating opmaker code for backward ops that invoke forward ops (#46912)

* [code-gen] Support code-gen for opmaker of sparse op (#46993)

* support generating opmaker code for backward ops that invoke forward ops

* support code-gen of opmaker for sparse op

* refine logic of choosing phi kernel

* fix compile bug

* fix code_gen bug

* fix bug

* fix kernel signature code-gen

* fix compile bug of VarType

* fix compile bug of VarType

* fix test_sparse_conv_op

* fix test_sparse_norm_op

* [Phi] Refactor logic of judging whether an op has a phi kernel (#46920)

* refine logic of choosing phi kernel

* fix compile bug

* update cmake
Parent 23c05f2f
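The change centers on describing sparse ops in paddle/phi/api/yaml/sparse_ops.yaml and sparse_backward.yaml and generating their OpMakers and argument mappings from those entries. A minimal sketch of the entry format the generator consumes (illustrative values, not copied from the real yaml; the `{inputs -> outputs}` suffix is the kernel-dispatch syntax that parse_kernel below parses):

import yaml

# Illustrative sparse op entry (values are hypothetical); the braces after
# each kernel func name declare the input/output tensor types used to pick
# the kernel at runtime.
entry = yaml.safe_load("""
- op : add
  args : (Tensor x, Tensor y)
  output : Tensor(out)
  infer_meta :
    func : UnchangedInferMeta
  kernel :
    func : add_coo_coo{sparse_coo, sparse_coo -> sparse_coo},
           add_csr_csr{sparse_csr, sparse_csr -> sparse_csr}
  backward : add_grad
""")[0]
print(entry["kernel"]["func"])  # one folded string; parse_kernel splits it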
......@@ -71,7 +71,9 @@ paddle/fluid/pybind/eager_op_function.cc
# these files (directories) are generated before build system generation
paddle/fluid/operators/generated_op.cc
paddle/fluid/operators/generated_sparse_op.cc
paddle/phi/ops/compat/generated_sig.cc
paddle/phi/ops/compat/generated_sparse_sig.cc
paddle/phi/api/yaml/parsed_apis/
python/paddle/utils/code_gen/
paddle/fluid/pybind/tmp_eager_op_function_impl.h
......
......@@ -55,7 +55,9 @@ static std::unordered_set<std::string> black_ops_list = {"run_program",
"fused_gate_attention",
"fused_feedforward",
"fused_attention",
"fused_gemm_epilogue"};
"fused_gemm_epilogue",
"sparse_divide_scalar",
"sparse_scale"};
static std::string LegalizeVariableName(const std::string& var_name) {
std::string ret = var_name;
......@@ -3161,6 +3163,12 @@ static void DygraphCodeGeneration(const std::string& output_dir,
continue;
}
// Skip the sparse op
if (op_type.compare(0, 7, "sparse_") == 0 && op_type != "sparse_momentum" &&
op_type != "sparse_attention") {
continue;
}
GradNodeGenerationInfo bwd_info;
bool is_available = CollectGradInformationFromOpInfo(op_info, &bwd_info);
......
......@@ -190,7 +190,7 @@ cc_test(
cc_library(
var_type_traits
SRCS var_type_traits.cc
DEPS framework_proto scope tensor_array sparse_coo_tensor)
DEPS framework_proto scope tensor_array sparse_coo_tensor sparse_csr_tensor)
if(WITH_GPU)
target_link_libraries(var_type_traits dynload_cuda)
endif()
......
......@@ -156,6 +156,8 @@ message VarType {
PSTRING = 29;
// the data type of phi::SparseCooTensor
SPARSE_COO = 30;
// the data type of phi::SparseCsrTensor
SPARSE_CSR = 31;
}
required Type type = 1;
......@@ -189,6 +191,7 @@ message VarType {
optional TensorDesc strings = 9;
optional TensorDesc vocab = 10;
optional TensorDesc sparse_coo = 11;
optional TensorDesc sparse_csr = 12;
}
message VarDesc {
......
......@@ -106,6 +106,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
return var_type == proto::VarType::SPARSE_COO;
}
bool IsSparseCsrTensorInput(const std::string& name) const override {
auto var_type = ctx_.GetInputVarType(name);
return var_type == proto::VarType::SPARSE_CSR;
}
bool IsDenseTensorOutput(const std::string& name) const override {
auto var_types = ctx_.GetOutputsVarType(name);
return std::all_of(var_types.begin(),
......
......@@ -529,6 +529,11 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
return var->IsType<phi::SparseCooTensor>();
}
bool IsSparseCsrTensorInput(const std::string& name) const override {
const auto* var = ctx_.InputVar(name);
return var->IsType<phi::SparseCsrTensor>();
}
bool IsDenseTensorOutput(const std::string& name) const override {
auto vars = ctx_.MultiOutputVar(name);
return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace paddle {
namespace framework {
......
......@@ -55,6 +55,7 @@ namespace phi {
class DenseTensor;
class SelectedRows;
class SparseCooTensor;
class SparseCsrTensor;
} // namespace phi
// Users should add forward declarations here
......@@ -182,6 +183,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
Tensor,
phi::SelectedRows,
phi::SparseCooTensor,
phi::SparseCsrTensor,
std::vector<Scope *>,
LoDRankTable,
Strings,
......
......@@ -108,6 +108,10 @@ bool PluginArgumentMappingContext::IsSparseCooTensorInput(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsSparseCsrTensorInput(
const std::string& name) const {
return false;
}
bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
const std::string& name) const {
return false;
......
......@@ -48,6 +48,8 @@ class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {
bool IsSparseCooTensorInput(const std::string& name) const override;
bool IsSparseCsrTensorInput(const std::string& name) const override;
bool IsDenseTensorVectorInput(const std::string& name) const override;
bool IsDenseTensorOutput(const std::string& name) const override;
......
......@@ -102,7 +102,7 @@ else()
cc_library(gather_scatter_kernel SRCS gather_scatter_kernel.cc gather_scatter_kernel.cu DEPS tensor)
endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta)
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel backward_infermeta sparse_backward_infermeta)
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
using framework::OpKernelType;
using framework::Tensor;
class FlipOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout,
library,
customized_type_value);
}
};
class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of flip op.");
AddOutput("Out", "(Tensor), The output tensor of flip op.");
AddAttr<std::vector<int>>("axis", "The axes to flip on.");
AddComment(R"DOC(
Flip Operator.
Reverse the order of an n-D tensor along the given axes.
)DOC");
}
};
class FlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
template <typename T>
class FlipOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("flip");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetOutput("Out", this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(flip,
FlipInferShapeFunctor,
PD_INFER_META(phi::FlipInferMeta));
REGISTER_OPERATOR(flip,
ops::FlipOp,
ops::FlipOpMaker,
ops::FlipOpInferVarType,
ops::FlipOpGradMaker<paddle::framework::OpDesc>,
ops::FlipOpGradMaker<paddle::imperative::OpBase>,
FlipInferShapeFunctor);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(flip).AddCheckpoint(
R"ROC(Upgrade flip, add new attr [axis] and delete attr [dims].)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("axis",
"The added attr 'axis' doesn't set default value.",
paddle::none)
.DeleteAttr("dims", "The attr 'dims' is deleted."));
......@@ -28,50 +28,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
class SparseSparseCooTensorOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("values", "(Tensor), input 0 of sparse_coo_tensor op.");
AddInput("indices", "(Tensor), input 1 of sparse_coo_tensor op.");
AddOutput("out", "(Tensor), output 0 of sparse_coo_tensor op.");
AddAttr<std::vector<int>>(
"dense_shape", "(vector<int>), attribute 0 for sparse_coo_tensor op.");
AddComment(R"DOC(
TODO: Documentation of sparse_coo_tensor op.
)DOC");
}
};
class SparseSparseCooTensorOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(
sparse_sparse_coo_tensor,
SparseSparseCooTensorInferShapeFunctor,
PD_INFER_META(phi::sparse::SparseCooTensorInferMeta));
class SparseValuesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_values op.");
AddOutput("out", "(Tensor), output 0 of sparse_values op.");
AddComment(R"DOC(
TODO: Documentation of sparse_values op.
)DOC");
}
};
class SparseValuesOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_values,
SparseValuesInferShapeFunctor,
PD_INFER_META(phi::sparse::ValuesInferMeta));
class SparseIndicesOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -92,182 +48,12 @@ DECLARE_INFER_SHAPE_FUNCTOR(sparse_indices,
SparseIndicesInferShapeFunctor,
PD_INFER_META(phi::sparse::IndicesInferMeta));
class SparseToDenseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_to_dense op.");
AddOutput("out", "(Tensor), output 0 of sparse_to_dense op.");
AddComment(R"DOC(
TODO: Documentation of sparse_to_dense op.
)DOC");
}
};
class SparseToDenseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_to_dense,
SparseToDenseInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
class SparseReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_relu op.");
AddOutput("out", "(Tensor), output 0 of sparse_relu op.");
AddComment(R"DOC(
TODO: Documentation of sparse_relu op.
)DOC");
}
};
class SparseReluOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_relu,
SparseReluInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
class SparseConv3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_conv3d op.");
AddInput("kernel", "(Tensor), input 1 of sparse_conv3d op.");
AddOutput("out", "(Tensor), output 0 of sparse_conv3d op.");
AddOutput("rulebook", "(Tensor), output 1 of sparse_conv3d op.");
AddOutput("counter", "(Tensor), output 2 of sparse_conv3d op.");
AddAttr<std::vector<int>>(
"paddings", "(vector<int>), attribute 0 for sparse_conv3d op.");
AddAttr<std::vector<int>>(
"dilations", "(vector<int>), attribute 1 for sparse_conv3d op.");
AddAttr<std::vector<int>>(
"strides", "(vector<int>), attribute 2 for sparse_conv3d op.");
AddAttr<int>("groups", "(int), attribute 3 for sparse_conv3d op.");
AddAttr<bool>("subm", "(bool), attribute 4 for conv3d_coo op.");
AddAttr<std::string>("key", "(string), attribute 5 for sparse_conv3d op.")
.SetDefault("");
AddComment(R"DOC(
TODO: Documentation of sparse_conv3d op.
)DOC");
}
};
class SparseConv3dOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_conv3d,
SparseConv3dInferShapeFunctor,
PD_INFER_META(phi::sparse::Conv3dInferMeta));
class SparseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_add op.");
AddInput("y", "(Tensor), input 1 of sparse_add op.");
AddOutput("out", "(Tensor), output 0 of sparse_add op.");
AddComment(R"DOC(
TODO: Documentation of sparse_add op.
)DOC");
}
};
class SparseAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_add,
SparseAddInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
class SparseBatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("x", "(Tensor), input 0 of sparse_batch_norm op.");
AddInput("scale", "(Tensor), input 1 of sparse_batch_norm op.");
AddInput("bias", "(Tensor), input 2 of sparse_batch_norm op.");
AddInput("mean", "(Tensor), input 3 of sparse_batch_norm op.");
AddInput("variance", "(Tensor), input 4 of sparse_batch_norm op.");
AddOutput("y", "(Tensor), output 0 of sparse_batch_norm op.");
AddOutput("mean_out", "(Tensor), output 1 of sparse_batch_norm op.");
AddOutput("variance_out", "(Tensor), output 2 of sparse_batch_norm op.");
AddOutput("saved_mean", "(Tensor), output 3 of sparse_batch_norm op.");
AddOutput("saved_variance", "(Tensor), output 4 of sparse_batch_norm op.");
AddOutput("reserve_space", "(Tensor), output 5 of sparse_batch_norm op.");
AddAttr<float>("momentum",
"(float), attribute 0 for sparse_batch_norm op.");
AddAttr<float>("epsilon", "(float), attribute 1 for sparse_batch_norm op.");
AddAttr<std::string>("data_layout",
"(string), attribute 2 for sparse_batch_norm op.");
AddAttr<bool>("is_test", "(bool), attribute 3 for sparse_batch_norm op.");
AddAttr<bool>("use_global_stats",
"(bool), attribute 4 for sparse_batch_norm op.");
AddAttr<bool>("trainable_statistics",
"(bool), attribute 4 for sparse_batch_norm op.");
AddAttr<bool>("fuse_with_relu",
"(bool), attribute 4 for sparse_batch_norm op.");
AddComment(R"DOC(
TODO: Documentation of sparse_batch_norm op.
)DOC");
}
};
class SparseBatchNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
DECLARE_INFER_SHAPE_FUNCTOR(sparse_batch_norm,
SparseBatchNormInferShapeFunctor,
PD_INFER_META(phi::BatchNormInferMeta));
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sparse_sparse_coo_tensor,
ops::SparseSparseCooTensorOp,
ops::SparseSparseCooTensorOpMaker,
ops::SparseSparseCooTensorInferShapeFunctor);
REGISTER_OPERATOR(sparse_values,
ops::SparseValuesOp,
ops::SparseValuesOpMaker,
ops::SparseValuesInferShapeFunctor);
REGISTER_OPERATOR(sparse_indices,
ops::SparseIndicesOp,
ops::SparseIndicesOpMaker,
ops::SparseIndicesInferShapeFunctor);
REGISTER_OPERATOR(sparse_to_dense,
ops::SparseToDenseOp,
ops::SparseToDenseOpMaker,
ops::SparseToDenseInferShapeFunctor);
REGISTER_OPERATOR(sparse_relu,
ops::SparseReluOp,
ops::SparseReluOpMaker,
ops::SparseReluInferShapeFunctor);
REGISTER_OPERATOR(sparse_conv3d,
ops::SparseConv3dOp,
ops::SparseConv3dOpMaker,
ops::SparseConv3dInferShapeFunctor);
REGISTER_OPERATOR(sparse_add,
ops::SparseAddOp,
ops::SparseAddOpMaker,
ops::SparseAddInferShapeFunctor);
REGISTER_OPERATOR(sparse_batch_norm,
ops::SparseBatchNormOp,
ops::SparseBatchNormOpMaker,
ops::SparseBatchNormInferShapeFunctor);
......@@ -108,7 +108,6 @@ register_unity_group(
register_unity_group(
cc
flatten_op.cc
flip_op.cc
fsp_op.cc
gather_nd_op.cc
gather_op.cc
......@@ -424,7 +423,6 @@ register_unity_group(cu expand_v2_op.cu fake_dequantize_op.cu
fill_any_like_op.cu)
register_unity_group(
cu
flip_op.cu
fsp_op.cu
gather_nd_op.cu
gather_op.cu
......
......@@ -416,6 +416,11 @@ GenerateOpFunctions() {
if (CUSTOM_HANDWRITE_OPS_SET.count(op_type)) {
continue;
}
// Skip the sparse op
if (op_type.compare(0, 7, "sparse_") == 0 && op_type != "sparse_momentum" &&
op_type != "sparse_attention") {
continue;
}
// Skip operator which is not inherited from OperatorWithKernel, like while,
// since only OperatorWithKernel can run in dygraph mode.
// if the phi lib contains op kernel, we still generate ops method
......
......@@ -118,8 +118,13 @@ endif()
set(parsed_api_dir ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/parsed_apis)
set(generated_op_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_op.cc)
set(generated_sparse_ops_path
${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generated_sparse_op.cc)
set(generated_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sig.cc)
set(generated_sparse_argument_mapping_path
${CMAKE_SOURCE_DIR}/paddle/phi/ops/compat/generated_sparse_sig.cc)
message(
"parse api yamls:
- ${api_yaml_file}
......@@ -130,16 +135,22 @@ execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_api_dir}
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./ops.yaml
--output_path ./parsed_apis/api.parsed.yaml
--output_path ./parsed_apis/ops.parsed.yaml
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./legacy_ops.yaml --output_path ./parsed_apis/legacy_api.parsed.yaml
./legacy_ops.yaml --output_path ./parsed_apis/legacy_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path ./backward.yaml
--output_path ./parsed_apis/backward_api.parsed.yaml --backward
--output_path ./parsed_apis/backward_ops.parsed.yaml --backward
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./legacy_backward.yaml --output_path
./parsed_apis/legacy_backward_api.parsed.yaml --backward RESULTS_VARIABLE
./parsed_apis/legacy_backward_ops.parsed.yaml --backward
COMMAND ${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./sparse_ops.yaml --output_path ./parsed_apis/sparse_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/parse_api.py --api_yaml_path
./sparse_backward.yaml --output_path
./parsed_apis/sparse_backward.parsed.yaml --backward RESULTS_VARIABLE
_results)
foreach(_result in ${_results})
if(${_result})
......@@ -149,38 +160,53 @@ endforeach()
# validation of api yamls
message("validate api yaml:
- ${parsed_api_dir}/api.parsed.yaml
- ${parsed_api_dir}/backward_api.parsed.yaml")
- ${parsed_api_dir}/ops.parsed.yaml
- ${parsed_api_dir}/backward_ops.parsed.yaml")
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND
${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths
./parsed_apis/api.parsed.yaml ./parsed_apis/legacy_api.parsed.yaml
--backward_yaml_paths ./parsed_apis/backward_api.parsed.yaml
./parsed_apis/legacy_backward_api.parsed.yaml
RESULT_VARIABLE _result)
if(${_result})
message(FATAL_ERROR "api validation failed, exiting.")
endif()
./parsed_apis/ops.parsed.yaml ./parsed_apis/legacy_ops.parsed.yaml
--backward_yaml_paths ./parsed_apis/backward_ops.parsed.yaml
./parsed_apis/legacy_backward_ops.parsed.yaml
COMMAND
${PYTHON_EXECUTABLE} generator/cross_validate.py --forward_yaml_paths
./parsed_apis/sparse_ops.parsed.yaml --backward_yaml_paths
./parsed_apis/sparse_backward.parsed.yaml
RESULT_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "ops validation failed, exiting.")
endif()
endforeach()
# code generation for op, op makers, and argument mapping functions
message(
"create or remove auto-geneated operators: ${generated_op_path}.tmp
create or remove auto-geneated argument mappings: ${generated_argument_mapping_path}.tmp"
)
execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml
COMMAND
${PYTHON_EXECUTABLE} generator/generate_op.py --api_yaml_path
./parsed_apis/api.parsed.yaml --backward_api_yaml_path
./parsed_apis/backward_api.parsed.yaml --api_version_yaml_path
${PYTHON_EXECUTABLE} generator/generate_op.py --ops_yaml_path
./parsed_apis/ops.parsed.yaml --backward_yaml_path
./parsed_apis/backward_ops.parsed.yaml --op_version_yaml_path
op_version.yaml --op_compat_yaml_path op_compat.yaml --output_op_path
"${generated_op_path}.tmp" --output_arg_map_path
"${generated_argument_mapping_path}.tmp"
RESULT_VARIABLE _result)
if(${_result})
message(FATAL_ERROR "operator codegen failed, exiting.")
endif()
COMMAND
${PYTHON_EXECUTABLE} generator/generate_sparse_op.py --ops_yaml_path
./parsed_apis/sparse_ops.parsed.yaml --backward_ops_yaml_path
./parsed_apis/sparse_backward.parsed.yaml --output_op_path
"${generated_sparse_ops_path}.tmp" --output_arg_map_path
"${generated_sparse_argument_mapping_path}.tmp"
RESULT_VARIABLE _results)
foreach(_result in ${_results})
if(${_result})
message(FATAL_ERROR "operator codegen failed, exiting.")
endif()
endforeach()
if(EXISTS "${generated_op_path}.tmp" AND EXISTS "${generated_op_path}")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
......@@ -195,6 +221,25 @@ else()
message("remove ${generated_op_path}")
endif()
if(EXISTS "${generated_sparse_ops_path}.tmp" AND EXISTS
"${generated_sparse_ops_path}")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${generated_sparse_ops_path}.tmp" "${generated_sparse_ops_path}")
message(
"copy if different ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}"
)
elseif(EXISTS "${generated_sparse_ops_path}.tmp")
execute_process(
COMMAND ${CMAKE_COMMAND} -E copy "${generated_sparse_ops_path}.tmp"
"${generated_sparse_ops_path}")
message("copy ${generated_sparse_ops_path}.tmp ${generated_sparse_ops_path}")
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_sparse_ops_path}")
message("remove ${generated_sparse_ops_path}")
endif()
if(EXISTS "${generated_argument_mapping_path}.tmp"
AND EXISTS "${generated_argument_mapping_path}")
execute_process(
......@@ -218,6 +263,30 @@ else()
message("remove ${generated_argument_mapping_path}")
endif()
if(EXISTS "${generated_sparse_argument_mapping_path}.tmp"
AND EXISTS "${generated_sparse_argument_mapping_path}")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy_if_different
"${generated_sparse_argument_mapping_path}.tmp"
"${generated_sparse_argument_mapping_path}")
message(
"copy if different ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}"
)
elseif(EXISTS "${generated_sparse_argument_mapping_path}.tmp")
execute_process(
COMMAND
${CMAKE_COMMAND} -E copy "${generated_sparse_argument_mapping_path}.tmp"
"${generated_sparse_argument_mapping_path}")
message(
"copy ${generated_sparse_argument_mapping_path}.tmp ${generated_sparse_argument_mapping_path}"
)
else()
execute_process(COMMAND ${CMAKE_COMMAND} -E remove -f
"${generated_sparse_argument_mapping_path}")
message("remove ${generated_sparse_argument_mapping_path}")
endif()
# generate ops extra info
execute_process(
COMMAND ${PYTHON_EXECUTABLE} ${ops_extra_info_gen_file} --op_compat_yaml_path
......
......@@ -147,6 +147,12 @@
data_type: out_grad
no_need_buffer: x
- backward_op : flip_grad
forward : flip (Tensor x, int[] axis) -> Tensor(out)
args : (Tensor out_grad, int[] axis)
output : Tensor(x_grad)
invoke : flip(out_grad, axis)
- backward_op : graph_send_uv_grad
forward : graph_send_uv (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, str message_op = "ADD") -> Tensor(out)
args: (Tensor x, Tensor y, Tensor src_index, Tensor dst_index, Tensor out_grad, str message_op = "ADD")
......
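The flip_grad entry above declares no kernel of its own; its invoke field reuses the forward flip op, which is what the "backward ops that invoke forward ops" part of this PR generates code for. A minimal sketch (assumed dict shapes, mirroring process_invoke_op in generate_op.py further down) of how `invoke : flip(out_grad, axis)` is resolved:

forward_flip = {
    "inputs": [{"name": "x"}],
    "attrs": [{"name": "axis"}],
    "outputs": [{"name": "out"}],
}
flip_grad = {
    "attr_dict": {"axis": {"name": "axis"}},
    "outputs": [{"name": "x_grad"}],
    "invoke": {"func": "flip", "args": ["out_grad", "axis"]},
}
invoke, args = flip_grad["invoke"], flip_grad["invoke"]["args"]
# the forward op's inputs bind positionally to the leading invoke args
invoke["inputs"] = [{"name": i["name"], "value": a}
                    for i, a in zip(forward_flip["inputs"], args)]
# remaining args feed the forward attrs; names found in attr_dict become
# this->GetAttr(...) calls in the generated GradOpMaker
invoke["attrs"] = [
    {"name": at["name"],
     "value": f'this->GetAttr("{a}")' if a in flip_grad["attr_dict"] else a}
    for at, a in zip(forward_flip["attrs"], args[len(invoke["inputs"]):])
]
# the forward op's outputs map onto the backward op's own outputs
invoke["outputs"] = [{"name": o["name"], "value": out["name"]}
                     for o, out in zip(forward_flip["outputs"],
                                       flip_grad["outputs"])]
# invoke["inputs"]  -> [{'name': 'x', 'value': 'out_grad'}]
# invoke["attrs"]   -> [{'name': 'axis', 'value': 'this->GetAttr("axis")'}]
# invoke["outputs"] -> [{'name': 'out', 'value': 'x_grad'}]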
......@@ -21,18 +21,32 @@ from pathlib import Path
import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from filters import to_op_attr_type, to_opmaker_name, to_opmaker_name_cstr, to_pascal_case
from tests import is_base_api, is_vec, is_scalar, is_initializer_list, supports_inplace, supports_no_need_buffer
from filters import (
to_op_attr_type,
to_opmaker_name,
to_opmaker_name_cstr,
to_pascal_case,
)
from tests import (
is_base_api,
is_vec,
is_scalar,
is_initializer_list,
supports_inplace,
supports_no_need_buffer,
)
from filters import to_input_name, cartesian_prod_mapping
from parse_utils import to_named_dict
file_loader = FileSystemLoader(Path(__file__).parent / "templates")
env = Environment(loader=file_loader,
keep_trailing_newline=True,
trim_blocks=True,
lstrip_blocks=True,
undefined=StrictUndefined,
extensions=['jinja2.ext.do'])
env = Environment(
loader=file_loader,
keep_trailing_newline=True,
trim_blocks=True,
lstrip_blocks=True,
undefined=StrictUndefined,
extensions=['jinja2.ext.do'],
)
env.filters["to_op_attr_type"] = to_op_attr_type
env.filters["to_opmaker_name"] = to_opmaker_name
env.filters["to_pascal_case"] = to_pascal_case
......@@ -56,7 +70,6 @@ def restruct_io(api):
# replace name of op and params for OpMaker
def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
def get_api_and_op_name(api_item):
names = api_item.split('(')
if len(names) == 1:
......@@ -76,7 +89,8 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
forward_api_item['op_name'] = op_name
if 'backward' in api_args and has_backward:
bw_api_name, bw_op_name = get_api_and_op_name(
api_args['backward'].split(',')[0])
api_args['backward'].split(',')[0]
)
forward_api_item['backward'] = bw_op_name
backward_api_item['op_name'] = bw_op_name
......@@ -102,8 +116,10 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
]
if forward_api_item['kernel']['data_type']:
forward_api_item['kernel']['data_type']['candidates'] = [
args_map[param] if param in args_map else param for param in
forward_api_item['kernel']['data_type']['candidates']
args_map[param] if param in args_map else param
for param in forward_api_item['kernel']['data_type'][
'candidates'
]
]
if forward_api_item['kernel']['backend']:
forward_api_item['kernel']['backend']['candidates'] = [
......@@ -130,21 +146,36 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
for args_item in backward_api_item['inputs']:
if args_item['name'] in args_map:
args_item['name'] = args_map[args_item['name']]
elif args_item['name'].endswith(
'_grad') and args_item['name'][:-5] in args_map:
args_map[args_item['name']] = args_map[args_item['name']
[:-5]] + '_grad'
elif (
args_item['name'].endswith('_grad')
and args_item['name'][:-5] in args_map
):
args_map[args_item['name']] = (
args_map[args_item['name'][:-5]] + '_grad'
)
args_item['name'] = args_map[args_item['name']]
for args_item in backward_api_item['attrs']:
if args_item['name'] in args_map:
args_item['name'] = args_map[args_item['name']]
for args_item in backward_api_item['outputs']:
if args_item['name'].endswith(
'_grad') and args_item['name'][:-5] in args_map:
args_map[args_item['name']] = args_map[args_item['name']
[:-5]] + '_grad'
if (
args_item['name'].endswith('_grad')
and args_item['name'][:-5] in args_map
):
args_map[args_item['name']] = (
args_map[args_item['name'][:-5]] + '_grad'
)
args_item['name'] = args_map[args_item['name']]
if 'invoke' in backward_api_item:
backward_api_item['invoke']['args'] = [
args_map[param.strip()]
if param.strip() in args_map
else param.strip()
for param in backward_api_item['invoke']['args'].split(',')
]
continue
backward_api_item['infer_meta']['param'] = [
args_map[param] if param in args_map else param
for param in backward_api_item['infer_meta']['param']
......@@ -155,18 +186,24 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
]
if backward_api_item['kernel']['data_type']:
backward_api_item['kernel']['data_type']['candidates'] = [
args_map[param] if param in args_map else param for param in
backward_api_item['kernel']['data_type']['candidates']
args_map[param] if param in args_map else param
for param in backward_api_item['kernel']['data_type'][
'candidates'
]
]
if backward_api_item['kernel']['backend']:
backward_api_item['kernel']['backend']['candidates'] = [
args_map[param] if param in args_map else param for param in
backward_api_item['kernel']['backend']['candidates']
args_map[param] if param in args_map else param
for param in backward_api_item['kernel']['backend'][
'candidates'
]
]
if backward_api_item['kernel']['layout']:
backward_api_item['kernel']['layout']['candidates'] = [
args_map[param] if param in args_map else param for param in
backward_api_item['kernel']['layout']['candidates']
args_map[param] if param in args_map else param
for param in backward_api_item['kernel']['layout'][
'candidates'
]
]
if backward_api_item['no_need_buffer']:
backward_api_item['no_need_buffer'] = [
......@@ -175,9 +212,56 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
]
def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path,
api_version_yaml_path, output_op_path, output_arg_map_path):
with open(api_yaml_path, "rt") as f:
def process_invoke_op(forward_api_dict, backward_api_dict):
for bw_api in backward_api_dict.values():
if 'invoke' in bw_api:
invoke_op = bw_api['invoke']['func']
args_list = bw_api['invoke']['args']
args_index = 0
if invoke_op in forward_api_dict:
reuse_op = forward_api_dict[invoke_op]
bw_api['invoke']['inputs'] = []
bw_api['invoke']['attrs'] = []
bw_api['invoke']['outputs'] = []
for input_item in reuse_op['inputs']:
bw_api['invoke']['inputs'].append(
{
'name': input_item['name'],
'value': args_list[args_index],
}
)
args_index = args_index + 1
for attr in reuse_op['attrs']:
if args_index < len(args_list):
attr_value = (
f"this->GetAttr(\"{args_list[args_index]}\")"
if args_list[args_index] in bw_api['attr_dict']
else args_list[args_index]
)
bw_api['invoke']['attrs'].append(
{'name': attr['name'], 'value': attr_value}
)
args_index = args_index + 1
else:
break
for idx, output_item in enumerate(reuse_op['outputs']):
bw_api['invoke']['outputs'].append(
{
'name': output_item['name'],
'value': bw_api['outputs'][idx]['name'],
}
)
def main(
ops_yaml_path,
backward_yaml_path,
op_compat_yaml_path,
op_version_yaml_path,
output_op_path,
output_arg_map_path,
):
with open(ops_yaml_path, "rt") as f:
apis = yaml.safe_load(f)
apis = [restruct_io(api) for api in apis]
forward_api_dict = to_named_dict(apis)
......@@ -187,7 +271,7 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path,
backward_apis = [restruct_io(api) for api in backward_apis]
backward_api_dict = to_named_dict(backward_apis)
with open(api_version_yaml_path, "rt") as f:
with open(op_version_yaml_path, "rt") as f:
api_versions = yaml.safe_load(f)
# add api version info into api
for api_version in api_versions:
......@@ -203,6 +287,9 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path,
replace_compat_name(api_op_map, forward_api_dict, backward_api_dict)
# prepare for invoke case
process_invoke_op(forward_api_dict, backward_api_dict)
# fill backward field for an api if another api claims it as forward
for name, backward_api in backward_api_dict.items():
forward_name = backward_api["forward"]["name"]
......@@ -224,9 +311,9 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path,
op_template = env.get_template('op.c.j2')
with open(output_op_path, "wt") as f:
msg = op_template.render(apis=apis,
backward_apis=backward_apis,
api_dict=api_dict)
msg = op_template.render(
apis=apis, backward_apis=backward_apis, api_dict=api_dict
)
f.write(msg)
ks_template = env.get_template('ks.c.j2')
......@@ -237,28 +324,35 @@ def main(api_yaml_path, backward_yaml_path, op_compat_yaml_path,
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate operator file from api yaml.")
parser.add_argument('--api_yaml_path',
type=str,
help="parsed api yaml file.")
parser.add_argument('--backward_api_yaml_path',
type=str,
help="parsed backward api yaml file.")
parser.add_argument('--op_compat_yaml_path',
type=str,
help="api args compat yaml file.")
parser.add_argument('--api_version_yaml_path',
type=str,
help="api version yaml file.")
parser.add_argument("--output_op_path",
type=str,
help="path to save generated operators.")
description="Generate operator file from api yaml."
)
parser.add_argument(
'--ops_yaml_path', type=str, help="parsed ops yaml file."
)
parser.add_argument(
'--backward_yaml_path', type=str, help="parsed backward ops yaml file."
)
parser.add_argument(
'--op_compat_yaml_path', type=str, help="ops args compat yaml file."
)
parser.add_argument(
'--op_version_yaml_path', type=str, help="ops version yaml file."
)
parser.add_argument(
"--output_op_path", type=str, help="path to save generated operators."
)
parser.add_argument(
"--output_arg_map_path",
type=str,
help="path to save generated argument mapping functions.")
help="path to save generated argument mapping functions.",
)
args = parser.parse_args()
main(args.api_yaml_path, args.backward_api_yaml_path,
args.op_compat_yaml_path, args.api_version_yaml_path,
args.output_op_path, args.output_arg_map_path)
main(
args.ops_yaml_path,
args.backward_yaml_path,
args.op_compat_yaml_path,
args.op_version_yaml_path,
args.output_op_path,
args.output_arg_map_path,
)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from filters import (
to_op_attr_type,
to_opmaker_name,
to_opmaker_name_cstr,
to_pascal_case,
)
from tests import (
is_base_api,
is_vec,
is_scalar,
is_initializer_list,
supports_inplace,
supports_no_need_buffer,
)
from filters import to_input_name, cartesian_prod_mapping
from parse_utils import to_named_dict
from generate_op import process_invoke_op
file_loader = FileSystemLoader(Path(__file__).parent / "templates")
env = Environment(
loader=file_loader,
keep_trailing_newline=True,
trim_blocks=True,
lstrip_blocks=True,
undefined=StrictUndefined,
extensions=['jinja2.ext.do'],
)
env.filters["to_op_attr_type"] = to_op_attr_type
env.filters["to_opmaker_name"] = to_opmaker_name
env.filters["to_pascal_case"] = to_pascal_case
env.filters["to_input_name"] = to_input_name
env.filters["to_opmaker_name_cstr"] = to_opmaker_name_cstr
env.filters["cartesian_prod_mapping"] = cartesian_prod_mapping
env.tests["base_api"] = is_base_api
env.tests["vec"] = is_vec
env.tests["scalar"] = is_scalar
env.tests["initializer_list"] = is_initializer_list
env.tests["supports_inplace"] = supports_inplace
env.tests["supports_no_need_buffer"] = supports_no_need_buffer
def restruct_io(api):
api["input_dict"] = to_named_dict(api["inputs"])
api["attr_dict"] = to_named_dict(api["attrs"])
api["output_dict"] = to_named_dict(api["outputs"])
return api
SPARSE_OP_PREFIX = 'sparse_'
def main(
api_yaml_path, backward_yaml_path, output_op_path, output_arg_map_path
):
with open(api_yaml_path, "rt") as f:
apis = yaml.safe_load(f)
apis = [restruct_io(api) for api in apis]
forward_api_dict = to_named_dict(apis)
with open(backward_yaml_path, "rt") as f:
backward_apis = yaml.safe_load(f)
backward_apis = [restruct_io(api) for api in backward_apis]
backward_api_dict = to_named_dict(backward_apis)
for api in apis:
api['op_name'] = SPARSE_OP_PREFIX + api['name']
api['name'] = api['op_name']
if api["backward"] is not None:
api["backward"] = SPARSE_OP_PREFIX + api["backward"]
for bw_api in backward_apis:
bw_api['op_name'] = SPARSE_OP_PREFIX + bw_api['name']
bw_api['name'] = bw_api['op_name']
if 'invoke' in bw_api:
bw_api['invoke']['args'] = [
param.strip() for param in bw_api['invoke']['args'].split(',')
]
# prepare for invoke case
process_invoke_op(forward_api_dict, backward_api_dict)
for bw_api in backward_apis:
if 'invoke' in bw_api:
if bw_api['invoke']['func'] in forward_api_dict:
bw_api['invoke']['func'] = (
SPARSE_OP_PREFIX + bw_api['invoke']['func']
)
# fill backward field for an api if another api claims it as forward
for name, backward_api in backward_api_dict.items():
forward_name = backward_api["forward"]["name"]
if forward_name in backward_api_dict:
forward_api = backward_api_dict[forward_name]
if forward_api["backward"] is None:
forward_api["backward"] = name
forward_api["backward"] = SPARSE_OP_PREFIX + forward_api["backward"]
api_dict = {}
api_dict.update(forward_api_dict)
api_dict.update(backward_api_dict)
if len(apis) == 0 and len(backward_apis) == 0:
if os.path.isfile(output_op_path):
os.remove(output_op_path)
if os.path.isfile(output_arg_map_path):
os.remove(output_arg_map_path)
return
op_template = env.get_template('sparse_op.c.j2')
with open(output_op_path, "wt") as f:
msg = op_template.render(
apis=apis, backward_apis=backward_apis, api_dict=api_dict
)
f.write(msg)
ks_template = env.get_template('sparse_ks.c.j2')
with open(output_arg_map_path, 'wt') as f:
msg = ks_template.render(apis=apis, backward_apis=backward_apis)
f.write(msg)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate operator file from api yaml."
)
parser.add_argument(
'--ops_yaml_path', type=str, help="parsed sparse ops yaml file."
)
parser.add_argument(
'--backward_ops_yaml_path',
type=str,
help="parsed backward sparse ops yaml file.",
)
parser.add_argument(
"--output_op_path", type=str, help="path to save generated operators."
)
parser.add_argument(
"--output_arg_map_path",
type=str,
help="path to save generated argument mapping functions.",
)
args = parser.parse_args()
main(
args.ops_yaml_path,
args.backward_ops_yaml_path,
args.output_op_path,
args.output_arg_map_path,
)
......@@ -35,39 +35,42 @@ def parse_arg(api_name: str, s: str) -> Dict[str, str]:
2. typename name = default_value
"""
typename, rest = [item.strip() for item in s.split(" ", 1)]
assert len(
typename
) > 0, f"The arg typename should not be empty. Please check the args of {api_name} in yaml."
assert (
len(typename) > 0
), f"The arg typename should not be empty. Please check the args of {api_name} in yaml."
assert rest.count(
"=") <= 1, f"There is more than 1 = in an arg in {api_name}"
assert (
rest.count("=") <= 1
), f"There is more than 1 = in an arg in {api_name}"
if rest.count("=") == 1:
name, default_value = [item.strip() for item in rest.split("=", 1)]
assert len(
name
) > 0, f"The arg name should not be empty. Please check the args of {api_name} in yaml."
assert len(
default_value
) > 0, f"The default value should not be empty. Please check the args of {api_name} in yaml."
assert (
len(name) > 0
), f"The arg name should not be empty. Please check the args of {api_name} in yaml."
assert (
len(default_value) > 0
), f"The default value should not be empty. Please check the args of {api_name} in yaml."
return {
"typename": typename,
"name": name,
"default_value": default_value
"default_value": default_value,
}
else:
name = rest.strip()
assert len(
name
) > 0, f"The arg name should not be empty. Please check the args of {api_name} in yaml."
assert (
len(name) > 0
), f"The arg name should not be empty. Please check the args of {api_name} in yaml."
return {"typename": typename, "name": name}
def parse_input_and_attr(api_name: str,
arguments: str) -> Tuple[List, List, Dict, Dict]:
def parse_input_and_attr(
api_name: str, arguments: str
) -> Tuple[List, List, Dict, Dict]:
args_str = arguments.strip()
assert args_str.startswith('(') and args_str.endswith(')'), \
(f"Args declaration should start with '(' and end with ')', "
f"please check the args of {api_name} in yaml.")
assert args_str.startswith('(') and args_str.endswith(')'), (
f"Args declaration should start with '(' and end with ')', "
f"please check the args of {api_name} in yaml."
)
args_str = args_str[1:-1]
args = parse_plain_list(args_str)
......@@ -81,14 +84,17 @@ def parse_input_and_attr(api_name: str,
typename = item["typename"]
name = item["name"]
if is_input(typename):
assert len(attrs) == 0, \
(f"The input Tensor should appear before attributes. "
assert len(attrs) == 0, (
f"The input Tensor should appear before attributes. "
f"please check the position of {api_name}:input({name}) "
f"in yaml.")
f"in yaml."
)
inputs.append(item)
elif is_attr(typename):
if met_attr_with_default_value:
assert "default_value" in item, f"{api_name}: Arguments with default value should not precede those without default value"
assert (
"default_value" in item
), f"{api_name}: Arguments with default value should not precede those without default value"
elif "default_value" in item:
met_attr_with_default_value = True
attrs.append(item)
......@@ -101,7 +107,8 @@ def parse_output(api_name: str, s: str) -> Dict[str, str]:
"""parse an output, typename or typename(name)."""
match = re.search(
r"(?P<out_type>[a-zA-Z0-9_[\]]+)\s*(?P<name>\([a-zA-Z0-9_@]+\))?\s*(?P<expr>\{[^\}]+\})?",
s)
s,
)
typename = match.group("out_type")
name = match.group("name")
size_expr = match.group("expr")
......@@ -109,13 +116,15 @@ def parse_output(api_name: str, s: str) -> Dict[str, str]:
name = name[1:-1] if name is not None else 'out'
size_expr = size_expr[1:-1] if size_expr is not None else None
assert is_output(typename), \
(f"Invalid output type: {typename} in api: {api_name}."
f"Supported types are Tensor and Tensor[]")
assert is_output(typename), (
f"Invalid output type: {typename} in api: {api_name}."
f"Supported types are Tensor and Tensor[]"
)
if size_expr is not None:
assert is_vec(typename), \
(f"Invalid output size: output {name} in api: {api_name} is "
f"not a vector but has size expr")
assert is_vec(typename), (
f"Invalid output size: output {name} in api: {api_name} is "
f"not a vector but has size expr"
)
return {"typename": typename, "name": name, "size": size_expr}
else:
return {"typename": typename, "name": name}
......@@ -149,22 +158,24 @@ def parse_plain_list(s: str, sep=",") -> List[str]:
return items
def parse_kernel(api_name: str, kernel_config: Dict[str,
Any]) -> Dict[str, Any]:
def parse_kernel(
api_name: str, kernel_config: Dict[str, Any]
) -> Dict[str, Any]:
# kernel :
# func : [], Kernel functions (example: scale, scale_sr)
# param : [], Input params of kernel
# backend : str, the names of param to choose the kernel backend, default is None
# layout : str, the names of param to choose the kernel layout, default is None
# data_type : str, the names of param to choose the kernel data_type, default is None
# dispatch : {}, the key is kernel_func, the value is type of inputs and outputs for kernel (example: {kernel_name : (['dense','sparse_coo']#input,['sparse_coo']#output)})
kernel = {
'func': None, # up to 2 function names
'func': [], # up to 2 function names
'param': None,
'backend': None,
'layout': None,
'data_type': None
'data_type': None,
'dispatch': {},
}
kernel['func'] = parse_plain_list(kernel_config['func'])
if 'param' in kernel_config:
kernel['param'] = kernel_config['param']
......@@ -176,6 +187,42 @@ def parse_kernel(api_name: str, kernel_config: Dict[str,
if 'data_type' in kernel_config:
kernel['data_type'] = parse_candidates(kernel_config["data_type"])
kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(
kernel_config['func']
)
def parse_kernel_in_out_type(in_out_str):
if len(in_out_str) == 0:
return None
tmp_in_out_list = in_out_str[1:-1].split('->')
inputs = [item.strip() for item in tmp_in_out_list[0].split(',')]
outputs = [item.strip() for item in tmp_in_out_list[1].split(',')]
# check the tensor type
for item in inputs:
assert item in [
'dense',
'selected_rows',
'sparse_coo',
'sparse_csr',
], f"{api_name} : Invalid input tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
for item in outputs:
assert item in [
'dense',
'selected_rows',
'sparse_coo',
'sparse_csr',
], f"{api_name} : Invalid output tensor type ('{item}'), here we only support 'dense', 'selected_rows', 'sparse_coo' and 'sparse_csr'."
return (inputs, outputs)
for func_item in kernel_funcs:
kernel['func'].append(func_item[0])
kernel['dispatch'][func_item[0]] = parse_kernel_in_out_type(
func_item[1]
)
return kernel
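# Sketch of the dispatch syntax handled above (the func string is made up;
# re is imported at the top of this module):
_demo = ("add_coo_coo{sparse_coo, sparse_coo -> sparse_coo}, "
         "add_csr_csr{sparse_csr, sparse_csr -> sparse_csr}")
for _name, _sig in re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall(_demo):
    _ins, _outs = _sig[1:-1].split('->')
    print(_name,
          [t.strip() for t in _ins.split(',')],
          [t.strip() for t in _outs.split(',')])
# add_coo_coo ['sparse_coo', 'sparse_coo'] ['sparse_coo']
# add_csr_csr ['sparse_csr', 'sparse_csr'] ['sparse_csr']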
......@@ -200,10 +247,9 @@ def parse_invoke(api_name: str, invoke_config: str) -> Dict[str, Any]:
def extract_type_and_name(records: List[Dict]) -> List[Dict]:
"""extract type and name from forward call, it is simpler than forward api."""
extracted = [{
"name": item["name"],
"typename": item["typename"]
} for item in records]
extracted = [
{"name": item["name"], "typename": item["typename"]} for item in records
]
return extracted
......@@ -211,7 +257,8 @@ def parse_forward(api_name: str, forward_config: str) -> Dict[str, Any]:
# api_name (const Tensor& input, ... , int attr, ...) -> Tensor(out)
result = re.search(
r"(?P<op>[a-z][a-z0-9_]+)\s*(?P<args>\([^\)]+\))\s*->\s*(?P<outputs>.+)",
forward_config)
forward_config,
)
api = result.group("op")
outputs = parse_outputs(api_name, result.group("outputs"))
outputs = extract_type_and_name(outputs)
......@@ -223,7 +270,7 @@ def parse_forward(api_name: str, forward_config: str) -> Dict[str, Any]:
"name": api,
"inputs": inputs,
"attrs": attrs,
"outputs": outputs
"outputs": outputs,
}
return forward_cfg
......@@ -239,13 +286,19 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
typename = attr["typename"]
default_value = attr["default_value"]
if typename == "DataType":
assert "DataType" in default_value, f"invalid DataType default value in {api_name}"
assert (
"DataType" in default_value
), f"invalid DataType default value in {api_name}"
# remove namespace
default_value = default_value[default_value.find("DataType"):]
default_value = default_value[default_value.find("DataType") :]
attr["default_value"] = default_value
elif typename == "DataLayout":
assert "DataLayout" in default_value, f"invalid DataLayout default value in {api_name}"
default_value = default_value[default_value.find("DataLayout"):]
assert (
"DataLayout" in default_value
), f"invalid DataLayout default value in {api_name}"
default_value = default_value[
default_value.find("DataLayout") :
]
attr["default_value"] = default_value
input_names = [item["name"] for item in inputs]
......@@ -258,7 +311,9 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
if "optional" in api_entry:
optional_args = parse_plain_list(api_entry["optional"])
for name in optional_args:
assert name in input_names, f"{api_name} has an optional input: '{name}' which is not an input."
assert (
name in input_names
), f"{api_name} has an optional input: '{name}' which is not an input."
for input in inputs:
if input["name"] in optional_args:
input["optional"] = True
......@@ -269,7 +324,9 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
if "intermediate" in api_entry:
intermediate_outs = parse_plain_list(api_entry["intermediate"])
for name in intermediate_outs:
assert name in output_names, f"{api_name} has an intermediate output: '{name}' which is not an output."
assert (
name in output_names
), f"{api_name} has an intermediate output: '{name}' which is not an output."
for output in outputs:
if output["name"] in intermediate_outs:
output["intermediate"] = True
......@@ -280,7 +337,9 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
if "no_need_buffer" in api_entry:
no_buffer_args = parse_plain_list(api_entry["no_need_buffer"])
for name in no_buffer_args:
assert name in input_names, f"{api_name} has an no buffer input: '{name}' which is not an input."
assert (
name in input_names
), f"{api_name} has an no buffer input: '{name}' which is not an input."
for input in inputs:
if input["name"] in no_buffer_args:
input["no_need_buffer"] = True
......@@ -294,7 +353,7 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
"inputs": inputs,
"attrs": attrs,
"outputs": outputs,
"no_need_buffer": no_buffer_args
"no_need_buffer": no_buffer_args,
}
# invokes another api?
......@@ -316,11 +375,13 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
inplace_pairs = parse_inplace(api_name, api_entry["inplace"])
else:
inplace_pairs = None
api.update({
"infer_meta": infer_meta,
"kernel": kernel,
"inplace": inplace_pairs
})
api.update(
{
"infer_meta": infer_meta,
"kernel": kernel,
"inplace": inplace_pairs,
}
)
else:
# invoke
invoke = parse_invoke(api_name, api_entry["invoke"])
......@@ -339,8 +400,9 @@ def parse_api_entry(api_entry: Dict[str, Any], name_field="op"):
if "forward" in api_entry:
forward = parse_forward(api_name, api_entry["forward"])
# validate_fb
validate_backward_inputs(api_name, forward["inputs"],
forward["outputs"], inputs)
validate_backward_inputs(
api_name, forward["inputs"], forward["outputs"], inputs
)
validate_backward_attrs(api_name, forward["attrs"], attrs)
validate_backward_outputs(api_name, forward["inputs"], outputs)
else:
......@@ -356,23 +418,27 @@ def validate_backward_attrs(api, forward_attrs, backward_attrs):
# this is a not-that-clean trick to allow backward api to have more attrs
# than the forward api, as long as they all have default value
for i in range(-num_exceptional_attrs, 0):
assert "default_value" in backward_attrs[
i], f"{api} has exceptional attr without default value"
assert (
"default_value" in backward_attrs[i]
), f"{api} has exceptional attr without default value"
def validate_backward_inputs(api, forward_inputs, forward_outputs,
backward_inputs):
def validate_backward_inputs(
api, forward_inputs, forward_outputs, backward_inputs
):
foward_input_names = [item["name"] for item in forward_inputs]
forward_output_names = [item["name"] for item in forward_outputs]
backward_input_names = [item["name"] for item in backward_inputs]
assert len(backward_input_names) <= len(foward_input_names) + 2 * len(
forward_output_names), f"{api} has too many inputs."
forward_output_names
), f"{api} has too many inputs."
def validate_backward_outputs(api, forward_inputs, backward_outputs):
assert len(backward_outputs) <= len(
forward_inputs), f"{api} has too many outputs"
forward_inputs
), f"{api} has too many outputs"
def cross_validate(apis):
......@@ -391,15 +457,17 @@ def cross_validate(apis):
f"Something Wrong here, {name}'s forward api({fw_name}) does not claim {name} as its backward."
)
else:
assert fw_api[
"backward"] == name, f"{name}: backward and forward name mismatch"
assert (
fw_api["backward"] == name
), f"{name}: backward and forward name mismatch"
assert len(fw_call["inputs"]) <= len(
fw_api["inputs"]
), f"{name}: forward call has more inputs than the api"
for (input, input_) in zip(fw_call["inputs"], fw_api["inputs"]):
assert input["typename"] == input_[
"typename"], f"type mismatch in {name} and {fw_name}"
assert (
input["typename"] == input_["typename"]
), f"type mismatch in {name} and {fw_name}"
assert len(fw_call["attrs"]) <= len(
fw_api["attrs"]
......@@ -411,13 +479,16 @@ def cross_validate(apis):
r"Scalar(\(\w+\))*", attr_["typename"]
), f"type mismatch in {name} and {fw_name}"
else:
assert attr["typename"] == attr_[
"typename"], f"type mismatch in {name} and {fw_name}"
assert (
attr["typename"] == attr_["typename"]
), f"type mismatch in {name} and {fw_name}"
assert len(fw_call["outputs"]) == len(
fw_api["outputs"]
), f"{name}: forward call has more outputs than the api"
for (output, output_) in zip(fw_call["outputs"],
fw_api["outputs"]):
assert output["typename"] == output_[
"typename"], f"type mismatch in {name} and {fw_name}"
for (output, output_) in zip(
fw_call["outputs"], fw_api["outputs"]
):
assert (
output["typename"] == output_["typename"]
), f"type mismatch in {name} and {fw_name}"
{% from "operator_utils.c.j2" import op_maker, backward_op_maker, operator, register_op_with_components, register_op_version %}
{% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version %}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
......@@ -33,6 +33,8 @@ using paddle::framework::GradVarName;
{{backward_op_maker(api, api_dict[api["forward"]["name"]])}}
{{operator(api)}}
{% else %}
{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}}
{% endif %}
{% endfor %}
} // namespace operators
......
......@@ -81,7 +81,11 @@ AddAttr<{{typename | to_op_attr_type}}>("{{name}}", "({{typename | to_op_attr_ty
{% set default_value = attr["default_value"] %}
{% set typename = attr["typename"] %}
{% if typename == "DataType" %}{# convert back to VarType #}
{% if default_value == "DataType::UNDEFINED" %}
-1
{%- else %}
static_cast<int>(framework::TransToProtoVarType(experimental::{{default_value}}))
{%- endif %}
{%- elif typename == "DataLayout" %} {# does DataLayout need any processing?#}
static_cast<int>(experimental::{{default_value}})
{%- elif typename == "Place" %}{# construct a Place to get the type #}
......@@ -94,7 +98,7 @@ static_cast<int>(phi::Place({{"phi::" if not default_value is initializer_list}}
{# --------------------------------------- name mapping ---------------------------------------------- #}
{% macro name_map(api) %}
KernelSignature {{api["name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
{% set kernel_args = api["kernel"]["param"] %}
{{get_input_list(api["inputs"], kernel_args)}};
paddle::small_vector<const char*> attrs;
......@@ -124,12 +128,64 @@ All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArg
*/
{% endmacro %}
{% macro get_kernel_dispatch(inputs, kernel_config) %}{# inline #}
{%- for kernel_func in kernel_config["func"] %}
{% set input_idx = namespace(idx=0) %}
{% set kernel_in_type_list = kernel_config["dispatch"][kernel_func][0] %}
if ( {%- for input in inputs %}
{%- if input["name"] in kernel_config["param"] %}
{%- if kernel_in_type_list[input_idx.idx] == "dense" %}
ctx.IsDenseTensorInput("{{input["name"]}}"){{" && " if not loop.last}}
{%- elif kernel_in_type_list[input_idx.idx] == "selected_rows" %}
ctx.IsSelectedRowsInput("{{input["name"]}}"){{" && " if not loop.last}}
{%- elif kernel_in_type_list[input_idx.idx] == "sparse_coo" %}
ctx.IsSparseCooTensorInput("{{input["name"]}}"){{" && " if not loop.last}}
{%- elif kernel_in_type_list[input_idx.idx] == "sparse_csr" %}
ctx.IsSparseCsrTensorInput("{{input["name"]}}"){{" && " if not loop.last}}
{%- endif %}
{% set input_idx.idx = input_idx.idx + 1 %}
{%- endif %}
{%- endfor %}) {
kernel_name = "{{kernel_func}}";
}
{%- endfor %}
{%- endmacro %}
{% macro sparse_op_name_map(api) %}
KernelSignature {{api["op_name"] | to_pascal_case }}OpArgumentMapping(const ArgumentMappingContext& ctx) {
{% set kernel_args = api["kernel"]["param"] %}
{{get_input_list(api["inputs"], kernel_args)}};
paddle::small_vector<const char*> attrs;
{% for attr in api["attrs"]%}
{% filter indent(2)%}
{{get_an_attr(attr)}};
{% endfilter %}
{% endfor %}
{{get_output_list(api["outputs"], kernel_args)}};
const char* kernel_name = "unregistered";
{{get_kernel_dispatch(api["inputs"], api["kernel"])}}
KernelSignature sig (kernel_name, std::move(inputs), std::move(attrs), std::move(outputs));
return sig;
}
/*
******************************************************************
NOTE: The following codes are for 'get_compat_kernel_signature.py'
All possible KernelSignatures returned by {{api["name"] | to_pascal_case }}OpArgumentMapping:
{{api | cartesian_prod_mapping}}
******************************************************************
*/
{% endmacro %}
{% macro register_base_kernel_name(api) %}
PD_REGISTER_BASE_KERNEL_NAME({{api["op_name"]}}, {{api["name"]}});
{%- endmacro %}
{% macro register_name_map(api) %}
PD_REGISTER_ARG_MAPPING_FN({{api["op_name"]}}, phi::{{api["name"] | to_pascal_case}}OpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN({{api["op_name"]}}, phi::{{api["op_name"] | to_pascal_case}}OpArgumentMapping);
{%- endmacro %}
{% macro get_input_list(inputs, kernel_args) %}{# inline #}
......@@ -352,6 +408,48 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
};
{% endmacro %}
{% macro backward_op_reused_maker(bw_op, forward_op, invoke_op) %}
{% set name = bw_op["op_name"] %}
{% set forward_input_names = bw_op["forward"]["inputs"] | map(attribute="name") | list %}
{% set forward_output_names = bw_op["forward"]["outputs"] | map(attribute="name") | list %}
{% set forward_attr_names = bw_op["forward"]["attrs"] | map(attribute="name") | list %}
{% set forward_input_orig_names = forward_op["inputs"] | map(attribute="name") | list %}
{% set forward_output_orig_names = forward_op["outputs"] | map(attribute="name") | list %}
{% set forward_attr_orig_names = forward_op["attrs"] | map(attribute="name") | list %}
template <typename T>
class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("{{invoke_op["func"]}}");
{% for input in invoke_op["inputs"] %}
grad_op->SetInput({{input["name"] | to_opmaker_name}}, this->{{extract_input_from_forward(
input["value"],
forward_input_names,
forward_output_names,
forward_input_orig_names,
forward_output_orig_names)}});
{% endfor %}
{% for output in invoke_op["outputs"] %}
grad_op->SetOutput({{output["name"] | to_opmaker_name}}, this->{{extract_output_from_forward(
output["value"],
forward_input_names,
forward_output_names,
forward_input_orig_names,
forward_output_orig_names)}});
{% endfor %}
{% for attr in invoke_op["attrs"] %}
grad_op->SetAttr("{{attr["name"]}}", {{attr["value"]}});
{% endfor %}
}
};
{% endmacro %}
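Rendered against the flip_grad entry shown further down (invoke : flip(out_grad, axis)), this macro should generate a class along these lines; the accessor calls emitted by extract_input_from_forward and extract_output_from_forward (OutputGrad, InputGrad) are assumed, since those macros are truncated in this diff:

template <typename T>
class FlipGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("flip");
    // out_grad feeds the forward input slot "X"; x_grad comes out of slot "Out"
    grad_op->SetInput("X", this->OutputGrad("Out"));
    grad_op->SetOutput("Out", this->InputGrad("X"));
    grad_op->SetAttr("axis", this->GetAttr("axis"));
  }
};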
{% macro extract_input_from_forward(name,
input_names, output_names,
......
{% from "operator_utils.c.j2" import sparse_op_name_map, register_name_map, register_base_kernel_name %}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h"
namespace phi {
{% for api in apis %}
{% if api is base_api %}
{{sparse_op_name_map(api)}}
{% endif %}
{% endfor %}
{% for api in backward_apis %}
{% if api is base_api %}
{{sparse_op_name_map(api)}}
{% endif %}
{% endfor %}
} // namespace phi
{% for api in apis + backward_apis %}
{% if api is base_api %}
{{register_name_map(api)}}
{% endif %}
{% endfor %}
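Rendered, this template yields one argument-mapping function per base op plus a registration line each. A minimal sketch of the output for a single op, reusing the sparse add mapping sketched above (formatting assumed):

// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h"

namespace phi {
KernelSignature SparseAddOpArgumentMapping(const ArgumentMappingContext& ctx) {
  /* body as sketched above */
}
}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(sparse_add, phi::SparseAddOpArgumentMapping);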
{% from "operator_utils.c.j2" import op_maker, backward_op_maker, backward_op_reused_maker, operator, register_op_with_components, register_op_version %}
// this file is generated by paddle/phi/api/yaml/generator/generate_op.py, do not edit.
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/sparse/backward.h"
#include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/multiary.h"
#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
using paddle::framework::GradVarName;
{% for api in apis %}
{% if api is base_api %}
{{op_maker(api)}}
{{operator(api)}}
{% endif %}
{% endfor %}
{% for api in backward_apis %}
{% if api is base_api %}
{{backward_op_maker(api, api_dict[api["forward"]["name"]])}}
{{operator(api)}}
{% else %}
{{backward_op_reused_maker(api, api_dict[api["forward"]["name"]], api["invoke"])}}
{% endif %}
{% endfor %}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
{% for api in apis + backward_apis %}
{% if api is base_api %}
{{register_op_with_components(api)}}
{% endif %}
{% endfor %}
......@@ -839,12 +839,6 @@
layout: out_grad
inplace : (out_grad -> x_grad)
- backward_op : flip_grad
forward : flip (Tensor x, int[] axis) -> Tensor(out)
args : (Tensor out_grad, int[] axis)
output : Tensor(x_grad)
invoke : flip(out_grad, axis)
- backward_op : floor_grad
forward : floor(Tensor x) -> Tensor(out)
args : (Tensor out_grad)
......
......@@ -938,15 +938,6 @@
intermediate : xshape
backward : flatten_grad
- op : flip
args : (Tensor x, int[] axis)
output : Tensor
infer_meta :
func : FlipInferMeta
kernel :
func : flip
backward : flip_grad
- op : floor
args : (Tensor x)
output : Tensor(out)
......
......@@ -324,6 +324,12 @@
inputs: {x: X}
outputs: {out: Out}
- op : flip
inputs :
x : X
outputs :
out : Out
- op : floor
backward : floor_grad
extra :
......
- op : flip
version :
- checkpoint : Upgrade flip, add new attr [axis] and delete attr [dims]
action :
- add_attr : axis
comment : The added attr 'axis' doesn't set default value
default : paddle::none
- delete_attr : dims
comment : The attr 'dims' is deleted.
- op : trace
version :
- checkpoint : Upgrade trace add a new attribute [axis2]
......
......@@ -199,3 +199,12 @@
kernel :
func : trunc
backward : trunc_grad
- op : flip
args : (Tensor x, int[] axis)
output : Tensor (out)
infer_meta :
func : FlipInferMeta
kernel :
func : flip
backward : flip_grad
- backward_op : abs_grad
forward : tanh(Tensor x) -> Tensor(out)
forward : abs(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
......@@ -124,8 +124,8 @@
cast_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
data_type : out_grad
- backward_op : conv3d_coo_grad
forward : conv3d_coo (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter)
- backward_op : conv3d_grad
forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter)
args : (Tensor x, Tensor kernel, Tensor out, Tensor rulebook, Tensor counter, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
output : Tensor(x_grad), Tensor(kernel_grad)
infer_meta :
......@@ -432,7 +432,7 @@
transpose_csr_grad {sparse_csr -> sparse_csr}
- backward_op : values_grad
forward : values_coo(Tensor x) -> Tensor(out)
forward : values(Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
......@@ -442,7 +442,7 @@
func : values_coo_grad{sparse_coo, dense-> sparse_coo}
- backward_op: fused_attention_grad
forward : fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
forward : fused_attention(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
args: (Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad)
output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
infer_meta :
......
......@@ -112,7 +112,7 @@
backward : cast_grad
- op : conv3d
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key="")
output : Tensor(out), Tensor(rulebook), Tensor(counter)
infer_meta :
func : sparse::Conv3dInferMeta
......@@ -120,7 +120,7 @@
func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense, dense}
layout : x
intermediate: rulebook, counter
backward : conv3d_coo_grad
backward : conv3d_grad
- op : divide
args : (Tensor x, Tensor y)
......
......@@ -109,6 +109,7 @@ class ArgumentMappingContext {
virtual bool IsDenseTensorInputs(const std::string& name) const = 0;
virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
virtual bool IsSparseCooTensorInput(const std::string& name) const = 0;
virtual bool IsSparseCsrTensorInput(const std::string& name) const = 0;
// For compatibility with LoDTensorArray
virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0;
......
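Generated sparse mappings rely on this new hook to distinguish CSR inputs from COO and dense ones. For example, for the cast_csr_grad kernel listed above ({sparse_csr, sparse_csr -> sparse_csr}), the dispatch condition would read roughly as follows (the input names x and out_grad are assumed):

if (ctx.IsSparseCsrTensorInput("x") && ctx.IsSparseCsrTensorInput("out_grad")) {
  kernel_name = "cast_csr_grad";
}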
......@@ -40,7 +40,7 @@ const std::unordered_set<std::string> standard_kernel_suffixs({
* after 2.0, and can no longer be occupied by the previously abandoned ops.
* They are marked here uniformly.
*/
const std::unordered_set<std::string> deprecated_op_names(
static const std::unordered_set<std::string> deprecated_op_names(
{"diag",
"flatten",
"flatten_grad",
......
......@@ -16,6 +16,11 @@
#include "glog/logging.h"
#include "paddle/phi/core/enforce.h"
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#include "paddle/phi/core/compat/convert_utils.h"
#endif
#include "paddle/phi/core/compat/op_utils.h"
DECLARE_bool(enable_api_kernel_fallback);
......@@ -41,6 +46,17 @@ KernelFactory& KernelFactory::Instance() {
return g_op_kernel_factory;
}
bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const {
if (deprecated_op_names.find(op_type) == deprecated_op_names.end()) {
if (phi::OpUtilsMap::Instance().Contains(op_type)) {
return true;
} else if (kernels_.find(op_type) != kernels_.end()) {
return true;
}
}
return false;
}
const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
......
......@@ -272,9 +272,7 @@ class KernelFactory {
KernelNameMap& kernels() { return kernels_; }
bool HasCompatiblePhiKernel(const std::string& op_type) const {
return kernels_.find(TransToPhiKernelName(op_type)) != kernels_.end();
}
bool HasCompatiblePhiKernel(const std::string& op_type) const;
KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key,
......
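Taken together, these two hunks replace the inline header definition with the out-of-line one above, so the check now also consults the compat op-name mapping and rejects deprecated names. A hedged illustration (op names assumed to be registered as in this PR):

auto& factory = phi::KernelFactory::Instance();
factory.HasCompatiblePhiKernel("sparse_add");  // true once an arg mapping or phi kernel exists
factory.HasCompatiblePhiKernel("flatten");     // false: "flatten" is in deprecated_op_names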
......@@ -16,23 +16,6 @@
namespace phi {
// TODO(zhangkaihuo): add csr op
KernelSignature SparseSparseCooTensorOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"sparse_coo_tensor", {"values", "indices"}, {"dense_shape"}, {"out"});
}
KernelSignature SparseValuesOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("values_coo", {"x"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseIndicesOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
......@@ -42,94 +25,6 @@ KernelSignature SparseIndicesOpArgumentMapping(
}
}
KernelSignature SparseToDenseOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("coo_to_dense", {"x"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("relu_coo", {"x"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseConv3dOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature(
"conv3d_coo",
{"x", "kernel"},
{"paddings", "dilations", "strides", "groups", "subm", "key"},
{"out", "rulebook", "counter"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseAddOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x") && ctx.IsSparseCooTensorInput("y")) {
return KernelSignature("add_coo_coo", {"x", "y"}, {}, {"out"});
} else if (ctx.IsSparseCooTensorInput("x") && ctx.IsDenseTensorInput("y")) {
return KernelSignature("add_coo_dense", {"x", "y"}, {}, {"out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
KernelSignature SparseBatchNormOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsSparseCooTensorInput("x")) {
return KernelSignature("batch_norm_coo",
{"x", "scale", "bias", "mean", "variance"},
{"momentum",
"epsilon",
"data_layout",
"is_test",
"use_global_stats",
"trainable_statistics",
"fuse_with_relu"},
{"y",
"mean_out",
"variance_out",
"saved_mean",
"saved_variance",
"reserve_space"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(sparse_sparse_coo_tensor, sparse_coo_tensor);
PD_REGISTER_ARG_MAPPING_FN(sparse_sparse_coo_tensor,
phi::SparseSparseCooTensorOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_values, values_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_values, phi::SparseValuesOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_indices, indices_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_indices, phi::SparseIndicesOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_to_dense, coo_to_dense);
PD_REGISTER_ARG_MAPPING_FN(sparse_to_dense,
phi::SparseToDenseOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_relu, relu_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_relu, phi::SparseReluOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_conv3d, conv3d_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_conv3d, phi::SparseConv3dOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_add, add_coo_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_add, phi::SparseAddOpArgumentMapping);
PD_REGISTER_BASE_KERNEL_NAME(sparse_batch_norm, batch_norm_coo);
PD_REGISTER_ARG_MAPPING_FN(sparse_batch_norm,
phi::SparseBatchNormOpArgumentMapping);
......@@ -86,6 +86,10 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
return false;
}
bool IsSparseCsrTensorInput(const std::string& name) const override {
return false;
}
bool IsDenseTensorOutput(const std::string& name) const override {
return dense_tensor_outputs.count(name) > 0;
}
......
......@@ -190,19 +190,20 @@ class BatchNorm(paddle.nn.BatchNorm1D):
reserve_space = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True
)
y = helper.create_sparse_variable_for_type_inference(dtype)
out = helper.create_sparse_variable_for_type_inference(dtype)
outputs = {
"y": y,
"out": out,
"mean_out": mean_out,
"variance_out": variance_out,
"saved_mean": saved_mean,
"saved_variance": saved_variance,
"reserve_space": reserve_space,
}
helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=attrs
)
return y
return out
class SyncBatchNorm(paddle.nn.SyncBatchNorm):
......
File mode changed from 100644 to 100755