From b75507d34203c28fef5a2f1d1df4c57ba5633695 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 26 Jan 2022 09:53:12 +0800 Subject: [PATCH] [PTen] Unify InferMeta(Shape) Function in pten and fluid op (#38976) * infermeta context init design * support infermeta called in fluid op * add hasattr and attr methods * add dygraah GetVarPtrs support * rename arg_map_context to arg_map_utils * add registry for arg map func * resolve conflit * refactor op utils design * polish meta config * fix details * remove hasattr method * resolve conflit * revert cmake order change * revert some change * change init pos * fix compile faileed * fix typo * fix inference failed * fix windows ccompile failed * polish format Co-authored-by: Wang Huan --- cmake/pten.cmake | 26 +++ paddle/fluid/framework/CMakeLists.txt | 6 +- paddle/fluid/framework/details/op_registry.h | 6 +- paddle/fluid/framework/infershape_utils.cc | 38 ++++ paddle/fluid/framework/infershape_utils.h | 1 - paddle/fluid/framework/operator.cc | 14 +- paddle/fluid/framework/operator.h | 6 +- paddle/fluid/framework/pten_utils.cc | 61 ++----- paddle/fluid/framework/pten_utils.h | 22 +-- paddle/fluid/operators/scale_op.cc | 7 - paddle/fluid/operators/sign_op.cc | 16 +- paddle/fluid/pybind/CMakeLists.txt | 2 +- paddle/fluid/pybind/pybind.cc | 1 + paddle/pten/CMakeLists.txt | 2 +- paddle/pten/core/CMakeLists.txt | 3 +- paddle/pten/core/compat/CMakeLists.txt | 1 + paddle/pten/core/compat/op_utils.cc | 29 +++ paddle/pten/core/compat/op_utils.h | 166 ++++++++++++++++++ paddle/pten/core/infermeta_utils.h | 2 +- paddle/pten/core/meta_tensor.cc | 2 +- paddle/pten/core/meta_tensor.h | 2 +- paddle/pten/infermeta/CMakeLists.txt | 2 +- paddle/pten/infermeta/unary.cc | 11 +- paddle/pten/infermeta/unary.h | 26 +-- paddle/pten/ops/CMakeLists.txt | 1 + paddle/pten/ops/compat/CMakeLists.txt | 11 ++ .../compat/{scale_args_fn.h => scale_sig.cc} | 15 +- paddle/pten/tests/core/CMakeLists.txt | 1 + paddle/pten/tests/core/test_op_utils.cc | 32 ++++ 29 files changed, 395 insertions(+), 117 deletions(-) create mode 100644 paddle/pten/core/compat/op_utils.cc create mode 100644 paddle/pten/core/compat/op_utils.h create mode 100644 paddle/pten/ops/compat/CMakeLists.txt rename paddle/pten/ops/compat/{scale_args_fn.h => scale_sig.cc} (72%) create mode 100644 paddle/pten/tests/core/test_op_utils.cc diff --git a/cmake/pten.cmake b/cmake/pten.cmake index 70d61027da..8e1d233986 100644 --- a/cmake/pten.cmake +++ b/cmake/pten.cmake @@ -243,3 +243,29 @@ function(register_kernels) endif() endforeach() endfunction() + +function(append_op_util_declare TARGET) + file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET} target_content) + string(REGEX MATCH "(PT_REGISTER_API_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}") + string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}") + string(REPLACE "PT_REGISTER_API_NAME" "PT_REGISTER_API_NAME" util_declare "${util_declare}") + string(APPEND util_declare ");") + file(APPEND ${op_utils_header} "${util_declare}") +endfunction() + +function(register_op_utils TARGET_NAME) + set(utils_srcs) + set(options "") + set(oneValueArgs "") + set(multiValueArgs EXCLUDES DEPS) + cmake_parse_arguments(register_op_utils "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + + file(GLOB SIGNATURES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_sig.cc") + foreach(target ${SIGNATURES}) + append_op_util_declare(${target}) + list(APPEND utils_srcs 
${CMAKE_CURRENT_SOURCE_DIR}/${target}) + endforeach() + + cc_library(${TARGET_NAME} SRCS ${utils_srcs} DEPS ${register_op_utils_DEPS}) +endfunction() diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 9282054227..ce63a58d41 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -192,11 +192,11 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va IF(WITH_XPU) cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils - pten pten_utils kernel_factory infershape_utils) + pten pten_utils kernel_factory infershape_utils op_utils) ELSE() cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils - pten pten_utils kernel_factory infershape_utils) + pten pten_utils kernel_factory infershape_utils op_utils) ENDIF() cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) @@ -404,7 +404,7 @@ cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tens cc_library(generator SRCS generator.cc DEPS enforce place) cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info) -cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference) +cc_library(infershape_utils SRCS infershape_utils.cc DEPS pten_utils attribute shape_inference op_utils) # Get the current working branch execute_process( diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index 27f55e237f..427b981e7c 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -275,10 +275,8 @@ struct OpInfoFiller { template struct OpInfoFiller { void operator()(const char* op_type, OpInfo* info) const { - PADDLE_ENFORCE_EQ( - info->infer_shape_, nullptr, - platform::errors::AlreadyExists( - "Duplicate InferShapeFN of %s has been registered", op_type)); + // Note: if fill InferShapeFN by this Filler, the infershape here + // will overwrite the op->InferShape func registered in kOperator Filler info->infer_shape_ = [](InferShapeContext* ctx) { T inference; inference(ctx); diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 9a91a5208e..08b945159a 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -15,11 +15,14 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/pten/core/compat/arg_map_context.h" +#include "paddle/pten/core/compat/op_utils.h" #include "paddle/pten/core/compat_utils.h" #include "paddle/pten/core/convert_utils.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/infermeta_utils.h" #include "paddle/pten/core/meta_tensor.h" namespace paddle { @@ -186,5 +189,40 @@ class CompatMetaTensor : public pten::MetaTensor { bool is_runtime_; }; +pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, + const std::string& op_type) { + // 1. get kernel args + InitDefaultKernelSignatureMap(); + auto arg_map_fn = pten::OpUtilsMap::Instance().GetArgumentMappingFn(op_type); + PADDLE_ENFORCE_NOT_NULL( + arg_map_fn, platform::errors::NotFound( + "The ArgumentMappingFn of %s op is not found.", op_type)); + InferShapeArgumentMappingContext arg_map_context(*ctx); + auto signature = arg_map_fn(arg_map_context); + VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature; + + // 2. build infermeta context + pten::InferMetaContext infer_meta_context(ctx->IsRuntime()); + + auto& input_names = std::get<0>(signature.args); + auto& output_names = std::get<2>(signature.args); + // TODO(chenweihang): support attrs in next pr + // auto& attr_names = std::get<1>(signature.args); + + // TODO(chenweihang): support multiple inputs and outputs + pten::InferMetaContext infer_mete_context; + for (auto& in_name : input_names) { + infer_meta_context.EmplaceBackInput(std::make_shared( + ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime())); + } + for (auto& out_name : output_names) { + infer_meta_context.EmplaceBackOutput(std::make_shared( + ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime())); + } + // TODO(chenweihang): support attrs later + + return infer_meta_context; +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/infershape_utils.h b/paddle/fluid/framework/infershape_utils.h index f943989523..fbfb44e27c 100644 --- a/paddle/fluid/framework/infershape_utils.h +++ b/paddle/fluid/framework/infershape_utils.h @@ -26,7 +26,6 @@ class InferMetaContext; namespace paddle { namespace framework { -// TODO(chenweihang): impl this function in next PR pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, const std::string& op_type); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 45c55340cb..ae61b7388d 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -31,6 +31,7 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/profiler.h" #include "paddle/pten/common/scalar.h" #include "paddle/pten/common/scalar_array.h" +#include "paddle/pten/ops/compat/signatures.h" namespace pten { class DenseTensor; @@ -1086,6 +1087,13 @@ bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, return use_mkldnn_ctx && this->SupportsMKLDNN(data_type); } +void OperatorWithKernel::InferShape(InferShapeContext* ctx) const { + PADDLE_THROW(platform::errors::PermissionDenied( + "The default InferShape function of OperatorWithKernel is not allowed to " + "be called, please override corresponding InferShape function in the " + "specific operator.")); +} + void OperatorWithKernel::RuntimeInferShape(const Scope& scope, const platform::Place& place, const RuntimeContext& ctx) const { @@ -1784,8 +1792,10 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs( const ExecutionContext& ctx) const { - return KernelSignatureMap::Instance().Get( - pten::TransToPtenKernelName(Type())); + InitDefaultKernelSignatureMap(); + ExecutionArgumentMappingContext arg_mapping_ctx(ctx); + return pten::OpUtilsMap::Instance().GetArgumentMappingFn(Type())( + arg_mapping_ctx); } Scope* OperatorWithKernel::PreparePtenData( diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 7971d6154f..c280eeaa0f 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -41,6 +41,7 @@ limitations under the License. */ #include "paddle/utils/flat_hash_map.h" #include "paddle/pten/core/compat/arg_map_context.h" +#include "paddle/pten/core/compat/op_utils.h" #include "paddle/pten/core/kernel_context.h" #include "paddle/pten/core/kernel_factory.h" @@ -468,8 +469,7 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { } bool IsDenseTensorInput(const std::string& name) const override { - return ctx_.InputVar(name)->IsType() || - ctx_.InputVar(name)->IsType(); + return ctx_.InputVar(name)->IsType(); } bool IsSelectedRowsInput(const std::string& name) const override { @@ -550,7 +550,7 @@ class OperatorWithKernel : public OperatorBase { bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, proto::VarType::Type data_type) const; - virtual void InferShape(InferShapeContext* ctx) const = 0; + virtual void InferShape(InferShapeContext* ctx) const; void RuntimeInferShape(const Scope& scope, const platform::Place& place, const RuntimeContext& ctx) const override; diff --git a/paddle/fluid/framework/pten_utils.cc b/paddle/fluid/framework/pten_utils.cc index 2fd5b87b7f..dc20aaffec 100644 --- a/paddle/fluid/framework/pten_utils.cc +++ b/paddle/fluid/framework/pten_utils.cc @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/pten_utils.h" +#include "paddle/pten/core/compat/op_utils.h" #include "paddle/pten/core/convert_utils.h" #include "paddle/pten/core/kernel_factory.h" @@ -89,48 +90,6 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey( return pten::KernelKey(backend, layout, dtype); } -KernelSignatureMap* KernelSignatureMap::kernel_signature_map_ = nullptr; -std::once_flag KernelSignatureMap::init_flag_; - -KernelSignatureMap& KernelSignatureMap::Instance() { - std::call_once(init_flag_, [] { - kernel_signature_map_ = new KernelSignatureMap(); - for (const auto& pair : OpInfoMap::Instance().map()) { - const auto& op_type = pair.first; - const auto* op_proto = pair.second.proto_; - if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) && - op_proto) { - KernelArgsNameMakerByOpProto maker(op_proto); - VLOG(10) << "Register kernel signature for " << op_type; - auto success = kernel_signature_map_->map_ - .emplace(pten::TransToPtenKernelName(op_type), - std::move(maker.GetKernelSignature())) - .second; - PADDLE_ENFORCE_EQ( - success, true, - platform::errors::PermissionDenied( - "Kernel signature of the operator %s has been registered.", - op_type)); - } - } - }); - return *kernel_signature_map_; -} - -bool KernelSignatureMap::Has(const std::string& op_type) const { - return map_.find(op_type) != map_.end(); -} - -const KernelSignature& KernelSignatureMap::Get( - const std::string& op_type) const { - auto it = map_.find(op_type); - PADDLE_ENFORCE_NE( - it, map_.end(), - platform::errors::NotFound( - "Operator `%s`'s kernel signature is not registered.", op_type)); - return it->second; -} - const paddle::SmallVector& KernelArgsNameMakerByOpProto::GetInputArgsNames() { for (int i = 0; i < op_proto_->inputs_size(); ++i) { @@ -196,6 +155,24 @@ KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() { GetOutputArgsNames()); } +std::once_flag kernel_sig_map_init_flag; + +void InitDefaultKernelSignatureMap() { + std::call_once(kernel_sig_map_init_flag, [] { + for (const auto& pair : paddle::framework::OpInfoMap::Instance().map()) { + const auto& op_type = pair.first; + const auto* op_proto = pair.second.proto_; + if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) && + op_proto) { + paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto); + VLOG(10) << "Register kernel signature for " << op_type; + pten::DefaultKernelSignatureMap::Instance().Insert( + op_type, std::move(maker.GetKernelSignature())); + } + } + }); +} + void SetAllocationForOutputTenosr(pten::DenseTensor* tensor, const platform::Place& place) { if (!tensor->IsInitialized() || !(tensor->place() == place)) { diff --git a/paddle/fluid/framework/pten_utils.h b/paddle/fluid/framework/pten_utils.h index 4985e53ee6..9b1019f658 100644 --- a/paddle/fluid/framework/pten_utils.h +++ b/paddle/fluid/framework/pten_utils.h @@ -44,26 +44,6 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey( /* Kernel Args parse */ -// TODO(chenweihang): we can generate this map by proto info in compile time -class KernelSignatureMap { - public: - static KernelSignatureMap& Instance(); - - bool Has(const std::string& op_type) const; - - const KernelSignature& Get(const std::string& op_type) const; - - private: - KernelSignatureMap() = default; - DISABLE_COPY_AND_ASSIGN(KernelSignatureMap); - - private: - static KernelSignatureMap* kernel_signature_map_; - static std::once_flag init_flag_; - - paddle::flat_hash_map map_; -}; - class KernelArgsNameMaker { public: virtual 
~KernelArgsNameMaker() {} @@ -72,6 +52,8 @@ class KernelArgsNameMaker { virtual const paddle::SmallVector& GetAttrsArgsNames() = 0; }; +void InitDefaultKernelSignatureMap(); + void SetAllocationForOutputTenosr(pten::DenseTensor* tensor, const platform::Place& place); diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index 86f4e1b3ac..a195452791 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -15,7 +15,6 @@ limitations under the License. */ #include "paddle/fluid/operators/scale_op.h" #include #include "paddle/fluid/platform/float16.h" -#include "paddle/pten/ops/compat/scale_args_fn.h" namespace paddle { namespace framework { @@ -71,12 +70,6 @@ class ScaleOp : public framework::OperatorWithKernel { #endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } - - framework::KernelSignature GetExpectedPtenKernelArgs( - const framework::ExecutionContext &ctx) const override { - framework::ExecutionArgumentMappingContext arg_mapping_ctx(ctx); - return pten::ScaleOpArgumentMapping(arg_mapping_ctx); - } }; class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/sign_op.cc b/paddle/fluid/operators/sign_op.cc index 6207c33f9d..f361240780 100644 --- a/paddle/fluid/operators/sign_op.cc +++ b/paddle/fluid/operators/sign_op.cc @@ -14,7 +14,10 @@ limitations under the License. */ #include "paddle/fluid/operators/sign_op.h" #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/pten/core/infermeta_utils.h" +#include "paddle/pten/infermeta/unary.h" namespace paddle { namespace operators { @@ -22,14 +25,6 @@ namespace operators { class SignOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "sign"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sign"); - - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - ctx->ShareLoD("X", /*->*/ "Out"); - } }; template @@ -64,9 +59,12 @@ class SignGradMaker : public framework::SingleGradOpMaker { namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(sign, SignInferShapeFunctor, + PT_INFER_META(pten::UnchangedInferMetaNew)); REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker, ops::SignGradMaker, - ops::SignGradMaker); + ops::SignGradMaker, + SignInferShapeFunctor); REGISTER_OP_CPU_KERNEL( sign, ops::SignKernel, ops::SignKernel); diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 1df77c78a4..8c1d3d0230 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -2,7 +2,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator - cost_model cuda_graph_with_memory_pool fleet_executor global_utils) + cost_model cuda_graph_with_memory_pool fleet_executor global_utils pten_utils) if (WITH_PSCORE) set(PYBIND_DEPS ${PYBIND_DEPS} ps_service) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 57d7e98cef..ad018944e4 100644 --- 
a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -50,6 +50,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/prune.h" +#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/save_load_util.h" #include "paddle/fluid/framework/scope_pool.h" diff --git a/paddle/pten/CMakeLists.txt b/paddle/pten/CMakeLists.txt index 671ed28313..78e86c12cb 100644 --- a/paddle/pten/CMakeLists.txt +++ b/paddle/pten/CMakeLists.txt @@ -21,7 +21,7 @@ add_subdirectory(ops) add_subdirectory(tests) # make an unity target for compile deps -set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta lod_utils) +set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta lod_utils op_compat_infos) get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS) # keep this message for debug, remove it later if needless message(STATUS "All standard pten kernels: ${pten_kernels}") diff --git a/paddle/pten/core/CMakeLists.txt b/paddle/pten/core/CMakeLists.txt index 185ff2858d..e89d2cd3b3 100644 --- a/paddle/pten/core/CMakeLists.txt +++ b/paddle/pten/core/CMakeLists.txt @@ -21,13 +21,14 @@ cc_library(tensor_meta SRCS tensor_meta.cc DEPS pten_enforce mixed_vector) cc_library(lod_utils SRCS lod_utils.cc DEPS pten_enforce mixed_vector) cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base) cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base ) + cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor) cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor) +cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim) cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce) -cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim) # Will remove once we implemented MKLDNN_Tensor if(WITH_MKLDNN) add_dependencies(dense_tensor mkldnn) diff --git a/paddle/pten/core/compat/CMakeLists.txt b/paddle/pten/core/compat/CMakeLists.txt index 0af35c20b3..0c081edb81 100644 --- a/paddle/pten/core/compat/CMakeLists.txt +++ b/paddle/pten/core/compat/CMakeLists.txt @@ -1 +1,2 @@ cc_library(arg_map_context SRCS arg_map_context.cc DEPS pten_enforce) +cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce convert_utils) diff --git a/paddle/pten/core/compat/op_utils.cc b/paddle/pten/core/compat/op_utils.cc new file mode 100644 index 0000000000..12c2d74737 --- /dev/null +++ b/paddle/pten/core/compat/op_utils.cc @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +DefaultKernelSignatureMap& DefaultKernelSignatureMap::Instance() { + static DefaultKernelSignatureMap g_default_kernel_sig_map; + return g_default_kernel_sig_map; +} + +OpUtilsMap& OpUtilsMap::Instance() { + static OpUtilsMap g_op_utils_map; + return g_op_utils_map; +} + +} // namespace pten diff --git a/paddle/pten/core/compat/op_utils.h b/paddle/pten/core/compat/op_utils.h new file mode 100644 index 0000000000..505ef13891 --- /dev/null +++ b/paddle/pten/core/compat/op_utils.h @@ -0,0 +1,166 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/pten/core/compat/arg_map_context.h" +#include "paddle/pten/core/infermeta_utils.h" +#include "paddle/pten/core/kernel_def.h" +#include "paddle/pten/core/macros.h" +#include "paddle/utils/flat_hash_map.h" + +#include "paddle/fluid/platform/enforce.h" + +namespace pten { + +class DefaultKernelSignatureMap { + public: + static DefaultKernelSignatureMap& Instance(); + + bool Has(const std::string& op_type) const { return map_.count(op_type) > 0; } + + const KernelSignature& Get(const std::string& op_type) const { + auto it = map_.find(op_type); + PADDLE_ENFORCE_NE( + it, + map_.end(), + paddle::platform::errors::NotFound( + "Operator `%s`'s kernel signature is not registered.", op_type)); + return it->second; + } + + void Insert(std::string op_type, KernelSignature signature) { + PADDLE_ENFORCE_NE( + Has(op_type), + true, + paddle::platform::errors::AlreadyExists( + "Operator (%s)'s Kernel Siginature has been registered.", op_type)); + map_.insert({std::move(op_type), std::move(signature)}); + } + + private: + DefaultKernelSignatureMap() = default; + + paddle::flat_hash_map map_; + + DISABLE_COPY_AND_ASSIGN(DefaultKernelSignatureMap); +}; + +class OpUtilsMap { + public: + static OpUtilsMap& Instance(); + + bool Contains(const std::string& op_type) const { + return name_map_.count(op_type) || arg_mapping_fn_map_.count(op_type); + } + + void InsertApiName(std::string op_type, std::string api_name) { + PADDLE_ENFORCE_EQ( + name_map_.count(op_type), + 0UL, + paddle::platform::errors::AlreadyExists( + "Operator (%s)'s api name has been registered.", op_type)); + name_map_.insert({std::move(op_type), std::move(api_name)}); + } + + void InsertArgumentMappingFn(std::string op_type, ArgumentMappingFn fn) { + PADDLE_ENFORCE_EQ( + arg_mapping_fn_map_.count(op_type), + 0UL, + paddle::platform::errors::AlreadyExists( + "Operator (%s)'s argu,emt mapping function has been registered.", + op_type)); + arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)}); + } + + std::string GetApiName(const std::string& op_type) const { + auto it = name_map_.find(op_type); + if (it == name_map_.end()) { + return "deprecated"; + } else { + return it->second; + } + } + + ArgumentMappingFn GetArgumentMappingFn(const std::string& op_type) const { + auto it = arg_mapping_fn_map_.find(op_type); + if (it == 
arg_mapping_fn_map_.end()) { + auto func = + [op_type](const ArgumentMappingContext& ctx) -> KernelSignature { + return DefaultKernelSignatureMap::Instance().Get(op_type); + }; + return func; + } else { + return it->second; + } + } + + private: + OpUtilsMap() = default; + + paddle::flat_hash_map name_map_; + paddle::flat_hash_map arg_mapping_fn_map_; + + DISABLE_COPY_AND_ASSIGN(OpUtilsMap); +}; + +struct ApiNameRegistrar { + ApiNameRegistrar(const char* op_type, const char* api_name) { + OpUtilsMap::Instance().InsertApiName(op_type, api_name); + } +}; + +struct ArgumentMappingFnRegistrar { + ArgumentMappingFnRegistrar(const char* op_type, + ArgumentMappingFn arg_mapping_fn) { + OpUtilsMap::Instance().InsertArgumentMappingFn(op_type, + std::move(arg_mapping_fn)); + } +}; + +#define PT_REGISTER_API_NAME(op_type, api_name) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_api_name_ns_check_##op_type, \ + "PT_REGISTER_API_NAME must be called in global namespace."); \ + static const ::pten::ApiNameRegistrar __registrar_api_name_for_##op_type( \ + #op_type, #api_name); \ + int TouchApiNameSymbol_##op_type() { return 0; } + +#define PT_DECLARE_API_NAME(op_type) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_declare_ai_name_ns_check_##op_type, \ + "PT_DECLARE_API_NAME must be called in global namespace."); \ + extern int TouchApiNameSymbol_##op_type(); \ + UNUSED static int __declare_api_name_symbol_for_##op_type = \ + TouchApiNameSymbol_##op_type() + +#define PT_REGISTER_ARG_MAPPING_FN(op_type, arg_mapping_fn) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_arg_map_fn_ns_check_##op_type, \ + "PT_REGISTER_ARG_MAPPING_FN must be called in global namespace."); \ + static const ::pten::ArgumentMappingFnRegistrar \ + __registrar_arg_map_fn_for_##op_type(#op_type, arg_mapping_fn); \ + int TouchArgumentMappingFnSymbol_##op_type() { return 0; } + +#define PT_DECLARE_ARG_MAPPING_FN(op_type) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_declare_arg_map_fn_ns_check_##op_type, \ + "PT_DECLARE_ARG_MAPPING_FN must be called in global namespace."); \ + extern int TouchArgumentMappingFnSymbol_##op_type(); \ + UNUSED static int __declare_arg_map_fn_symbol_for_##op_type = \ + TouchArgumentMappingFnSymbol_##op_type() + +} // namespace pten diff --git a/paddle/pten/core/infermeta_utils.h b/paddle/pten/core/infermeta_utils.h index c6812dee92..bfc9d29e63 100644 --- a/paddle/pten/core/infermeta_utils.h +++ b/paddle/pten/core/infermeta_utils.h @@ -151,7 +151,7 @@ struct InferMetaFnImpl { struct InferMetaFnCallHelper { template static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { - const MetaConfig& arg = ctx->GetMetaConfig(); + MetaConfig arg = ctx->GetMetaConfig(); InferMetaFnCallHelper::template Call( ctx, pargs..., arg); } diff --git a/paddle/pten/core/meta_tensor.cc b/paddle/pten/core/meta_tensor.cc index f52d771b73..a8229b568a 100644 --- a/paddle/pten/core/meta_tensor.cc +++ b/paddle/pten/core/meta_tensor.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/paddle/pten/core/meta_tensor.h b/paddle/pten/core/meta_tensor.h index 442ff4137d..1435e1c391 100644 --- a/paddle/pten/core/meta_tensor.h +++ b/paddle/pten/core/meta_tensor.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/paddle/pten/infermeta/CMakeLists.txt b/paddle/pten/infermeta/CMakeLists.txt index 8e50d9d2c9..2216d38708 100644 --- a/paddle/pten/infermeta/CMakeLists.txt +++ b/paddle/pten/infermeta/CMakeLists.txt @@ -1,2 +1,2 @@ -cc_library(infermeta SRCS nullary.cc unary.cc binary.cc multiary.cc DEPS convert_utils) +cc_library(infermeta SRCS nullary.cc unary.cc binary.cc multiary.cc DEPS convert_utils infermeta_utils) cc_library(backward_infermeta SRCS backward.cc DEPS convert_utils) diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc index 95b419856b..fec50d528d 100644 --- a/paddle/pten/infermeta/unary.cc +++ b/paddle/pten/infermeta/unary.cc @@ -12,12 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -// See Note [ Why still include the fluid headers? ] #include "paddle/pten/infermeta/unary.h" + #include +#include "paddle/pten/core/infermeta_utils.h" + namespace pten { +void UnchangedInferMetaNew(MetaConfig config, + const MetaTensor& x, + MetaTensor* out) { + out->set_dims(x.dims()); + out->share_lod(x); +} + DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta) { return x_meta; } diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h index 388a9fca34..670c70de84 100644 --- a/paddle/pten/infermeta/unary.h +++ b/paddle/pten/infermeta/unary.h @@ -16,23 +16,27 @@ limitations under the License. */ // See Note [ Why still include the fluid headers? ] #include "paddle/pten/common/scalar_array.h" +#include "paddle/pten/core/infermeta_utils.h" +#include "paddle/pten/core/meta_tensor.h" #include "paddle/pten/core/tensor_meta.h" namespace pten { +class MetaConfig; + // Common InferMeta Functions for unary operators, The format like: // -// 1. DenseTensorMeta [OpName]InferMeta(const DenseTensorMeta& x_meta, ...) -// {} -// 2. std::pair [OpName]InferMeta(const -// DenseTensorMeta& -// x_meta, ...) {} -// 3. std::tuple -// [OpName]InferMeta(const -// DenseTensorMeta& x_meta, ...) -// NOTE: The name "InferMeta" may be not appropriate. "InferMeta" may be good. -// Because functions in this file -// not only can infer shape, but alse need infer lod or other useful data. +// void [OpName]InferMeta(const MetaTensor& x, ..., MetaTensor* out) {} +// +// NOTE: The name "InferShape" may be not appropriate. "InferMeta" may be good. +// Because functions in this file not only can infer shape, but also need +// infer lod or other useful data. 
+ +// TODO(chenweihang): update all InferMeta function format in next pr, +// now add UnchangedInferMetaNew for test new format +void UnchangedInferMetaNew(MetaConfig config, + const MetaTensor& x, + MetaTensor* out); DenseTensorMeta UnchangedInferMeta(const DenseTensorMeta& x_meta); diff --git a/paddle/pten/ops/CMakeLists.txt b/paddle/pten/ops/CMakeLists.txt index e69de29bb2..910b62766e 100644 --- a/paddle/pten/ops/CMakeLists.txt +++ b/paddle/pten/ops/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(compat) diff --git a/paddle/pten/ops/compat/CMakeLists.txt b/paddle/pten/ops/compat/CMakeLists.txt new file mode 100644 index 0000000000..dd214087e1 --- /dev/null +++ b/paddle/pten/ops/compat/CMakeLists.txt @@ -0,0 +1,11 @@ +set(op_utils_header ${PADDLE_BINARY_DIR}/paddle/pten/ops/compat/signatures.h.tmp CACHE INTERNAL "op_args_fns.cc file") +set(op_utils_header_final ${PADDLE_BINARY_DIR}/paddle/pten/ops/compat/signatures.h) +file(WRITE ${op_utils_header} "// Generated by the paddle/pten/ops/compat/CMakeLists.txt. DO NOT EDIT!\n\n") +file(APPEND ${op_utils_header} "#include \"paddle/pten/core/compat/op_utils.h\"\n\n") + +# Automatically generate the registration code of all arg map functions +# and compile the corresponding target to avoid frequent code conflicts +# when writing to same file +register_op_utils(op_compat_infos DEPS op_utils) + +copy_if_different(${op_utils_header} ${op_utils_header_final}) diff --git a/paddle/pten/ops/compat/scale_args_fn.h b/paddle/pten/ops/compat/scale_sig.cc similarity index 72% rename from paddle/pten/ops/compat/scale_args_fn.h rename to paddle/pten/ops/compat/scale_sig.cc index 91f0db389d..5ce159a5d8 100644 --- a/paddle/pten/ops/compat/scale_args_fn.h +++ b/paddle/pten/ops/compat/scale_sig.cc @@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#pragma once - -#include "paddle/pten/core/compat/arg_map_context.h" +#include "paddle/pten/core/compat/op_utils.h" namespace pten { @@ -22,15 +20,18 @@ KernelSignature ScaleOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.IsDenseTensorInput("X")) { std::string scale_attr; if (ctx.HasInput("ScaleTensor")) { - scale_attr = "ScaleTensor"; + return KernelSignature( + "scale", {"X"}, {"ScaleTensor", "bias", "bias_after_scale"}, {"Out"}); } else { - scale_attr = "scale"; + return KernelSignature( + "scale", {"X"}, {"scale", "bias", "bias_after_scale"}, {"Out"}); } - return KernelSignature( - "scale", {"X"}, {scale_attr, "bias", "bias_after_scale"}, {"Out"}); } // TODO(chenweihang): support other cases after selected rows added return KernelSignature("scale.unregistered", {}, {}, {}); } } // namespace pten + +// op_type, api_name, arg_mapping_fn +PT_REGISTER_ARG_MAPPING_FN(scale, pten::ScaleOpArgumentMapping); diff --git a/paddle/pten/tests/core/CMakeLists.txt b/paddle/pten/tests/core/CMakeLists.txt index 43e1480e2c..27a0173ef6 100644 --- a/paddle/pten/tests/core/CMakeLists.txt +++ b/paddle/pten/tests/core/CMakeLists.txt @@ -3,6 +3,7 @@ cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) cc_test(test_type_info SRCS test_type_info.cc) cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils) cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel) +cc_test(test_op_utils SRCS test_op_utils.cc DEPS op_compat_infos) cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context) cc_test(test_ddim SRCS test_ddim.cc DEPS ddim) diff --git a/paddle/pten/tests/core/test_op_utils.cc b/paddle/pten/tests/core/test_op_utils.cc new file mode 100644 index 0000000000..6c4a418685 --- /dev/null +++ b/paddle/pten/tests/core/test_op_utils.cc @@ -0,0 +1,32 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "gtest/gtest.h" +#include "paddle/pten/core/compat/op_utils.h" +#include "paddle/pten/ops/compat/signatures.h" + +namespace pten { +namespace tests { + +TEST(OpUtilsMap, ArgMappingFnExists) { + std::cout << "enter ArgMappingFnExists"; + auto scale_arg_mapping_fn = + pten::OpUtilsMap::Instance().GetArgumentMappingFn("scale"); + EXPECT_NE(scale_arg_mapping_fn, nullptr); +} + +} // namespace tests +} // namespace pten -- GitLab
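
For reference, a minimal sketch of the new InferMeta function format introduced by this patch is shown below. It follows pten::UnchangedInferMetaNew in paddle/pten/infermeta/unary.cc above; the function name MyReluInferMeta is hypothetical and only illustrates the signature, while MetaConfig and MetaTensor are the types used in the patch. An operator then wires such a function into its InferShape through DELCARE_INFER_SHAPE_FUNCTOR / PT_INFER_META and REGISTER_OPERATOR, as done for the sign op in paddle/fluid/operators/sign_op.cc above.

// Hypothetical new-format InferMeta function -- a sketch, not part of this patch.
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/meta_tensor.h"

namespace pten {

// New-format functions take MetaTensor arguments plus a MetaConfig and write
// the output meta (dims, lod, ...) in place instead of returning a
// DenseTensorMeta, so one function can serve both the fluid InferShape path
// and the pten kernel path.
void MyReluInferMeta(MetaConfig config, const MetaTensor& x, MetaTensor* out) {
  out->set_dims(x.dims());
  out->share_lod(x);
}

}  // namespace pten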
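
Likewise, the sketch below shows how another operator would register an argument mapping function with the OpUtilsMap registry, mirroring paddle/pten/ops/compat/scale_sig.cc above. The operator name my_relu and the argument names X/alpha/Out are hypothetical and only illustrate the pattern; KernelSignature, ArgumentMappingContext and PT_REGISTER_ARG_MAPPING_FN are the types and macro used in this patch.

// Hypothetical paddle/pten/ops/compat/my_relu_sig.cc -- a sketch, not part of this patch.
#include "paddle/pten/core/compat/op_utils.h"

namespace pten {

// Maps the fluid op's inputs/attrs/outputs onto the pten kernel signature.
KernelSignature MyReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.IsDenseTensorInput("X")) {
    return KernelSignature("my_relu", {"X"}, {"alpha"}, {"Out"});
  }
  // Fallback for input types the kernel does not handle yet.
  return KernelSignature("my_relu.unregistered", {}, {}, {});
}

}  // namespace pten

PT_REGISTER_ARG_MAPPING_FN(my_relu, pten::MyReluOpArgumentMapping);

Because register_op_utils() in cmake/pten.cmake globs *_sig.cc files, dropping such a file into paddle/pten/ops/compat/ is enough for the matching PT_DECLARE_ARG_MAPPING_FN line to be emitted into the generated signatures.h.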