diff --git a/cmake/pten.cmake b/cmake/pten.cmake index acc30aa22996a7ccb8fbe0b04ecd5b8365887bc1..2a040c73b981fa07e1d77f10c8306d7711095de6 100644 --- a/cmake/pten.cmake +++ b/cmake/pten.cmake @@ -282,9 +282,9 @@ endfunction() function(append_op_util_declare TARGET) file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET} target_content) - string(REGEX MATCH "(PT_REGISTER_API_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}") + string(REGEX MATCH "(PT_REGISTER_BASE_KERNEL_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}") string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}") - string(REPLACE "PT_REGISTER_API_NAME" "PT_REGISTER_API_NAME" util_declare "${util_declare}") + string(REPLACE "PT_REGISTER_BASE_KERNEL_NAME" "PT_DECLARE_BASE_KERNEL_NAME" util_declare "${util_declare}") string(APPEND util_declare ");") file(APPEND ${op_utils_header} "${util_declare}") endfunction() diff --git a/paddle/fluid/framework/pten_utils.cc b/paddle/fluid/framework/pten_utils.cc index 387e3a0aed71408c7b7703976484383cea261d4f..1a27f971fa082f894ac422fcfb8762ab7fb46725 100644 --- a/paddle/fluid/framework/pten_utils.cc +++ b/paddle/fluid/framework/pten_utils.cc @@ -185,9 +185,8 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { } KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() { - return KernelSignature(pten::TransToPtenKernelName(op_proto_->type()), - GetInputArgsNames(), GetAttrsArgsNames(), - GetOutputArgsNames()); + return KernelSignature(op_proto_->type(), GetInputArgsNames(), + GetAttrsArgsNames(), GetOutputArgsNames()); } std::once_flag kernel_sig_map_init_flag; diff --git a/paddle/pten/core/compat/CMakeLists.txt b/paddle/pten/core/compat/CMakeLists.txt index c6377f2e812b335e6af7a33208687a2ea3a8eb75..6d1529d94fd4051f9c11d9047b75f58f527ca089 100644 --- a/paddle/pten/core/compat/CMakeLists.txt +++ b/paddle/pten/core/compat/CMakeLists.txt @@ -1,13 +1,13 @@ cc_library(arg_map_context SRCS arg_map_context.cc DEPS pten_enforce) +cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce) if(WITH_GPU) - cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info) + cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils gpu_info) elseif(WITH_ROCM) - cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info) + cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils gpu_info) elseif(WITH_XPU) - cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place xpu_info) + cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils xpu_info) elseif(WITH_ASCEND_CL) - cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place npu_info) + cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils npu_info) else() - cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place) + cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils) endif() -cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce convert_utils) diff --git a/paddle/pten/core/compat/convert_utils.cc b/paddle/pten/core/compat/convert_utils.cc index c819626870dfe5a1862c7d126fb331f7bfc7d367..355a67601dd96df693faf117b6385471070bceac 100644 --- a/paddle/pten/core/compat/convert_utils.cc +++ b/paddle/pten/core/compat/convert_utils.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/pten/core/compat/convert_utils.h" - -#include "paddle/pten/core/compat/kernel_alias_name.h" +#include "paddle/pten/core/compat/op_utils.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device/gpu/gpu_info.h" @@ -235,21 +234,18 @@ std::string DataType2String(DataType dtype) { } } -const std::string& TransToPtenKernelName(const std::string& fluid_op_name) { - if (kernel_alias_name_map.find(fluid_op_name) != - kernel_alias_name_map.end()) { - return kernel_alias_name_map.at(fluid_op_name); - } - return fluid_op_name; +std::string TransToPtenKernelName(const std::string& fluid_op_name) { + return OpUtilsMap::Instance().GetBaseKernelName(fluid_op_name); } const std::string& TransToFluidOpName(const std::string& pten_kernel_name) { - auto it = std::find_if(kernel_alias_name_map.begin(), - kernel_alias_name_map.end(), + auto& base_kernel_name_map = OpUtilsMap::Instance().base_kernel_name_map(); + auto it = std::find_if(base_kernel_name_map.begin(), + base_kernel_name_map.end(), [&pten_kernel_name](const auto& pair) { return pair.second == pten_kernel_name; }); - if (it != kernel_alias_name_map.end()) { + if (it != base_kernel_name_map.end()) { return it->first; } return pten_kernel_name; diff --git a/paddle/pten/core/compat/convert_utils.h b/paddle/pten/core/compat/convert_utils.h index 31e38eee0f7c2e326eefbf4941cafdaa3ee2e9b5..1d241c5ad4040a989101f286fd59ce49581ce3e5 100644 --- a/paddle/pten/core/compat/convert_utils.h +++ b/paddle/pten/core/compat/convert_utils.h @@ -27,7 +27,7 @@ limitations under the License. */ namespace pten { -const std::string& TransToPtenKernelName(const std::string& fluid_op_name); +std::string TransToPtenKernelName(const std::string& fluid_op_name); const std::string& TransToFluidOpName(const std::string& pten_kernel_name); Backend TransToPtenBackend(const pten::Place& place); diff --git a/paddle/pten/core/compat/kernel_alias_name.h b/paddle/pten/core/compat/kernel_alias_name.h deleted file mode 100644 index cfe3f7579741135fa64fc9957cc5b65e6a98dc1c..0000000000000000000000000000000000000000 --- a/paddle/pten/core/compat/kernel_alias_name.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -// TODO(yuanrisheng): this file may need to be removed -#pragma once - -namespace pten { - -// the key is kernel_name in fluid, the value is the kernel_name in pten -// the key is sorted by key's alphabet -const std::unordered_map kernel_alias_name_map = { - {"elementwise_add", "add_raw"}, - {"elementwise_add_grad", "add_grad"}, - {"elementwise_div", "divide_raw"}, - {"elementwise_mul", "muliply_raw"}, - {"elementwise_sub", "subtract_raw"}, - {"elementwise_sub_grad", "subtract_grad"}, - {"fill_any_like", "full_like"}, - {"fill_constant", "full"}, - {"flatten_contiguous_range", "flatten"}, - {"flatten_contiguous_range_grad", "flatten_grad"}, - {"matmul_v2", "matmul"}, - {"matmul_v2_grad", "matmul_grad"}, - {"matmul_v2_grad_grad", "matmul_double_grad"}, - {"matmul_v2_triple_grad", "matmul_triple_grad"}, - {"reduce_mean", "mean_raw"}, - {"reduce_sum", "sum_raw"}, - {"reshape2", "reshape"}, - {"reshape2_grad", "reshape_grad"}, - {"reshape2_grad_grad", "reshape_double_grad"}, - // fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated - {"flatten", "deprecated"}, - {"flatten_grad", "deprecated"}, - {"matmul", "deprecated"}, - {"matmul_grad", "deprecated"}, - {"matmul_grad_grad", "deprecated"}, - {"mean", "deprecated"}, - {"reshape", "deprecated"}, - {"reshape_grad", "deprecated"}, - {"sum", "deprecated"}}; - -} // namespace pten diff --git a/paddle/pten/core/compat/op_utils.h b/paddle/pten/core/compat/op_utils.h index d35eeffd7ef830dc2f786c4f37581899332a32e5..93090616366f007427b6b1d5d20608545a13f13f 100644 --- a/paddle/pten/core/compat/op_utils.h +++ b/paddle/pten/core/compat/op_utils.h @@ -14,7 +14,8 @@ limitations under the License. */ #pragma once -#include +#include +#include #include "paddle/pten/core/compat/arg_map_context.h" #include "paddle/pten/core/enforce.h" @@ -25,6 +26,22 @@ limitations under the License. */ namespace pten { +/** + * Some fluid ops are no longer used under the corresponding official API + * system of 2.0. These names need to correspond to the official API names + * after 2.0, and can no longer be occupied by the previously abandoned ops. + * They are marked here uniformly. + */ +const std::unordered_set deprecated_op_names({"flatten", + "flatten_grad", + "matmul", + "matmul_grad", + "matmul_grad_grad", + "mean", + "reshape", + "reshape_grad", + "sum"}); + class DefaultKernelSignatureMap { public: static DefaultKernelSignatureMap& Instance(); @@ -63,16 +80,18 @@ class OpUtilsMap { static OpUtilsMap& Instance(); bool Contains(const std::string& op_type) const { - return name_map_.count(op_type) || arg_mapping_fn_map_.count(op_type); + return base_kernel_name_map_.count(op_type) || + arg_mapping_fn_map_.count(op_type); } - void InsertApiName(std::string op_type, std::string api_name) { + void InsertBaseKernelName(std::string op_type, std::string base_kernel_name) { PADDLE_ENFORCE_EQ( - name_map_.count(op_type), + base_kernel_name_map_.count(op_type), 0UL, pten::errors::AlreadyExists( "Operator (%s)'s api name has been registered.", op_type)); - name_map_.insert({std::move(op_type), std::move(api_name)}); + base_kernel_name_map_.insert( + {std::move(op_type), std::move(base_kernel_name)}); } void InsertArgumentMappingFn(std::string op_type, ArgumentMappingFn fn) { @@ -85,10 +104,13 @@ class OpUtilsMap { arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)}); } - std::string GetApiName(const std::string& op_type) const { - auto it = name_map_.find(op_type); - if (it == name_map_.end()) { + std::string GetBaseKernelName(const std::string& op_type) const { + if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) { return "deprecated"; + } + auto it = base_kernel_name_map_.find(op_type); + if (it == base_kernel_name_map_.end()) { + return op_type; } else { return it->second; } @@ -107,18 +129,23 @@ class OpUtilsMap { } } + const paddle::flat_hash_map& base_kernel_name_map() + const { + return base_kernel_name_map_; + } + private: OpUtilsMap() = default; - paddle::flat_hash_map name_map_; + paddle::flat_hash_map base_kernel_name_map_; paddle::flat_hash_map arg_mapping_fn_map_; DISABLE_COPY_AND_ASSIGN(OpUtilsMap); }; -struct ApiNameRegistrar { - ApiNameRegistrar(const char* op_type, const char* api_name) { - OpUtilsMap::Instance().InsertApiName(op_type, api_name); +struct BaseKernelNameRegistrar { + BaseKernelNameRegistrar(const char* op_type, const char* base_kernel_name) { + OpUtilsMap::Instance().InsertBaseKernelName(op_type, base_kernel_name); } }; @@ -130,21 +157,21 @@ struct ArgumentMappingFnRegistrar { } }; -#define PT_REGISTER_API_NAME(op_type, api_name) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - pt_register_api_name_ns_check_##op_type, \ - "PT_REGISTER_API_NAME must be called in global namespace."); \ - static const ::pten::ApiNameRegistrar __registrar_api_name_for_##op_type( \ - #op_type, #api_name); \ - int TouchApiNameSymbol_##op_type() { return 0; } - -#define PT_DECLARE_API_NAME(op_type) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - pt_declare_ai_name_ns_check_##op_type, \ - "PT_DECLARE_API_NAME must be called in global namespace."); \ - extern int TouchApiNameSymbol_##op_type(); \ - UNUSED static int __declare_api_name_symbol_for_##op_type = \ - TouchApiNameSymbol_##op_type() +#define PT_REGISTER_BASE_KERNEL_NAME(op_type, base_kernel_name) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_base_kernel_name_ns_check_##op_type, \ + "PT_REGISTER_BASE_KERNEL_NAME must be called in global namespace."); \ + static const ::pten::BaseKernelNameRegistrar \ + __registrar_base_kernel_name_for_##op_type(#op_type, #base_kernel_name); \ + int TouchBaseKernelNameSymbol_##op_type() { return 0; } + +#define PT_DECLARE_BASE_KERNEL_NAME(op_type) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_declare_ai_name_ns_check_##op_type, \ + "PT_DECLARE_BASE_KERNEL_NAME must be called in global namespace."); \ + extern int TouchBaseKernelNameSymbol_##op_type(); \ + UNUSED static int __declare_base_kernel_name_symbol_for_##op_type = \ + TouchBaseKernelNameSymbol_##op_type() #define PT_REGISTER_ARG_MAPPING_FN(op_type, arg_mapping_fn) \ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ diff --git a/paddle/pten/ops/compat/elementwise_sig.cc b/paddle/pten/ops/compat/elementwise_sig.cc index 77e7625532b18b0856e1c645271f3793cc292db7..4c14a5d139e3cb1c80fb6a71c03a7c72bd37f92d 100644 --- a/paddle/pten/ops/compat/elementwise_sig.cc +++ b/paddle/pten/ops/compat/elementwise_sig.cc @@ -66,6 +66,13 @@ KernelSignature ElementwiseDivOpArgumentMapping( } // namespace pten +PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add_raw); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract_raw); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, muliply_raw); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide_raw); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad); + PT_REGISTER_ARG_MAPPING_FN(elementwise_add, pten::ElementwiseAddOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(elementwise_sub, diff --git a/paddle/pten/ops/compat/fill_any_like_sig.cc b/paddle/pten/ops/compat/fill_any_like_sig.cc index 39e301d633863d153de62fa1bc59d5839fabf847..81065d0c8aebd54a9b09b55ad68b900ae075485a 100644 --- a/paddle/pten/ops/compat/fill_any_like_sig.cc +++ b/paddle/pten/ops/compat/fill_any_like_sig.cc @@ -23,4 +23,6 @@ KernelSignature FillAnyLikeOpArgumentMapping( } // namespace pten +PT_REGISTER_BASE_KERNEL_NAME(fill_any_like, full_like); + PT_REGISTER_ARG_MAPPING_FN(fill_any_like, pten::FillAnyLikeOpArgumentMapping); diff --git a/paddle/pten/ops/compat/fill_constant_sig.cc b/paddle/pten/ops/compat/fill_constant_sig.cc index 6acf01c7c6f05ca536320797771233879cde782b..73dee270f7072ef4860da1a404c260aaa35a787b 100644 --- a/paddle/pten/ops/compat/fill_constant_sig.cc +++ b/paddle/pten/ops/compat/fill_constant_sig.cc @@ -68,4 +68,6 @@ KernelSignature FillConstantOpArgumentMapping( } // namespace pten +PT_REGISTER_BASE_KERNEL_NAME(fill_constant, full); + PT_REGISTER_ARG_MAPPING_FN(fill_constant, pten::FillConstantOpArgumentMapping); diff --git a/paddle/pten/ops/compat/flatten_sig.cc b/paddle/pten/ops/compat/flatten_sig.cc index f1c774401648e8f6b7daab8bd9ecb0451a6f35cd..1ef2977bf88d796c8e70cf159bb9cf1e994c986e 100644 --- a/paddle/pten/ops/compat/flatten_sig.cc +++ b/paddle/pten/ops/compat/flatten_sig.cc @@ -30,5 +30,8 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) { } // namespace pten +PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range, flatten); +PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range_grad, flatten_grad); + PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range, pten::FlattenOpArgumentMapping); diff --git a/paddle/pten/ops/compat/matmul_sig.cc b/paddle/pten/ops/compat/matmul_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..67ef91b429e36cbac2bb789aa1853bd37302f190 --- /dev/null +++ b/paddle/pten/ops/compat/matmul_sig.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten {} // namespace pten + +PT_REGISTER_BASE_KERNEL_NAME(matmul_v2, matmul); +PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad, matmul_grad); +PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad_grad, matmul_double_grad); +PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_triple_grad, matmul_triple_grad); diff --git a/paddle/pten/ops/compat/reduce_sig.cc b/paddle/pten/ops/compat/reduce_sig.cc index 7f9171fd5811e070f162fdc72f574c25f7ed5263..a8a2b517d3e9d37e7078302a736e7438d1d0d4c3 100644 --- a/paddle/pten/ops/compat/reduce_sig.cc +++ b/paddle/pten/ops/compat/reduce_sig.cc @@ -45,5 +45,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) { } // namespace pten +PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum_raw); +PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean_raw); + PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping); diff --git a/paddle/pten/ops/compat/reshape_sig.cc b/paddle/pten/ops/compat/reshape_sig.cc index 22d39ef41109e7026e9403a90e6173bd0816a887..031b6875867a5fb7e066a18b46c90f4f8a50adc7 100644 --- a/paddle/pten/ops/compat/reshape_sig.cc +++ b/paddle/pten/ops/compat/reshape_sig.cc @@ -28,4 +28,8 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) { } // namespace pten +PT_REGISTER_BASE_KERNEL_NAME(reshape2, reshape); +PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad, reshape_grad); +PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad_grad, reshape_double_grad); + PT_REGISTER_ARG_MAPPING_FN(reshape2, pten::ReshapeOpArgumentMapping);