Unverified commit e337d280, authored by zyfncg, committed by GitHub

Fix the name map of operator from Phi to fluid (#48496)

* rename some kernel names

* fix compile problem
Parent 35902ec6
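Note: the fluid-side operator keeps its name (size), while the Phi-side kernel and InferMeta function are renamed to numel / NumelInferMeta; a new compat registration, added at the end of this diff, maps one name onto the other. Quoted from that new sig file:

// fluid op "size"  ->  Phi kernel "numel"
PD_REGISTER_BASE_KERNEL_NAME(size, numel);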
@@ -111,19 +111,31 @@ function(kernel_declare TARGET_LIST)
 endfunction()
 
 function(append_op_util_declare TARGET)
   file(READ ${TARGET} target_content)
+  string(REGEX MATCH "(PD_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*"
+               util_registrar "${target_content}")
+  if(NOT ${util_registrar} EQUAL "")
+    string(REPLACE "PD_REGISTER_ARG_MAPPING_FN" "PD_DECLARE_ARG_MAPPING_FN"
+                   util_declare "${util_registrar}")
+    string(APPEND util_declare ");\n")
+    file(APPEND ${op_utils_header} "${util_declare}")
+  endif()
+endfunction()
+
+function(append_op_kernel_map_declare TARGET)
+  file(READ ${TARGET} target_content)
   string(
     REGEX
     MATCH
-    "(PD_REGISTER_BASE_KERNEL_NAME|PD_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*"
-    util_registrar
+    "(PD_REGISTER_BASE_KERNEL_NAME)\\([ \t\r\n]*[a-z0-9_]*,[ \\\t\r\n]*[a-z0-9_]*"
+    kernel_mapping_registrar
     "${target_content}")
-  string(REPLACE "PD_REGISTER_ARG_MAPPING_FN" "PD_DECLARE_ARG_MAPPING_FN"
-                 util_declare "${util_registrar}")
-  string(REPLACE "PD_REGISTER_BASE_KERNEL_NAME" "PD_DECLARE_BASE_KERNEL_NAME"
-                 util_declare "${util_declare}")
-  string(APPEND util_declare ");\n")
-  file(APPEND ${op_utils_header} "${util_declare}")
+  if(NOT ${kernel_mapping_registrar} EQUAL "")
+    string(REPLACE "PD_REGISTER_BASE_KERNEL_NAME" "PD_DECLARE_BASE_KERNEL_NAME"
+                   kernel_mapping_declare "${kernel_mapping_registrar}")
+    string(APPEND kernel_mapping_declare ");\n")
+    file(APPEND ${op_utils_header} "${kernel_mapping_declare}")
+  endif()
 endfunction()
 
 function(register_op_utils TARGET_NAME)
@@ -137,6 +149,7 @@ function(register_op_utils TARGET_NAME)
   file(GLOB SIGNATURES "${PADDLE_SOURCE_DIR}/paddle/phi/ops/compat/*_sig.cc")
   foreach(target ${SIGNATURES})
     append_op_util_declare(${target})
+    append_op_kernel_map_declare(${target})
     list(APPEND utils_srcs ${target})
   endforeach()
......
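The original append_op_util_declare matched either registration macro with a single regex and always appended one declaration per file; it is now split into two functions, each with its own regex (the base-kernel-name one also captures the second argument, matching the two-parameter PD_DECLARE_BASE_KERNEL_NAME introduced further down in op_utils.h) and an if() guard so nothing is emitted when a sig file lacks that macro, and register_op_utils calls both for every *_sig.cc. A minimal sketch of the lines the two functions would now append to the generated ${op_utils_header} for a sig file containing both kinds of registration (names taken from the embedding sig file further down in this diff; string(REGEX MATCH) only picks up the first occurrence of each macro per file):

// Emitted by append_op_util_declare:
PD_DECLARE_ARG_MAPPING_FN(lookup_table_v2);
// Emitted by append_op_kernel_map_declare:
PD_DECLARE_BASE_KERNEL_NAME(lookup_table_v2, embedding);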
@@ -62,7 +62,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SizeOpNoNeedBufferVarInferer, "Input");
 namespace ops = paddle::operators;
 DECLARE_INFER_SHAPE_FUNCTOR(size,
                             SizeInferShapeFunctor,
-                            PD_INFER_META(phi::SizeInferMeta));
+                            PD_INFER_META(phi::NumelInferMeta));
 REGISTER_OPERATOR(
     size,
     ops::SizeOp,
......
@@ -1465,9 +1465,9 @@
   args : (Tensor x)
   output : Tensor(size)
   infer_meta :
-    func : SizeInferMeta
+    func : NumelInferMeta
   kernel :
-    func : size
+    func : numel
   data_transform:
     skip_transform : x
......
@@ -223,21 +223,22 @@ struct ArgumentMappingFnRegistrar {
   }
 };
 
-#define PD_REGISTER_BASE_KERNEL_NAME(op_type, base_kernel_name)            \
-  PD_STATIC_ASSERT_GLOBAL_NAMESPACE(                                       \
-      PD_REGISTER_base_kernel_name_ns_check_##op_type,                     \
-      "PD_REGISTER_BASE_KERNEL_NAME must be called in global namespace."); \
-  static const ::phi::BaseKernelNameRegistrar                              \
-      __registrar_base_kernel_name_for_##op_type(#op_type, #base_kernel_name); \
-  int TouchBaseKernelNameSymbol_##op_type() { return 0; }
+#define PD_REGISTER_BASE_KERNEL_NAME(op_type, base_kernel_name)               \
+  PD_STATIC_ASSERT_GLOBAL_NAMESPACE(                                          \
+      PD_REGISTER_base_kernel_name_ns_check_##base_kernel_name,               \
+      "PD_REGISTER_BASE_KERNEL_NAME must be called in global namespace.");    \
+  static const ::phi::BaseKernelNameRegistrar                                 \
+      __registrar_base_kernel_name_for_##base_kernel_name(#op_type,           \
+                                                          #base_kernel_name); \
+  int TouchBaseKernelNameSymbol_##base_kernel_name() { return 0; }
 
-#define PD_DECLARE_BASE_KERNEL_NAME(op_type)                                \
-  PD_STATIC_ASSERT_GLOBAL_NAMESPACE(                                        \
-      PD_DECLARE_ai_name_ns_check_##op_type,                                \
-      "PD_DECLARE_BASE_KERNEL_NAME must be called in global namespace.");   \
-  extern int TouchBaseKernelNameSymbol_##op_type();                         \
-  UNUSED static int __declare_base_kernel_name_symbol_for_##op_type =       \
-      TouchBaseKernelNameSymbol_##op_type()
+#define PD_DECLARE_BASE_KERNEL_NAME(op_type, base_kernel_name)                 \
+  PD_STATIC_ASSERT_GLOBAL_NAMESPACE(                                           \
+      PD_DECLARE_ai_name_ns_check_##base_kernel_name,                          \
+      "PD_DECLARE_BASE_KERNEL_NAME must be called in global namespace.");      \
+  extern int TouchBaseKernelNameSymbol_##base_kernel_name();                   \
+  UNUSED static int __declare_base_kernel_name_symbol_for_##base_kernel_name = \
+      TouchBaseKernelNameSymbol_##base_kernel_name()
 
 #define PD_REGISTER_ARG_MAPPING_FN(op_type, arg_mapping_fn) \
   PD_STATIC_ASSERT_GLOBAL_NAMESPACE(                        \
......
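The registrar object and the touch symbol are now keyed on base_kernel_name rather than op_type. A rough expansion of the new PD_REGISTER_BASE_KERNEL_NAME for one of the embedding mappings further down (the static-assert part is omitted; phi::BaseKernelNameRegistrar is the type the macro above uses):

static const ::phi::BaseKernelNameRegistrar
    __registrar_base_kernel_name_for_embedding_grad("lookup_table_v2_grad",
                                                    "embedding_grad");
int TouchBaseKernelNameSymbol_embedding_grad() { return 0; }

Keying on base_kernel_name matters because one fluid op can register several Phi base kernel names in the same file (lookup_table_v2_grad registers four below); with op_type-keyed names, those expansions would all define the same registrar object and touch symbol.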
@@ -3243,7 +3243,7 @@ void ShardIndexInferMeta(const MetaTensor& in,
   out->set_dtype(in.dtype());
 }
 
-void SizeInferMeta(const MetaTensor& input, MetaTensor* out) {
+void NumelInferMeta(const MetaTensor& input, MetaTensor* out) {
   out->set_dtype(DataType::INT64);
   if (input.dims().size() == 0) {
     out->set_dims(phi::make_ddim({}));
......
@@ -478,7 +478,7 @@ void ShardIndexInferMeta(const MetaTensor& in,
                          MetaTensor* out,
                          MetaConfig config = MetaConfig());
 
-void SizeInferMeta(const MetaTensor& input, MetaTensor* out);
+void NumelInferMeta(const MetaTensor& input, MetaTensor* out);
 
 void SliceRawInferMeta(const MetaTensor& input,
                        const std::vector<int64_t>& axes,
......
@@ -12,16 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/size_kernel.h"
+#include "paddle/phi/kernels/numel_kernel.h"
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/size_kernel_impl.h"
+#include "paddle/phi/kernels/impl/numel_kernel_impl.h"
 
-PD_REGISTER_KERNEL(size,
+PD_REGISTER_KERNEL(numel,
                    CPU,
                    ALL_LAYOUT,
-                   phi::SizeKernel,
+                   phi::NumelKernel,
                    uint8_t,
                    int16_t,
                    int,
......
@@ -12,16 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/size_kernel.h"
+#include "paddle/phi/kernels/numel_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/impl/size_kernel_impl.h"
+#include "paddle/phi/kernels/impl/numel_kernel_impl.h"
 
-PD_REGISTER_KERNEL(size,
+PD_REGISTER_KERNEL(numel,
                    GPU,
                    ALL_LAYOUT,
-                   phi::SizeKernel,
+                   phi::NumelKernel,
                    int16_t,
                    int,
                    int64_t,
......
@@ -19,9 +19,9 @@
 namespace phi {
 
 template <typename T, typename Context>
-void SizeKernel(const Context& ctx,
-                const DenseTensor& input,
-                DenseTensor* out) {
+void NumelKernel(const Context& ctx,
+                 const DenseTensor& input,
+                 DenseTensor* out) {
   auto place = ctx.GetPlace();
   auto out_data = ctx.template Alloc<int64_t>(out);
......
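Apart from the rename, the kernel body (truncated above) appears unchanged: it writes the element count of input into a single int64 output. A standalone illustration of the same computation, outside the Phi kernel machinery (not Paddle code):

#include <cstdint>
#include <vector>

// "numel" is simply the product of the tensor's dimensions, i.e. its total
// number of elements; an empty dims list (a 0-D tensor) yields 1.
int64_t NumelOf(const std::vector<int64_t>& dims) {
  int64_t n = 1;
  for (int64_t d : dims) n *= d;
  return n;
}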
@@ -19,6 +19,8 @@
 namespace phi {
 
 template <typename T, typename Context>
-void SizeKernel(const Context& ctx, const DenseTensor& input, DenseTensor* out);
+void NumelKernel(const Context& ctx,
+                 const DenseTensor& input,
+                 DenseTensor* out);
 
 }  // namespace phi
} // namespace phi
@@ -31,5 +31,7 @@ KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
 }
 
 }  // namespace phi
 
+PD_REGISTER_BASE_KERNEL_NAME(einsum, einsum_raw);
+
 PD_REGISTER_ARG_MAPPING_FN(einsum, phi::EinsumOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(einsum_grad, phi::EinsumGradOpArgumentMapping);
@@ -58,6 +58,11 @@ KernelSignature EmbeddingGradOpArgumentMapping(
 
 PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2, embedding);
 PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad, embedding_grad);
+PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad, embedding_sparse_grad);
+PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad,
+                             sparse_weight_embedding_grad);
+PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad,
+                             sparse_weight_embedding_sparse_grad);
 
 PD_REGISTER_ARG_MAPPING_FN(lookup_table_v2, phi::EmbeddingOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(lookup_table_v2_grad,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
PD_REGISTER_BASE_KERNEL_NAME(size, numel);
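This new compat file is what ties the fluid op name to the Phi kernel registered as numel above. A hedged sketch of the effect at lookup time; TransToPhiKernelName is assumed to be the relevant compat helper here and is not verified against this exact revision:

#include <iostream>
#include <string>

#include "paddle/phi/core/compat/convert_utils.h"

int main() {
  // With PD_REGISTER_BASE_KERNEL_NAME(size, numel) linked in, the compat layer
  // should translate the fluid op name before the Phi kernel is looked up.
  std::cout << phi::TransToPhiKernelName("size") << std::endl;  // expected: numel
  return 0;
}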