Unverified commit ca2b6095 authored by GGBond8488, committed by GitHub

add cumsum prim backward (#50565)

* add cumsum prim backward

* skip axis=None test case

* fix op generate error

* fix static test error

* remove unused code

* fix static test error

* skip cpu float16 test case

* skip eager cpu cumsum float16 test case

* add cinn test

* reshape flatten out

* Disable cinn single test

* remove cinn test

* reformat todo

* add prim in cumsum op test

* remove old test

* fix typo

* fix typo

* fix typo

* pass axis=None test case

* remove forward prim test

* remove same name axis
Parent 16a1b4a1
......@@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
......@@ -100,6 +103,27 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
}
};
class CumsumCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 public:
  void Apply() override {
    paddle::experimental::Tensor x = this->GetSingleForwardInput("X");
    paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out");
    paddle::experimental::Tensor dx = this->GetSingleInputGrad("X");
    auto* dx_ptr = this->GetOutputPtr(&dx);
    std::string dx_name = this->GetOutputName(dx);
    int axis = static_cast<int>(this->Attr<int>("axis"));
    bool flatten = static_cast<bool>(this->Attr<bool>("flatten"));
    bool exclusive = static_cast<bool>(this->Attr<bool>("exclusive"));
    bool reverse = static_cast<bool>(this->Attr<bool>("reverse"));
    VLOG(6) << "Running cumsum_grad composite func";
    prim::cumsum_grad<prim::DescTensor>(
        x, out_grad, axis, flatten, exclusive, reverse, dx_ptr);
    this->RecoverOutputName(dx, dx_name);
  }
};
class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -182,6 +206,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp,
REGISTER_OPERATOR(cumsum,
ops::CumOp,
ops::CumsumOpMaker,
ops::CumsumCompositeGradOpMaker,
ops::CumsumGradMaker<paddle::framework::OpDesc>,
ops::CumsumGradMaker<paddle::imperative::OpBase>,
CumsumInferShapeFunctor);
......
......@@ -25,3 +25,4 @@
- tile
- transpose
- pad
- cumsum
......@@ -414,5 +414,20 @@ void slice_grad(const Tensor& input,
}
}
template <typename T>
void cumsum_grad(const Tensor& x,
                 const Tensor& out_grad,
                 const Scalar& axis,
                 bool flatten,
                 bool exclusive,
                 bool reverse,
                 Tensor* x_grad) {
  if (x_grad) {
    auto grad = cumsum<T>(out_grad, axis, flatten, exclusive, !reverse);
    grad = reshape<T>(grad, x.shape());
    set_output<T>(grad, x_grad);
  }
}
} // namespace prim
} // namespace paddle
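The rule above follows from linearity: in `y = cumsum(x, axis)`, each `x[j]` feeds every `y[i]` with `i >= j`, so `x_grad` is the cumulative sum of `out_grad` along the same axis in the opposite direction (hence the `!reverse`), and the trailing `reshape` restores the input shape when `flatten` collapsed it. A minimal NumPy sketch of that identity, verified by central differences (helper names here are illustrative, not Paddle internals):

```python
import numpy as np

def cumsum_grad_ref(out_grad, axis, reverse=False):
    # Gradient of y = cumsum(x, axis, reverse): cumsum of out_grad
    # along the same axis with the direction flipped.
    if reverse:
        return np.cumsum(out_grad, axis=axis)
    flipped = np.flip(out_grad, axis=axis)
    return np.flip(np.cumsum(flipped, axis=axis), axis=axis)

x = np.random.rand(5)
g = np.random.rand(5)  # upstream gradient for L = sum(cumsum(x) * g)
eps = 1e-6
numeric = np.array([
    ((np.cumsum(x + eps * np.eye(5)[j])
      - np.cumsum(x - eps * np.eye(5)[j])) * g).sum() / (2 * eps)
    for j in range(5)
])
assert np.allclose(cumsum_grad_ref(g, axis=0), numeric)
```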
......@@ -313,6 +313,7 @@
  kernel :
    func : cumsum_grad
    data_type : x
  composite : cumsum_grad(x, out_grad, axis, flatten, exclusive, reverse, x_grad)
- backward_op : deformable_conv_grad
forward : deformable_conv(Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) -> Tensor(out)
......
......@@ -365,6 +365,17 @@
outputs :
out : Out
- op : cumsum
  backward : cumsum_grad
  inputs :
    x : X
  outputs :
    out : Out
  scalar :
    axis :
      data_type : int
      tensor_name : AxisTensor
- op : data_norm
backward : data_norm_grad
extra :
......
......@@ -115,6 +115,9 @@ class TestCumsumOp(unittest.TestCase):
class TestSumOp1(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': 2}
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].cumsum(axis=2)}
......@@ -123,12 +126,15 @@ class TestSumOp1(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOp2(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': -1, 'reverse': True}
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.outputs = {
......@@ -141,12 +147,15 @@ class TestSumOp2(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOp3(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': 1}
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].cumsum(axis=1)}
......@@ -155,12 +164,15 @@ class TestSumOp3(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOp4(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': 0}
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].cumsum(axis=0)}
......@@ -175,6 +187,9 @@ class TestSumOp4(OpTest):
class TestSumOp5(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.inputs = {'X': np.random.random((5, 20)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].cumsum(axis=1)}
......@@ -188,6 +203,9 @@ class TestSumOp5(OpTest):
class TestSumOp7(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.inputs = {'X': np.random.random((100)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].cumsum(axis=0)}
......@@ -226,6 +244,9 @@ class TestCumsumFP16(unittest.TestCase):
class TestSumOpExclusive1(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': 2, "exclusive": True}
a = np.random.random((4, 5, 20)).astype("float64")
self.inputs = {'X': a}
......@@ -243,12 +264,15 @@ class TestSumOpExclusive1(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOpExclusive2(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': 2, "exclusive": True}
a = np.random.random((1, 1, 100)).astype("float64")
self.inputs = {'X': a}
......@@ -266,12 +290,15 @@ class TestSumOpExclusive2(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOpExclusive3(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': 2, "exclusive": True}
a = np.random.random((4, 5, 20)).astype("float64")
self.inputs = {'X': a}
......@@ -289,12 +316,15 @@ class TestSumOpExclusive3(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOpExclusive4(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': 2, "exclusive": True}
a = np.random.random((1, 1, 100)).astype("float64")
self.inputs = {'X': a}
......@@ -312,12 +342,15 @@ class TestSumOpExclusive4(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOpExclusive5(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': 2, "exclusive": True}
a = np.random.random((4, 5, 40)).astype("float64")
self.inputs = {'X': a}
......@@ -335,12 +368,15 @@ class TestSumOpExclusive5(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOpExclusiveFP16(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': 2, "exclusive": True, "dtype": "float16"}
a = np.random.random((4, 5, 20)).astype("float64")
self.inputs = {'X': a}
......@@ -358,12 +394,15 @@ class TestSumOpExclusiveFP16(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestSumOpReverseExclusive(OpTest):
def setUp(self):
self.op_type = "cumsum"
self.prim_op_type = "prim"
self.python_api = paddle.cumsum
self.enable_cinn = False
self.attrs = {'axis': 2, 'reverse': True, "exclusive": True}
a = np.random.random((4, 5, 6)).astype("float64")
self.inputs = {'X': a}
......@@ -382,7 +421,7 @@ class TestSumOpReverseExclusive(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class BadInputTest(unittest.TestCase):
......
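The `check_prim=True` calls added above make OpTest validate the composite gradient against numeric gradients as well, while `enable_cinn = False` keeps CINN out of these cases (the separate CINN test was dropped during review, per the commit log). For the exclusive + reverse combination that `TestSumOpReverseExclusive` covers, the identity under test reduces to: the gradient of an exclusive reverse cumsum is an exclusive forward cumsum of `out_grad`. A hedged NumPy spot-check (helper names are illustrative):

```python
import numpy as np

def exclusive_cumsum(a, reverse=False):
    # y[i] = sum of a[:i] (forward) or sum of a[i+1:] (reverse).
    if reverse:
        return np.flip(exclusive_cumsum(np.flip(a)))
    return np.concatenate([[0.0], np.cumsum(a)[:-1]])

x = np.random.rand(6)
g = np.random.rand(6)
eps = 1e-6
numeric = np.array([
    ((exclusive_cumsum(x + eps * np.eye(6)[j], reverse=True)
      - exclusive_cumsum(x - eps * np.eye(6)[j], reverse=True)) * g).sum() / (2 * eps)
    for j in range(6)
])
assert np.allclose(exclusive_cumsum(g), numeric)  # reverse flipped, exclusive kept
```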