Unverified commit d3f52efd, authored by zyfncg, committed by GitHub

Fix bug of TransToFluidOpName (#48355)

* add fluid_op_name_map

* rename some kernel names

* add comments for op-kernel map

* refine the name of the op-to-kernel map
Parent 86d92092
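The substance of the fix: TransToFluidOpName previously answered the kernel-to-op question by scanning the single op-to-kernel map linearly with std::find_if. This commit gives OpUtilsMap one hash map per lookup direction, so the reverse translation becomes a single hash lookup. Below is a minimal, self-contained sketch of that two-map scheme, using std::unordered_map in place of paddle::flat_hash_map and simplified free functions; it is an illustration of the idea, not the Paddle code itself.

#include <iostream>
#include <string>
#include <unordered_map>

// Simplified stand-ins for the two OpUtilsMap members this commit ends up with:
// one map per lookup direction.
std::unordered_map<std::string, std::string> fluid_op_to_phi_kernel;
std::unordered_map<std::string, std::string> phi_kernel_to_fluid_op;

// Register a fluid op name together with its phi kernel name, filling both
// directions (conceptually what one base-kernel-name registration does via
// the registrar changed later in this diff).
void RegisterBaseKernelName(const std::string& op, const std::string& kernel) {
  fluid_op_to_phi_kernel.insert({op, kernel});
  phi_kernel_to_fluid_op.insert({kernel, op});
}

// Reverse translation the way the patched TransToFluidOpName does it: one
// hash lookup, falling back to the kernel name itself when nothing matches.
const std::string& TransToFluidOpName(const std::string& kernel) {
  auto it = phi_kernel_to_fluid_op.find(kernel);
  return it != phi_kernel_to_fluid_op.end() ? it->second : kernel;
}

int main() {
  RegisterBaseKernelName("hard_swish", "hardswish");
  RegisterBaseKernelName("arg_max", "argmax");
  std::cout << TransToFluidOpName("hardswish") << "\n";  // prints: hard_swish
  std::cout << TransToFluidOpName("relu") << "\n";       // prints: relu (no alias)
  return 0;
}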
@@ -48,7 +48,7 @@ int main(int argc, char **argv) {
   for (const auto &op_kernel_pair : kernel_factory.kernels()) {
     std::string op_name = op_kernel_pair.first;
     const paddle::flat_hash_map<std::string, std::string> &kernel_name_map =
-        phi::OpUtilsMap::Instance().base_kernel_name_map();
+        phi::OpUtilsMap::Instance().fluid_op_to_phi_kernel();
     for (auto &it : kernel_name_map) {
       if (it.second == op_name) {
         op_name = it.first;
......
@@ -633,7 +633,7 @@
     func : UnchangedInferMeta
     param : [x]
   kernel :
-    func : hard_swish_grad
+    func : hardswish_grad
   inplace : (out_grad -> x_grad)

- backward_op : hardtanh_grad
@@ -644,7 +644,7 @@
     func : UnchangedInferMeta
     param : [x]
   kernel :
-    func : hard_tanh_grad
+    func : hardtanh_grad
   inplace : (out_grad -> x_grad)

- backward_op : hsigmoid_loss_grad
......
@@ -171,7 +171,7 @@
   infer_meta :
     func : ArgMinMaxInferMeta
   kernel :
-    func : arg_max
+    func : argmax

- op : argmin
  args : (Tensor x, Scalar axis, bool keepdims, bool flatten, int dtype)
@@ -179,7 +179,7 @@
   infer_meta :
     func : ArgMinMaxInferMeta
   kernel :
-    func : arg_min
+    func : argmin

- op : assign
  args : (Tensor x)
@@ -914,7 +914,7 @@
     func : UnchangedInferMeta
     param : [x]
   kernel :
-    func : hard_swish
+    func : hardswish
   backward : hardswish_grad

- op : hardtanh
@@ -924,7 +924,7 @@
     func : UnchangedInferMeta
     param : [x]
   kernel :
-    func : hard_tanh
+    func : hardtanh
   backward : hardtanh_grad

- op : hsigmoid_loss
......
@@ -110,14 +110,11 @@ const std::string& TransToPhiKernelName(const std::string& fluid_op_name) {
 }

 const std::string& TransToFluidOpName(const std::string& phi_kernel_name) {
-  auto& base_kernel_name_map = OpUtilsMap::Instance().base_kernel_name_map();
-  auto it = std::find_if(base_kernel_name_map.begin(),
-                         base_kernel_name_map.end(),
-                         [&phi_kernel_name](const auto& pair) {
-                           return pair.second == phi_kernel_name;
-                         });
-  if (it != base_kernel_name_map.end()) {
-    return it->first;
+  const auto& phi_kernel_to_fluid_op =
+      OpUtilsMap::Instance().phi_kernel_to_fluid_op();
+  auto it = phi_kernel_to_fluid_op.find(phi_kernel_name);
+  if (it != phi_kernel_to_fluid_op.end()) {
+    return it->second;
   }
   return phi_kernel_name;
 }
......
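In effect, once a mapping such as PD_REGISTER_BASE_KERNEL_NAME(hard_swish, hardswish) is registered (see the op-signature changes further down), TransToFluidOpName("hardswish") resolves to "hard_swish" through one lookup in phi_kernel_to_fluid_op(), while a kernel name with no registered fluid alias still falls through and is returned unchanged, preserving the previous behavior for unmapped names.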
@@ -131,18 +131,23 @@ class OpUtilsMap {
   static OpUtilsMap& Instance();

   bool Contains(const std::string& op_type) const {
-    return base_kernel_name_map_.count(op_type) ||
+    return fluid_op_to_phi_kernel_.count(op_type) ||
            arg_mapping_fn_map_.count(op_type);
   }

-  void InsertBaseKernelName(std::string op_type, std::string base_kernel_name) {
+  void InsertBaseKernelName(const std::string& op_type,
+                            const std::string& base_kernel_name) {
+    fluid_op_to_phi_kernel_.insert({op_type, base_kernel_name});
+  }
+
+  void InsertFluidOplName(std::string op_type, std::string base_kernel_name) {
     PADDLE_ENFORCE_EQ(
-        base_kernel_name_map_.count(op_type),
+        phi_kernel_to_fluid_op_.count(base_kernel_name),
         0UL,
         phi::errors::AlreadyExists(
-            "Operator (%s)'s api name has been registered.", op_type));
-    base_kernel_name_map_.insert(
-        {std::move(op_type), std::move(base_kernel_name)});
+            "Operator (%s)'s kernel name (%s) has been registered.",
+            op_type,
+            base_kernel_name));
+    phi_kernel_to_fluid_op_.insert({base_kernel_name, op_type});
   }

   bool HasArgumentMappingFn(const std::string& op_type) const {
@@ -163,8 +168,8 @@ class OpUtilsMap {
     if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) {
       return deprecated_kernel_name;
     }
-    auto it = base_kernel_name_map_.find(op_type);
-    if (it == base_kernel_name_map_.end()) {
+    auto it = fluid_op_to_phi_kernel_.find(op_type);
+    if (it == fluid_op_to_phi_kernel_.end()) {
       return op_type;
     } else {
       return it->second;
@@ -181,15 +186,23 @@
     }
   }

-  const paddle::flat_hash_map<std::string, std::string>& base_kernel_name_map()
-      const {
-    return base_kernel_name_map_;
+  const paddle::flat_hash_map<std::string, std::string>&
+  fluid_op_to_phi_kernel() const {
+    return fluid_op_to_phi_kernel_;
   }
+
+  const paddle::flat_hash_map<std::string, std::string>&
+  phi_kernel_to_fluid_op() const {
+    return phi_kernel_to_fluid_op_;
+  }

  private:
   OpUtilsMap() = default;

-  paddle::flat_hash_map<std::string, std::string> base_kernel_name_map_;
+  paddle::flat_hash_map<std::string, std::string> fluid_op_to_phi_kernel_;
+  paddle::flat_hash_map<std::string, std::string> phi_kernel_to_fluid_op_;
   paddle::flat_hash_map<std::string, ArgumentMappingFn> arg_mapping_fn_map_;

   DISABLE_COPY_AND_ASSIGN(OpUtilsMap);
@@ -198,6 +211,7 @@ class OpUtilsMap {

 struct BaseKernelNameRegistrar {
   BaseKernelNameRegistrar(const char* op_type, const char* base_kernel_name) {
     OpUtilsMap::Instance().InsertBaseKernelName(op_type, base_kernel_name);
+    OpUtilsMap::Instance().InsertFluidOplName(op_type, base_kernel_name);
   }
 };
......
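Because the registrar's constructor now calls both insert functions, a single base-kernel-name registration keeps the forward and reverse maps consistent. A rough illustration of what one registration amounts to (the static variable name below is made up for illustration; the real PD_REGISTER_BASE_KERNEL_NAME macro generates its own symbols and extra checks):

// Illustrative only: one static registrar per PD_REGISTER_BASE_KERNEL_NAME
// entry; its constructor runs at static-initialization time.
static phi::BaseKernelNameRegistrar hard_swish_alias_registrar("hard_swish",
                                                               "hardswish");
// Effect on OpUtilsMap::Instance():
//   fluid_op_to_phi_kernel_["hard_swish"] == "hardswish"   // op     -> kernel
//   phi_kernel_to_fluid_op_["hardswish"]  == "hard_swish"  // kernel -> op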
@@ -45,12 +45,12 @@ using complex64 = ::phi::dtype::complex<float>;
 using complex128 = ::phi::dtype::complex<double>;

 PD_REGISTER_KERNEL(
-    hard_swish, CPU, ALL_LAYOUT, phi::HardSwishKernel, float, double) {}
+    hardswish, CPU, ALL_LAYOUT, phi::HardSwishKernel, float, double) {}
 PD_REGISTER_KERNEL(relu6, CPU, ALL_LAYOUT, phi::Relu6Kernel, float, double) {}
 PD_REGISTER_KERNEL(swish, CPU, ALL_LAYOUT, phi::SwishKernel, float, double) {}

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(hard_swish,
+PD_REGISTER_KERNEL(hardswish,
                    GPU,
                    ALL_LAYOUT,
                    phi::HardSwishKernel,
@@ -80,13 +80,13 @@ PD_REGISTER_KERNEL(swish,
 #endif

 #if defined PADDLE_WITH_XPU
-PD_REGISTER_KERNEL(hard_swish, XPU, ALL_LAYOUT, phi::HardSwishKernel, float) {}
+PD_REGISTER_KERNEL(hardswish, XPU, ALL_LAYOUT, phi::HardSwishKernel, float) {}
 PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {}
 PD_REGISTER_KERNEL(swish, XPU, ALL_LAYOUT, phi::SwishKernel, float) {}
 #endif

 #ifdef PADDLE_WITH_MKLDNN
-PD_REGISTER_KERNEL(hard_swish,
+PD_REGISTER_KERNEL(hardswish,
                    OneDNN,
                    ONEDNN,
                    phi::HardSwishKernel,
......
@@ -263,7 +263,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(asinh_grad, AsinhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(acosh_grad, AcoshGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_tanh_grad, HardTanhGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardtanh_grad, HardTanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
                                    ThresholdedReluGradKernel)
@@ -388,7 +388,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
......
@@ -146,7 +146,7 @@ PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
 PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
-PD_REGISTER_ACTIVATION_KERNEL(hard_tanh, HardTanhKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
@@ -183,7 +183,7 @@ PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
-PD_REGISTER_ACTIVATION_KERNEL(hard_swish_raw, HardSwishRawKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
 PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
 PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
......
@@ -180,7 +180,7 @@ void ArgMaxKernel(const Context& dev_ctx,

 }  // namespace phi

-PD_REGISTER_KERNEL(arg_min,
+PD_REGISTER_KERNEL(argmin,
                    CPU,
                    ALL_LAYOUT,
                    phi::ArgMinKernel,
@@ -191,7 +191,7 @@ PD_REGISTER_KERNEL(arg_min,
                    int16_t,
                    uint8_t) {}

-PD_REGISTER_KERNEL(arg_max,
+PD_REGISTER_KERNEL(argmax,
                    CPU,
                    ALL_LAYOUT,
                    phi::ArgMaxKernel,
......
@@ -347,7 +347,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_double_grad, TanhDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_triple_grad, TanhTripleGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_tanh_grad, HardTanhGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardtanh_grad, HardTanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad,
                                    LeakyReluDoubleGradKernel)
@@ -474,7 +474,7 @@ PD_REGISTER_KERNEL(log_double_grad,
                    float,
                    double,
                    phi::dtype::float16) {}
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
......
@@ -196,7 +196,7 @@ PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
 PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
-PD_REGISTER_ACTIVATION_KERNEL(hard_tanh, HardTanhKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel)
 PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
@@ -254,7 +254,7 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
 PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
 PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
-PD_REGISTER_ACTIVATION_KERNEL(hard_swish_raw, HardSwishRawKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
 PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
 PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
......
@@ -248,7 +248,7 @@ void ArgMaxKernel(const Context& dev_ctx,

 }  // namespace phi

-PD_REGISTER_KERNEL(arg_min,
+PD_REGISTER_KERNEL(argmin,
                    GPU,
                    ALL_LAYOUT,
                    phi::ArgMinKernel,
@@ -261,7 +261,7 @@ PD_REGISTER_KERNEL(arg_min,
                    int16_t,
                    uint8_t) {}

-PD_REGISTER_KERNEL(arg_max,
+PD_REGISTER_KERNEL(argmax,
                    GPU,
                    ALL_LAYOUT,
                    phi::ArgMaxKernel,
......
@@ -273,7 +273,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(abs_grad, AbsGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(gelu_grad, GeluGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
......
@@ -202,7 +202,7 @@ PD_REGISTER_ACTIVATION_KERNEL(abs, AbsKernel)
 PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)
 PD_REGISTER_ACTIVATION_KERNEL(gelu, GeluKernel)
-PD_REGISTER_ACTIVATION_KERNEL(hard_swish_raw, HardSwishRawKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(relu, ReluKernel)
......
@@ -617,7 +617,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_sigmoid_grad, HardSigmoidGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
......
@@ -486,7 +486,7 @@ PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)  // no grad
 PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
 PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
-PD_REGISTER_ACTIVATION_KERNEL(hard_swish_raw, HardSwishRawKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
 PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
 PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
......
@@ -65,4 +65,4 @@ void ArgMaxKernel(const Context& dev_ctx,
                                      XPUAPIErrorMsg[r]));
 }
 }  // namespace phi
-PD_REGISTER_KERNEL(arg_max, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {}
+PD_REGISTER_KERNEL(argmax, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {}
@@ -39,10 +39,10 @@ namespace phi {

 #define comma ,

-DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardTanh, "hard_tanh", "t_min" comma "t_max");
+DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardTanh, "hardtanh", "t_min" comma "t_max");
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Mish, "mish", "threshold");
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardSwish,
-                               "hard_swish",
+                               "hardswish",
                                "threshold" comma "scale" comma
                                "offset");  // NOLINT
 DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Swish, "swish", "beta");  // NOLINT
@@ -55,7 +55,7 @@ DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu6, "relu6", "threshold");  // NOLINT

 KernelSignature HardSwishOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "hard_swish_raw", {"X"}, {"threshold", "scale", "offset"}, {"Out"});
+      "hardswish_raw", {"X"}, {"threshold", "scale", "offset"}, {"Out"});
 }

 KernelSignature SwishOpArgumentMapping(const ArgumentMappingContext& ctx) {
@@ -113,8 +113,10 @@ KernelSignature PowTripleGradOpArgumentMapping(
 }

 }  // namespace phi

-PD_REGISTER_BASE_KERNEL_NAME(brelu, hard_tanh);
-PD_REGISTER_BASE_KERNEL_NAME(brelu_grad, hard_tanh_grad);
+PD_REGISTER_BASE_KERNEL_NAME(brelu, hardtanh);
+PD_REGISTER_BASE_KERNEL_NAME(brelu_grad, hardtanh_grad);
+PD_REGISTER_BASE_KERNEL_NAME(hard_swish, hardswish);
+PD_REGISTER_BASE_KERNEL_NAME(hard_swish_grad, hardswish_grad);
 PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(stanh_grad, phi::STanhGradOpArgumentMapping);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
PD_REGISTER_BASE_KERNEL_NAME(arg_max, argmax);
PD_REGISTER_BASE_KERNEL_NAME(arg_min, argmin);