Unverified commit 4892d592, authored by BrilliantYuKaimin, committed by GitHub

[PaddlePaddle Hackathon 2] Task 18: Add the paddle.heaviside and paddle.Tensor.heaviside APIs to Paddle (#41872)

* Create elementwise_heaviside_op.cc

* add ElementwiseHeavisideFunctor

* Create test_elementwise_heaviside_op.py

* add the Python interface for heaviside

* add heaviside in white list

* add the kernel signature for heaviside

* add the heaviside kernel

* add the heaviside gradient kernel

* register the heaviside gradient kernel

* adjust code formatting

* Update elementwise_sig.cc

* add heaviside in __all__

* Update heaviside docs

* Update math.py

* Update math.py

* Update math.py
Parent 81644145
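For orientation, here is a minimal usage sketch of the API this commit introduces. It is not part of the diff; it simply mirrors the docstring example and the np.heaviside reference used by the tests further down, and assumes a Paddle build that already includes this change.

import numpy as np
import paddle

# x selects the branch (< 0, == 0, > 0); y supplies the value used where x == 0.
x = paddle.to_tensor([-0.5, 0.0, 0.5])
y = paddle.to_tensor([0.1])

out = paddle.heaviside(x, y)            # functional form
# out = x.heaviside(y)                  # Tensor-method form, also added by this PR

print(out.numpy())                      # approximately [0.  0.1 1. ]
print(np.heaviside(x.numpy(), y.numpy()))   # NumPy reference: same values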
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwiseHeavisideOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Heaviside"; }
std::string GetEquation() const override { return "Out = Heaviside(X, Y)"; }
void AddInputX() override {
AddInput("X",
"(Tensor), The input tensor of Heaviside step function. "
"Its dtype can be int32, int64, float32 and float64");
}
void AddInputY() override {
AddInput("Y",
"(Tensor), The tensor determining a Heaviside step function, "
"which is the value when X = 0. Its dtype should be same as X.");
}
std::string GetOpFuntionality() const override {
return "Computes the Heaviside step function determined by Y "
"for each element in X.";
}
};
template <typename T>
class ElementwiseHeavisideGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_heaviside_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
elementwise_heaviside, ops::ElementwiseOp, ops::ElementwiseHeavisideOpMaker,
ops::ElementwiseHeavisideGradOpMaker<paddle::framework::OpDesc>,
ops::ElementwiseHeavisideGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_heaviside_grad, ops::ElementwiseOpGrad);
@@ -88,6 +88,16 @@ PD_REGISTER_KERNEL(minimum_grad,
int,
int64_t,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(elementwise_heaviside_grad,
CPU,
ALL_LAYOUT,
phi::ElementwiseHeavisideGradKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_pow_grad,
CPU,
ALL_LAYOUT,
@@ -95,6 +95,18 @@ void ElementwisePowRawKernel(const Context& dev_ctx,
dev_ctx, x, y, axis, funcs::ElementwisePowFunctor<T>(), out);
}
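// "Raw" kernels take the broadcast axis explicitly; the axis-free
// ElementwiseHeavisideKernel in the non-raw kernel file simply forwards
// here with axis = -1 (the default).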
template <typename T, typename Context>
void ElementwiseHeavisideRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
// allocate memory for out
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::ElementwiseHeavisideFunctor<T>, T>(
dev_ctx, x, y, axis, funcs::ElementwiseHeavisideFunctor<T>(), out);
}
} // namespace phi
using complex64 = ::phi::dtype::complex<float>;
@@ -149,3 +161,11 @@ PD_REGISTER_KERNEL(elementwise_pow_raw,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_heaviside_raw,
CPU,
ALL_LAYOUT,
phi::ElementwiseHeavisideRawKernel,
float,
double,
int,
int64_t) {}
@@ -55,6 +55,15 @@ void MinimumGradKernel(const Context& dev_ctx,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void ElementwiseHeavisideGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void ElementwisePowGradKernel(const Context& dev_ctx,
const DenseTensor& x,
@@ -64,6 +64,15 @@ void ElementwisePowKernel(const Context& dev_ctx,
ElementwisePowRawKernel<T>(dev_ctx, x, y, axis, out);
}
template <typename T, typename Context>
void ElementwiseHeavisideKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
int axis = -1;
ElementwiseHeavisideRawKernel<T>(dev_ctx, x, y, axis, out);
}
} // namespace phi
using complex64 = ::phi::dtype::complex<float>;
@@ -91,6 +100,14 @@ PD_REGISTER_KERNEL(
modulo, CPU, ALL_LAYOUT, phi::ModuloKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(
floor_divide, CPU, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {}
PD_REGISTER_KERNEL(elementwise_heaviside,
CPU,
ALL_LAYOUT,
phi::ElementwiseHeavisideKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_pow,
CPU,
ALL_LAYOUT,
@@ -126,6 +143,14 @@ PD_REGISTER_KERNEL(
modulo, GPU, ALL_LAYOUT, phi::ModuloKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(
floor_divide, KPS, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {}
PD_REGISTER_KERNEL(elementwise_heaviside,
GPU,
ALL_LAYOUT,
phi::ElementwiseHeavisideKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_pow,
KPS,
ALL_LAYOUT,
@@ -98,6 +98,19 @@ void ElementwisePowKernel(const Context& dev_ctx,
const DenseTensor& y,
DenseTensor* out);
template <typename T, typename Context>
void ElementwiseHeavisideRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void ElementwiseHeavisideKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor Maximum(const Context& dev_ctx,
const DenseTensor& x,
@@ -142,6 +155,17 @@ DenseTensor FloorDivide(const Context& dev_ctx,
return dense_out;
}
template <typename T, typename Context>
DenseTensor ElementwiseHeaviside(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y) {
DenseTensor dense_out;
MetaTensor meta_out(&dense_out);
ElementwiseInferMeta(x, y, &meta_out);
ElementwiseHeavisideKernel<T, Context>(dev_ctx, x, y, &dense_out);
return dense_out;
}
template <typename T, typename Context>
DenseTensor ElementwisePow(const Context& dev_ctx,
const DenseTensor& x,
@@ -543,6 +543,13 @@ struct InverseModuloFunctor<
}
};
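// Elementwise Heaviside step functor: returns 0 for a < 0, b for a == 0, and
// 1 for a > 0. The a != 0 cases are folded into static_cast<T>(a > 0), which
// is 1 when a is positive and 0 when a is negative.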
template <typename T>
struct ElementwiseHeavisideFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
return a == static_cast<T>(0) ? b : static_cast<T>(a > 0);
}
};
template <typename T>
struct FloorDivideFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
@@ -128,6 +128,16 @@ PD_REGISTER_KERNEL(minimum_grad,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(elementwise_heaviside_grad,
GPU,
ALL_LAYOUT,
phi::ElementwiseHeavisideGradKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_pow_grad,
GPU,
ALL_LAYOUT,
@@ -683,6 +683,43 @@ struct MinGradDy {
}
};
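// Gradient convention for heaviside(x, y): the derivative with respect to x
// is taken to be 0 everywhere (the step is flat away from x == 0), while the
// derivative with respect to y is 1 where x == 0 and 0 elsewhere, matching
// the two functors below.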
template <typename T>
struct HeavisideGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(0);
}
};
template <typename T>
struct HeavisideGradDy {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(x == static_cast<T>(0));
}
};
template <typename T, typename Context>
void ElementwiseHeavisideGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
int axis,
DenseTensor* dx,
DenseTensor* dy) {
funcs::ElementwiseGradPreProcess(dout, dx);
phi::funcs::
ElemwiseGradCompute<Context, T, HeavisideGradDx<T>, HeavisideGradDy<T>>(
dev_ctx,
x,
y,
dout,
dout,
axis,
dx,
dy,
HeavisideGradDx<T>(),
HeavisideGradDy<T>());
}
template <typename T>
struct PowGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
@@ -54,6 +54,8 @@ void FloorDivideKernel(const Context& dev_ctx,
int axis = -1;
FloorDivideRawKernel<T>(dev_ctx, x, y, axis, out);
}
// Create the definition of Heaviside
DEFINE_CUDA_ELEMENTWISE_OP(ElementwiseHeaviside)
// Create the definition of Pow
DEFINE_CUDA_ELEMENTWISE_OP(ElementwisePow)
template <typename T, typename Context>
@@ -130,6 +132,14 @@ PD_REGISTER_KERNEL(floor_divide_raw,
phi::FloorDivideRawKernel,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_heaviside_raw,
KPS,
ALL_LAYOUT,
phi::ElementwiseHeavisideRawKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_pow_raw,
KPS,
ALL_LAYOUT,
@@ -95,6 +95,16 @@ KernelSignature ElementwiseFloorDivOpArgumentMapping(
return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
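// Map the fluid elementwise_heaviside op to the phi kernels: with the default
// axis of -1 the plain "elementwise_heaviside" kernel is used; an explicit
// axis routes to the "elementwise_heaviside_raw" variant that accepts it.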
KernelSignature ElementwiseHeavisideOpArgumentMapping(
const ArgumentMappingContext& ctx) {
int axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (axis == -1) {
return KernelSignature("elementwise_heaviside", {"X", "Y"}, {}, {"Out"});
}
return KernelSignature(
"elementwise_heaviside_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature ElementwisePowOpArgumentMapping(
const ArgumentMappingContext& ctx) {
int axis = paddle::any_cast<int>(ctx.Attr("axis"));
@@ -208,6 +218,15 @@ KernelSignature ElementwiseMinGradOpArgumentMapping(
return KernelSignature(
"minimum_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseHeavisideGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("elementwise_heaviside_grad",
{"X", "Y", "Out@GRAD"},
{"axis"},
{"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwisePowGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("elementwise_pow_grad",
@@ -258,6 +277,8 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_mod,
phi::ElementwiseModOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_floordiv,
phi::ElementwiseFloorDivOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_heaviside,
phi::ElementwiseHeavisideOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow,
phi::ElementwisePowOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad,
@@ -292,5 +313,7 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_max_grad,
phi::ElementwiseMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad,
phi::ElementwiseMinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_heaviside_grad,
phi::ElementwiseHeavisideGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_pow_grad,
phi::ElementwisePowGradOpArgumentMapping);
@@ -269,6 +269,7 @@ from .tensor.math import fmax # noqa: F401
from .tensor.math import fmin # noqa: F401
from .tensor.math import inner # noqa: F401
from .tensor.math import outer # noqa: F401
from .tensor.math import heaviside # noqa: F401
from .tensor.math import frac # noqa: F401
from .tensor.random import bernoulli # noqa: F401
@@ -635,4 +636,5 @@ __all__ = [ # noqa
'renorm',
'take_along_axis',
'put_along_axis',
'heaviside',
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle
class TestElementwiseOp(OpTest):
def setUp(self):
self.op_type = "elementwise_heaviside"
x = np.random.random((13, 17)).astype("float64")
y = np.random.random((13, 17)).astype("float64")
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.heaviside(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ignore_x(self):
self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
def test_check_grad_ignore_y(self):
self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
class TestHeavisideBroadcast(unittest.TestCase):
def setUp(self):
self.input_1 = np.random.rand(2, 100, 13, 17).astype("float32")
self.input_2 = np.random.rand(100, 13, 17).astype("float32")
self.input_3 = np.random.rand(100, 13, 1).astype("float32")
self.input_4 = np.random.rand(13, 17).astype("float32")
self.input_5 = np.random.rand(1).astype("float32")
self.np_expected1 = np.heaviside(self.input_1, self.input_2)
self.np_expected2 = np.heaviside(self.input_2, self.input_3)
self.np_expected3 = np.heaviside(self.input_2, self.input_4)
self.np_expected4 = np.heaviside(self.input_4, self.input_5)
def test_broadcast(self):
paddle.disable_static()
self.tensor_1 = paddle.to_tensor(self.input_1)
self.tensor_2 = paddle.to_tensor(self.input_2)
self.tensor_3 = paddle.to_tensor(self.input_3)
self.tensor_4 = paddle.to_tensor(self.input_4)
self.tensor_5 = paddle.to_tensor(self.input_5)
res = paddle.heaviside(self.tensor_1, self.tensor_2)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected1))
res = paddle.heaviside(self.tensor_2, self.tensor_3)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected2))
res = paddle.heaviside(self.tensor_2, self.tensor_4)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected3))
res = paddle.heaviside(self.tensor_4, self.tensor_5)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected4))
class TestHeavisideAPI_float64(unittest.TestCase):
def setUp(self):
self.x_np = np.random.random((13, 17)).astype("float64")
self.y_np = np.random.random((13, 17)).astype("float64")
self.out_np = np.heaviside(self.x_np, self.y_np)
self.dtype = "float64"
def test_static(self):
for use_cuda in ([False, True]
if paddle.device.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
prog = paddle.static.Program()
with paddle.static.program_guard(prog):
x = paddle.static.data(
name=f"x_{self.dtype}", shape=[13, 17], dtype=self.dtype)
y = paddle.static.data(
name=f"y_{self.dtype}", shape=[13, 17], dtype=self.dtype)
out = paddle.heaviside(x, y)
exe = paddle.static.Executor(place=place)
res = exe.run(prog,
feed={
f"x_{self.dtype}": self.x_np,
f"y_{self.dtype}": self.y_np
},
fetch_list=out,
use_prune=True)
self.assertTrue(np.allclose(res, self.out_np))
def test_dygraph(self):
for use_cuda in ([False, True]
if paddle.device.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.disable_static(place=place)
result = paddle.heaviside(
paddle.to_tensor(self.x_np), paddle.to_tensor(self.y_np))
self.assertTrue(np.allclose(result.numpy(), self.out_np))
class TestHeavisideAPI_float32(TestHeavisideAPI_float64):
def setUp(self):
self.x_np = np.random.random((13, 17)).astype("float32")
self.y_np = np.random.random((13, 17)).astype("float32")
self.out_np = np.heaviside(self.x_np, self.y_np)
self.dtype = "float32"
class TestHeavisideAPI_int64(TestHeavisideAPI_float64):
def setUp(self):
self.x_np = np.random.random((13, 17)).astype("int64")
self.y_np = np.random.random((13, 17)).astype("int64")
self.out_np = np.heaviside(self.x_np, self.y_np)
self.dtype = "int64"
class TestHeavisideAPI_int32(TestHeavisideAPI_float64):
def setUp(self):
self.x_np = np.random.random((13, 17)).astype("int32")
self.y_np = np.random.random((13, 17)).astype("int32")
self.out_np = np.heaviside(self.x_np, self.y_np)
self.dtype = "int32"
class TestHeavisideError(unittest.TestCase):
def test_input(self):
paddle.disable_static()
def test_input_x():
paddle.heaviside(1, paddle.randn([100]))
self.assertRaises(ValueError, test_input_x)
def test_input_y():
paddle.heaviside(paddle.randn([100]), 1)
self.assertRaises(ValueError, test_input_y)
def test_input_xy():
paddle.heaviside(
paddle.randn([100], 'float32'), paddle.randn([100], 'float64'))
self.assertRaises(ValueError, test_input_xy)
if __name__ == '__main__':
unittest.main()
@@ -37,6 +37,7 @@ NEED_TO_FIX_OP_LIST = [
'dot',
'elementwise_add',
'elementwise_div',
'elementwise_heaviside',
'elementwise_max',
'elementwise_min',
'elementwise_mul',
@@ -229,6 +229,7 @@ from .math import fmax # noqa: F401
from .math import fmin # noqa: F401
from .math import inner # noqa: F401
from .math import outer # noqa: F401
from .math import heaviside # noqa: F401
from .math import frac # noqa: F401
from .random import multinomial # noqa: F401
@@ -495,6 +496,7 @@ tensor_method_func = [ #noqa
'put_along_axis',
'put_along_axis_',
'exponential_',
'heaviside',
]
#this list used in math_op_patch.py for magic_method bind
@@ -4381,6 +4381,54 @@ def angle(x, name=None):
helper.append_op(type=op_type, inputs=inputs, outputs=outputs)
return out
def heaviside(x, y, name=None):
r"""
Computes the Heaviside step function determined by the corresponding element in y for each element in x. The equation is

.. math::

    heaviside(x, y)=
        \left\{
            \begin{array}{lcl}
            0, & & \text{if} \ x < 0, \\
            y, & & \text{if} \ x = 0, \\
            1, & & \text{if} \ x > 0.
            \end{array}
        \right.

Notes:
``paddle.heaviside`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.
Args:
x (Tensor): The input tensor of the Heaviside step function; its data type should be float32, float64, int32 or int64.
y (Tensor): The tensor that determines the Heaviside step function; its data type should be float32, float64, int32 or int64.
name (str, optional): Name for the operation (optional, default is None). Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
N-D Tensor. The result of the Heaviside step function computed elementwise. If x and y have different shapes and are broadcastable, the output has the broadcast shape of x and y; if x and y have the same shape, the output has that same shape.
Examples:
.. code-block:: python
:name: heaviside-example
import paddle
x = paddle.to_tensor([-0.5, 0, 0.5])
y = paddle.to_tensor([0.1])
paddle.heaviside(x, y)
# [0. , 0.10000000, 1. ]
x = paddle.to_tensor([[-0.5, 0, 0.5], [-0.5, 0.5, 0]])
y = paddle.to_tensor([0.1, 0.2, 0.3])
paddle.heaviside(x, y)
# [[0. , 0.20000000, 1. ],
# [0. , 1. , 0.30000001]]
"""
op_type = 'elementwise_heaviside'
axis = -1
act = None
if _non_static_mode():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type)
return _elementwise_op(LayerHelper(op_type, **locals()))
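# Example (dygraph), using the Tensor-method form bound through the
# tensor_method_func list above; equivalent to paddle.heaviside(x, y):
#
#     import paddle
#     x = paddle.to_tensor([[-0.5, 0.0, 0.5], [-0.5, 0.5, 0.0]])
#     y = paddle.to_tensor([0.1, 0.2, 0.3])
#     out = x.heaviside(y)
#     # [[0. , 0.2, 1. ],
#     #  [0. , 1. , 0.3]]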
def frac(x, name=None):
"""
This API is used to return the fractional portion of each element in input.
@@ -170,6 +170,7 @@ STATIC_MODE_TESTING_LIST = [
'test_elementwise_div_op',
'test_elementwise_floordiv_op',
'test_elementwise_gradient_op',
'test_elementwise_heaviside_op',
'test_elementwise_max_op',
'test_elementwise_min_op',
'test_elementwise_mod_op',