未验证 提交 978d544b 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【prim】delete high order prim flag && add special prune rules for node.cc (#51676)

* delete prim flag for matmul_2_grad

* delete prim flag for matmul_2_grad

* add new setgradoutmeta for matmul_double_grad_node

* modify test and delete log

* deal with review
上级 53bb883d
......@@ -60,6 +60,19 @@ black_ops_list = [
]
# white ops list whose kernel can be deleted after performance analysis
# original kernel and its derivative kernel can be deleted when composite_grad
# kernel performs same to it.
prim_white_list = ["matmul_double_grad"]
# dict of special api that forward api's output will affect bacward api's output
# bacward api's output usually affected by backward api's input
special_prune_dict = {
"matmul_grad": {"x": "grad_y", "y": "grad_x"},
"multiply_grad": {"x": "grad_y", "y": "grad_x"},
}
#########
# Utils #
#########
......@@ -977,12 +990,25 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
grad_node_out_list.append(name)
is_optional = name in self.optional_inputs
is_special_forward_api = (
True if forward_api_name in special_prune_dict else False
)
if is_optional:
set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});"
else:
set_grad_out_meta = (
f"{indent}grad_node->SetGradOutMeta({name}, {pos});"
)
if (
is_special_forward_api
and name in special_prune_dict[forward_api_name]
):
meta_name = GetAutoGradMetaName(
special_prune_dict[forward_api_name][name]
)
set_grad_out_meta = f"{indent}grad_node->SetGradOutMeta({name}, {meta_name}, {pos});"
else:
set_grad_out_meta = (
f"{indent}grad_node->SetGradOutMeta({name}, {pos});"
)
set_grad_out_meta_list.append(set_grad_out_meta)
set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)
......@@ -2255,13 +2281,33 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
"""
# TODO(Ruting):using composite only when we don't have backward kernel in the future.
elif is_composite_grad_api:
grad_function_call_str = f"""
if (paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) {{
if composite_grad_api_name in prim_white_list:
grad_function_call_str = f"""
{indent}bool original_global_grad = egr::Controller::Instance().HasGrad();
{indent}if(!create_graph){{
{indent}{indent}egr::Controller::Instance().SetHasGrad(create_graph);
}}
{indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
VLOG(4) << "Composite api {composite_grad_api_name} is called ";
{indent}if(!create_graph){{
{indent}{indent}egr::Controller::Instance().SetHasGrad(original_global_grad);
}}
"""
else:
grad_function_call_str = f"""
if (paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) {{
{indent}bool original_global_grad = egr::Controller::Instance().HasGrad();
{indent}if(!create_graph){{
{indent}{indent}egr::Controller::Instance().SetHasGrad(create_graph);
}}
{indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
{indent}VLOG(4) << "Composite api {composite_grad_api_name} is called ";
{indent}if(!create_graph){{
{indent}{indent}egr::Controller::Instance().SetHasGrad(original_global_grad);
}}
}}else{{
{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});
VLOG(4) << "Fused api {backward_api_name} is called ";
{indent}VLOG(4) << "Fused api {backward_api_name} is called ";
}}
"""
else:
......
......@@ -263,6 +263,69 @@ void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in,
}
}
/*
special func for matmul_double_grad etc. dx exists when x and y exists,
if stop_gradient of y is true, dy is None who is matmul_grad's out_put, ddy is
None, so dx = ddy * dout should be None.
*/
void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in,
const AutogradMeta* fwd_out_meta,
size_t slot_rank) {
auto* fwd_in_meta = egr::EagerUtils::nullable_autograd_meta(fwd_in);
PADDLE_ENFORCE_LE(
(slot_rank + 1),
bwd_out_meta_.size(),
paddle::platform::errors::InvalidArgument(
"Slot Rank should less equal than bwd_out_meta_ size, "
"since bwd_out_meta_ is designed to hold as same num as "
"backward outputs."));
auto& metas = bwd_out_meta_.at(slot_rank);
// Init stop gradient vector before use to avoid push back
if (metas.size() == 0) {
metas.resize(1);
}
auto& meta = metas[0];
// Set Stop_gradient
if (fwd_in_meta && !fwd_in_meta->StopGradient() && fwd_out_meta) {
meta.SetStopGradient(false);
} else {
meta.SetStopGradient(true);
}
// Set Adj Edges
if (fwd_in_meta && !fwd_in_meta->StopGradient() && fwd_out_meta) {
auto node = fwd_in_meta->GetMutableGradNode();
if (!node || !node.get()) {
fwd_in_meta->SetGradNode(
std::make_shared<egr::GradNodeAccumulation>(fwd_in_meta));
}
VLOG(3) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
<< this->name() << " (addr: " << this << ") "
<< " to " << fwd_in_meta->GetMutableGradNode()->name()
<< " (addr: " << fwd_in_meta->GetMutableGradNode().get() << ")";
meta.SetEdge(fwd_in_meta->GetMutableGradNode(), fwd_in_meta->OutRankInfo());
}
// Record TensorMeta
if (fwd_in.impl() && fwd_in.impl().get()) {
if (phi::DenseTensor::classof(fwd_in.impl().get())) {
// Only Copy Meta
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(fwd_in.impl().get());
PADDLE_ENFORCE_NE(
dense_tensor->meta().dtype,
phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
meta.SetPlace(fwd_in.place());
}
} else {
VLOG(7) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
"non-DenseTensor argument.";
}
}
void GradNodeBase::SetGradOutMeta(const std::vector<paddle::Tensor>& fwd_in,
size_t slot_rank) {
size_t slot_size = fwd_in.size();
......
......@@ -225,6 +225,9 @@ class GradNodeBase {
void SetGradOutMeta(const std::vector<const paddle::Tensor*>& fwd_in,
size_t slot_rank);
void SetGradOutMeta(const paddle::Tensor& fwd_in, size_t slot_rank);
void SetGradOutMeta(const paddle::Tensor& fwd_in,
const AutogradMeta* fwd_in_other,
size_t slot_rank);
/**
* Default setters for Grad in/out meta this should be used for same special
* Node which will not create by user
......
......@@ -681,8 +681,7 @@
param : [x, y, grad_out]
kernel :
func : matmul_double_grad
composite : matmul_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, transpose_x, transpose_y, x_grad, y_grad, grad_out_grad)
backward : matmul_triple_grad
composite : matmul_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, transpose_x=false, transpose_y=false)
optional : grad_x_grad, grad_y_grad
- backward_op : matmul_grad
......@@ -696,17 +695,6 @@
func : matmul_grad
backward : matmul_double_grad
- backward_op : matmul_triple_grad
forward : matmul_double_grad (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, bool transpose_x=false, bool transpose_y=false) -> Tensor(grad_x), Tensor(grad_y), Tensor(grad_grad_out)
args : (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, Tensor grad_x_grad, Tensor grad_y_grad, Tensor grad_grad_out_grad, bool transpose_x=false, bool transpose_y=false)
output : Tensor(x_grad), Tensor(y_grad), Tensor(fwd_grad_out_grad), Tensor(fwd_grad_grad_x_grad), Tensor(fwd_grad_grad_y_grad)
infer_meta :
func : GeneralQuinaryGradInferMeta
param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y]
kernel :
func : matmul_triple_grad
optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad
- backward_op : max_grad
forward: max (Tensor x, IntArray axis={}, bool keepdim=false) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis={}, bool keepdim=false, bool reduce_all=false)
......
......@@ -706,30 +706,38 @@ class TestDygraphDoubleGradMatmul(TestCase):
dout = paddle.to_tensor(
np.ones([3, 3]), stop_gradient=False, dtype='float32'
)
(dx,) = paddle.grad(
[out], [x], [dout], retain_graph=True, create_graph=True
(dx, dy) = paddle.grad(
[out], [x, y], [dout], retain_graph=True, create_graph=True
)
ddx = paddle.to_tensor(
np.ones([3, 3]), stop_gradient=False, dtype='float32'
)
ddy = ddx
dx_double_grad, dy_double_grad, ddout = paddle.grad(
[dx],
[dx, dy],
[x, y, dout],
[ddx],
[ddx, ddy],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, dy_double_grad, ddout
def expected():
dx_double_grad_expected = np.zeros([3, 3], dtype="float32")
dx_double_grad_expected = np.matmul(
np.ones([3, 3], dtype="float32"),
np.ones([3, 3], dtype="float32"),
)
dy_double_grad_expected = np.matmul(
np.ones([3, 3], dtype="float32"),
np.ones([3, 3], dtype="float32"),
)
ddout_expected = np.matmul(
ddout_expected1 = np.matmul(
np.ones([3, 3], dtype="float32"), input_numpy_y
)
ddout_expected2 = np.matmul(
input_numpy_x, np.ones([3, 3], dtype="float32")
)
ddout_expected = ddout_expected1 + ddout_expected2
return (
dx_double_grad_expected,
dy_double_grad_expected,
......@@ -773,27 +781,26 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddy = paddle.to_tensor(
np.ones([3, 3]), stop_gradient=False, dtype='float32'
)
dx_double_grad, dy_double_grad, ddout = paddle.grad(
# when x isnot be differentiate in first grad dy in second grad could be None in composite op
dx_double_grad, ddout = paddle.grad(
[dy],
[x, y, dout],
[x, dout],
[ddy],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, dy_double_grad, ddout
return dx_double_grad, ddout
def expected():
dx_double_grad_expected = np.matmul(
np.ones([3, 3], dtype="float32"),
np.ones([3, 3], dtype="float32"),
)
dy_double_grad_expected = np.zeros([3, 3], dtype="float32")
ddout_expected = np.matmul(
input_numpy_x, np.ones([3, 3], dtype="float32")
)
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
......@@ -834,24 +841,23 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddy = paddle.to_tensor(
np.ones([3]), stop_gradient=False, dtype='float32'
)
dx_double_grad, dy_double_grad, ddout = paddle.grad(
# when x is not be differentiate in first grad, dy from second grad could be None in composite api.
dx_double_grad, ddout = paddle.grad(
[dy],
[x, y, dout],
[x, dout],
[ddy],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, dy_double_grad, ddout
return dx_double_grad, ddout
def expected():
dx_double_grad_expected = np.ones([3], dtype="float32")
dy_double_grad_expected = np.zeros([3], dtype="float32")
ddout_expected = np.matmul(
input_numpy_x, np.ones([3], dtype="float32")
)
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
......@@ -892,23 +898,22 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddx = paddle.to_tensor(
np.ones([3]), stop_gradient=False, dtype='float32'
)
dx_double_grad, dy_double_grad, ddout = paddle.grad(
# when y is not be differentiate in first grad, dx from second grad could be None in composite api.
dy_double_grad, ddout = paddle.grad(
[dx],
[x, y, dout],
[y, dout],
[ddx],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, dy_double_grad, ddout
return dy_double_grad, ddout
def expected():
dx_double_grad_expected = np.zeros([3], dtype="float32")
dy_double_grad_expected = np.ones([3], dtype="float32")
ddout_expected = np.matmul(
input_numpy_y, np.ones([3], dtype="float32")
)
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
......@@ -920,6 +925,7 @@ class TestDygraphDoubleGradMatmul(TestCase):
for place in places:
paddle.device.set_device(place)
actual_results = actual()
for expected_result, actual_result in zip(
expected_results, actual_results
):
......@@ -950,24 +956,22 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddy = paddle.to_tensor(
np.ones([1]), stop_gradient=False, dtype='float32'
)
dx_double_grad, dy_double_grad, ddout = paddle.grad(
dx_double_grad, ddout = paddle.grad(
[dy],
[x, y, dout],
[x, dout],
[ddy],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, dy_double_grad, ddout
return dx_double_grad, ddout
def expected():
dx_double_grad_expected = np.ones([2, 1], dtype="float32")
dy_double_grad_expected = np.zeros([1], dtype="float32")
ddout_expected = np.matmul(
input_numpy_x, np.ones([1], dtype="float32")
)
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
......@@ -1008,21 +1012,19 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddx = paddle.to_tensor(
np.ones([2, 1]), stop_gradient=False, dtype='float32'
)
dx_double_grad, dy_double_grad, ddout = paddle.grad(
dy_double_grad, ddout = paddle.grad(
[dx],
[x, y, dout],
[y, dout],
[ddx],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, dy_double_grad, ddout
return dy_double_grad, ddout
def expected():
dx_double_grad_expected = np.zeros([2, 1], dtype="float32")
dy_double_grad_expected = np.ones([1], dtype="float32") * 2
ddout_expected = np.ones([2], dtype="float32") * input_numpy_y[0]
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
......@@ -1041,6 +1043,8 @@ class TestDygraphDoubleGradMatmul(TestCase):
expected_result, actual_result, rtol=1e-6
)
# TODO(Ruting) test complex dtype when composite api support
'''
# case7: ddx is none, dims = 1, complex dtype
def test_matmul_double_grad_case7(self):
input_numpy_x = np.random.random([3]).astype(
......@@ -1069,19 +1073,17 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddx = paddle.to_tensor(
np.ones([3]), stop_gradient=False, dtype='complex64'
)
dx_double_grad, dy_double_grad, ddout = paddle.grad(
# when y is not be differentiate in first grad, dx from second grad could be None in composite api.
dy_double_grad, ddout = paddle.grad(
[dx],
[x, y, dout],
[y, dout],
[ddx],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, dy_double_grad, ddout
return dy_double_grad, ddout
def expected():
dx_double_grad_expected = np.zeros(
[3], dtype="float32"
) + 0j * np.zeros([3], dtype="float32")
dy_double_grad_expected = np.ones(
[3], dtype="float32"
) + 0j * np.ones([3], dtype="float32")
......@@ -1089,7 +1091,6 @@ class TestDygraphDoubleGradMatmul(TestCase):
input_numpy_y_conj, np.ones([3], dtype="float32")
)
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
......@@ -1108,6 +1109,7 @@ class TestDygraphDoubleGradMatmul(TestCase):
expected_result, actual_result, rtol=1e-6
)
# case8: ddy is none, dims = 1, complex dtype
def test_matmul_double_grad_case8(self):
input_numpy_x = np.random.random([3]).astype(
......@@ -1136,24 +1138,22 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddy = paddle.to_tensor(
np.ones([3]), stop_gradient=False, dtype='complex64'
)
dx_double_grad, dy_double_grad, ddout = paddle.grad(
dx_double_grad, ddout = paddle.grad(
[dy],
[x, y, dout],
[x, dout],
[ddy],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, dy_double_grad, ddout
return dx_double_grad, ddout
def expected():
dx_double_grad_expected = np.ones([3], dtype="float32")
dy_double_grad_expected = np.zeros([3], dtype="float32")
ddout_expected = np.matmul(
input_numpy_x_conj, np.ones([3], dtype="float32")
)
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
......@@ -1170,6 +1170,39 @@ class TestDygraphDoubleGradMatmul(TestCase):
np.testing.assert_allclose(
expected_result, actual_result, rtol=1e-6
)
'''
def test_value_error(self):
def test():
import paddle
import paddle.nn as nn
model = nn.Sequential(nn.Linear(3, 4))
x = paddle.randn([4, 1])
y = paddle.randn([4, 1])
z = paddle.randn([4, 1])
x.stop_gradient = False
y.stop_gradient = False
z.stop_gradient = False
out = model(paddle.concat((x, y, z), axis=1))
data = {
"x": x,
"y": y,
"z": z,
"u": out[:, 0:1],
"v": out[:, 1:2],
"w": out[:, 2:3],
"p": out[:, 3:4],
}
v = out[:, 1:2]
z = paddle.grad(v, x, create_graph=True)[0]
zz = paddle.grad(z, x, create_graph=True)[0]
with self.assertRaises(ValueError):
test()
if __name__ == '__main__':
......
......@@ -89,17 +89,17 @@ class TestDygraphTripleGradMatmul(TestCase):
np.testing.assert_array_equal(new_c.numpy(), new_c_ref)
x_grad_ref = np.ones([3, 3]) * 0.0
np.testing.assert_array_equal(x.grad.numpy(), x_grad_ref)
assert x.grad is None
y_grad_ref = np.ones([3, 3]) * 0.0
np.testing.assert_array_equal(y.grad.numpy(), y_grad_ref)
assert y.grad is None
new_out_g_ref = np.ones([3, 3]) * 3.0
np.testing.assert_array_equal(new_out_g.grad.numpy(), new_out_g_ref)
new_x_g_g_ref = np.ones([3, 3]) * 0.0
new_y_g_g_ref = np.ones([3, 3]) * 3.0
np.testing.assert_array_equal(new_x_g_g.grad.numpy(), new_x_g_g_ref)
assert new_x_g_g.grad is None
np.testing.assert_array_equal(new_y_g_g.grad.numpy(), new_y_g_g_ref)
......@@ -359,13 +359,14 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
retain_graph=True,
create_graph=True,
)
d_x, d_y, d_dout, d_ddx, d_ddy = paddle.grad(
# d_x, d_y should be none because ddd_out = None
d_dout, d_ddx, d_ddy = paddle.grad(
[dx_double_grad, dy_double_grad],
[x, y, dout, ddx, ddy],
[dout, ddx, ddy],
retain_graph=False,
create_graph=False,
)
return d_x, d_y, d_dout, d_ddx, d_ddy
return d_dout, d_ddx, d_ddy
# case1: d_ddout is none, dims != 1
def test_matmul_triple_grad_case1(self):
......@@ -377,14 +378,10 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
self.input_numpy_ddy = np.ones([3, 3], dtype="float32")
init_data()
d_x_expected = np.zeros([3, 3], dtype="float32")
d_y_expected = np.zeros([3, 3], dtype="float32")
d_dout_expected = np.ones([3, 3], dtype="float32") * 6
d_ddx_expected = np.ones([3, 3], dtype="float32") * 3
d_ddy_expected = np.ones([3, 3], dtype="float32") * 3
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
......@@ -418,18 +415,6 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
self.input_numpy_ddy = np.ones([3], dtype="float32")
init_data()
d_x_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_y_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_dout_expected = np.ones([1], dtype="float32") * 6
d_ddx_expected = np.ones(
[
......@@ -444,8 +429,6 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
dtype="float32",
)
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
......@@ -475,8 +458,6 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
self.input_numpy_ddy = np.ones([1], dtype="float32")
init_data()
d_x_expected = np.zeros([3, 1], dtype="float32")
d_y_expected = np.zeros([1], dtype="float32")
d_dout_expected = (
np.ones(
[
......@@ -489,8 +470,6 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
d_ddx_expected = np.ones([3, 1], dtype="float32")
d_ddy_expected = np.ones([1], dtype="float32") * 3
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
......@@ -507,6 +486,7 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
)
'''
# d_ddout is none, dtype is complex64
class TestDygraphTripleGradMatmulcase2(TestCase):
def setUp(self):
......@@ -666,6 +646,7 @@ class TestDygraphTripleGradMatmulcase2(TestCase):
np.testing.assert_allclose(
expected_result, actual_result, rtol=1e-6
)
'''
# d_ddout is none, d_dx is none, dtype is float32
......@@ -708,13 +689,15 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
retain_graph=True,
create_graph=True,
)
d_x, d_y, d_dout, d_ddx, d_ddy = paddle.grad(
# d_x d_y is None because (double grad out_put ddout grad tensor)d_ddout is None
# d_ddy is None because (double grad out_put dx grad tensor) d_dx and d_ddout is None
d_dout, d_ddx = paddle.grad(
[dy_double_grad],
[x, y, dout, ddx, ddy],
[dout, ddx],
retain_graph=False,
create_graph=False,
)
return d_x, d_y, d_dout, d_ddx, d_ddy
return d_dout, d_ddx
# case1: d_ddout is none, d_dx is none, dims != 1
def test_matmul_triple_grad_case1(self):
......@@ -726,17 +709,11 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
self.input_numpy_ddy = np.ones([3, 3], dtype="float32")
init_data()
d_x_expected = np.zeros([3, 3], dtype="float32")
d_y_expected = np.zeros([3, 3], dtype="float32")
d_dout_expected = np.ones([3, 3], dtype="float32") * 3
d_ddx_expected = np.ones([3, 3], dtype="float32") * 3
d_ddy_expected = np.zeros([3, 3], dtype="float32")
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
)
for place in self.places:
......@@ -767,18 +744,6 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
self.input_numpy_ddy = np.ones([3], dtype="float32")
init_data()
d_x_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_y_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_dout_expected = np.ones([1], dtype="float32") * 3
d_ddx_expected = np.ones(
[
......@@ -786,18 +751,9 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
],
dtype="float32",
)
d_ddy_expected = np.zeros(
[
3,
],
dtype="float32",
)
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
)
for place in self.places:
......@@ -824,8 +780,6 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
self.input_numpy_ddy = np.ones([1], dtype="float32")
init_data()
d_x_expected = np.zeros([3, 1], dtype="float32")
d_y_expected = np.zeros([1], dtype="float32")
d_dout_expected = np.ones(
[
3,
......@@ -833,13 +787,9 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
dtype="float32",
)
d_ddx_expected = np.ones([3, 1], dtype="float32")
d_ddy_expected = np.zeros([1], dtype="float32")
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
)
for place in self.places:
......@@ -853,6 +803,7 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
)
'''
# d_ddout is none, d_dx is none, dtype is complex64
class TestDygraphTripleGradMatmulcase4(TestCase):
def setUp(self):
......@@ -999,6 +950,7 @@ class TestDygraphTripleGradMatmulcase4(TestCase):
np.testing.assert_allclose(
expected_result, actual_result, rtol=1e-6
)
'''
# d_ddout is none, d_dy is none, dtype is float32
......@@ -1041,13 +993,13 @@ class TestDygraphTripleGradMatmulcase5(TestCase):
retain_graph=True,
create_graph=True,
)
d_x, d_y, d_dout, d_ddx, d_ddy = paddle.grad(
d_dout, d_ddy = paddle.grad(
[dx_double_grad],
[x, y, dout, ddx, ddy],
[dout, ddy],
retain_graph=False,
create_graph=False,
)
return d_x, d_y, d_dout, d_ddx, d_ddy
return d_dout, d_ddy
# case1: d_ddout is none, d_dy is none, dims != 1
def test_matmul_triple_grad_case1(self):
......@@ -1059,16 +1011,10 @@ class TestDygraphTripleGradMatmulcase5(TestCase):
self.input_numpy_ddy = np.ones([3, 3], dtype="float32")
init_data()
d_x_expected = np.zeros([3, 3], dtype="float32")
d_y_expected = np.zeros([3, 3], dtype="float32")
d_dout_expected = np.ones([3, 3], dtype="float32") * 3
d_ddx_expected = np.zeros([3, 3], dtype="float32") * 3
d_ddy_expected = np.ones([3, 3], dtype="float32") * 3
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
)
......@@ -1100,25 +1046,7 @@ class TestDygraphTripleGradMatmulcase5(TestCase):
self.input_numpy_ddy = np.ones([3], dtype="float32")
init_data()
d_x_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_y_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_dout_expected = np.ones([1], dtype="float32") * 3
d_ddx_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_ddy_expected = np.ones(
[
3,
......@@ -1126,10 +1054,7 @@ class TestDygraphTripleGradMatmulcase5(TestCase):
dtype="float32",
)
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
)
......@@ -1157,21 +1082,15 @@ class TestDygraphTripleGradMatmulcase5(TestCase):
self.input_numpy_ddy = np.ones([1], dtype="float32")
init_data()
d_x_expected = np.zeros([3, 1], dtype="float32")
d_y_expected = np.zeros([1], dtype="float32")
d_dout_expected = np.ones(
[
3,
],
dtype="float32",
)
d_ddx_expected = np.zeros([3, 1], dtype="float32")
d_ddy_expected = np.ones([1], dtype="float32") * 3
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
)
......@@ -1186,6 +1105,8 @@ class TestDygraphTripleGradMatmulcase5(TestCase):
)
'''
TODO(Ruting) test complex dtype when composite api support
# d_ddout is none, d_dy is none, dtype is complex64
class TestDygraphTripleGradMatmulcase6(TestCase):
def setUp(self):
......@@ -1332,7 +1253,7 @@ class TestDygraphTripleGradMatmulcase6(TestCase):
np.testing.assert_allclose(
expected_result, actual_result, rtol=1e-6
)
'''
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册