Unverified · Commit 594bd723 authored by: Y YuanRisheng, committed by: GitHub

[PHI]Standardise some C++ API (Part4) (#47702)

* standard api

* fix sparse bugs

* fix xpu bugs, test=kunlun

* remove hard code for custom unittest

* open ci, test=kunlun

* deal with conflict
Parent 28c56d77
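
The recurring pattern in this commit: for each operator whose C++ API is being standardized, the attribute-carrying kernel is renamed to a "Raw" variant (Relu6RawKernel, HardSwishRawKernel, SwishRawKernel, FMaxRawKernel, LessThanRawKernel, ...), and a new attribute-free kernel keeps the original name and simply forwards a fixed default to the raw one. The standalone sketch below only illustrates that naming/forwarding convention with plain std::vector math; it is not PaddlePaddle's actual kernel code, which takes a device context and DenseTensor arguments as shown in the diff that follows (e.g. Relu6Kernel calling Relu6RawKernel with threshold 6).

    // Standalone illustration of the Raw/default-wrapper convention (not phi code).
    #include <algorithm>
    #include <iostream>
    #include <vector>

    // Attribute-carrying ("raw") form, matching the legacy operator path.
    void Relu6RawKernel(const std::vector<float>& x, float threshold,
                        std::vector<float>* out) {
      out->resize(x.size());
      for (size_t i = 0; i < x.size(); ++i) {
        (*out)[i] = std::min(std::max(x[i], 0.0f), threshold);
      }
    }

    // Standardized form: no attribute; forwards the fixed default (6) to the raw kernel.
    void Relu6Kernel(const std::vector<float>& x, std::vector<float>* out) {
      Relu6RawKernel(x, 6.0f, out);
    }

    int main() {
      std::vector<float> x = {-1.0f, 3.0f, 9.0f};
      std::vector<float> out;
      Relu6Kernel(x, &out);
      for (float v : out) std::cout << v << ' ';  // prints: 0 3 6
      std::cout << '\n';
      return 0;
    }

The real wrappers added in this commit follow the same shape: the legacy operator path keeps its attributes via the *RawKernel, while the new standardized C++ API (and the updated yaml op definitions) call the attribute-free kernel.
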
...@@ -100,7 +100,7 @@ class CompareOp : public framework::OperatorWithKernel { ...@@ -100,7 +100,7 @@ class CompareOp : public framework::OperatorWithKernel {
char _##op_type##Comment::equation[]{_equation}; \ char _##op_type##Comment::equation[]{_equation}; \
DECLARE_INFER_SHAPE_FUNCTOR(op_type, \ DECLARE_INFER_SHAPE_FUNCTOR(op_type, \
op_type##_InferShapeFunctor, \ op_type##_InferShapeFunctor, \
PD_INFER_META(phi::CompareInferMeta)); \ PD_INFER_META(phi::CompareRawInferMeta)); \
REGISTER_OPERATOR( \ REGISTER_OPERATOR( \
op_type, \ op_type, \
::paddle::operators::CompareOp<_##op_type##Comment>, \ ::paddle::operators::CompareOp<_##op_type##Comment>, \
......
...@@ -400,7 +400,6 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T> ...@@ -400,7 +400,6 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
grad_op->SetInput("{{attr_name | to_pascal_case}}Tensor", this->Input("{{attr_name | to_pascal_case}}Tensor")); grad_op->SetInput("{{attr_name | to_pascal_case}}Tensor", this->Input("{{attr_name | to_pascal_case}}Tensor"));
{% endif %} {% endif %}
{% else %}{# maybe something wrong: backward op has more attrs than the forward one#} {% else %}{# maybe something wrong: backward op has more attrs than the forward one#}
grad_op->AddAttr<{{attr["typename"] | to_op_attr_type}}>({{attr_name}}, "({{attr["typename"] | to_op_attr_type}}), exceptional attr {{attr_name}}");
grad_op->SetAttr("{{attr_name}}", {{process_default_value(attr)}}); grad_op->SetAttr("{{attr_name}}", {{process_default_value(attr)}});
{% endif %} {% endif %}
{% endfor %} {% endfor %}
......
...@@ -841,7 +841,7 @@ static PyObject* tensor__gt__method(TensorObject* self, ...@@ -841,7 +841,7 @@ static PyObject* tensor__gt__method(TensorObject* self,
VLOG(6) << "Calling greater_than_ad_func in tensor__gt__method"; VLOG(6) << "Calling greater_than_ad_func in tensor__gt__method";
{ {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
ret = greater_than_ad_func(self_tensor, other_tensor, -1); ret = greater_than_ad_func(self_tensor, other_tensor);
} }
return ToPyObject(ret); return ToPyObject(ret);
...@@ -927,7 +927,7 @@ static PyObject* tensor__ge__method(TensorObject* self, ...@@ -927,7 +927,7 @@ static PyObject* tensor__ge__method(TensorObject* self,
VLOG(6) << "Calling greater_equal_ad_func in tensor__ge__method"; VLOG(6) << "Calling greater_equal_ad_func in tensor__ge__method";
{ {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
ret = greater_equal_ad_func(self_tensor, other_tensor, -1); ret = greater_equal_ad_func(self_tensor, other_tensor);
} }
return ToPyObject(ret); return ToPyObject(ret);
...@@ -1204,7 +1204,7 @@ static PyObject* tensor__lt__method(TensorObject* self, ...@@ -1204,7 +1204,7 @@ static PyObject* tensor__lt__method(TensorObject* self,
VLOG(6) << "Calling less_than_ad_func in tensor__lt__method"; VLOG(6) << "Calling less_than_ad_func in tensor__lt__method";
{ {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
ret = less_than_ad_func(self_tensor, other_tensor, -1); ret = less_than_ad_func(self_tensor, other_tensor);
} }
return ToPyObject(ret); return ToPyObject(ret);
...@@ -1290,7 +1290,7 @@ static PyObject* tensor__le__method(TensorObject* self, ...@@ -1290,7 +1290,7 @@ static PyObject* tensor__le__method(TensorObject* self,
VLOG(6) << "Calling less_equal_ad_func in tensor__le__method"; VLOG(6) << "Calling less_equal_ad_func in tensor__le__method";
{ {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
ret = less_equal_ad_func(self_tensor, other_tensor, -1); ret = less_equal_ad_func(self_tensor, other_tensor);
} }
return ToPyObject(ret); return ToPyObject(ret);
...@@ -1636,7 +1636,7 @@ static PyObject* tensor__ne__method(TensorObject* self, ...@@ -1636,7 +1636,7 @@ static PyObject* tensor__ne__method(TensorObject* self,
VLOG(6) << "Calling not_equal_ad_func in tensor__ne__method"; VLOG(6) << "Calling not_equal_ad_func in tensor__ne__method";
{ {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
ret = not_equal_ad_func(self_tensor, other_tensor, -1); ret = not_equal_ad_func(self_tensor, other_tensor);
} }
return ToPyObject(ret); return ToPyObject(ret);
...@@ -1722,7 +1722,7 @@ static PyObject* tensor__eq__method(TensorObject* self, ...@@ -1722,7 +1722,7 @@ static PyObject* tensor__eq__method(TensorObject* self,
VLOG(6) << "Calling equal_ad_func in tensor__eq__method"; VLOG(6) << "Calling equal_ad_func in tensor__eq__method";
{ {
eager_gil_scoped_release guard; eager_gil_scoped_release guard;
ret = equal_ad_func(self_tensor, other_tensor, -1); ret = equal_ad_func(self_tensor, other_tensor);
} }
return ToPyObject(ret); return ToPyObject(ret);
......
...@@ -67,7 +67,7 @@ ...@@ -67,7 +67,7 @@
func : addmm_grad func : addmm_grad
- backward_op : affine_grid_grad - backward_op : affine_grid_grad
forward : affine_grid (Tensor input, IntArray outputShape, bool use_cudnn=true, bool align_corners=true) -> Tensor(output) forward : affine_grid (Tensor input, IntArray outputShape, bool align_corners=true, bool use_cudnn=true) -> Tensor(output)
args : (Tensor output_grad, IntArray outputShape, bool use_cudnn=true, bool align_corners=true) args : (Tensor output_grad, IntArray outputShape, bool use_cudnn=true, bool align_corners=true)
output : Tensor(input_grad) output : Tensor(input_grad)
infer_meta : infer_meta :
...@@ -577,8 +577,8 @@ ...@@ -577,8 +577,8 @@
inplace : (out_grad -> x_grad) inplace : (out_grad -> x_grad)
- backward_op : fmax_grad - backward_op : fmax_grad
forward : fmax(Tensor x, Tensor y, int axis) -> Tensor(out) forward : fmax(Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad, int axis) args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1)
output : Tensor(x_grad), Tensor(y_grad) output : Tensor(x_grad), Tensor(y_grad)
infer_meta : infer_meta :
func : GeneralBinaryGradInferMeta func : GeneralBinaryGradInferMeta
...@@ -587,8 +587,8 @@ ...@@ -587,8 +587,8 @@
func : fmax_grad func : fmax_grad
- backward_op : fmin_grad - backward_op : fmin_grad
forward : fmin(Tensor x, Tensor y, int axis) -> Tensor(out) forward : fmin(Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad, int axis) args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1)
output : Tensor(x_grad), Tensor(y_grad) output : Tensor(x_grad), Tensor(y_grad)
infer_meta : infer_meta :
func : GeneralBinaryGradInferMeta func : GeneralBinaryGradInferMeta
...@@ -684,8 +684,8 @@ ...@@ -684,8 +684,8 @@
func : gumbel_softmax_grad func : gumbel_softmax_grad
- backward_op : hardswish_grad - backward_op : hardswish_grad
forward : hardswish (Tensor x, float threshold = 6.0, float scale = 6.0, float offset = 3.0) -> Tensor(out) forward : hardswish (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float threshold, float scale, float offset) args : (Tensor x, Tensor out_grad, float threshold = 6.0, float scale = 6.0, float offset = 3.0)
output : Tensor(x_grad) output : Tensor(x_grad)
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
...@@ -1418,8 +1418,8 @@ ...@@ -1418,8 +1418,8 @@
invoke : real_grad_impl(out_grad, x_grad) invoke : real_grad_impl(out_grad, x_grad)
- backward_op : relu6_grad - backward_op : relu6_grad
forward : relu6 (Tensor x, float threshold) -> Tensor(out) forward : relu6 (Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad, float threshold) args : (Tensor out, Tensor out_grad, float threshold = 6)
output : Tensor(x_grad) output : Tensor(x_grad)
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
...@@ -1810,7 +1810,7 @@ ...@@ -1810,7 +1810,7 @@
optional: u_grad, vh_grad, s_grad optional: u_grad, vh_grad, s_grad
- backward_op : swish_grad - backward_op : swish_grad
forward : swish (Tensor x, float beta=1.0) -> Tensor(out) forward : swish (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad, float bete=1.0) args : (Tensor x, Tensor out_grad, float bete=1.0)
output : Tensor(x_grad) output : Tensor(x_grad)
infer_meta : infer_meta :
......
...@@ -97,7 +97,7 @@ ...@@ -97,7 +97,7 @@
backward : addmm_grad backward : addmm_grad
- op : affine_grid - op : affine_grid
args : (Tensor input, IntArray outputShape, bool use_cudnn=true, bool align_corners=true) args : (Tensor input, IntArray outputShape, bool align_corners=true, bool use_cudnn=true)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : AffineGridInferMeta func : AffineGridInferMeta
...@@ -649,7 +649,7 @@ ...@@ -649,7 +649,7 @@
backend : place > x backend : place > x
- op : equal - op : equal
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
...@@ -751,7 +751,7 @@ ...@@ -751,7 +751,7 @@
func : floor_divide func : floor_divide
- op : fmax - op : fmax
args : (Tensor x, Tensor y, int axis) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
param: [x, y] param: [x, y]
...@@ -761,7 +761,7 @@ ...@@ -761,7 +761,7 @@
backward : fmax_grad backward : fmax_grad
- op : fmin - op : fmin
args : (Tensor x, Tensor y, int axis) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
param: [x, y] param: [x, y]
...@@ -898,7 +898,7 @@ ...@@ -898,7 +898,7 @@
func : generate_proposals_v2 func : generate_proposals_v2
- op : greater_equal - op : greater_equal
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
...@@ -906,7 +906,7 @@ ...@@ -906,7 +906,7 @@
func : greater_equal func : greater_equal
- op : greater_than - op : greater_than
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
...@@ -945,7 +945,7 @@ ...@@ -945,7 +945,7 @@
backward : gumbel_softmax_grad backward : gumbel_softmax_grad
- op : hardswish - op : hardswish
args : (Tensor x, float threshold = 6.0, float scale = 6.0, float offset = 3.0) args : (Tensor x)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
...@@ -1180,7 +1180,7 @@ ...@@ -1180,7 +1180,7 @@
backward : lerp_grad backward : lerp_grad
- op : less_equal - op : less_equal
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
...@@ -1188,7 +1188,7 @@ ...@@ -1188,7 +1188,7 @@
func : less_equal func : less_equal
- op : less_than - op : less_than
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
...@@ -1623,7 +1623,7 @@ ...@@ -1623,7 +1623,7 @@
backward : norm_grad backward : norm_grad
- op : not_equal - op : not_equal
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
...@@ -1820,7 +1820,7 @@ ...@@ -1820,7 +1820,7 @@
backward : real_grad backward : real_grad
- op : relu6 - op : relu6
args : (Tensor x, float threshold) args : (Tensor x)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
...@@ -2192,9 +2192,8 @@ ...@@ -2192,9 +2192,8 @@
func : svd func : svd
backward : svd_grad backward : svd_grad
# The python API paddle.nn.functional.swish has no `bete` argument, it may be removed later
- op : swish - op : swish
args : (Tensor x, float beta=1.0) args : (Tensor x)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
......
...@@ -251,8 +251,8 @@ ...@@ -251,8 +251,8 @@
pow_csr_grad {sparse_csr, sparse_csr -> sparse_csr} pow_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
- backward_op : relu6_grad - backward_op : relu6_grad
forward : relu6(Tensor x, float threshold) -> Tensor(out) forward : relu6(Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad, float threshold) args : (Tensor out, Tensor out_grad, float threshold = 6)
output : Tensor(x_grad) output : Tensor(x_grad)
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
......
...@@ -213,7 +213,7 @@ ...@@ -213,7 +213,7 @@
backward : relu_grad backward : relu_grad
- op : relu6 - op : relu6
args : (Tensor x, float threshold) args : (Tensor x)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
......
...@@ -328,10 +328,10 @@ void CholeskySolveInferMeta(const MetaTensor& x, ...@@ -328,10 +328,10 @@ void CholeskySolveInferMeta(const MetaTensor& x,
out->share_lod(x); out->share_lod(x);
} }
void CompareInferMeta(const MetaTensor& x, void CompareRawInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
int axis, int axis,
MetaTensor* out) { MetaTensor* out) {
auto dim_x = x.dims(); auto dim_x = x.dims();
auto dim_y = y.dims(); auto dim_y = y.dims();
...@@ -358,6 +358,12 @@ void CompareInferMeta(const MetaTensor& x, ...@@ -358,6 +358,12 @@ void CompareInferMeta(const MetaTensor& x,
out->set_dtype(DataType::BOOL); out->set_dtype(DataType::BOOL);
} }
void CompareInferMeta(const MetaTensor& x,
const MetaTensor& y,
MetaTensor* out) {
CompareRawInferMeta(x, y, -1, out);
}
void CompareAllInferMeta(const MetaTensor& x, void CompareAllInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
MetaTensor* out) { MetaTensor* out) {
......
...@@ -69,9 +69,13 @@ void CompareAllInferMeta(const MetaTensor& x, ...@@ -69,9 +69,13 @@ void CompareAllInferMeta(const MetaTensor& x,
void CompareInferMeta(const MetaTensor& x, void CompareInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
int axis,
MetaTensor* out); MetaTensor* out);
void CompareRawInferMeta(const MetaTensor& x,
const MetaTensor& y,
int axis,
MetaTensor* out);
void ComplexInferMeta(const MetaTensor& x, void ComplexInferMeta(const MetaTensor& x,
const MetaTensor& y, const MetaTensor& y,
MetaTensor* out); MetaTensor* out);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
HardSwishRawKernel<T, Context>(dev_ctx, x, 6, 6, 3, out);
}
template <typename T, typename Context>
void Relu6Kernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
Relu6RawKernel<T, Context>(dev_ctx, x, 6, out);
}
template <typename T, typename Context>
void SwishKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
SwishRawKernel<T, Context>(dev_ctx, x, 1.0, out);
}
} // namespace phi
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(
hard_swish, CPU, ALL_LAYOUT, phi::HardSwishKernel, float, double) {}
PD_REGISTER_KERNEL(relu6, CPU, ALL_LAYOUT, phi::Relu6Kernel, float, double) {}
PD_REGISTER_KERNEL(swish, CPU, ALL_LAYOUT, phi::SwishKernel, float, double) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(hard_swish,
GPU,
ALL_LAYOUT,
phi::HardSwishKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(relu6,
GPU,
ALL_LAYOUT,
phi::Relu6Kernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(swish,
GPU,
ALL_LAYOUT,
phi::SwishKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif
#if defined PADDLE_WITH_XPU
PD_REGISTER_KERNEL(hard_swish, XPU, ALL_LAYOUT, phi::HardSwishKernel, float) {}
PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {}
PD_REGISTER_KERNEL(swish, XPU, ALL_LAYOUT, phi::SwishKernel, float) {}
#endif
#ifdef PADDLE_WITH_MKLDNN
PD_REGISTER_KERNEL(hard_swish,
OneDNN,
ONEDNN,
phi::HardSwishKernel,
float,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(
relu6, OneDNN, ONEDNN, phi::Relu6Kernel, float, phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(
swish, OneDNN, ONEDNN, phi::SwishKernel, float, phi::dtype::bfloat16) {}
#endif
...@@ -75,13 +75,13 @@ DECLARE_ACTIVATION_KERNEL(Negative) ...@@ -75,13 +75,13 @@ DECLARE_ACTIVATION_KERNEL(Negative)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu6, threshold) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu6Raw, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, threshold) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Swish, beta) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SwishRaw, beta)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Celu, alpha) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Celu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Logit, eps) DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Logit, eps)
...@@ -90,14 +90,29 @@ DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b) ...@@ -90,14 +90,29 @@ DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus, beta, threshold) DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus, beta, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset) DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset)
template <typename T, typename Context>
void HardSwishRawKernel(const Context& dev_ctx,
const DenseTensor& x,
float threshold,
float scale,
float offset,
DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx, void HardSwishKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
float threshold,
float scale,
float offset,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context>
void Relu6Kernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
template <typename T, typename Context>
void SwishKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void PowKernel(const Context& dev_ctx, void PowKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
......
...@@ -18,20 +18,25 @@ limitations under the License. */ ...@@ -18,20 +18,25 @@ limitations under the License. */
namespace phi { namespace phi {
#define DECALRE_COMPARE_KERNEL(compare_kernel) \ #define DECALRE_COMPARE_KERNEL(name) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void compare_kernel(const Context& ctx, \ void name##RawKernel(const Context& ctx, \
const DenseTensor& x, \ const DenseTensor& x, \
const DenseTensor& y, \ const DenseTensor& y, \
int axis, \ int axis, \
DenseTensor* out); DenseTensor* out); \
template <typename T, typename Context> \
DECALRE_COMPARE_KERNEL(LessThanKernel) void name##Kernel(const Context& ctx, \
DECALRE_COMPARE_KERNEL(LessEqualKernel) const DenseTensor& x, \
DECALRE_COMPARE_KERNEL(GreaterThanKernel) const DenseTensor& y, \
DECALRE_COMPARE_KERNEL(GreaterEqualKernel) DenseTensor* out);
DECALRE_COMPARE_KERNEL(EqualKernel)
DECALRE_COMPARE_KERNEL(NotEqualKernel) DECALRE_COMPARE_KERNEL(LessThan)
DECALRE_COMPARE_KERNEL(LessEqual)
DECALRE_COMPARE_KERNEL(GreaterThan)
DECALRE_COMPARE_KERNEL(GreaterEqual)
DECALRE_COMPARE_KERNEL(Equal)
DECALRE_COMPARE_KERNEL(NotEqual)
#undef DECALRE_COMPARE_KERNEL #undef DECALRE_COMPARE_KERNEL
#define DECALRE_COMPARE_ALL_KERNEL(compare_all_kernel) \ #define DECALRE_COMPARE_ALL_KERNEL(compare_all_kernel) \
......
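
Since the updated DECALRE_COMPARE_KERNEL macro now declares two entry points per comparison op, one hand-expanded instance may help. Expanding DECALRE_COMPARE_KERNEL(LessThan) with the macro above yields roughly the following declarations (hand expansion for illustration only; DenseTensor is the phi tensor type used throughout this diff):

    // Hand expansion of DECALRE_COMPARE_KERNEL(LessThan) after this change.
    template <typename T, typename Context>
    void LessThanRawKernel(const Context& ctx,
                           const DenseTensor& x,
                           const DenseTensor& y,
                           int axis,
                           DenseTensor* out);

    template <typename T, typename Context>
    void LessThanKernel(const Context& ctx,
                        const DenseTensor& x,
                        const DenseTensor& y,
                        DenseTensor* out);
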
...@@ -96,12 +96,12 @@ DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha) ...@@ -96,12 +96,12 @@ DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
ThresholdedReluFunctor, ThresholdedReluFunctor,
threshold) threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Relu6, Relu6Functor, threshold) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Relu6Raw, Relu6Functor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishFunctor, beta) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SwishRaw, SwishFunctor, beta)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CELUFunctor, alpha) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardTanh, HardTanhFunctor, t_min, t_max) DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardTanh, HardTanhFunctor, t_min, t_max)
...@@ -113,12 +113,12 @@ DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid, ...@@ -113,12 +113,12 @@ DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
offset) offset)
template <typename T, typename Context> template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx, void HardSwishRawKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
float threshold, float threshold,
float scale, float scale,
float offset, float offset,
DenseTensor* out) { DenseTensor* out) {
funcs::HardSwishFunctor<T> functor; funcs::HardSwishFunctor<T> functor;
auto attrs = functor.GetAttrs(); auto attrs = functor.GetAttrs();
*(attrs[0].second) = threshold; *(attrs[0].second) = threshold;
...@@ -149,7 +149,7 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel) ...@@ -149,7 +149,7 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_tanh, HardTanhKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_tanh, HardTanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel) PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel) PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel) PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel) PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
...@@ -182,8 +182,8 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) ...@@ -182,8 +182,8 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel) PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel) PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel) PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_swish_raw, HardSwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel) PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel) PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
......
...@@ -71,73 +71,6 @@ inline void CompareAllKernelImpl(const Context& ctx, ...@@ -71,73 +71,6 @@ inline void CompareAllKernelImpl(const Context& ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(less_than,
CPU,
ALL_LAYOUT,
phi::LessThanKernel,
bool,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(less_equal,
CPU,
ALL_LAYOUT,
phi::LessEqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_than,
CPU,
ALL_LAYOUT,
phi::GreaterThanKernel,
bool,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(greater_equal,
CPU,
ALL_LAYOUT,
phi::GreaterEqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(equal,
CPU,
ALL_LAYOUT,
phi::EqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(not_equal,
CPU,
ALL_LAYOUT,
phi::NotEqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(equal_all, PD_REGISTER_KERNEL(equal_all,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -147,3 +80,33 @@ PD_REGISTER_KERNEL(equal_all, ...@@ -147,3 +80,33 @@ PD_REGISTER_KERNEL(equal_all,
int64_t, int64_t,
float, float,
double) {} double) {}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16) {} \
PD_REGISTER_KERNEL(name##_raw, \
CPU, \
ALL_LAYOUT, \
phi::func##RawKernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16) {}
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual)
PD_REGISTER_COMPARE_KERNEL(equal, Equal)
PD_REGISTER_COMPARE_KERNEL(not_equal, NotEqual)
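
Likewise, the new PD_REGISTER_COMPARE_KERNEL macro registers each comparison op twice per backend: once under the plain name (attribute-free kernel) and once under the _raw suffix (kernel that keeps the axis attribute). A hand expansion of the first entry, using the same dtype list as the macro above (illustration only):

    // Hand expansion of PD_REGISTER_COMPARE_KERNEL(less_than, LessThan) on CPU.
    PD_REGISTER_KERNEL(less_than, CPU, ALL_LAYOUT, phi::LessThanKernel,
                       bool, int16_t, int, int64_t, float, double,
                       phi::dtype::float16) {}
    PD_REGISTER_KERNEL(less_than_raw, CPU, ALL_LAYOUT, phi::LessThanRawKernel,
                       bool, int16_t, int, int64_t, float, double,
                       phi::dtype::float16) {}
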
...@@ -122,11 +122,23 @@ using complex128 = ::phi::dtype::complex<double>; ...@@ -122,11 +122,23 @@ using complex128 = ::phi::dtype::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::phi::dtype::bfloat16; // using bfloat16 = ::phi::dtype::bfloat16;
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(fmax_raw,
fmax, CPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {} CPU,
ALL_LAYOUT,
phi::FMaxRawKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(fmin_raw,
fmin, CPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {} CPU,
ALL_LAYOUT,
phi::FMinRawKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(maximum_raw, PD_REGISTER_KERNEL(maximum_raw,
CPU, CPU,
......
...@@ -110,10 +110,32 @@ void SubtractKernel(const Context& dev_ctx, ...@@ -110,10 +110,32 @@ void SubtractKernel(const Context& dev_ctx,
SubtractRawKernel<T>(dev_ctx, x, y, axis, out); SubtractRawKernel<T>(dev_ctx, x, y, axis, out);
} }
template <typename T, typename Context>
void FMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
FMaxRawKernel<T, Context>(dev_ctx, x, y, -1, out);
}
template <typename T, typename Context>
void FMinKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
FMinRawKernel<T, Context>(dev_ctx, x, y, -1, out);
}
} // namespace phi } // namespace phi
using complex64 = ::phi::dtype::complex<float>; using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>; using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(
fmax, CPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(
fmin, CPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(maximum, PD_REGISTER_KERNEL(maximum,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -210,6 +232,26 @@ PD_REGISTER_KERNEL(divide, ...@@ -210,6 +232,26 @@ PD_REGISTER_KERNEL(divide,
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(fmax,
KPS,
ALL_LAYOUT,
phi::FMaxKernel,
float,
double,
int,
phi::dtype::float16,
int64_t) {}
PD_REGISTER_KERNEL(fmin,
KPS,
ALL_LAYOUT,
phi::FMinKernel,
float,
double,
int,
phi::dtype::float16,
int64_t) {}
PD_REGISTER_KERNEL(maximum, PD_REGISTER_KERNEL(maximum,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
......
...@@ -19,18 +19,30 @@ ...@@ -19,18 +19,30 @@
namespace phi { namespace phi {
template <typename T, typename Context>
void FMaxRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void FMaxKernel(const Context& dev_ctx, void FMaxKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
int axis,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context>
void FMinRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void FMinKernel(const Context& dev_ctx, void FMinKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
int axis,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
......
...@@ -112,13 +112,13 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha) ...@@ -112,13 +112,13 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
CudaThresholdedReluFunctor, CudaThresholdedReluFunctor,
threshold) threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Relu6, CudaRelu6Functor, threshold) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Relu6Raw, CudaRelu6Functor, threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
CudaHardShrinkFunctor, CudaHardShrinkFunctor,
threshold) threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, CudaSoftShrinkFunctor, lambda) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, CudaSoftShrinkFunctor, lambda)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, CudaELUFunctor, alpha) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, CudaELUFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, CudaSwishFunctor, beta) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SwishRaw, CudaSwishFunctor, beta)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, CudaMishFunctor, threshold) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, CudaMishFunctor, threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CudaCELUFunctor, alpha) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CudaCELUFunctor, alpha)
...@@ -138,12 +138,12 @@ DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid, ...@@ -138,12 +138,12 @@ DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Selu, CudaSeluFunctor, scale, alpha) DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Selu, CudaSeluFunctor, scale, alpha)
template <typename T, typename Context> template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx, void HardSwishRawKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
float threshold, float threshold,
float scale, float scale,
float offset, float offset,
DenseTensor* out) { DenseTensor* out) {
funcs::CudaHardSwishFunctor<T> functor; funcs::CudaHardSwishFunctor<T> functor;
auto attrs = functor.GetAttrs(); auto attrs = functor.GetAttrs();
*(attrs[0].second) = threshold; *(attrs[0].second) = threshold;
...@@ -198,7 +198,7 @@ PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel) ...@@ -198,7 +198,7 @@ PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel) PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_tanh, HardTanhKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_tanh, HardTanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel) PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel) PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(stanh, StanhKernel) PD_REGISTER_ACTIVATION_KERNEL(stanh, StanhKernel)
...@@ -254,8 +254,8 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) ...@@ -254,8 +254,8 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel) PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel) PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel) PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_swish_raw, HardSwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel) PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel) PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel) PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
......
...@@ -36,33 +36,38 @@ inline void CompareAllKernelImpl(const Context& ctx, ...@@ -36,33 +36,38 @@ inline void CompareAllKernelImpl(const Context& ctx,
const DenseTensor& y, const DenseTensor& y,
DenseTensor* out); DenseTensor* out);
#define DEFINE_COMPARE_KERNEL(compare_kernel, functor, inverse_functor) \ #define DEFINE_COMPARE_KERNEL(name, functor, inverse_functor) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void compare_kernel(const Context& ctx, \ void name##RawKernel(const Context& ctx, \
const DenseTensor& x, \ const DenseTensor& x, \
const DenseTensor& y, \ const DenseTensor& y, \
int axis, \ int axis, \
DenseTensor* out) { \ DenseTensor* out) { \
CompareKernelImpl<T, Context, functor<T>, inverse_functor<T>>( \ CompareKernelImpl<T, Context, functor<T>, inverse_functor<T>>( \
ctx, x, y, axis, out); \ ctx, x, y, axis, out); \
} \
template <typename T, typename Context> \
void name##Kernel(const Context& ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
name##RawKernel<T, Context>(ctx, x, y, -1, out); \
} }
DEFINE_COMPARE_KERNEL(LessThanKernel, DEFINE_COMPARE_KERNEL(LessThan,
funcs::LessThanFunctor, funcs::LessThanFunctor,
funcs::GreaterThanFunctor) funcs::GreaterThanFunctor)
DEFINE_COMPARE_KERNEL(LessEqualKernel, DEFINE_COMPARE_KERNEL(LessEqual,
funcs::LessEqualFunctor, funcs::LessEqualFunctor,
funcs::GreaterEqualFunctor) funcs::GreaterEqualFunctor)
DEFINE_COMPARE_KERNEL(GreaterThanKernel, DEFINE_COMPARE_KERNEL(GreaterThan,
funcs::GreaterThanFunctor, funcs::GreaterThanFunctor,
funcs::LessThanFunctor) funcs::LessThanFunctor)
DEFINE_COMPARE_KERNEL(GreaterEqualKernel, DEFINE_COMPARE_KERNEL(GreaterEqual,
funcs::GreaterEqualFunctor, funcs::GreaterEqualFunctor,
funcs::LessEqualFunctor) funcs::LessEqualFunctor)
DEFINE_COMPARE_KERNEL(EqualKernel, funcs::EqualFunctor, funcs::EqualFunctor) DEFINE_COMPARE_KERNEL(Equal, funcs::EqualFunctor, funcs::EqualFunctor)
DEFINE_COMPARE_KERNEL(NotEqualKernel, DEFINE_COMPARE_KERNEL(NotEqual, funcs::NotEqualFunctor, funcs::NotEqualFunctor)
funcs::NotEqualFunctor,
funcs::NotEqualFunctor)
#undef DEFINE_COMPARE_KERNEL #undef DEFINE_COMPARE_KERNEL
#define DEFINE_COMPARE_ALL_KERNEL(compare_all_kernel, functor) \ #define DEFINE_COMPARE_ALL_KERNEL(compare_all_kernel, functor) \
......
...@@ -67,22 +67,22 @@ namespace phi { ...@@ -67,22 +67,22 @@ namespace phi {
} }
template <typename T, typename Context> template <typename T, typename Context>
void FMaxKernel(const Context& dev_ctx, void FMaxRawKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
int axis, int axis,
DenseTensor* out) { DenseTensor* out) {
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::FMaxFunctor<T>, T, T>( funcs::ElementwiseCompute<funcs::FMaxFunctor<T>, T, T>(
dev_ctx, x, y, axis, funcs::FMaxFunctor<T>(), out); dev_ctx, x, y, axis, funcs::FMaxFunctor<T>(), out);
} }
template <typename T, typename Context> template <typename T, typename Context>
void FMinKernel(const Context& dev_ctx, void FMinRawKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
int axis, int axis,
DenseTensor* out) { DenseTensor* out) {
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::FMinFunctor<T>, T, T>( funcs::ElementwiseCompute<funcs::FMinFunctor<T>, T, T>(
dev_ctx, x, y, axis, funcs::FMinFunctor<T>(), out); dev_ctx, x, y, axis, funcs::FMinFunctor<T>(), out);
......
...@@ -103,79 +103,20 @@ PD_REGISTER_KERNEL( ...@@ -103,79 +103,20 @@ PD_REGISTER_KERNEL(
greater_equal, KPS, ALL_LAYOUT, phi::GreaterEqualKernel, int) {} greater_equal, KPS, ALL_LAYOUT, phi::GreaterEqualKernel, int) {}
PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, phi::EqualKernel, int) {} PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, phi::EqualKernel, int) {}
PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, phi::NotEqualKernel, int) {} PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, phi::NotEqualKernel, int) {}
PD_REGISTER_KERNEL(
less_than_raw, KPS, ALL_LAYOUT, phi::LessThanRawKernel, int) {}
PD_REGISTER_KERNEL(
less_equal_raw, KPS, ALL_LAYOUT, phi::LessEqualRawKernel, int) {}
PD_REGISTER_KERNEL(
greater_than_raw, KPS, ALL_LAYOUT, phi::GreaterThanRawKernel, int) {}
PD_REGISTER_KERNEL(
greater_equal_raw, KPS, ALL_LAYOUT, phi::GreaterEqualRawKernel, int) {}
PD_REGISTER_KERNEL(equal_raw, KPS, ALL_LAYOUT, phi::EqualRawKernel, int) {}
PD_REGISTER_KERNEL(
not_equal_raw, KPS, ALL_LAYOUT, phi::NotEqualRawKernel, int) {}
#else #else
PD_REGISTER_KERNEL(less_than,
KPS,
ALL_LAYOUT,
phi::LessThanKernel,
bool,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(less_equal,
KPS,
ALL_LAYOUT,
phi::LessEqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(greater_than,
KPS,
ALL_LAYOUT,
phi::GreaterThanKernel,
bool,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(greater_equal,
KPS,
ALL_LAYOUT,
phi::GreaterEqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(equal,
KPS,
ALL_LAYOUT,
phi::EqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(not_equal,
KPS,
ALL_LAYOUT,
phi::NotEqualKernel,
bool,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(equal_all, PD_REGISTER_KERNEL(equal_all,
KPS, KPS,
...@@ -186,4 +127,38 @@ PD_REGISTER_KERNEL(equal_all, ...@@ -186,4 +127,38 @@ PD_REGISTER_KERNEL(equal_all,
int64_t, int64_t,
float, float,
double) {} double) {}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \
KPS, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) {} \
PD_REGISTER_KERNEL(name##_raw, \
KPS, \
ALL_LAYOUT, \
phi::func##RawKernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) {}
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual)
PD_REGISTER_COMPARE_KERNEL(equal, Equal)
PD_REGISTER_COMPARE_KERNEL(not_equal, NotEqual)
#endif #endif
...@@ -93,20 +93,20 @@ using bfloat16 = phi::dtype::bfloat16; ...@@ -93,20 +93,20 @@ using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>; using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>; using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(fmax, PD_REGISTER_KERNEL(fmax_raw,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::FMaxKernel, phi::FMaxRawKernel,
float, float,
double, double,
int, int,
float16, float16,
int64_t) {} int64_t) {}
PD_REGISTER_KERNEL(fmin, PD_REGISTER_KERNEL(fmin_raw,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::FMinKernel, phi::FMinRawKernel,
float, float,
double, double,
int, int,
......
...@@ -154,15 +154,15 @@ DEFINE_ONEDNN_ACTIVATION_KERNEL(Round, RoundOneDNNFunctor) ...@@ -154,15 +154,15 @@ DEFINE_ONEDNN_ACTIVATION_KERNEL(Round, RoundOneDNNFunctor)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluOneDNNFunctor, alpha) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluOneDNNFunctor, alpha)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishOneDNNFunctor, threshold) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishOneDNNFunctor, threshold)
DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishOneDNNFunctor, beta) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(SwishRaw, SwishOneDNNFunctor, beta)
template <typename T, typename Context> template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx, void HardSwishRawKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
float threshold, float threshold,
float scale, float scale,
float offset, float offset,
DenseTensor* out) { DenseTensor* out) {
HardSwishOneDNNFunctor<T> functor; HardSwishOneDNNFunctor<T> functor;
functor(dev_ctx, x, threshold, 0, out); functor(dev_ctx, x, threshold, 0, out);
} }
...@@ -182,10 +182,10 @@ void GeluKernel(const Context& dev_ctx, ...@@ -182,10 +182,10 @@ void GeluKernel(const Context& dev_ctx,
} }
template <typename T, typename Context> template <typename T, typename Context>
void Relu6Kernel(const Context& dev_ctx, void Relu6RawKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
float threshold, float threshold,
DenseTensor* out) { DenseTensor* out) {
Relu6OneDNNFunctor<T> functor; Relu6OneDNNFunctor<T> functor;
functor(dev_ctx, x, 0, threshold, out); functor(dev_ctx, x, 0, threshold, out);
} }
...@@ -202,12 +202,12 @@ PD_REGISTER_ACTIVATION_KERNEL(abs, AbsKernel) ...@@ -202,12 +202,12 @@ PD_REGISTER_ACTIVATION_KERNEL(abs, AbsKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel) PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)
PD_REGISTER_ACTIVATION_KERNEL(gelu, GeluKernel) PD_REGISTER_ACTIVATION_KERNEL(gelu, GeluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_swish_raw, HardSwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu, ReluKernel) PD_REGISTER_ACTIVATION_KERNEL(relu, ReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel) PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel) PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel) PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel) PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
...@@ -95,6 +95,7 @@ PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(pow, Pow) ...@@ -95,6 +95,7 @@ PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(scale, Scale) PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(scale, Scale)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(expm1, Expm1) PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu6, Relu6) PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu6_raw, Relu6Raw)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(leaky_relu, LeakyRelu) PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(leaky_relu, LeakyRelu)
PD_REGISTER_KERNEL(divide_scalar_coo, PD_REGISTER_KERNEL(divide_scalar_coo,
......
...@@ -99,6 +99,7 @@ PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(abs, Abs) ...@@ -99,6 +99,7 @@ PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(pow, Pow) PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(scale, Scale) PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(scale, Scale)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(expm1, Expm1) PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu6_raw, Relu6Raw)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu6, Relu6) PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(leaky_relu, LeakyRelu) PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(leaky_relu, LeakyRelu)
......
...@@ -89,9 +89,23 @@ DEFINE_SPARSE_UNARY_KERNEL(Relu) ...@@ -89,9 +89,23 @@ DEFINE_SPARSE_UNARY_KERNEL(Relu)
DEFINE_SPARSE_UNARY_KERNEL(Abs) DEFINE_SPARSE_UNARY_KERNEL(Abs)
DEFINE_SPARSE_UNARY_KERNEL(Expm1) DEFINE_SPARSE_UNARY_KERNEL(Expm1)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor) DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu6, threshold) DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu6Raw, threshold)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha) DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)
template <typename T, typename Context>
void Relu6CooKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCooTensor* out) {
Relu6RawCooKernel<T, Context>(dev_ctx, x, 6, out);
}
template <typename T, typename Context>
void Relu6CsrKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
SparseCsrTensor* out) {
Relu6RawCsrKernel<T, Context>(dev_ctx, x, 6, out);
}
template <typename T, typename Context> template <typename T, typename Context>
void ScaleCooKernel(const Context& dev_ctx, void ScaleCooKernel(const Context& dev_ctx,
const SparseCooTensor& x, const SparseCooTensor& x,
......
...@@ -356,10 +356,10 @@ struct XPUMishFunctor : public funcs::BaseActivationFunctor<T> { ...@@ -356,10 +356,10 @@ struct XPUMishFunctor : public funcs::BaseActivationFunctor<T> {
}; };
template <typename T, typename Context> template <typename T, typename Context>
void SwishKernel(const Context& dev_ctx, void SwishRawKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
float beta, float beta,
DenseTensor* out) { DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
int r = xpu::swish(dev_ctx.x_context(), int r = xpu::swish(dev_ctx.x_context(),
...@@ -415,7 +415,9 @@ DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, XPUMishFunctor, threshold) ...@@ -415,7 +415,9 @@ DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, XPUMishFunctor, threshold)
DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu,
XPULeakyReluFunctor, XPULeakyReluFunctor,
alpha) alpha)
DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu6, XPURelu6Functor, threshold) DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu6Raw,
XPURelu6Functor,
threshold)
DEFINE_XPU_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus, DEFINE_XPU_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus,
XPUSoftplusFunctor, XPUSoftplusFunctor,
...@@ -423,12 +425,12 @@ DEFINE_XPU_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus, ...@@ -423,12 +425,12 @@ DEFINE_XPU_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus,
threshold) threshold)
template <typename T, typename Context> template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx, void HardSwishRawKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
float threshold, float threshold,
float scale, float scale,
float offset, float offset,
DenseTensor* out) { DenseTensor* out) {
XPUHardSwishFunctor<T> functor; XPUHardSwishFunctor<T> functor;
auto attrs = functor.GetAttrs(); auto attrs = functor.GetAttrs();
*(attrs[0].second) = threshold; *(attrs[0].second) = threshold;
...@@ -452,13 +454,13 @@ PD_REGISTER_KERNEL( ...@@ -452,13 +454,13 @@ PD_REGISTER_KERNEL(
PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad
PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_swish_raw, HardSwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel) PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel) PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel) PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel) PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel) PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel) PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(square, SquareKernel) PD_REGISTER_ACTIVATION_KERNEL(square, SquareKernel)
...@@ -52,48 +52,59 @@ void XPUCompareKernelImpl(const Context& dev_ctx, ...@@ -52,48 +52,59 @@ void XPUCompareKernelImpl(const Context& dev_ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "compare op"); PADDLE_ENFORCE_XDNN_SUCCESS(ret, "compare op");
} }
#define DEFINE_XPU_COMPARE_KERNEL(compare_kernel, functor) \ #define DEFINE_XPU_COMPARE_KERNEL(name, functor) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void compare_kernel(const Context& dev_ctx, \ void name##RawKernel(const Context& dev_ctx, \
const DenseTensor& x, \ const DenseTensor& x, \
const DenseTensor& y, \ const DenseTensor& y, \
int axis, \ int axis, \
DenseTensor* out) { \ DenseTensor* out) { \
using XPUType = typename XPUTypeTrait<T>::Type; \ using XPUType = typename XPUTypeTrait<T>::Type; \
XPUCompareKernelImpl<T, XPUType, Context>(dev_ctx, x, y, out, functor); \ XPUCompareKernelImpl<T, XPUType, Context>(dev_ctx, x, y, out, functor); \
} \
template <typename T, typename Context> \
void name##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
name##RawKernel<T, Context>(dev_ctx, x, y, -1, out); \
} }
DEFINE_XPU_COMPARE_KERNEL(EqualKernel, xpu::broadcast_equal<XPUType>) DEFINE_XPU_COMPARE_KERNEL(Equal, xpu::broadcast_equal<XPUType>)
DEFINE_XPU_COMPARE_KERNEL(NotEqualKernel, xpu::broadcast_not_equal<XPUType>) DEFINE_XPU_COMPARE_KERNEL(NotEqual, xpu::broadcast_not_equal<XPUType>)
DEFINE_XPU_COMPARE_KERNEL(LessThanKernel, xpu::broadcast_less_than<XPUType>) DEFINE_XPU_COMPARE_KERNEL(LessThan, xpu::broadcast_less_than<XPUType>)
DEFINE_XPU_COMPARE_KERNEL(LessEqualKernel, xpu::broadcast_less_equal<XPUType>) DEFINE_XPU_COMPARE_KERNEL(LessEqual, xpu::broadcast_less_equal<XPUType>)
DEFINE_XPU_COMPARE_KERNEL(GreaterThanKernel, DEFINE_XPU_COMPARE_KERNEL(GreaterThan, xpu::broadcast_greater_than<XPUType>)
xpu::broadcast_greater_than<XPUType>) DEFINE_XPU_COMPARE_KERNEL(GreaterEqual, xpu::broadcast_greater_equal<XPUType>)
DEFINE_XPU_COMPARE_KERNEL(GreaterEqualKernel,
xpu::broadcast_greater_equal<XPUType>)
#undef DEFINE_XPU_COMPARE_KERNEL #undef DEFINE_XPU_COMPARE_KERNEL
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
equal, XPU, ALL_LAYOUT, phi::EqualKernel, float, int, int64_t) {} less_than, XPU, ALL_LAYOUT, phi::LessThanKernel, int, int64_t, float) {}
PD_REGISTER_KERNEL(
not_equal, XPU, ALL_LAYOUT, phi::NotEqualKernel, float, int, int64_t) {} PD_REGISTER_KERNEL(less_than_raw,
PD_REGISTER_KERNEL(
less_than, XPU, ALL_LAYOUT, phi::LessThanKernel, float, int, int64_t) {}
PD_REGISTER_KERNEL(
less_equal, XPU, ALL_LAYOUT, phi::LessEqualKernel, float, int, int64_t) {}
PD_REGISTER_KERNEL(greater_than,
XPU,
ALL_LAYOUT,
phi::GreaterThanKernel,
float,
int,
int64_t) {}
PD_REGISTER_KERNEL(greater_equal,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::GreaterEqualKernel, phi::LessThanRawKernel,
float,
int, int,
int64_t) {} int64_t,
float) {}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL( \
name, XPU, ALL_LAYOUT, phi::func##Kernel, int, int64_t, float) {} \
PD_REGISTER_KERNEL(name##_raw, \
XPU, \
ALL_LAYOUT, \
phi::func##RawKernel, \
int, \
int64_t, \
float) {}
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
PD_REGISTER_COMPARE_KERNEL(greater_equal, GreaterEqual)
PD_REGISTER_COMPARE_KERNEL(equal, Equal)
PD_REGISTER_COMPARE_KERNEL(not_equal, NotEqual)
...@@ -53,6 +53,19 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(STanh, ...@@ -53,6 +53,19 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(STanh,
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu6, "relu6", "threshold"); // NOLINT DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu6, "relu6", "threshold"); // NOLINT
KernelSignature HardSwishOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"hard_swish_raw", {"X"}, {"threshold", "scale", "offset"}, {"Out"});
}
KernelSignature SwishOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("swish_raw", {"X"}, {"beta"}, {"Out"});
}
KernelSignature Relu6OpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("relu6_raw", {"X"}, {"threshold"}, {"Out"});
}
KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("FactorTensor")) { if (ctx.HasInput("FactorTensor")) {
return KernelSignature("pow", {"X"}, {"FactorTensor"}, {"Out"}); return KernelSignature("pow", {"X"}, {"FactorTensor"}, {"Out"});
...@@ -108,10 +121,12 @@ PD_REGISTER_ARG_MAPPING_FN(stanh_grad, phi::STanhGradOpArgumentMapping); ...@@ -108,10 +121,12 @@ PD_REGISTER_ARG_MAPPING_FN(stanh_grad, phi::STanhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(brelu_grad, phi::HardTanhGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(brelu_grad, phi::HardTanhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(relu6_grad, phi::Relu6GradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(relu6_grad, phi::Relu6GradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(relu6, phi::Relu6OpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(hard_swish_grad, PD_REGISTER_ARG_MAPPING_FN(hard_swish_grad,
phi::HardSwishGradOpArgumentMapping); phi::HardSwishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(hard_swish, phi::HardSwishOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(swish, phi::SwishOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow_grad, phi::PowGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(pow_grad, phi::PowGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow_double_grad, PD_REGISTER_ARG_MAPPING_FN(pow_double_grad,
phi::PowDoubleGradOpArgumentMapping); phi::PowDoubleGradOpArgumentMapping);
......
...@@ -17,27 +17,27 @@ ...@@ -17,27 +17,27 @@
namespace phi { namespace phi {
KernelSignature LessThanArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature LessThanArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("less_than", {"X", "Y"}, {"axis"}, {"Out"}); return KernelSignature("less_than_raw", {"X", "Y"}, {"axis"}, {"Out"});
} }
KernelSignature LessEqualArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature LessEqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("less_equal", {"X", "Y"}, {"axis"}, {"Out"}); return KernelSignature("less_equal_raw", {"X", "Y"}, {"axis"}, {"Out"});
} }
KernelSignature GreaterThanArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature GreaterThanArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("greater_than", {"X", "Y"}, {"axis"}, {"Out"}); return KernelSignature("greater_than_raw", {"X", "Y"}, {"axis"}, {"Out"});
} }
KernelSignature GreaterEqualArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature GreaterEqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("greater_equal", {"X", "Y"}, {"axis"}, {"Out"}); return KernelSignature("greater_equal_raw", {"X", "Y"}, {"axis"}, {"Out"});
} }
KernelSignature EqualArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("equal", {"X", "Y"}, {"axis"}, {"Out"}); return KernelSignature("equal_raw", {"X", "Y"}, {"axis"}, {"Out"});
} }
KernelSignature NotEqualArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature NotEqualArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("not_equal", {"X", "Y"}, {"axis"}, {"Out"}); return KernelSignature("not_equal_raw", {"X", "Y"}, {"axis"}, {"Out"});
} }
} // namespace phi } // namespace phi
......
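These mappings keep the legacy static-graph comparison operators, which still carry an axis attribute, pointed at the new *_raw PHI kernels, while the axis-free dygraph path calls the standardized kernels directly. A minimal sketch of that legacy path, assuming the public paddle.static API:

import paddle

paddle.enable_static()
# Building the program emits the old-style `equal` op with its axis attr;
# the argument mapping above routes it to the `equal_raw` kernel at run time.
x = paddle.static.data(name='x', shape=[3], dtype='int64')
y = paddle.static.data(name='y', shape=[3], dtype='int64')
out = paddle.equal(x, y)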
...@@ -181,12 +181,12 @@ KernelSignature ElementwiseMulGradOpArgumentMapping( ...@@ -181,12 +181,12 @@ KernelSignature ElementwiseMulGradOpArgumentMapping(
KernelSignature ElementwiseFMaxOpArgumentMapping( KernelSignature ElementwiseFMaxOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("fmax", {"X", "Y"}, {"axis"}, {"Out"}); return KernelSignature("fmax_raw", {"X", "Y"}, {"axis"}, {"Out"});
} }
KernelSignature ElementwiseFMinOpArgumentMapping( KernelSignature ElementwiseFMinOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("fmin", {"X", "Y"}, {"axis"}, {"Out"}); return KernelSignature("fmin_raw", {"X", "Y"}, {"axis"}, {"Out"});
} }
KernelSignature ElementwiseFMaxGradOpArgumentMapping( KernelSignature ElementwiseFMaxGradOpArgumentMapping(
......
...@@ -2075,7 +2075,7 @@ def greater_than(x, y, cond=None, name=None): ...@@ -2075,7 +2075,7 @@ def greater_than(x, y, cond=None, name=None):
attrs = dict() attrs = dict()
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.greater_than(x, y, -1) return _C_ops.greater_than(x, y)
else: else:
helper.append_op( helper.append_op(
type='greater_than', type='greater_than',
...@@ -2173,8 +2173,7 @@ def equal(x, y, cond=None, name=None): ...@@ -2173,8 +2173,7 @@ def equal(x, y, cond=None, name=None):
out2 = fluid.layers.equal(x=label_cond,y=limit, cond=out_cond) #out2=[False, True] out_cond=[False, True] out2 = fluid.layers.equal(x=label_cond,y=limit, cond=out_cond) #out2=[False, True] out_cond=[False, True]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
default_axis = -1 return _C_ops.equal(x, y)
return _C_ops.equal(x, y, default_axis)
check_variable_and_dtype( check_variable_and_dtype(
x, "x", ["float32", "float64", "int32", "int64"], "equal" x, "x", ["float32", "float64", "int32", "int64"], "equal"
......
if(WITH_CUSTOM_DEVICE AND NOT WITH_GPU) if(WITH_CUSTOM_DEVICE AND NOT WITH_GPU)
set(PLUGIN_URL https://github.com/PaddlePaddle/PaddleCustomDevice.git) set(PLUGIN_URL https://github.com/PaddlePaddle/PaddleCustomDevice.git)
set(PLUGIN_TAG 0698428ddba21e6baecb690579f37c48896f7d56) set(PLUGIN_TAG develop)
file( file(
GLOB TEST_OPS GLOB TEST_OPS
......
...@@ -402,7 +402,7 @@ def hardswish(x, name=None): ...@@ -402,7 +402,7 @@ def hardswish(x, name=None):
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _legacy_C_ops.hard_swish(x) return _legacy_C_ops.hard_swish(x)
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.hardswish(x, 6, 6, 3) return _C_ops.hardswish(x)
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'hardswish' x, 'x', ['float16', 'float32', 'float64'], 'hardswish'
...@@ -893,7 +893,7 @@ def relu6(x, name=None): ...@@ -893,7 +893,7 @@ def relu6(x, name=None):
""" """
threshold = 6.0 threshold = 6.0
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.relu6(x, threshold) return _C_ops.relu6(x)
if in_dynamic_mode(): if in_dynamic_mode():
return _legacy_C_ops.relu6(x, 'threshold', threshold) return _legacy_C_ops.relu6(x, 'threshold', threshold)
...@@ -1388,7 +1388,7 @@ def swish(x, name=None): ...@@ -1388,7 +1388,7 @@ def swish(x, name=None):
# [-0.23840584, 0. , 0.73105854]) # [-0.23840584, 0. , 0.73105854])
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.swish(x, 1.0) return _C_ops.swish(x)
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _legacy_C_ops.swish(x, 'beta', 1.0) return _legacy_C_ops.swish(x, 'beta', 1.0)
......
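The dygraph branches above stop forwarding the activation attributes from Python; the defaults they used to pass (threshold=6, scale=6, offset=3 for hardswish, threshold=6 for relu6, beta=1 for swish) are now owned by the standardized C++ ops. A minimal dygraph sketch, assuming the public paddle.nn.functional API:

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-4.0, 1.0, 7.0])
y1 = F.relu6(x)      # min(max(x, 0), 6)
y2 = F.hardswish(x)  # x * relu6(x + 3) / 6
y3 = F.swish(x)      # x * sigmoid(x)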
...@@ -92,7 +92,7 @@ def affine_grid(theta, out_shape, align_corners=True, name=None): ...@@ -92,7 +92,7 @@ def affine_grid(theta, out_shape, align_corners=True, name=None):
if isinstance(out_shape, Variable) if isinstance(out_shape, Variable)
else out_shape else out_shape
) )
return _C_ops.affine_grid(theta, _out_shape, use_cudnn, align_corners) return _C_ops.affine_grid(theta, _out_shape, align_corners, use_cudnn)
elif in_dynamic_mode(): elif in_dynamic_mode():
_out_shape = ( _out_shape = (
out_shape.numpy().tolist() out_shape.numpy().tolist()
......
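Only the positional order of the internal _C_ops.affine_grid call changes (align_corners now precedes use_cudnn); the public signature is untouched. A minimal sketch, assuming the public paddle.nn.functional.affine_grid API:

import paddle
import paddle.nn.functional as F

theta = paddle.rand([1, 2, 3])                       # batch of 2x3 affine matrices
grid = F.affine_grid(theta, [1, 3, 8, 8], align_corners=True)
print(grid.shape)                                    # [1, 8, 8, 2]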
...@@ -140,7 +140,7 @@ def relu6(x, name=None): ...@@ -140,7 +140,7 @@ def relu6(x, name=None):
sparse_x = dense_x.to_sparse_coo(1) sparse_x = dense_x.to_sparse_coo(1)
out = paddle.sparse.nn.functional.relu6(sparse_x) out = paddle.sparse.nn.functional.relu6(sparse_x)
""" """
return _C_ops.sparse_relu6(x, 6.0) return _C_ops.sparse_relu6(x)
@dygraph_only @dygraph_only
......
...@@ -445,8 +445,7 @@ def equal(x, y, name=None): ...@@ -445,8 +445,7 @@ def equal(x, y, name=None):
y = full(shape=[1], dtype=x.dtype, fill_value=y) y = full(shape=[1], dtype=x.dtype, fill_value=y)
if in_dygraph_mode(): if in_dygraph_mode():
default_axis = -1 return _C_ops.equal(x, y)
return _C_ops.equal(x, y, default_axis)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _legacy_C_ops.equal(x, y) return _legacy_C_ops.equal(x, y)
...@@ -502,8 +501,7 @@ def greater_equal(x, y, name=None): ...@@ -502,8 +501,7 @@ def greater_equal(x, y, name=None):
print(result1) # result1 = [True False True] print(result1) # result1 = [True False True]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
default_axis = -1 return _C_ops.greater_equal(x, y)
return _C_ops.greater_equal(x, y, default_axis)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _legacy_C_ops.greater_equal(x, y) return _legacy_C_ops.greater_equal(x, y)
...@@ -559,7 +557,7 @@ def greater_than(x, y, name=None): ...@@ -559,7 +557,7 @@ def greater_than(x, y, name=None):
print(result1) # result1 = [False False True] print(result1) # result1 = [False False True]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.greater_than(x, y, -1) return _C_ops.greater_than(x, y)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _legacy_C_ops.greater_than(x, y) return _legacy_C_ops.greater_than(x, y)
...@@ -616,8 +614,7 @@ def less_equal(x, y, name=None): ...@@ -616,8 +614,7 @@ def less_equal(x, y, name=None):
print(result1) # result1 = [True True False] print(result1) # result1 = [True True False]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
axis = -1 return _C_ops.less_equal(x, y)
return _C_ops.less_equal(x, y, axis)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _legacy_C_ops.less_equal(x, y) return _legacy_C_ops.less_equal(x, y)
...@@ -674,8 +671,7 @@ def less_than(x, y, name=None): ...@@ -674,8 +671,7 @@ def less_than(x, y, name=None):
print(result1) # result1 = [False True False] print(result1) # result1 = [False True False]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
default_axis = -1 return _C_ops.less_than(x, y)
return _C_ops.less_than(x, y, default_axis)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _legacy_C_ops.less_than(x, y) return _legacy_C_ops.less_than(x, y)
...@@ -732,8 +728,7 @@ def not_equal(x, y, name=None): ...@@ -732,8 +728,7 @@ def not_equal(x, y, name=None):
print(result1) # result1 = [False True True] print(result1) # result1 = [False True True]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
axis = -1 return _C_ops.not_equal(x, y)
return _C_ops.not_equal(x, y, axis)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _legacy_C_ops.not_equal(x, y) return _legacy_C_ops.not_equal(x, y)
......
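All six comparison wrappers above now call the standardized C++ ops without the -1 axis placeholder; broadcasting behavior is unchanged. A minimal dygraph sketch, assuming the public paddle comparison API:

import paddle

x = paddle.to_tensor([1, 2, 3], dtype='int64')
y = paddle.to_tensor([1, 3, 2], dtype='int64')
print(paddle.equal(x, y))          # [True , False, False]
print(paddle.not_equal(x, y))      # [False, True , True ]
print(paddle.less_than(x, y))      # [False, True , False]
print(paddle.greater_equal(x, y))  # [True , False, True ]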
...@@ -1168,7 +1168,7 @@ def fmax(x, y, name=None): ...@@ -1168,7 +1168,7 @@ def fmax(x, y, name=None):
axis = -1 axis = -1
act = None act = None
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.fmax(x, y, axis) return _C_ops.fmax(x, y)
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _elementwise_op_in_dygraph( return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type x, y, axis=axis, act=act, op_name=op_type
...@@ -1236,7 +1236,7 @@ def fmin(x, y, name=None): ...@@ -1236,7 +1236,7 @@ def fmin(x, y, name=None):
axis = -1 axis = -1
act = None act = None
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.fmin(x, y, axis) return _C_ops.fmin(x, y)
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _elementwise_op_in_dygraph( return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type x, y, axis=axis, act=act, op_name=op_type
......
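fmax and fmin likewise drop the axis argument on the dygraph path. A minimal sketch, assuming the public paddle.fmax/paddle.fmin API (a NaN operand loses to the non-NaN one, as in numpy.fmax/numpy.fmin):

import paddle

x = paddle.to_tensor([1.0, float('nan'), 3.0])
y = paddle.to_tensor([2.0, 5.0, float('nan')])
print(paddle.fmax(x, y))  # [2., 5., 3.]
print(paddle.fmin(x, y))  # [1., 5., 3.]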