未验证 提交 5dc20c27 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Remove kernel alias name (#39321)

* remove kernel alias name

* fix depreacted error

* fix deprecated failed

* fix mean error

* resolve conflict

* fix windows failed
上级 34cce62f
......@@ -282,9 +282,9 @@ endfunction()
function(append_op_util_declare TARGET)
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET} target_content)
string(REGEX MATCH "(PT_REGISTER_API_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}")
string(REGEX MATCH "(PT_REGISTER_BASE_KERNEL_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}")
string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}")
string(REPLACE "PT_REGISTER_API_NAME" "PT_REGISTER_API_NAME" util_declare "${util_declare}")
string(REPLACE "PT_REGISTER_BASE_KERNEL_NAME" "PT_DECLARE_BASE_KERNEL_NAME" util_declare "${util_declare}")
string(APPEND util_declare ");")
file(APPEND ${op_utils_header} "${util_declare}")
endfunction()
......
......@@ -185,9 +185,8 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
}
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return KernelSignature(pten::TransToPtenKernelName(op_proto_->type()),
GetInputArgsNames(), GetAttrsArgsNames(),
GetOutputArgsNames());
return KernelSignature(op_proto_->type(), GetInputArgsNames(),
GetAttrsArgsNames(), GetOutputArgsNames());
}
std::once_flag kernel_sig_map_init_flag;
......
cc_library(arg_map_context SRCS arg_map_context.cc DEPS pten_enforce)
cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce)
if(WITH_GPU)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils gpu_info)
elseif(WITH_ROCM)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils gpu_info)
elseif(WITH_XPU)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place xpu_info)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils xpu_info)
elseif(WITH_ASCEND_CL)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place npu_info)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils npu_info)
else()
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils)
endif()
cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce convert_utils)
......@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/convert_utils.h"
#include "paddle/pten/core/compat/kernel_alias_name.h"
#include "paddle/pten/core/compat/op_utils.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
......@@ -235,21 +234,18 @@ std::string DataType2String(DataType dtype) {
}
}
const std::string& TransToPtenKernelName(const std::string& fluid_op_name) {
if (kernel_alias_name_map.find(fluid_op_name) !=
kernel_alias_name_map.end()) {
return kernel_alias_name_map.at(fluid_op_name);
}
return fluid_op_name;
std::string TransToPtenKernelName(const std::string& fluid_op_name) {
return OpUtilsMap::Instance().GetBaseKernelName(fluid_op_name);
}
const std::string& TransToFluidOpName(const std::string& pten_kernel_name) {
auto it = std::find_if(kernel_alias_name_map.begin(),
kernel_alias_name_map.end(),
auto& base_kernel_name_map = OpUtilsMap::Instance().base_kernel_name_map();
auto it = std::find_if(base_kernel_name_map.begin(),
base_kernel_name_map.end(),
[&pten_kernel_name](const auto& pair) {
return pair.second == pten_kernel_name;
});
if (it != kernel_alias_name_map.end()) {
if (it != base_kernel_name_map.end()) {
return it->first;
}
return pten_kernel_name;
......
......@@ -27,7 +27,7 @@ limitations under the License. */
namespace pten {
const std::string& TransToPtenKernelName(const std::string& fluid_op_name);
std::string TransToPtenKernelName(const std::string& fluid_op_name);
const std::string& TransToFluidOpName(const std::string& pten_kernel_name);
Backend TransToPtenBackend(const pten::Place& place);
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// TODO(yuanrisheng): this file may need to be removed
#pragma once
namespace pten {
// the key is kernel_name in fluid, the value is the kernel_name in pten
// the key is sorted by key's alphabet
const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"elementwise_add", "add_raw"},
{"elementwise_add_grad", "add_grad"},
{"elementwise_div", "divide_raw"},
{"elementwise_mul", "muliply_raw"},
{"elementwise_sub", "subtract_raw"},
{"elementwise_sub_grad", "subtract_grad"},
{"fill_any_like", "full_like"},
{"fill_constant", "full"},
{"flatten_contiguous_range", "flatten"},
{"flatten_contiguous_range_grad", "flatten_grad"},
{"matmul_v2", "matmul"},
{"matmul_v2_grad", "matmul_grad"},
{"matmul_v2_grad_grad", "matmul_double_grad"},
{"matmul_v2_triple_grad", "matmul_triple_grad"},
{"reduce_mean", "mean_raw"},
{"reduce_sum", "sum_raw"},
{"reshape2", "reshape"},
{"reshape2_grad", "reshape_grad"},
{"reshape2_grad_grad", "reshape_double_grad"},
// fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated
{"flatten", "deprecated"},
{"flatten_grad", "deprecated"},
{"matmul", "deprecated"},
{"matmul_grad", "deprecated"},
{"matmul_grad_grad", "deprecated"},
{"mean", "deprecated"},
{"reshape", "deprecated"},
{"reshape_grad", "deprecated"},
{"sum", "deprecated"}};
} // namespace pten
......@@ -14,7 +14,8 @@ limitations under the License. */
#pragma once
#include <mutex>
#include <string>
#include <unordered_set>
#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/enforce.h"
......@@ -25,6 +26,22 @@ limitations under the License. */
namespace pten {
/**
* Some fluid ops are no longer used under the corresponding official API
* system of 2.0. These names need to correspond to the official API names
* after 2.0, and can no longer be occupied by the previously abandoned ops.
* They are marked here uniformly.
*/
const std::unordered_set<std::string> deprecated_op_names({"flatten",
"flatten_grad",
"matmul",
"matmul_grad",
"matmul_grad_grad",
"mean",
"reshape",
"reshape_grad",
"sum"});
class DefaultKernelSignatureMap {
public:
static DefaultKernelSignatureMap& Instance();
......@@ -63,16 +80,18 @@ class OpUtilsMap {
static OpUtilsMap& Instance();
bool Contains(const std::string& op_type) const {
return name_map_.count(op_type) || arg_mapping_fn_map_.count(op_type);
return base_kernel_name_map_.count(op_type) ||
arg_mapping_fn_map_.count(op_type);
}
void InsertApiName(std::string op_type, std::string api_name) {
void InsertBaseKernelName(std::string op_type, std::string base_kernel_name) {
PADDLE_ENFORCE_EQ(
name_map_.count(op_type),
base_kernel_name_map_.count(op_type),
0UL,
pten::errors::AlreadyExists(
"Operator (%s)'s api name has been registered.", op_type));
name_map_.insert({std::move(op_type), std::move(api_name)});
base_kernel_name_map_.insert(
{std::move(op_type), std::move(base_kernel_name)});
}
void InsertArgumentMappingFn(std::string op_type, ArgumentMappingFn fn) {
......@@ -85,10 +104,13 @@ class OpUtilsMap {
arg_mapping_fn_map_.insert({std::move(op_type), std::move(fn)});
}
std::string GetApiName(const std::string& op_type) const {
auto it = name_map_.find(op_type);
if (it == name_map_.end()) {
std::string GetBaseKernelName(const std::string& op_type) const {
if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) {
return "deprecated";
}
auto it = base_kernel_name_map_.find(op_type);
if (it == base_kernel_name_map_.end()) {
return op_type;
} else {
return it->second;
}
......@@ -107,18 +129,23 @@ class OpUtilsMap {
}
}
const paddle::flat_hash_map<std::string, std::string>& base_kernel_name_map()
const {
return base_kernel_name_map_;
}
private:
OpUtilsMap() = default;
paddle::flat_hash_map<std::string, std::string> name_map_;
paddle::flat_hash_map<std::string, std::string> base_kernel_name_map_;
paddle::flat_hash_map<std::string, ArgumentMappingFn> arg_mapping_fn_map_;
DISABLE_COPY_AND_ASSIGN(OpUtilsMap);
};
struct ApiNameRegistrar {
ApiNameRegistrar(const char* op_type, const char* api_name) {
OpUtilsMap::Instance().InsertApiName(op_type, api_name);
struct BaseKernelNameRegistrar {
BaseKernelNameRegistrar(const char* op_type, const char* base_kernel_name) {
OpUtilsMap::Instance().InsertBaseKernelName(op_type, base_kernel_name);
}
};
......@@ -130,21 +157,21 @@ struct ArgumentMappingFnRegistrar {
}
};
#define PT_REGISTER_API_NAME(op_type, api_name) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_api_name_ns_check_##op_type, \
"PT_REGISTER_API_NAME must be called in global namespace."); \
static const ::pten::ApiNameRegistrar __registrar_api_name_for_##op_type( \
#op_type, #api_name); \
int TouchApiNameSymbol_##op_type() { return 0; }
#define PT_DECLARE_API_NAME(op_type) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_declare_ai_name_ns_check_##op_type, \
"PT_DECLARE_API_NAME must be called in global namespace."); \
extern int TouchApiNameSymbol_##op_type(); \
UNUSED static int __declare_api_name_symbol_for_##op_type = \
TouchApiNameSymbol_##op_type()
#define PT_REGISTER_BASE_KERNEL_NAME(op_type, base_kernel_name) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_base_kernel_name_ns_check_##op_type, \
"PT_REGISTER_BASE_KERNEL_NAME must be called in global namespace."); \
static const ::pten::BaseKernelNameRegistrar \
__registrar_base_kernel_name_for_##op_type(#op_type, #base_kernel_name); \
int TouchBaseKernelNameSymbol_##op_type() { return 0; }
#define PT_DECLARE_BASE_KERNEL_NAME(op_type) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_declare_ai_name_ns_check_##op_type, \
"PT_DECLARE_BASE_KERNEL_NAME must be called in global namespace."); \
extern int TouchBaseKernelNameSymbol_##op_type(); \
UNUSED static int __declare_base_kernel_name_symbol_for_##op_type = \
TouchBaseKernelNameSymbol_##op_type()
#define PT_REGISTER_ARG_MAPPING_FN(op_type, arg_mapping_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
......
......@@ -66,6 +66,13 @@ KernelSignature ElementwiseDivOpArgumentMapping(
} // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, muliply_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
PT_REGISTER_ARG_MAPPING_FN(elementwise_add,
pten::ElementwiseAddOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(elementwise_sub,
......
......@@ -23,4 +23,6 @@ KernelSignature FillAnyLikeOpArgumentMapping(
} // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(fill_any_like, full_like);
PT_REGISTER_ARG_MAPPING_FN(fill_any_like, pten::FillAnyLikeOpArgumentMapping);
......@@ -68,4 +68,6 @@ KernelSignature FillConstantOpArgumentMapping(
} // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(fill_constant, full);
PT_REGISTER_ARG_MAPPING_FN(fill_constant, pten::FillConstantOpArgumentMapping);
......@@ -30,5 +30,8 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
} // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range, flatten);
PT_REGISTER_BASE_KERNEL_NAME(flatten_contiguous_range_grad, flatten_grad);
PT_REGISTER_ARG_MAPPING_FN(flatten_contiguous_range,
pten::FlattenOpArgumentMapping);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {} // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(matmul_v2, matmul);
PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad, matmul_grad);
PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_grad_grad, matmul_double_grad);
PT_REGISTER_BASE_KERNEL_NAME(matmul_v2_triple_grad, matmul_triple_grad);
......@@ -45,5 +45,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
} // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum_raw);
PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean_raw);
PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping);
......@@ -28,4 +28,8 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
} // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(reshape2, reshape);
PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad, reshape_grad);
PT_REGISTER_BASE_KERNEL_NAME(reshape2_grad_grad, reshape_double_grad);
PT_REGISTER_ARG_MAPPING_FN(reshape2, pten::ReshapeOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册