Unverified commit d9de3ef6, authored by wangzhen38, committed by GitHub

[prim] add elementwise_pow backward (#51230)

* [cinn] add elementwise_pow backward

* [cinn] update unittest

* [cinn] update by comments

* [cinn] for ci

* [cinn] for ci

* [cinn] for ci

* [cinn] for ci

* [cinn] for ci
Parent 3328a3d5
...@@ -12,6 +12,9 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle {
namespace framework {
...@@ -41,6 +44,36 @@ class ElementwisePowOpGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
}
};
class ElementwisePowCompositeGradOpMaker
: public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor y = this->GetSingleForwardInput("Y");
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::Tensor dx = this->GetSingleInputGrad("X");
auto dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
paddle::Tensor dy = this->GetSingleInputGrad("Y");
auto dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy);
int axis = static_cast<int>(this->Attr<int>("axis"));
PADDLE_ENFORCE_EQ(
axis,
-1,
phi::errors::InvalidArgument(
"We only support axis = -1 in composite pow but we got: ", axis));
VLOG(6) << "Runing pow_grad composite func";
prim::elementwise_pow_grad<prim::DescTensor>(
x, y, out_grad, axis, dx_ptr, dy_ptr);
this->RecoverOutputName(dx, dx_name);
this->RecoverOutputName(dy, dy_name);
}
};
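For reference, the gradients that this composite maker decomposes into primitive ops are the standard power-rule derivatives of out = x^y:

```latex
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \mathrm{out}} \cdot y\, x^{y-1},
\qquad
\frac{\partial L}{\partial y} = \frac{\partial L}{\partial \mathrm{out}} \cdot x^{y} \ln x
```

When x and y broadcast against each other, each gradient is additionally summed over the broadcast dimensions and reshaped back to its operand's shape, which is what the reduce branches in prim::elementwise_pow_grad below implement.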
class ElementwisePowOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "Pow"; }
...@@ -65,7 +98,9 @@ REGISTER_OPERATOR(elementwise_pow,
ops::ElementwiseOpInferVarType,
ops::ElementwisePowOpGradMaker<paddle::framework::OpDesc>,
ops::ElementwisePowOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_pow_grad,
ops::ElementwiseOpGrad,
ops::ElementwisePowCompositeGradOpMaker);
REGISTER_OP_VERSION(elementwise_pow)
.AddCheckpoint(
......
...@@ -278,6 +278,56 @@ void divide_grad(const Tensor& x,
} // indicate we will compute dx
}
template <typename T>
void elementwise_pow_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis,
Tensor* dx,
Tensor* dy) {
if (dy) {
// dy = dout * x^y * ln(x)
auto lnx = log<T>(x);
auto x_pow_y = elementwise_pow<T>(x, y);
auto dy_res = lnx * x_pow_y * out_grad;
if (x.dims() != y.dims()) {
// may need to reduce dy over the broadcast dimensions of y
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
if (!reduce_dim.size()) {
set_output<T>(dy_res, dy);
} else {
auto dy_reduce_res =
dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp, dy);
}
} else {
set_output<T>(dy_res, dy);
}
} // indicate we will compute dy
if (dx) {
// dx = dout * y * x^(y-1)
auto tmp_z = y - 1.0;
auto x_pow_z = elementwise_pow<T>(x, tmp_z);
auto dx_res = y * x_pow_z * out_grad;
if (y.dims() != x.dims()) {
// may need to reduce dx over the broadcast dimensions of x
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
if (!reduce_dim.size()) {
set_output<T>(dx_res, dx);
} else {
auto dx_reduce_res =
dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp, dx);
}
} else {
set_output<T>(dx_res, dx);
}
} // indicate we will compute dx
}
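As a sanity check, here is a minimal NumPy sketch of the same rule; the helper names are illustrative, not part of the Paddle API. It mirrors the sum-then-reshape reduction above when x and y differ in shape:

```python
import numpy as np

def reduce_to(grad, shape):
    # Sum out the dimensions that broadcasting introduced,
    # then reshape back to the original operand shape.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    for axis, size in enumerate(shape):
        if size == 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad.reshape(shape)

def elementwise_pow_grad_ref(x, y, out_grad):
    # dx = dout * y * x^(y-1);  dy = dout * ln(x) * x^y
    dx = out_grad * y * np.power(x, y - 1.0)
    dy = out_grad * np.log(x) * np.power(x, y)
    return reduce_to(dx, x.shape), reduce_to(dy, y.shape)

x = np.random.uniform(1, 2, (20, 5))
y = np.random.uniform(1, 2, (5,))  # broadcast against x
dx, dy = elementwise_pow_grad_ref(x, y, np.ones((20, 5)))
assert dx.shape == x.shape and dy.shape == y.shape
```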
template <typename T>
void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
......
...@@ -423,6 +423,7 @@
infer_meta :
func : GeneralBinaryGradInferMeta
param: [x, y]
composite : elementwise_pow_grad(x, y, out_grad, axis)
kernel :
func : elementwise_pow_grad
......
...@@ -1218,6 +1218,7 @@ set(TEST_CINN_OPS
test_elementwise_div_op
test_elementwise_mul_op
test_gather_nd_op
test_elementwise_pow_op
test_transpose_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
......
...@@ -31,6 +31,7 @@ class TestElementwisePowOp(OpTest):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.uniform(1, 2, [20, 5]).astype("float64"),
'Y': np.random.uniform(1, 2, [20, 5]).astype("float64"),
...@@ -45,15 +46,22 @@ class TestElementwisePowOp(OpTest):
def test_check_grad_normal(self):
if hasattr(self, 'attrs'):
self.check_grad(
['X', 'Y'], 'Out', check_eager=False, check_prim=True
)
else:
self.check_grad(
['X', 'Y'], 'Out', check_eager=True, check_prim=True
)
class TestElementwisePowOp_ZeroDim1(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.enable_cinn = False
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.uniform(1, 2, []).astype("float64"),
'Y': np.random.uniform(1, 2, []).astype("float64"),
...@@ -65,6 +73,9 @@ class TestElementwisePowOp_ZeroDim2(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.enable_cinn = False
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.uniform(1, 2, [20, 5]).astype("float64"),
'Y': np.random.uniform(1, 2, []).astype("float64"),
...@@ -76,6 +87,9 @@ class TestElementwisePowOp_ZeroDim3(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.enable_cinn = False
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.uniform(1, 2, []).astype("float64"),
'Y': np.random.uniform(1, 2, [20, 5]).astype("float64"),
...@@ -87,6 +101,8 @@ class TestElementwisePowOp_big_shape_1(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.uniform(1, 2, [10, 10]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [10, 10]).astype("float64"),
...@@ -98,6 +114,8 @@ class TestElementwisePowOp_big_shape_2(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.uniform(1, 2, [10, 10]).astype("float64"),
'Y': np.random.uniform(0.2, 2, [10, 10]).astype("float64"),
...@@ -112,6 +130,8 @@ class TestElementwisePowOp_scalar(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.uniform(0.1, 1, [3, 3, 4]).astype(np.float64),
'Y': np.random.uniform(0.1, 1, [1]).astype(np.float64),
...@@ -123,6 +143,9 @@ class TestElementwisePowOp_tensor(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.uniform(0.1, 1, [100]).astype("float64"),
'Y': np.random.uniform(1, 3, [100]).astype("float64"),
...@@ -134,6 +157,8 @@ class TestElementwisePowOp_broadcast_0(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 1, 100]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [100]).astype("float64"),
...@@ -145,6 +170,7 @@ class TestElementwisePowOp_broadcast_1(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 100, 1]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [100]).astype("float64"),
...@@ -154,11 +180,18 @@ class TestElementwisePowOp_broadcast_1(TestElementwisePowOp):
'Out': np.power(self.inputs['X'], self.inputs['Y'].reshape(100, 1))
}
def test_check_grad_normal(self):
if hasattr(self, 'attrs'):
self.check_grad(['X', 'Y'], 'Out', check_eager=False)
else:
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
class TestElementwisePowOp_broadcast_2(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.inputs = {
'X': np.random.uniform(0.1, 1, [100, 3, 1]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [100]).astype("float64"),
...@@ -170,11 +203,18 @@ class TestElementwisePowOp_broadcast_2(TestElementwisePowOp):
)
}
def test_check_grad_normal(self):
if hasattr(self, 'attrs'):
self.check_grad(['X', 'Y'], 'Out', check_eager=False)
else:
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
class TestElementwisePowOp_broadcast_3(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 20, 5, 1]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [20, 5]).astype("float64"),
...@@ -186,11 +226,19 @@ class TestElementwisePowOp_broadcast_3(TestElementwisePowOp):
)
}
def test_check_grad_normal(self):
if hasattr(self, 'attrs'):
self.check_grad(['X', 'Y'], 'Out', check_eager=False)
else:
self.check_grad(['X', 'Y'], 'Out', check_eager=True)
class TestElementwisePowOp_broadcast_4(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 10, 3, 5]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [2, 10, 1, 5]).astype("float64"),
...@@ -202,6 +250,7 @@ class TestElementwisePowOpInt(OpTest):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.inputs = {'X': np.asarray([1, 3, 6]), 'Y': np.asarray([1, 1, 1])}
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
...@@ -217,6 +266,7 @@ class TestElementwisePowGradOpInt(unittest.TestCase):
self.x = np.asarray([1, 3, 6])
self.y = np.asarray([1, 1, 1])
self.res = self.x**self.y
# dout = 1
self.grad_res = np.asarray([1, 1, 1])
# dx = dout * y * pow(x, y-1)
...@@ -250,6 +300,8 @@ class TestElementwisePowOpFP16(OpTest):
def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.uniform(1, 2, [20, 5]).astype("float16"),
'Y': np.random.uniform(1, 2, [20, 5]).astype("float16"),
...@@ -270,6 +322,7 @@ class TestElementwisePowOpFP16(OpTest):
self.inputs['X'], self.inputs['Y'], 1 / self.inputs['X'].size
),
check_eager=True,
check_prim=True,
)
......