未验证 提交 d04c9cda 编写于 作者: W wangxiaoning 提交者: GitHub

Add reduce_max_grad composite rule (#51653)

* max comp

* fix

* add test

* fix

* fix

* fix

* fix

* fix test

* fix api
上级 9b2b3dad
......@@ -14,6 +14,9 @@
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
......@@ -25,6 +28,32 @@ class ReduceMaxOpMaker : public ops::ReduceOpMaker {
virtual std::string GetOpType() const { return "Reduce reduce_max"; }
};
namespace paddle {
namespace operators {
class ReduceMaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
public:
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
void Apply() override {
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor out = this->GetSingleForwardOutput("Out");
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
std::vector<int> axis = this->Attr<std::vector<int>>("dim");
bool keep_dim = this->Attr<bool>("keep_dim");
bool reduce_all = this->Attr<bool>("reduce_all");
paddle::Tensor x_grad_t = this->GetSingleInputGrad("X");
paddle::Tensor* x_grad = this->GetOutputPtr(&x_grad_t);
std::string x_grad_name = this->GetOutputName(x_grad_t);
VLOG(6) << "Runing max_grad composite func";
prim::max_grad<prim::DescTensor>(
x, out, out_grad, axis, keep_dim, reduce_all, x_grad);
this->RecoverOutputName(x_grad_t, x_grad_name);
}
};
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(
reduce_max,
ReduceMaxInferShapeFunctor,
......@@ -36,5 +65,6 @@ REGISTER_OPERATOR(
ReduceMaxOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
ops::ReduceMaxCompositeGradOpMaker,
ReduceMaxInferShapeFunctor);
REGISTER_OPERATOR(reduce_max_grad, ops::ReduceGradOp)
......@@ -38,6 +38,7 @@
- sqrt
- cumsum
- put_along_axis
- equal
- greater_than
- less_equal
- sin
......
......@@ -1044,6 +1044,57 @@ void gather_nd_grad(const Tensor& x,
}
}
template <typename T>
void max_grad(const Tensor& x,
const Tensor& out,
const Tensor& out_grad,
const IntArray& axis,
bool keepdim,
bool reduce_all,
Tensor* x_grad) {
if (!x_grad) {
return;
}
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
std::vector<int64_t> x_dim = phi::vectorize<int64_t>(x.dims());
int64_t axis_size = axis.size();
int64_t x_dim_size = x_dim.size();
reduce_all = false;
if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
reduce_all = true;
} else {
reduce_all = false;
}
auto x_grad_tmp = Tensor();
if (x_dim_size == 0 || x_dim_size == 1 || keepdim) {
auto out_grad_tmp = out_grad.expand(IntArray(x_dim));
auto out_tmp = out.expand(IntArray(x_dim));
auto mask = equal<T>(x, out_tmp);
x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
} else {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 1; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_[i] = axis[i] + x_dim_size;
}
}
}
auto out_grad_ = unsqueeze<T>(out_grad, axis_);
auto out_ = unsqueeze<T>(out, axis_);
auto out_grad_tmp = out_grad_.expand(IntArray(x_dim));
auto out_tmp = out_.expand(IntArray(x_dim));
auto mask = equal<T>(x, out_tmp);
x_grad_tmp = where<T>(mask, out_grad_tmp, zero_tensor);
}
set_output<T>(x_grad_tmp, x_grad);
}
template <typename T>
void assign_grad(const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
......
......@@ -727,6 +727,7 @@
param: [x]
kernel :
func : max_grad
composite : max_grad(x, out, out_grad, axis, keepdim, reduce_all, x_grad)
- backward_op : max_pool2d_with_index_grad
forward : max_pool2d_with_index(Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool global_pooling, bool adaptive) -> Tensor(out), Tensor(mask)
......
......@@ -240,7 +240,9 @@ class TestMaxOp(OpTest):
def setUp(self):
self.op_type = "reduce_max"
self.prim_op_type = "prim"
self.python_api = paddle.max
self.public_python_api = paddle.max
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.attrs = {'dim': [-1]}
self.outputs = {
......@@ -250,6 +252,16 @@ class TestMaxOp(OpTest):
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
# only composite op support gradient check of reduce_max
self.check_grad(
['X'],
'Out',
check_eager=True,
check_prim=True,
only_check_prim=True,
)
def test_raise_error(self):
if core.is_compiled_with_cuda():
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float16")}
......@@ -268,7 +280,10 @@ class TestMaxOp_ZeroDim(OpTest):
def setUp(self):
self.op_type = "reduce_max"
self.prim_op_type = "prim"
self.python_api = paddle.max
self.public_python_api = paddle.max
self.enable_cinn = False
self.inputs = {'X': np.random.random([]).astype("float64")}
self.attrs = {'dim': []}
self.outputs = {
......@@ -278,6 +293,46 @@ class TestMaxOp_ZeroDim(OpTest):
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
# only composite op support gradient check of reduce_max
self.check_grad(
['X'],
'Out',
check_eager=True,
check_prim=True,
only_check_prim=True,
)
class TestMaxOp_FP32(OpTest):
"""Remove Max with subgradient from gradient check to confirm the success of CI."""
def setUp(self):
self.op_type = "reduce_max"
self.prim_op_type = "prim"
self.python_api = paddle.max
self.public_python_api = paddle.max
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
self.attrs = {'dim': [-1], 'keep_dim': True}
self.outputs = {
'Out': self.inputs['X'].max(
axis=tuple(self.attrs['dim']), keepdims=True
)
}
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
# only composite op support gradient check of reduce_max
self.check_grad(
['X'],
'Out',
check_eager=True,
check_prim=True,
only_check_prim=True,
)
@skip_check_grad_ci(
reason="reduce_min is discontinuous non-derivable function,"
......@@ -829,7 +884,9 @@ class TestReduceMaxOpMultiAxises(OpTest):
def setUp(self):
self.op_type = "reduce_max"
self.prim_op_type = "prim"
self.python_api = paddle.max
self.public_python_api = paddle.max
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.attrs = {'dim': [-2, -1]}
self.outputs = {
......@@ -839,6 +896,16 @@ class TestReduceMaxOpMultiAxises(OpTest):
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
# only composite op support gradient check of reduce_max
self.check_grad(
['X'],
'Out',
check_eager=True,
check_prim=True,
only_check_prim=True,
)
@skip_check_grad_ci(
reason="reduce_min is discontinuous non-derivable function,"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册