未验证 提交 97690816 编写于 作者: warrentdrew's avatar warrentdrew 提交者: GitHub

add minimum grad composite rules (#52561)

* add minimum grad composite rules

* add public python api

* fix format

* fix format

* update testcase

* fix testcase

* fix format

* fix cmakelist.txt

* fix format

* fix param problem

* fix op and composite rule

* fix bf16 cpu support problem

* fix bf16 cpu issue

* fix axis error log

* add axis for maximum

* revert commit

* remove .orig

* fix generic problem

* revert max op

* fix axis error

* fix maximum axis

* fix test_check_output

* fix cinn

* fix minimum maximum axis check
上级 9d9f0ce5
...@@ -87,6 +87,13 @@ class ElementwiseMaxCompositeGradOpMaker ...@@ -87,6 +87,13 @@ class ElementwiseMaxCompositeGradOpMaker
auto* dy_ptr = this->GetOutputPtr(&dy); auto* dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy); std::string dy_name = this->GetOutputName(dy);
VLOG(6) << "Runing maximum_grad composite func"; VLOG(6) << "Runing maximum_grad composite func";
int axis = static_cast<int>(this->Attr<int>("axis"));
PADDLE_ENFORCE_EQ(
axis,
-1,
phi::errors::InvalidArgument(
"We only support axis = -1 in composite maximum_grad but we got: ",
axis));
prim::maximum_grad<prim::DescTensor>(x, y, out_grad, dx_ptr, dy_ptr); prim::maximum_grad<prim::DescTensor>(x, y, out_grad, dx_ptr, dy_ptr);
this->RecoverOutputName(dx, dx_name); this->RecoverOutputName(dx, dx_name);
this->RecoverOutputName(dy, dy_name); this->RecoverOutputName(dy, dy_name);
......
...@@ -15,6 +15,9 @@ limitations under the License. */ ...@@ -15,6 +15,9 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -68,6 +71,35 @@ class ElementwiseFMinOpMaker : public ElementwiseOpMaker { ...@@ -68,6 +71,35 @@ class ElementwiseFMinOpMaker : public ElementwiseOpMaker {
} }
}; };
class ElementwiseMinCompositeGradOpMaker
: public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor y = this->GetSingleForwardInput("Y");
paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
paddle::Tensor dx = this->GetSingleInputGrad("X");
auto* dx_ptr = this->GetOutputPtr(&dx);
std::string dx_name = this->GetOutputName(dx);
paddle::Tensor dy = this->GetSingleInputGrad("Y");
auto* dy_ptr = this->GetOutputPtr(&dy);
std::string dy_name = this->GetOutputName(dy);
VLOG(6) << "Runing minimum_grad composite func";
int axis = static_cast<int>(this->Attr<int>("axis"));
PADDLE_ENFORCE_EQ(
axis,
-1,
phi::errors::InvalidArgument(
"We only support axis = -1 in composite minimum_grad but we got: ",
axis));
prim::minimum_grad<prim::DescTensor>(x, y, out_grad, dx_ptr, dy_ptr);
this->RecoverOutputName(dx, dx_name);
this->RecoverOutputName(dy, dy_name);
}
};
template <typename T> template <typename T>
class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker<T> { class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker<T> {
public: public:
...@@ -112,7 +144,8 @@ REGISTER_OPERATOR(elementwise_min, ...@@ -112,7 +144,8 @@ REGISTER_OPERATOR(elementwise_min,
ops::ElementwiseMinOpMaker, ops::ElementwiseMinOpMaker,
ops::ElementwiseOpInferVarType, ops::ElementwiseOpInferVarType,
ops::ElementwiseMinGradOpMaker<paddle::framework::OpDesc>, ops::ElementwiseMinGradOpMaker<paddle::framework::OpDesc>,
ops::ElementwiseMinGradOpMaker<paddle::imperative::OpBase>); ops::ElementwiseMinGradOpMaker<paddle::imperative::OpBase>,
ops::ElementwiseMinCompositeGradOpMaker);
REGISTER_OPERATOR(elementwise_min_grad, ops::ElementwiseOpGrad); REGISTER_OPERATOR(elementwise_min_grad, ops::ElementwiseOpGrad);
......
...@@ -1571,6 +1571,51 @@ void gelu_grad(const Tensor& x, ...@@ -1571,6 +1571,51 @@ void gelu_grad(const Tensor& x,
} }
} }
template <typename T>
void minimum_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
Tensor* x_grad,
Tensor* y_grad) {
if (x_grad) {
auto x_tmp = cast<T>(less_than<T>(x, y), out_grad.dtype());
auto dx_res = out_grad * x_tmp;
if (y.dims() != x.dims()) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
if (!reduce_dim.size()) {
set_output<T>(dx_res, x_grad);
} else {
auto dx_reduce_res =
dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp, x_grad);
}
} else {
set_output<T>(dx_res, x_grad);
}
}
if (y_grad) {
auto y_tmp = cast<T>(greater_equal<T>(x, y), out_grad.dtype());
auto dy_res = out_grad * y_tmp;
if (x.dims() != y.dims()) {
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
if (!reduce_dim.size()) {
set_output<T>(dy_res, y_grad);
} else {
auto dy_reduce_res =
dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp, y_grad);
}
} else {
set_output<T>(dy_res, y_grad);
}
}
}
template <typename T> template <typename T>
void tile_grad(const Tensor& x, void tile_grad(const Tensor& x,
const Tensor& out_grad, const Tensor& out_grad,
......
...@@ -546,6 +546,7 @@ ...@@ -546,6 +546,7 @@
param: [x, y] param: [x, y]
kernel : kernel :
func : minimum_grad func : minimum_grad
composite : minimum_grad(x, y, out_grad, axis, x_grad, y_grad)
- backward_op : mish_grad - backward_op : mish_grad
forward : mish (Tensor x, float threshold) -> Tensor(out) forward : mish (Tensor x, float threshold) -> Tensor(out)
......
...@@ -1115,7 +1115,8 @@ set(TEST_CINN_OPS ...@@ -1115,7 +1115,8 @@ set(TEST_CINN_OPS
test_group_norm_op test_group_norm_op
test_tile_op test_tile_op
test_roll_op test_roll_op
test_sum_op) test_sum_op
test_elementwise_min_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN) if(WITH_CINN)
......
...@@ -34,6 +34,9 @@ class TestElementwiseOp(OpTest): ...@@ -34,6 +34,9 @@ class TestElementwiseOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = paddle.minimum self.python_api = paddle.minimum
self.public_python_api = paddle.minimum
self.prim_op_type = "prim"
self.if_enable_cinn()
# If x and y have the same value, the min() is not differentiable. # If x and y have the same value, the min() is not differentiable.
# So we generate test data by the following method # So we generate test data by the following method
# to avoid them being too close to each other. # to avoid them being too close to each other.
...@@ -47,23 +50,60 @@ class TestElementwiseOp(OpTest): ...@@ -47,23 +50,60 @@ class TestElementwiseOp(OpTest):
self.check_output() self.check_output()
def test_check_grad_normal(self): def test_check_grad_normal(self):
if hasattr(self, 'attrs'):
if self.attrs['axis'] == -1:
self.check_grad(['X', 'Y'], 'Out', check_prim=True)
else:
self.check_grad(['X', 'Y'], 'Out') self.check_grad(['X', 'Y'], 'Out')
else:
self.check_grad(['X', 'Y'], 'Out', check_prim=True)
def test_check_grad_ingore_x(self): def test_check_grad_ingore_x(self):
if hasattr(self, 'attrs') and self.attrs['axis'] != -1:
self.check_grad(
['Y'],
'Out',
max_relative_error=0.005,
no_grad_set=set("X"),
)
else:
self.check_grad( self.check_grad(
['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X") ['Y'],
'Out',
max_relative_error=0.005,
no_grad_set=set("X"),
check_prim=True,
) )
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
if hasattr(self, 'attrs') and self.attrs['axis'] != -1:
self.check_grad( self.check_grad(
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y') ['X'],
'Out',
max_relative_error=0.005,
no_grad_set=set('Y'),
check_dygraph=False,
)
else:
self.check_grad(
['X'],
'Out',
max_relative_error=0.005,
no_grad_set=set('Y'),
check_prim=True,
) )
def if_enable_cinn(self):
pass
class TestElementwiseFP16Op(TestElementwiseOp): class TestElementwiseFP16Op(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = paddle.minimum self.python_api = paddle.minimum
self.public_python_api = paddle.minimum
self.prim_op_type = "prim"
self.if_enable_cinn()
self.dtype = np.float16 self.dtype = np.float16
# If x and y have the same value, the min() is not differentiable. # If x and y have the same value, the min() is not differentiable.
# So we generate test data by the following method # So we generate test data by the following method
...@@ -74,66 +114,81 @@ class TestElementwiseFP16Op(TestElementwiseOp): ...@@ -74,66 +114,81 @@ class TestElementwiseFP16Op(TestElementwiseOp):
self.inputs = {'X': x, 'Y': y} self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self):
self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
class TestElementwiseMinOp_ZeroDim1(TestElementwiseOp): class TestElementwiseMinOp_ZeroDim1(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = paddle.minimum self.python_api = paddle.minimum
self.public_python_api = paddle.minimum
self.prim_op_type = "prim"
self.if_enable_cinn()
x = np.random.uniform(0.1, 1, []).astype("float64") x = np.random.uniform(0.1, 1, []).astype("float64")
y = np.random.uniform(0.1, 1, []).astype("float64") y = np.random.uniform(0.1, 1, []).astype("float64")
self.inputs = {'X': x, 'Y': y} self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
def if_enable_cinn(self):
self.enable_cinn = False
class TestElementwiseMinFP16Op_ZeroDim1(TestElementwiseFP16Op): class TestElementwiseMinFP16Op_ZeroDim1(TestElementwiseFP16Op):
def init_data(self): def init_data(self):
self.x = np.random.uniform(0.1, 1, []).astype(np.float16) self.x = np.random.uniform(0.1, 1, []).astype(np.float16)
self.y = np.random.uniform(0.1, 1, []).astype(np.float16) self.y = np.random.uniform(0.1, 1, []).astype(np.float16)
def if_enable_cinn(self):
self.enable_cinn = False
class TestElementwiseMinOp_ZeroDim2(TestElementwiseOp): class TestElementwiseMinOp_ZeroDim2(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = paddle.minimum self.python_api = paddle.minimum
self.public_python_api = paddle.minimum
self.prim_op_type = "prim"
self.if_enable_cinn()
x = np.random.uniform(0.1, 1, [13, 17]).astype("float64") x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
y = np.random.uniform(0.1, 1, []).astype("float64") y = np.random.uniform(0.1, 1, []).astype("float64")
self.inputs = {'X': x, 'Y': y} self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
def if_enable_cinn(self):
self.enable_cinn = False
class TestElementwiseMinFP16Op_ZeroDim2(TestElementwiseFP16Op): class TestElementwiseMinFP16Op_ZeroDim2(TestElementwiseFP16Op):
def init_data(self): def init_data(self):
self.x = np.random.uniform(0.1, 1, [13, 17]).astype("float16") self.x = np.random.uniform(0.1, 1, [13, 17]).astype("float16")
self.y = np.random.uniform(0.1, 1, []).astype("float16") self.y = np.random.uniform(0.1, 1, []).astype("float16")
def if_enable_cinn(self):
self.enable_cinn = False
class TestElementwiseMinOp_ZeroDim3(TestElementwiseOp): class TestElementwiseMinOp_ZeroDim3(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = paddle.minimum self.python_api = paddle.minimum
self.public_python_api = paddle.minimum
self.prim_op_type = "prim"
self.if_enable_cinn()
x = np.random.uniform(0.1, 1, []).astype("float64") x = np.random.uniform(0.1, 1, []).astype("float64")
y = np.random.uniform(0.1, 1, [13, 17]).astype("float64") y = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
self.inputs = {'X': x, 'Y': y} self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
def if_enable_cinn(self):
self.enable_cinn = False
class TestElementwiseMinFP16Op_ZeroDim3(TestElementwiseFP16Op): class TestElementwiseMinFP16Op_ZeroDim3(TestElementwiseFP16Op):
def init_data(self): def init_data(self):
self.x = np.random.uniform(0.1, 1, []).astype("float16") self.x = np.random.uniform(0.1, 1, []).astype("float16")
self.y = np.random.uniform(0.1, 1, [13, 17]).astype("float16") self.y = np.random.uniform(0.1, 1, [13, 17]).astype("float16")
def if_enable_cinn(self):
self.enable_cinn = False
@skip_check_grad_ci( @skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast." reason="[skip shape check] Use y_shape(1) to test broadcast."
...@@ -142,6 +197,9 @@ class TestElementwiseMinOp_scalar(TestElementwiseOp): ...@@ -142,6 +197,9 @@ class TestElementwiseMinOp_scalar(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = paddle.minimum self.python_api = paddle.minimum
self.public_python_api = paddle.minimum
self.prim_op_type = "prim"
self.if_enable_cinn()
x = np.random.random_integers(-5, 5, [10, 3, 4]).astype("float64") x = np.random.random_integers(-5, 5, [10, 3, 4]).astype("float64")
y = np.array([0.5]).astype("float64") y = np.array([0.5]).astype("float64")
self.inputs = {'X': x, 'Y': y} self.inputs = {'X': x, 'Y': y}
...@@ -155,6 +213,9 @@ class TestElementwiseMinFP16Op_scalar(TestElementwiseFP16Op): ...@@ -155,6 +213,9 @@ class TestElementwiseMinFP16Op_scalar(TestElementwiseFP16Op):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = paddle.minimum self.python_api = paddle.minimum
self.public_python_api = paddle.minimum
self.prim_op_type = "prim"
self.if_enable_cinn()
x = np.random.random_integers(-5, 5, [10, 3, 4]).astype(np.float16) x = np.random.random_integers(-5, 5, [10, 3, 4]).astype(np.float16)
y = np.array([0.5]).astype(np.float16) y = np.array([0.5]).astype(np.float16)
self.inputs = {'X': x, 'Y': y} self.inputs = {'X': x, 'Y': y}
...@@ -165,6 +226,9 @@ class TestElementwiseMinOp_Vector(TestElementwiseOp): ...@@ -165,6 +226,9 @@ class TestElementwiseMinOp_Vector(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = paddle.minimum self.python_api = paddle.minimum
self.public_python_api = paddle.minimum
self.prim_op_type = "prim"
self.if_enable_cinn()
x = np.random.random((100,)).astype("float64") x = np.random.random((100,)).astype("float64")
sgn = np.random.choice([-1, 1], (100,)).astype("float64") sgn = np.random.choice([-1, 1], (100,)).astype("float64")
y = x + sgn * np.random.uniform(0.1, 1, (100,)).astype("float64") y = x + sgn * np.random.uniform(0.1, 1, (100,)).astype("float64")
...@@ -176,6 +240,9 @@ class TestElementwiseMinFP16Op_Vector(TestElementwiseFP16Op): ...@@ -176,6 +240,9 @@ class TestElementwiseMinFP16Op_Vector(TestElementwiseFP16Op):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = paddle.minimum self.python_api = paddle.minimum
self.public_python_api = paddle.minimum
self.prim_op_type = "prim"
self.if_enable_cinn()
x = np.random.random((100,)).astype(np.float16) x = np.random.random((100,)).astype(np.float16)
sgn = np.random.choice([-1, 1], (100,)).astype(np.float16) sgn = np.random.choice([-1, 1], (100,)).astype(np.float16)
y = x + sgn * np.random.uniform(0.1, 1, (100,)).astype(np.float16) y = x + sgn * np.random.uniform(0.1, 1, (100,)).astype(np.float16)
...@@ -187,6 +254,9 @@ class TestElementwiseMinOp_broadcast_2(TestElementwiseOp): ...@@ -187,6 +254,9 @@ class TestElementwiseMinOp_broadcast_2(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = broadcast_wrapper(shape=[1, 1, 100]) self.python_api = broadcast_wrapper(shape=[1, 1, 100])
self.public_python_api = paddle.minimum
self.prim_op_type = "prim"
self.if_enable_cinn()
x = np.random.uniform(0.5, 1, (2, 3, 100)).astype(np.float64) x = np.random.uniform(0.5, 1, (2, 3, 100)).astype(np.float64)
sgn = np.random.choice([-1, 1], (100,)).astype(np.float64) sgn = np.random.choice([-1, 1], (100,)).astype(np.float64)
y = x[0, 0, :] + sgn * np.random.uniform(1, 2, (100,)).astype( y = x[0, 0, :] + sgn * np.random.uniform(1, 2, (100,)).astype(
...@@ -205,6 +275,9 @@ class TestElementwiseMinFP16Op_broadcast_2(TestElementwiseFP16Op): ...@@ -205,6 +275,9 @@ class TestElementwiseMinFP16Op_broadcast_2(TestElementwiseFP16Op):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = broadcast_wrapper(shape=[1, 1, 100]) self.python_api = broadcast_wrapper(shape=[1, 1, 100])
self.public_python_api = paddle.minimum
self.prim_op_type = "prim"
self.if_enable_cinn()
x = np.random.uniform(0.5, 1, (2, 3, 100)).astype(np.float16) x = np.random.uniform(0.5, 1, (2, 3, 100)).astype(np.float16)
sgn = np.random.choice([-1, 1], (100,)).astype(np.float16) sgn = np.random.choice([-1, 1], (100,)).astype(np.float16)
y = x[0, 0, :] + sgn * np.random.uniform(1, 2, (100,)).astype( y = x[0, 0, :] + sgn * np.random.uniform(1, 2, (100,)).astype(
...@@ -223,6 +296,9 @@ class TestElementwiseMinOp_broadcast_4(TestElementwiseOp): ...@@ -223,6 +296,9 @@ class TestElementwiseMinOp_broadcast_4(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = paddle.minimum self.python_api = paddle.minimum
self.prim_op_type = "prim"
self.public_python_api = paddle.minimum
self.if_enable_cinn()
x = np.random.uniform(0.5, 1, (2, 10, 2, 5)).astype(np.float64) x = np.random.uniform(0.5, 1, (2, 10, 2, 5)).astype(np.float64)
sgn = np.random.choice([-1, 1], (2, 10, 1, 5)).astype(np.float64) sgn = np.random.choice([-1, 1], (2, 10, 1, 5)).astype(np.float64)
y = x + sgn * np.random.uniform(1, 2, (2, 10, 1, 5)).astype(np.float64) y = x + sgn * np.random.uniform(1, 2, (2, 10, 1, 5)).astype(np.float64)
...@@ -235,6 +311,9 @@ class TestElementwiseMinFP16Op_broadcast_4(TestElementwiseFP16Op): ...@@ -235,6 +311,9 @@ class TestElementwiseMinFP16Op_broadcast_4(TestElementwiseFP16Op):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
self.python_api = paddle.minimum self.python_api = paddle.minimum
self.public_python_api = paddle.minimum
self.prim_op_type = "prim"
self.if_enable_cinn()
x = np.random.uniform(0.5, 1, (2, 10, 2, 5)).astype(np.float16) x = np.random.uniform(0.5, 1, (2, 10, 2, 5)).astype(np.float16)
sgn = np.random.choice([-1, 1], (2, 10, 1, 5)).astype(np.float16) sgn = np.random.choice([-1, 1], (2, 10, 1, 5)).astype(np.float16)
y = x + sgn * np.random.uniform(1, 2, (2, 10, 1, 5)).astype(np.float16) y = x + sgn * np.random.uniform(1, 2, (2, 10, 1, 5)).astype(np.float16)
...@@ -268,7 +347,7 @@ class TestElementwiseBF16Op(OpTest): ...@@ -268,7 +347,7 @@ class TestElementwiseBF16Op(OpTest):
self.python_api = paddle.minimum self.python_api = paddle.minimum
self.public_python_api = paddle.minimum self.public_python_api = paddle.minimum
self.prim_op_type = "prim" self.prim_op_type = "prim"
self.enable_cinn = False self.if_enable_cinn()
self.dtype = np.uint16 self.dtype = np.uint16
self.inputs = { self.inputs = {
'X': convert_float_to_uint16(self.x), 'X': convert_float_to_uint16(self.x),
...@@ -282,18 +361,83 @@ class TestElementwiseBF16Op(OpTest): ...@@ -282,18 +361,83 @@ class TestElementwiseBF16Op(OpTest):
self.check_output() self.check_output()
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', numeric_grad_delta=0.05) places = self._get_places()
for place in places:
if type(place) is paddle.fluid.libpaddle.CPUPlace:
check_prim = False
else:
check_prim = True
self.check_grad_with_place(
place,
inputs_to_check=['X', 'Y'],
output_names='Out',
no_grad_set=None,
numeric_grad_delta=0.05,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
user_defined_grad_outputs=None,
check_dygraph=True,
check_prim=check_prim,
only_check_prim=False,
atol=1e-5,
check_cinn=False,
)
def test_check_grad_ingore_x(self): def test_check_grad_ingore_x(self):
self.check_grad( places = self._get_places()
['Y'], 'Out', numeric_grad_delta=0.05, no_grad_set=set("X") for place in places:
if type(place) is paddle.fluid.libpaddle.CPUPlace:
check_prim = False
else:
check_prim = True
self.check_grad_with_place(
place,
inputs_to_check=['Y'],
output_names='Out',
no_grad_set=set("X"),
numeric_grad_delta=0.05,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
user_defined_grad_outputs=None,
check_dygraph=True,
check_prim=check_prim,
only_check_prim=False,
atol=1e-5,
check_cinn=False,
) )
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
self.check_grad( places = self._get_places()
['X'], 'Out', numeric_grad_delta=0.05, no_grad_set=set('Y') for place in places:
if type(place) is paddle.fluid.libpaddle.CPUPlace:
check_prim = False
else:
check_prim = True
self.check_grad_with_place(
place,
inputs_to_check=['Y'],
output_names='Out',
no_grad_set=set("X"),
numeric_grad_delta=0.05,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None,
user_defined_grad_outputs=None,
check_dygraph=True,
check_prim=check_prim,
only_check_prim=False,
atol=1e-5,
check_cinn=False,
) )
def if_enable_cinn(self):
self.enable_cinn = False
class TestElementwiseMinBF16Op_ZeroDim1(TestElementwiseBF16Op): class TestElementwiseMinBF16Op_ZeroDim1(TestElementwiseBF16Op):
def init_data(self): def init_data(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册