未验证 提交 d3f52efd 编写于 作者: Z zyfncg 提交者: GitHub

Fix bug of TransToFluidOpName (#48355)

* add fluid_op_name_map

* rename some kernel name

* add comments for op-kernel map

* refine map name of op to kernel
上级 86d92092
...@@ -48,7 +48,7 @@ int main(int argc, char **argv) { ...@@ -48,7 +48,7 @@ int main(int argc, char **argv) {
for (const auto &op_kernel_pair : kernel_factory.kernels()) { for (const auto &op_kernel_pair : kernel_factory.kernels()) {
std::string op_name = op_kernel_pair.first; std::string op_name = op_kernel_pair.first;
const paddle::flat_hash_map<std::string, std::string> &kernel_name_map = const paddle::flat_hash_map<std::string, std::string> &kernel_name_map =
phi::OpUtilsMap::Instance().base_kernel_name_map(); phi::OpUtilsMap::Instance().fluid_op_to_phi_kernel();
for (auto &it : kernel_name_map) { for (auto &it : kernel_name_map) {
if (it.second == op_name) { if (it.second == op_name) {
op_name = it.first; op_name = it.first;
......
...@@ -633,7 +633,7 @@ ...@@ -633,7 +633,7 @@
func : UnchangedInferMeta func : UnchangedInferMeta
param : [x] param : [x]
kernel : kernel :
func : hard_swish_grad func : hardswish_grad
inplace : (out_grad -> x_grad) inplace : (out_grad -> x_grad)
- backward_op : hardtanh_grad - backward_op : hardtanh_grad
...@@ -644,7 +644,7 @@ ...@@ -644,7 +644,7 @@
func : UnchangedInferMeta func : UnchangedInferMeta
param : [x] param : [x]
kernel : kernel :
func : hard_tanh_grad func : hardtanh_grad
inplace : (out_grad -> x_grad) inplace : (out_grad -> x_grad)
- backward_op : hsigmoid_loss_grad - backward_op : hsigmoid_loss_grad
......
...@@ -171,7 +171,7 @@ ...@@ -171,7 +171,7 @@
infer_meta : infer_meta :
func : ArgMinMaxInferMeta func : ArgMinMaxInferMeta
kernel : kernel :
func : arg_max func : argmax
- op : argmin - op : argmin
args : (Tensor x, Scalar axis, bool keepdims, bool flatten, int dtype) args : (Tensor x, Scalar axis, bool keepdims, bool flatten, int dtype)
...@@ -179,7 +179,7 @@ ...@@ -179,7 +179,7 @@
infer_meta : infer_meta :
func : ArgMinMaxInferMeta func : ArgMinMaxInferMeta
kernel : kernel :
func : arg_min func : argmin
- op : assign - op : assign
args : (Tensor x) args : (Tensor x)
...@@ -914,7 +914,7 @@ ...@@ -914,7 +914,7 @@
func : UnchangedInferMeta func : UnchangedInferMeta
param : [x] param : [x]
kernel : kernel :
func : hard_swish func : hardswish
backward : hardswish_grad backward : hardswish_grad
- op : hardtanh - op : hardtanh
...@@ -924,7 +924,7 @@ ...@@ -924,7 +924,7 @@
func : UnchangedInferMeta func : UnchangedInferMeta
param : [x] param : [x]
kernel : kernel :
func : hard_tanh func : hardtanh
backward : hardtanh_grad backward : hardtanh_grad
- op : hsigmoid_loss - op : hsigmoid_loss
......
...@@ -110,14 +110,11 @@ const std::string& TransToPhiKernelName(const std::string& fluid_op_name) { ...@@ -110,14 +110,11 @@ const std::string& TransToPhiKernelName(const std::string& fluid_op_name) {
} }
const std::string& TransToFluidOpName(const std::string& phi_kernel_name) { const std::string& TransToFluidOpName(const std::string& phi_kernel_name) {
auto& base_kernel_name_map = OpUtilsMap::Instance().base_kernel_name_map(); const auto& phi_kernel_to_fluid_op =
auto it = std::find_if(base_kernel_name_map.begin(), OpUtilsMap::Instance().phi_kernel_to_fluid_op();
base_kernel_name_map.end(), auto it = phi_kernel_to_fluid_op.find(phi_kernel_name);
[&phi_kernel_name](const auto& pair) { if (it != phi_kernel_to_fluid_op.end()) {
return pair.second == phi_kernel_name; return it->second;
});
if (it != base_kernel_name_map.end()) {
return it->first;
} }
return phi_kernel_name; return phi_kernel_name;
} }
......
...@@ -131,18 +131,23 @@ class OpUtilsMap { ...@@ -131,18 +131,23 @@ class OpUtilsMap {
static OpUtilsMap& Instance(); static OpUtilsMap& Instance();
bool Contains(const std::string& op_type) const { bool Contains(const std::string& op_type) const {
return base_kernel_name_map_.count(op_type) || return fluid_op_to_phi_kernel_.count(op_type) ||
arg_mapping_fn_map_.count(op_type); arg_mapping_fn_map_.count(op_type);
} }
void InsertBaseKernelName(std::string op_type, std::string base_kernel_name) { void InsertBaseKernelName(const std::string& op_type,
const std::string& base_kernel_name) {
fluid_op_to_phi_kernel_.insert({op_type, base_kernel_name});
}
void InsertFluidOplName(std::string op_type, std::string base_kernel_name) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
base_kernel_name_map_.count(op_type), phi_kernel_to_fluid_op_.count(base_kernel_name),
0UL, 0UL,
phi::errors::AlreadyExists( phi::errors::AlreadyExists(
"Operator (%s)'s api name has been registered.", op_type)); "Operator (%s)'s kernel name (%s) has been registered.",
base_kernel_name_map_.insert( op_type,
{std::move(op_type), std::move(base_kernel_name)}); base_kernel_name));
phi_kernel_to_fluid_op_.insert({base_kernel_name, op_type});
} }
bool HasArgumentMappingFn(const std::string& op_type) const { bool HasArgumentMappingFn(const std::string& op_type) const {
...@@ -163,8 +168,8 @@ class OpUtilsMap { ...@@ -163,8 +168,8 @@ class OpUtilsMap {
if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) { if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) {
return deprecated_kernel_name; return deprecated_kernel_name;
} }
auto it = base_kernel_name_map_.find(op_type); auto it = fluid_op_to_phi_kernel_.find(op_type);
if (it == base_kernel_name_map_.end()) { if (it == fluid_op_to_phi_kernel_.end()) {
return op_type; return op_type;
} else { } else {
return it->second; return it->second;
...@@ -181,15 +186,23 @@ class OpUtilsMap { ...@@ -181,15 +186,23 @@ class OpUtilsMap {
} }
} }
const paddle::flat_hash_map<std::string, std::string>& base_kernel_name_map() const paddle::flat_hash_map<std::string, std::string>&
const { fluid_op_to_phi_kernel() const {
return base_kernel_name_map_; return fluid_op_to_phi_kernel_;
}
const paddle::flat_hash_map<std::string, std::string>&
phi_kernel_to_fluid_op() const {
return phi_kernel_to_fluid_op_;
} }
private: private:
OpUtilsMap() = default; OpUtilsMap() = default;
paddle::flat_hash_map<std::string, std::string> base_kernel_name_map_; paddle::flat_hash_map<std::string, std::string> fluid_op_to_phi_kernel_;
paddle::flat_hash_map<std::string, std::string> phi_kernel_to_fluid_op_;
paddle::flat_hash_map<std::string, ArgumentMappingFn> arg_mapping_fn_map_; paddle::flat_hash_map<std::string, ArgumentMappingFn> arg_mapping_fn_map_;
DISABLE_COPY_AND_ASSIGN(OpUtilsMap); DISABLE_COPY_AND_ASSIGN(OpUtilsMap);
...@@ -198,6 +211,7 @@ class OpUtilsMap { ...@@ -198,6 +211,7 @@ class OpUtilsMap {
struct BaseKernelNameRegistrar { struct BaseKernelNameRegistrar {
BaseKernelNameRegistrar(const char* op_type, const char* base_kernel_name) { BaseKernelNameRegistrar(const char* op_type, const char* base_kernel_name) {
OpUtilsMap::Instance().InsertBaseKernelName(op_type, base_kernel_name); OpUtilsMap::Instance().InsertBaseKernelName(op_type, base_kernel_name);
OpUtilsMap::Instance().InsertFluidOplName(op_type, base_kernel_name);
} }
}; };
......
...@@ -45,12 +45,12 @@ using complex64 = ::phi::dtype::complex<float>; ...@@ -45,12 +45,12 @@ using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>; using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
hard_swish, CPU, ALL_LAYOUT, phi::HardSwishKernel, float, double) {} hardswish, CPU, ALL_LAYOUT, phi::HardSwishKernel, float, double) {}
PD_REGISTER_KERNEL(relu6, CPU, ALL_LAYOUT, phi::Relu6Kernel, float, double) {} PD_REGISTER_KERNEL(relu6, CPU, ALL_LAYOUT, phi::Relu6Kernel, float, double) {}
PD_REGISTER_KERNEL(swish, CPU, ALL_LAYOUT, phi::SwishKernel, float, double) {} PD_REGISTER_KERNEL(swish, CPU, ALL_LAYOUT, phi::SwishKernel, float, double) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(hard_swish, PD_REGISTER_KERNEL(hardswish,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::HardSwishKernel, phi::HardSwishKernel,
...@@ -80,13 +80,13 @@ PD_REGISTER_KERNEL(swish, ...@@ -80,13 +80,13 @@ PD_REGISTER_KERNEL(swish,
#endif #endif
#if defined PADDLE_WITH_XPU #if defined PADDLE_WITH_XPU
PD_REGISTER_KERNEL(hard_swish, XPU, ALL_LAYOUT, phi::HardSwishKernel, float) {} PD_REGISTER_KERNEL(hardswish, XPU, ALL_LAYOUT, phi::HardSwishKernel, float) {}
PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {} PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {}
PD_REGISTER_KERNEL(swish, XPU, ALL_LAYOUT, phi::SwishKernel, float) {} PD_REGISTER_KERNEL(swish, XPU, ALL_LAYOUT, phi::SwishKernel, float) {}
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
PD_REGISTER_KERNEL(hard_swish, PD_REGISTER_KERNEL(hardswish,
OneDNN, OneDNN,
ONEDNN, ONEDNN,
phi::HardSwishKernel, phi::HardSwishKernel,
......
...@@ -263,7 +263,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(asinh_grad, AsinhGradKernel) ...@@ -263,7 +263,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(asinh_grad, AsinhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(acosh_grad, AcoshGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(acosh_grad, AcoshGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_tanh_grad, HardTanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardtanh_grad, HardTanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad, PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
ThresholdedReluGradKernel) ThresholdedReluGradKernel)
...@@ -388,7 +388,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel) ...@@ -388,7 +388,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel) PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
......
...@@ -146,7 +146,7 @@ PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel) ...@@ -146,7 +146,7 @@ PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel) PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel) PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel) PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_tanh, HardTanhKernel) PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel) PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel) PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
...@@ -183,7 +183,7 @@ PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel) ...@@ -183,7 +183,7 @@ PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel) PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel) PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish_raw, HardSwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel) PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel) PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
......
...@@ -180,7 +180,7 @@ void ArgMaxKernel(const Context& dev_ctx, ...@@ -180,7 +180,7 @@ void ArgMaxKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(arg_min, PD_REGISTER_KERNEL(argmin,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::ArgMinKernel, phi::ArgMinKernel,
...@@ -191,7 +191,7 @@ PD_REGISTER_KERNEL(arg_min, ...@@ -191,7 +191,7 @@ PD_REGISTER_KERNEL(arg_min,
int16_t, int16_t,
uint8_t) {} uint8_t) {}
PD_REGISTER_KERNEL(arg_max, PD_REGISTER_KERNEL(argmax,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::ArgMaxKernel, phi::ArgMaxKernel,
......
...@@ -347,7 +347,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel) ...@@ -347,7 +347,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_double_grad, TanhDoubleGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_double_grad, TanhDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_triple_grad, TanhTripleGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_triple_grad, TanhTripleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_tanh_grad, HardTanhGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardtanh_grad, HardTanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad, PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad,
LeakyReluDoubleGradKernel) LeakyReluDoubleGradKernel)
...@@ -474,7 +474,7 @@ PD_REGISTER_KERNEL(log_double_grad, ...@@ -474,7 +474,7 @@ PD_REGISTER_KERNEL(log_double_grad,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
......
...@@ -196,7 +196,7 @@ PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel) ...@@ -196,7 +196,7 @@ PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel) PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel) PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel) PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_tanh, HardTanhKernel) PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel) PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel) PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
...@@ -254,7 +254,7 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) ...@@ -254,7 +254,7 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel) PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel) PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel) PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish_raw, HardSwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel) PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
......
...@@ -248,7 +248,7 @@ void ArgMaxKernel(const Context& dev_ctx, ...@@ -248,7 +248,7 @@ void ArgMaxKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(arg_min, PD_REGISTER_KERNEL(argmin,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::ArgMinKernel, phi::ArgMinKernel,
...@@ -261,7 +261,7 @@ PD_REGISTER_KERNEL(arg_min, ...@@ -261,7 +261,7 @@ PD_REGISTER_KERNEL(arg_min,
int16_t, int16_t,
uint8_t) {} uint8_t) {}
PD_REGISTER_KERNEL(arg_max, PD_REGISTER_KERNEL(argmax,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::ArgMaxKernel, phi::ArgMaxKernel,
......
...@@ -273,7 +273,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(abs_grad, AbsGradKernel) ...@@ -273,7 +273,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(abs_grad, AbsGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(gelu_grad, GeluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(gelu_grad, GeluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
......
...@@ -202,7 +202,7 @@ PD_REGISTER_ACTIVATION_KERNEL(abs, AbsKernel) ...@@ -202,7 +202,7 @@ PD_REGISTER_ACTIVATION_KERNEL(abs, AbsKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel) PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)
PD_REGISTER_ACTIVATION_KERNEL(gelu, GeluKernel) PD_REGISTER_ACTIVATION_KERNEL(gelu, GeluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish_raw, HardSwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu, ReluKernel) PD_REGISTER_ACTIVATION_KERNEL(relu, ReluKernel)
......
...@@ -617,7 +617,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel) ...@@ -617,7 +617,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_sigmoid_grad, HardSigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_sigmoid_grad, HardSigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
......
...@@ -486,7 +486,7 @@ PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad ...@@ -486,7 +486,7 @@ PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad
PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish_raw, HardSwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(hardswish_raw, HardSwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel) PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel) PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
......
...@@ -65,4 +65,4 @@ void ArgMaxKernel(const Context& dev_ctx, ...@@ -65,4 +65,4 @@ void ArgMaxKernel(const Context& dev_ctx,
XPUAPIErrorMsg[r])); XPUAPIErrorMsg[r]));
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(arg_max, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {} PD_REGISTER_KERNEL(argmax, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {}
...@@ -39,10 +39,10 @@ namespace phi { ...@@ -39,10 +39,10 @@ namespace phi {
#define comma , #define comma ,
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardTanh, "hard_tanh", "t_min" comma "t_max"); DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardTanh, "hardtanh", "t_min" comma "t_max");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Mish, "mish", "threshold"); DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Mish, "mish", "threshold");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardSwish, DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardSwish,
"hard_swish", "hardswish",
"threshold" comma "scale" comma "threshold" comma "scale" comma
"offset"); // NOLINT "offset"); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Swish, "swish", "beta"); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Swish, "swish", "beta"); // NOLINT
...@@ -55,7 +55,7 @@ DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu6, "relu6", "threshold"); // NOLINT ...@@ -55,7 +55,7 @@ DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu6, "relu6", "threshold"); // NOLINT
KernelSignature HardSwishOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature HardSwishOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"hard_swish_raw", {"X"}, {"threshold", "scale", "offset"}, {"Out"}); "hardswish_raw", {"X"}, {"threshold", "scale", "offset"}, {"Out"});
} }
KernelSignature SwishOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature SwishOpArgumentMapping(const ArgumentMappingContext& ctx) {
...@@ -113,8 +113,10 @@ KernelSignature PowTripleGradOpArgumentMapping( ...@@ -113,8 +113,10 @@ KernelSignature PowTripleGradOpArgumentMapping(
} }
} // namespace phi } // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(brelu, hard_tanh); PD_REGISTER_BASE_KERNEL_NAME(brelu, hardtanh);
PD_REGISTER_BASE_KERNEL_NAME(brelu_grad, hard_tanh_grad); PD_REGISTER_BASE_KERNEL_NAME(brelu_grad, hardtanh_grad);
PD_REGISTER_BASE_KERNEL_NAME(hard_swish, hardswish);
PD_REGISTER_BASE_KERNEL_NAME(hard_swish_grad, hardswish_grad);
PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(stanh_grad, phi::STanhGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(stanh_grad, phi::STanhGradOpArgumentMapping);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
PD_REGISTER_BASE_KERNEL_NAME(arg_max, argmax);
PD_REGISTER_BASE_KERNEL_NAME(arg_min, argmin);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册