Unverified commit 4a484973 authored by: H heyanru; committed by: GitHub

【prim】Maximum grad (#51006)

* refresh

* compat

* register

* testop

* fix

* fix

* fix

* cast

* cast

* fix

* type

* fix

* out

* cast

* fix

* fix

* fix

* broad

* broad

* broad

* fix

* fix

* fix

* fix

* fix

* broad

* broad

* numel

* fix

* fix

* fix

* fix

* cinn

* fix

* fix

* fix

* fix
Parent 9751bd0d
@@ -15,6 +15,9 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle {
namespace framework {
@@ -68,6 +71,35 @@ class ElementwiseFMaxOpMaker : public ElementwiseOpMaker {
  }
};

class ElementwiseMaxCompositeGradOpMaker
    : public prim::CompositeGradOpMakerBase {
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 public:
  void Apply() override {
    paddle::Tensor x = this->GetSingleForwardInput("X");
    paddle::Tensor y = this->GetSingleForwardInput("Y");
    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
    paddle::Tensor dx = this->GetSingleInputGrad("X");
    auto* dx_ptr = this->GetOutputPtr(&dx);
    std::string dx_name = this->GetOutputName(dx);
    paddle::Tensor dy = this->GetSingleInputGrad("Y");
    auto* dy_ptr = this->GetOutputPtr(&dy);
    std::string dy_name = this->GetOutputName(dy);
    int axis = static_cast<int>(this->Attr<int>("axis"));
    PADDLE_ENFORCE_EQ(
        axis,
        -1,
        phi::errors::InvalidArgument(
            "We only support axis = -1 in composite maximum_grad but we got: ",
            axis));
    VLOG(6) << "Running maximum_grad composite func";
    prim::maximum_grad<prim::DescTensor>(x, y, out_grad, axis, dx_ptr, dy_ptr);
    this->RecoverOutputName(dx, dx_name);
    this->RecoverOutputName(dy, dy_name);
  }
};
template <typename T>
class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
@@ -112,7 +144,8 @@ REGISTER_OPERATOR(elementwise_max,
                  ops::ElementwiseMaxOpMaker,
                  ops::ElementwiseOpInferVarType,
                  ops::ElementwiseMaxGradOpMaker<paddle::framework::OpDesc>,
                  ops::ElementwiseMaxGradOpMaker<paddle::imperative::OpBase>,
                  ops::ElementwiseMaxCompositeGradOpMaker);
REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);
......
@@ -31,3 +31,5 @@
- pad
- cumsum
- put_along_axis
- greater_than
- less_equal
@@ -898,5 +898,51 @@ void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
  }
}

template <typename T>
void maximum_grad(const Tensor& x,
                  const Tensor& y,
                  const Tensor& out_grad,
                  int axis,
                  Tensor* x_grad,
                  Tensor* y_grad) {
  if (x_grad) {
    auto x_tmp = cast<T>(greater_than<T>(x, y), out_grad.dtype());
    auto dx_res = out_grad * x_tmp;
    if (y.dims() != x.dims()) {
      // Maybe need reduce here
      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
      if (!reduce_dim.size()) {
        set_output<T>(dx_res, x_grad);
      } else {
        auto dx_reduce_res =
            dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false);
        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
        set_output<T>(dx_tmp, x_grad);
      }
    } else {
      set_output<T>(dx_res, x_grad);
    }
  }
  if (y_grad) {
    auto y_tmp = cast<T>(less_equal<T>(x, y), out_grad.dtype());
    auto dy_res = out_grad * y_tmp;
    if (x.dims() != y.dims()) {
      // Maybe need reduce here
      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
      if (!reduce_dim.size()) {
        set_output<T>(dy_res, y_grad);
      } else {
        auto dy_reduce_res =
            dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false);
        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
        set_output<T>(dy_tmp, y_grad);
      }
    } else {
      set_output<T>(dy_res, y_grad);
    }
  }
}
} // namespace prim
} // namespace paddle
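For reference, a minimal NumPy sketch of the gradient rule this composite implements (illustrative only; maximum_grad_ref and reduce_to_shape are made-up helper names, not part of the patch): out_grad is routed to x wherever x > y and to y wherever x <= y (ties go to y), and the gradient of a broadcast input is summed over its broadcast dimensions and reshaped back, mirroring the get_reduce_dims + sum + reshape sequence above.

import numpy as np

def reduce_to_shape(grad, shape):
    # Sum away broadcast dimensions, then reshape back to the input shape.
    if grad.shape == tuple(shape):
        return grad
    lead = grad.ndim - len(shape)
    axes = tuple(range(lead)) + tuple(
        i + lead for i, s in enumerate(shape) if s == 1
    )
    return grad.sum(axis=axes).reshape(shape)

def maximum_grad_ref(x, y, out_grad):
    dx = out_grad * (x > y).astype(out_grad.dtype)   # greater_than mask
    dy = out_grad * (x <= y).astype(out_grad.dtype)  # less_equal mask
    return reduce_to_shape(dx, x.shape), reduce_to_shape(dy, y.shape)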
@@ -39,6 +39,8 @@ class DescTensor : public phi::ExtendedTensor,
    return dims_;
  }

  int64_t numel() const override { return product(dims()); }

  DataType dtype() const override {
    return paddle::framework::TransToPhiDataType(desc_ptr_->GetDataType());
  }
......
@@ -752,6 +752,7 @@
    param: [x, y]
  kernel :
    func : maximum_grad
  composite : maximum_grad(x, y, out_grad, axis, x_grad, y_grad)

- backward_op : mean_all_grad
  forward : mean_all(Tensor x) -> Tensor(out)
......
@@ -1213,6 +1213,7 @@ set(TEST_CINN_OPS
    test_elementwise_mul_op
    test_gather_nd_op
    test_elementwise_pow_op
    test_elementwise_max_op
    test_transpose_op
    test_reshape_op
    test_mean_op
......
@@ -25,6 +25,8 @@ class TestElementwiseOp(OpTest):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        # If x and y have the same value, the max() is not differentiable.
        # So we generate test data by the following method
        # to avoid them being too close to each other.
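As an aside (illustrative only, not part of the patch; shapes are assumed), the generation scheme that comment refers to amounts to offsetting y from x by at least 0.1 in a random direction, so max(x, y) has a well-defined gradient at every sampled point:

import numpy as np

x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
sgn = np.random.choice([-1, 1], [13, 17]).astype("float64")
y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64")  # |x - y| >= 0.1 elementwise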
@@ -42,25 +44,58 @@ class TestElementwiseOp(OpTest):
    def test_check_grad_normal(self):
        if hasattr(self, 'attrs'):
            if self.attrs['axis'] == -1:
                self.check_grad(
                    ['X', 'Y'], 'Out', check_eager=False, check_prim=True
                )
            else:
                self.check_grad(['X', 'Y'], 'Out', check_eager=False)
        else:
            self.check_grad(
                ['X', 'Y'], 'Out', check_eager=True, check_prim=True
            )

    def test_check_grad_ingore_x(self):
        if hasattr(self, 'attrs') and self.attrs['axis'] != -1:
            self.check_grad(
                ['Y'],
                'Out',
                max_relative_error=0.005,
                no_grad_set=set("X"),
            )
        else:
            self.check_grad(
                ['Y'],
                'Out',
                max_relative_error=0.005,
                no_grad_set=set("X"),
                check_prim=True,
            )

    def test_check_grad_ingore_y(self):
        if hasattr(self, 'attrs') and self.attrs['axis'] != -1:
            self.check_grad(
                ['X'],
                'Out',
                max_relative_error=0.005,
                no_grad_set=set('Y'),
            )
        else:
            self.check_grad(
                ['X'],
                'Out',
                max_relative_error=0.005,
                no_grad_set=set('Y'),
                check_prim=True,
            )

class TestElementwiseMaxOp_ZeroDim1(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        x = np.random.uniform(0.1, 1, []).astype("float64")
        y = np.random.uniform(0.1, 1, []).astype("float64")
        self.inputs = {'X': x, 'Y': y}
@@ -71,6 +106,8 @@ class TestElementwiseMaxOp_ZeroDim2(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
        y = np.random.uniform(0.1, 1, []).astype("float64")
        self.inputs = {'X': x, 'Y': y}
@@ -81,6 +118,8 @@ class TestElementwiseMaxOp_ZeroDim3(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        x = np.random.uniform(0.1, 1, []).astype("float64")
        y = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
        self.inputs = {'X': x, 'Y': y}
@@ -99,6 +138,8 @@ class TestElementwiseBF16Op(OpTest):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        self.dtype = np.uint16
        # If x and y have the same value, the max() is not differentiable.
        # So we generate test data by the following method
@@ -120,6 +161,7 @@ class TestElementwiseBF16Op(OpTest):
    def test_check_grad_normal(self):
        if hasattr(self, 'attrs'):
            # check_prim=False, bfloat16 is not supported in `less_equal`
            self.check_grad(['X', 'Y'], 'Out', check_eager=False)
        else:
            self.check_grad(['X', 'Y'], 'Out', check_eager=True)
@@ -138,6 +180,8 @@ class TestElementwiseMaxOp_scalar(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        x = np.random.random_integers(-5, 5, [2, 3, 20]).astype("float64")
        y = np.array([0.5]).astype("float64")
        self.inputs = {'X': x, 'Y': y}
@@ -148,6 +192,8 @@ class TestElementwiseMaxOp_Vector(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        x = np.random.random((100,)).astype("float64")
        sgn = np.random.choice([-1, 1], (100,)).astype("float64")
        y = x + sgn * np.random.uniform(0.1, 1, (100,)).astype("float64")
@@ -159,6 +205,7 @@ class TestElementwiseMaxOp_broadcast_0(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        x = np.random.uniform(0.5, 1, (100, 5, 2)).astype(np.float64)
        sgn = np.random.choice([-1, 1], (100,)).astype(np.float64)
        y = x[:, 0, 0] + sgn * np.random.uniform(1, 2, (100,)).astype(
@@ -178,6 +225,7 @@ class TestElementwiseMaxOp_broadcast_1(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        x = np.random.uniform(0.5, 1, (2, 100, 3)).astype(np.float64)
        sgn = np.random.choice([-1, 1], (100,)).astype(np.float64)
        y = x[0, :, 0] + sgn * np.random.uniform(1, 2, (100,)).astype(
@@ -197,6 +245,7 @@ class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        x = np.random.uniform(0.5, 1, (1, 3, 100)).astype(np.float64)
        sgn = np.random.choice([-1, 1], (100,)).astype(np.float64)
        y = x[0, 0, :] + sgn * np.random.uniform(1, 2, (100,)).astype(
@@ -215,6 +264,7 @@ class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        x = np.random.uniform(0.5, 1, (2, 50, 2, 1)).astype(np.float64)
        sgn = np.random.choice([-1, 1], (50, 2)).astype(np.float64)
        y = x[0, :, :, 0] + sgn * np.random.uniform(1, 2, (50, 2)).astype(
@@ -234,6 +284,7 @@ class TestElementwiseMaxOp_broadcast_4(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float64)
        sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(np.float64)
        y = x + sgn * np.random.uniform(1, 2, (2, 3, 1, 5)).astype(np.float64)
......