未验证 提交 5dc20c27 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Remove kernel alias name (#39321)

* remove kernel alias name

* fix depreacted error

* fix deprecated failed

* fix mean error

* resolve conflict

* fix windows failed
上级 34cce62f
...@@ -282,9 +282,9 @@ endfunction() ...@@ -282,9 +282,9 @@ endfunction()
function(append_op_util_declare TARGET) function(append_op_util_declare TARGET)
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET} target_content) file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET} target_content)
string(REGEX MATCH "(PT_REGISTER_API_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}") string(REGEX MATCH "(PT_REGISTER_BASE_KERNEL_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}")
string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}") string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}")
string(REPLACE "PT_REGISTER_API_NAME" "PT_REGISTER_API_NAME" util_declare "${util_declare}") string(REPLACE "PT_REGISTER_BASE_KERNEL_NAME" "PT_DECLARE_BASE_KERNEL_NAME" util_declare "${util_declare}")
string(APPEND util_declare ");") string(APPEND util_declare ");")
file(APPEND ${op_utils_header} "${util_declare}") file(APPEND ${op_utils_header} "${util_declare}")
endfunction() endfunction()
......
...@@ -185,9 +185,8 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { ...@@ -185,9 +185,8 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
} }
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() { KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return KernelSignature(pten::TransToPtenKernelName(op_proto_->type()), return KernelSignature(op_proto_->type(), GetInputArgsNames(),
GetInputArgsNames(), GetAttrsArgsNames(), GetAttrsArgsNames(), GetOutputArgsNames());
GetOutputArgsNames());
} }
std::once_flag kernel_sig_map_init_flag; std::once_flag kernel_sig_map_init_flag;
......
cc_library(arg_map_context SRCS arg_map_context.cc DEPS pten_enforce) cc_library(arg_map_context SRCS arg_map_context.cc DEPS pten_enforce)
cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce)
if(WITH_GPU) if(WITH_GPU)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info) cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils gpu_info)
elseif(WITH_ROCM) elseif(WITH_ROCM)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info) cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils gpu_info)
elseif(WITH_XPU) elseif(WITH_XPU)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place xpu_info) cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils xpu_info)
elseif(WITH_ASCEND_CL) elseif(WITH_ASCEND_CL)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place npu_info) cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils npu_info)
else() else()
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place) cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils)
endif() endif()
cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce convert_utils)
...@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/pten/core/compat/convert_utils.h" #include "paddle/pten/core/compat/convert_utils.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/compat/kernel_alias_name.h"
// See Note [ Why still include the fluid headers? ] // See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
...@@ -235,21 +234,18 @@ std::string DataType2String(DataType dtype) { ...@@ -235,21 +234,18 @@ std::string DataType2String(DataType dtype) {
} }
} }
const std::string& TransToPtenKernelName(const std::string& fluid_op_name) { std::string TransToPtenKernelName(const std::string& fluid_op_name) {
if (kernel_alias_name_map.find(fluid_op_name) != return OpUtilsMap::Instance().GetBaseKernelName(fluid_op_name);
kernel_alias_name_map.end()) {
return kernel_alias_name_map.at(fluid_op_name);
}
return fluid_op_name;
} }
const std::string& TransToFluidOpName(const std::string& pten_kernel_name) { const std::string& TransToFluidOpName(const std::string& pten_kernel_name) {
auto it = std::find_if(kernel_alias_name_map.begin(), auto& base_kernel_name_map = OpUtilsMap::Instance().base_kernel_name_map();
kernel_alias_name_map.end(), auto it = std::find_if(base_kernel_name_map.begin(),
base_kernel_name_map.end(),
[&pten_kernel_name](const auto& pair) { [&pten_kernel_name](const auto& pair) {
return pair.second == pten_kernel_name; return pair.second == pten_kernel_name;
}); });
if (it != kernel_alias_name_map.end()) { if (it != base_kernel_name_map.end()) {
return it->first; return it->first;
} }
return pten_kernel_name; return pten_kernel_name;
......
...@@ -27,7 +27,7 @@ limitations under the License. */ ...@@ -27,7 +27,7 @@ limitations under the License. */
namespace pten { namespace pten {
const std::string& TransToPtenKernelName(const std::string& fluid_op_name); std::string TransToPtenKernelName(const std::string& fluid_op_name);
const std::string& TransToFluidOpName(const std::string& pten_kernel_name); const std::string& TransToFluidOpName(const std::string& pten_kernel_name);
Backend TransToPtenBackend(const pten::Place& place); Backend TransToPtenBackend(const pten::Place& place);
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// TODO(yuanrisheng): this file may need to be removed
#pragma once
namespace pten {
// the key is kernel_name in fluid, the value is the kernel_name in pten
// the key is sorted by key's alphabet
const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"elementwise_add", "add_raw"},
{"elementwise_add_grad", "add_grad"},
{"elementwise_div", "divide_raw"},
{"elementwise_mul", "muliply_raw"},
{"elementwise_sub", "subtract_raw"},
{"elementwise_sub_grad", "subtract_grad"},
{"fill_any_like", "full_like"},
{"fill_constant", "full"},
{"flatten_contiguous_range", "flatten"},
{"flatten_contiguous_range_grad", "flatten_grad"},
{"matmul_v2", "matmul"},
{"matmul_v2_grad", "matmul_grad"},
{"matmul_v2_grad_grad", "matmul_double_grad"},
{"matmul_v2_triple_grad", "matmul_triple_grad"},
{"reduce_mean", "mean_raw"},
{"reduce_sum", "sum_raw"},
{"reshape2", "reshape"},
{"reshape2_grad", "reshape_grad"},
{"reshape2_grad_grad", "reshape_double_grad"},
// fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated
{"flatten", "deprecated"},
{"flatten_grad", "deprecated"},
{"matmul", "deprecated"},
{"matmul_grad", "deprecated"},
{"matmul_grad_grad", "deprecated"},
{"mean", "deprecated"},
{"reshape", "deprecated"},
{"reshape_grad", "deprecated"},
{"sum", "deprecated"}};
} // namespace pten
...@@ -14,7 +14,8 @@ limitations under the License. */ ...@@ -14,7 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <mutex> #include <string>
#include <unordered_set>
#include "paddle/pten/core/compat/arg_map_context.h" #include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/enforce.h" #include "paddle/pten/core/enforce.h"
...@@ -25,6 +26,22 @@ limitations under the License. */ ...@@ -25,6 +26,22 @@ limitations under the License. */
namespace pten { namespace pten {
/**
* Some fluid ops are no longer used under the corresponding official API
* system of 2.0. These names need to correspond to the official API names
* after 2.0, and can no longer be occupied by the previously abandoned ops.
* They are marked here uniformly.
*/
const std::unordered_set<std::string> deprecated_op_names({"flatten",
"flatten_grad",
"matmul",
"matmul_grad",
"matmul_grad_grad",
"mean",
"reshape",
"reshape_grad",
"sum"});
class DefaultKernelSignatureMap { class DefaultKernelSignatureMap {
public: public:
static DefaultKernelSignatureMap& Instance(); static DefaultKernelSignatureMap& Instance();
...@@ -63,16 +80,18 @@ class OpUtilsMap { ...@@ -63,16 +80,18 @@ class OpUtilsMap {
static OpUtilsMap& Instance(); static OpUtilsMap& Instance();
bool Contains(const std::string& op_type) const { bool Contains(const std::string& op_type) const {
return name_map_.count(op_type) || arg_mapping_fn_map_.count(op_type); return base_kernel_name_map_.count(op_type) ||
arg_mapping_fn_map_.count(op_type);
} }
void InsertApiName(std::string op_type, std::string api_name) { void InsertBaseKernelName(std::string op_type, std::string base_kernel_name) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
name_map_.count(op_type), base_kernel_name_map_.count(op_type),
0UL, 0UL,
pten::errors::AlreadyExists( pten::errors::AlreadyExists(
"Operator (%s)'s api name has been registered.", op_type)); "Operator (%s)'s api name has been registered.", op_type));
name_map_.insert({std::move(op_type), std::move(api_name)}); base_kernel_name_map_.insert(
{std::move(op_type), std::move(base_kernel_name)});
} }
void InsertArgumentMappingFn(std::string op_type, ArgumentMappingFn fn) { void InsertArgumentMappingFn(std::string op_type, ArgumentMappingFn fn) {
...@@ -85,10 +104,13 @@ class OpUtilsMap { ...@@ -85,10 +104,13 @@ class OpUtilsMap {
arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)}); arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)});
} }
std::string GetApiName(const std::string& op_type) const { std::string GetBaseKernelName(const std::string& op_type) const {
auto it = name_map_.find(op_type); if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) {
if (it == name_map_.end()) {
return "deprecated"; return "deprecated";
}
auto it = base_kernel_name_map_.find(op_type);
if (it == base_kernel_name_map_.end()) {
return op_type;
} else { } else {
return it->second; return it->second;
} }
...@@ -107,18 +129,23 @@ class OpUtilsMap { ...@@ -107,18 +129,23 @@ class OpUtilsMap {
} }
} }
const paddle::flat_hash_map<std::string, std::string>& base_kernel_name_map()
const {
return base_kernel_name_map_;
}
private: private:
OpUtilsMap() = default; OpUtilsMap() = default;
paddle::flat_hash_map<std::string, std::string> name_map_; paddle::flat_hash_map<std::string, std::string> base_kernel_name_map_;
paddle::flat_hash_map<std::string, ArgumentMappingFn> arg_mapping_fn_map_; paddle::flat_hash_map<std::string, ArgumentMappingFn> arg_mapping_fn_map_;
DISABLE_COPY_AND_ASSIGN(OpUtilsMap); DISABLE_COPY_AND_ASSIGN(OpUtilsMap);
}; };
struct ApiNameRegistrar { struct BaseKernelNameRegistrar {
ApiNameRegistrar(const char* op_type, const char* api_name) { BaseKernelNameRegistrar(const char* op_type, const char* base_kernel_name) {
OpUtilsMap::Instance().InsertApiName(op_type, api_name); OpUtilsMap::Instance().InsertBaseKernelName(op_type, base_kernel_name);
} }
}; };
...@@ -130,21 +157,21 @@ struct ArgumentMappingFnRegistrar { ...@@ -130,21 +157,21 @@ struct ArgumentMappingFnRegistrar {
} }
}; };
#define PT_REGISTER_API_NAME(op_type, api_name) \ #define PT_REGISTER_BASE_KERNEL_NAME(op_type, base_kernel_name) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_api_name_ns_check_##op_type, \ pt_register_base_kernel_name_ns_check_##op_type, \
"PT_REGISTER_API_NAME must be called in global namespace."); \ "PT_REGISTER_BASE_KERNEL_NAME must be called in global namespace."); \
static const ::pten::ApiNameRegistrar __registrar_api_name_for_##op_type( \ static const ::pten::BaseKernelNameRegistrar \
#op_type, #api_name); \ __registrar_base_kernel_name_for_##op_type(#op_type, #base_kernel_name); \
int TouchApiNameSymbol_##op_type() { return 0; } int TouchBaseKernelNameSymbol_##op_type() { return 0; }
#define PT_DECLARE_API_NAME(op_type) \ #define PT_DECLARE_BASE_KERNEL_NAME(op_type) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_declare_ai_name_ns_check_##op_type, \ pt_declare_ai_name_ns_check_##op_type, \
"PT_DECLARE_API_NAME must be called in global namespace."); \ "PT_DECLARE_BASE_KERNEL_NAME must be called in global namespace."); \
extern int TouchApiNameSymbol_##op_type(); \ extern int TouchBaseKernelNameSymbol_##op_type(); \
UNUSED static int __declare_api_name_symbol_for_##op_type = \ UNUSED static int __declare_base_kernel_name_symbol_for_##op_type = \
TouchApiNameSymbol_##op_type() TouchBaseKernelNameSymbol_##op_type()
#define PT_REGISTER_ARG_MAPPING_FN(op_type, arg_mapping_fn) \ #define PT_REGISTER_ARG_MAPPING_FN(op_type, arg_mapping_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
......
...@@ -66,6 +66,13 @@ KernelSignature ElementwiseDivOpArgumentMapping( ...@@ -66,6 +66,13 @@ KernelSignature ElementwiseDivOpArgumentMapping(
} // namespace pten } // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, muliply_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
PT_REGISTER_ARG_MAPPING_FN(elementwise_add, PT_REGISTER_ARG_MAPPING_FN(elementwise_add,
pten::ElementwiseAddOpArgumentMapping); pten::ElementwiseAddOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(elementwise_sub, PT_REGISTER_ARG_MAPPING_FN(elementwise_sub,
......
...@@ -23,4 +23,6 @@ KernelSignature FillAnyLikeOpArgumentMapping( ...@@ -23,4 +23,6 @@ KernelSignature FillAnyLikeOpArgumentMapping(
} // namespace pten } // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(fill_any_like, full_like);
PT_REGISTER_ARG_MAPPING_FN(fill_any_like, pten::FillAnyLikeOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(fill_any_like, pten::FillAnyLikeOpArgumentMapping);
...@@ -68,4 +68,6 @@ KernelSignature FillConstantOpArgumentMapping( ...@@ -68,4 +68,6 @@ KernelSignature FillConstantOpArgumentMapping(
} // namespace pten } // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(fill_constant, full);
PT_REGISTER_ARG_MAPPING_FN(fill_constant, pten::FillConstantOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(fill_constant, pten::FillConstantOpArgumentMapping);
...@@ -30,5 +30,8 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -30,5 +30,8 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
} // namespace pten } // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range, flatten);
PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range_grad, flatten_grad);
PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range, PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range,
pten::FlattenOpArgumentMapping); pten::FlattenOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {} // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(matmul_v2, matmul);
PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad, matmul_grad);
PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad_grad, matmul_double_grad);
PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_triple_grad, matmul_triple_grad);
...@@ -45,5 +45,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -45,5 +45,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
} // namespace pten } // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum_raw);
PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean_raw);
PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping);
...@@ -28,4 +28,8 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -28,4 +28,8 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
} // namespace pten } // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(reshape2, reshape);
PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad, reshape_grad);
PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad_grad, reshape_double_grad);
PT_REGISTER_ARG_MAPPING_FN(reshape2, pten::ReshapeOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(reshape2, pten::ReshapeOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册