Unverified · Commit 978d544b authored by xiaoguoguo626807, committed by GitHub

【prim】delete high order prim flag && add special prune rules for node.cc (#51676)

* delete prim flag for matmul_2_grad

* delete prim flag for matmul_2_grad

* add new setgradoutmeta for matmul_double_grad_node

* modify test and delete log

* deal with review
Parent 53bb883d
@@ -60,6 +60,19 @@ black_ops_list = [
]
# White list of ops whose original kernels can be deleted after performance analysis:
# an original kernel and its derivative kernel can be removed once the composite_grad
# kernel performs the same as they do.
prim_white_list = ["matmul_double_grad"]
# Dict of special APIs for which a forward API's output affects the backward API's outputs;
# normally a backward API's outputs are affected only by the backward API's inputs.
special_prune_dict = {
"matmul_grad": {"x": "grad_y", "y": "grad_x"},
"multiply_grad": {"x": "grad_y", "y": "grad_x"},
}
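As a rough sketch of how this table is meant to be read (the wrapper function below is hypothetical and not part of this PR; only special_prune_dict and the existing codegen helper GetAutoGradMetaName come from this file): for a special forward API, a given input name maps to the backward output whose autograd meta must also be passed to SetGradOutMeta.

# Illustrative helper only -- an assumption, not code from this change.
def lookup_special_meta(forward_api_name, input_name):
    affected_output = special_prune_dict.get(forward_api_name, {}).get(input_name)
    if affected_output is None:
        return None  # not a special case; the plain SetGradOutMeta overload is used
    # GetAutoGradMetaName is the existing codegen helper for autograd-meta variable names.
    return GetAutoGradMetaName(affected_output)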
#########
# Utils #
#########
@@ -977,8 +990,21 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
grad_node_out_list.append(name)
is_optional = name in self.optional_inputs
is_special_forward_api = (
True if forward_api_name in special_prune_dict else False
)
if is_optional:
set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});"
else:
if (
is_special_forward_api
and name in special_prune_dict[forward_api_name]
):
meta_name = GetAutoGradMetaName(
special_prune_dict[forward_api_name][name]
)
set_grad_out_meta = f"{indent}grad_node->SetGradOutMeta({name}, {meta_name}, {pos});"
else:
set_grad_out_meta = (
f"{indent}grad_node->SetGradOutMeta({name}, {pos});"
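For a concrete, hypothetical instantiation of the special branch above (matmul_grad's input x at position 0, whose affected backward output is grad_y), the rendered line would look roughly as follows; the exact autograd-meta variable name depends on what GetAutoGradMetaName returns and is only a placeholder here.

# Assumed values for illustration only.
indent = "  "
name, pos = "x", 0
meta_name = "grad_y_autograd_meta"  # placeholder for GetAutoGradMetaName("grad_y")
set_grad_out_meta = f"{indent}grad_node->SetGradOutMeta({name}, {meta_name}, {pos});"
print(set_grad_out_meta)
# Expected output:   grad_node->SetGradOutMeta(x, grad_y_autograd_meta, 0);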
@@ -2255,13 +2281,33 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
"""
# TODO(Ruting):using composite only when we don't have backward kernel in the future.
elif is_composite_grad_api:
if composite_grad_api_name in prim_white_list:
grad_function_call_str = f"""
{indent}bool original_global_grad = egr::Controller::Instance().HasGrad();
{indent}if(!create_graph){{
{indent}{indent}egr::Controller::Instance().SetHasGrad(create_graph);
}}
{indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
VLOG(4) << "Composite api {composite_grad_api_name} is called ";
{indent}if(!create_graph){{
{indent}{indent}egr::Controller::Instance().SetHasGrad(original_global_grad);
}}
"""
else:
grad_function_call_str = f"""
if (paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) {{
{indent}bool original_global_grad = egr::Controller::Instance().HasGrad();
{indent}if(!create_graph){{
{indent}{indent}egr::Controller::Instance().SetHasGrad(create_graph);
}}
{indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
{indent}VLOG(4) << "Composite api {composite_grad_api_name} is called ";
{indent}if(!create_graph){{
{indent}{indent}egr::Controller::Instance().SetHasGrad(original_global_grad);
}}
}}else{{
{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});
{indent}VLOG(4) << "Fused api {backward_api_name} is called ";
}}
"""
else:
......
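The HasGrad save/restore around the composite call mirrors paddle.grad's create_graph semantics: when create_graph is false, the decomposed ops must not be recorded for higher-order differentiation. A minimal user-level illustration of that contract (a sketch in plain eager mode, no prim flag involved; the printed values are expectations, not output from this PR):

import numpy as np
import paddle

x = paddle.to_tensor(np.ones([3, 3]), dtype='float32', stop_gradient=False)
y = paddle.to_tensor(np.ones([3, 3]), dtype='float32', stop_gradient=False)
out = paddle.matmul(x, y)

# Nothing is recorded for higher order: the returned grad is detached.
(dx_detached,) = paddle.grad([out], [x], create_graph=False, retain_graph=True)
print(dx_detached.stop_gradient)  # expected: True

# With create_graph=True the grad itself stays differentiable, so double grad works.
(dx_graph,) = paddle.grad([out], [x], create_graph=True)
print(dx_graph.stop_gradient)  # expected: False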
@@ -263,6 +263,69 @@ void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in,
}
}
/*
Special overload for matmul_double_grad and similar ops: dx exists only when both
x and y exist. If y's stop_gradient is true, then dy (an output of matmul_grad) is
None and ddy is None, so dx = ddy * dout should also be None.
*/
void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in,
const AutogradMeta* fwd_out_meta,
size_t slot_rank) {
auto* fwd_in_meta = egr::EagerUtils::nullable_autograd_meta(fwd_in);
PADDLE_ENFORCE_LE(
(slot_rank + 1),
bwd_out_meta_.size(),
paddle::platform::errors::InvalidArgument(
"Slot Rank should be less than or equal to bwd_out_meta_ size, "
"since bwd_out_meta_ is designed to hold the same number as "
"backward outputs."));
auto& metas = bwd_out_meta_.at(slot_rank);
// Init stop gradient vector before use to avoid push back
if (metas.size() == 0) {
metas.resize(1);
}
auto& meta = metas[0];
// Set Stop_gradient
if (fwd_in_meta && !fwd_in_meta->StopGradient() && fwd_out_meta) {
meta.SetStopGradient(false);
} else {
meta.SetStopGradient(true);
}
// Set Adj Edges
if (fwd_in_meta && !fwd_in_meta->StopGradient() && fwd_out_meta) {
auto node = fwd_in_meta->GetMutableGradNode();
if (!node || !node.get()) {
fwd_in_meta->SetGradNode(
std::make_shared<egr::GradNodeAccumulation>(fwd_in_meta));
}
VLOG(3) << "Add Edges for slot: " << slot_rank << ", the Edge is from "
<< this->name() << " (addr: " << this << ") "
<< " to " << fwd_in_meta->GetMutableGradNode()->name()
<< " (addr: " << fwd_in_meta->GetMutableGradNode().get() << ")";
meta.SetEdge(fwd_in_meta->GetMutableGradNode(), fwd_in_meta->OutRankInfo());
}
// Record TensorMeta
if (fwd_in.impl() && fwd_in.impl().get()) {
if (phi::DenseTensor::classof(fwd_in.impl().get())) {
// Only Copy Meta
phi::DenseTensor* dense_tensor =
static_cast<phi::DenseTensor*>(fwd_in.impl().get());
PADDLE_ENFORCE_NE(
dense_tensor->meta().dtype,
phi::DataType::UNDEFINED,
paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
"which is illegal."));
meta.SetTensorMeta(dense_tensor->meta());
meta.SetPlace(fwd_in.place());
}
} else {
VLOG(7) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
"non-DenseTensor argument.";
}
}
void GradNodeBase::SetGradOutMeta(const std::vector<paddle::Tensor>& fwd_in,
size_t slot_rank) {
size_t slot_size = fwd_in.size();
......
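A user-level illustration (a sketch, not part of this change's test suite, and the final print is an expectation rather than a verified result) of the pruning rule the new SetGradOutMeta overload implements: when y never requires grad, dy does not exist, so ddy does not exist either, and the double-grad output for x (dx = ddy * dout) should be pruned instead of returning zeros.

import numpy as np
import paddle

x = paddle.to_tensor(np.ones([3, 3]), dtype='float32', stop_gradient=False)
y = paddle.to_tensor(np.ones([3, 3]), dtype='float32', stop_gradient=True)  # dy will not exist
out = paddle.matmul(x, y)

# First-order grad: only dx exists; dy is never produced because y.stop_gradient is True.
(dx,) = paddle.grad([out], [x], create_graph=True)

# Second-order: the double-grad output for x would be ddy * dout, but ddy is None,
# so with the prune rule x is not connected to dx's graph at all
# (allow_unused=True should then yield None instead of an error).
(dx_double,) = paddle.grad(
    [dx], [x], grad_outputs=[paddle.ones_like(dx)], allow_unused=True
)
print(dx_double)  # expected: None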
@@ -225,6 +225,9 @@ class GradNodeBase {
void SetGradOutMeta(const std::vector<const paddle::Tensor*>& fwd_in,
size_t slot_rank);
void SetGradOutMeta(const paddle::Tensor& fwd_in, size_t slot_rank);
void SetGradOutMeta(const paddle::Tensor& fwd_in,
const AutogradMeta* fwd_in_other,
size_t slot_rank);
/**
* Default setters for Grad in/out meta; these should be used for some special
* Node which will not be created by the user
......
@@ -681,8 +681,7 @@
param : [x, y, grad_out]
kernel :
func : matmul_double_grad
composite : matmul_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, transpose_x=false, transpose_y=false)
backward : matmul_triple_grad
optional : grad_x_grad, grad_y_grad
- backward_op : matmul_grad
@@ -696,17 +695,6 @@
func : matmul_grad
backward : matmul_double_grad
- backward_op : matmul_triple_grad
forward : matmul_double_grad (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, bool transpose_x=false, bool transpose_y=false) -> Tensor(grad_x), Tensor(grad_y), Tensor(grad_grad_out)
args : (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, Tensor grad_x_grad, Tensor grad_y_grad, Tensor grad_grad_out_grad, bool transpose_x=false, bool transpose_y=false)
output : Tensor(x_grad), Tensor(y_grad), Tensor(fwd_grad_out_grad), Tensor(fwd_grad_grad_x_grad), Tensor(fwd_grad_grad_y_grad)
infer_meta :
func : GeneralQuinaryGradInferMeta
param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y]
kernel :
func : matmul_triple_grad
optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad
- backward_op : max_grad
forward: max (Tensor x, IntArray axis={}, bool keepdim=false) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, IntArray axis={}, bool keepdim=false, bool reduce_all=false)
......
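For orientation, a NumPy sketch (an assumption for illustration, not the framework's actual composite implementation) of what matmul_double_grad computes for transpose_x = transpose_y = false; terms whose incoming grads (ddx/ddy) are missing are pruned to None, which is the behaviour the updated tests below rely on.

import numpy as np

def matmul_double_grad_reference(x, y, dout, ddx=None, ddy=None):
    """Reference sketch of matmul_double_grad (transpose_x=False, transpose_y=False).
    ddx/ddy are the incoming grads of dx/dy; missing ones prune the matching outputs."""
    dx = dout @ ddy.T if ddy is not None else None   # double-grad output for x
    dy = ddx.T @ dout if ddx is not None else None   # double-grad output for y
    parts = []
    if ddx is not None:
        parts.append(ddx @ y)
    if ddy is not None:
        parts.append(x @ ddy)
    ddout = sum(parts) if parts else None            # double-grad output for dout
    return dx, dy, ddout

# Example usage with all-ones inputs:
ones = np.ones([3, 3], dtype="float32")
dx, dy, ddout = matmul_double_grad_reference(ones, ones, ones, ddx=ones, ddy=ones)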
@@ -706,30 +706,38 @@ class TestDygraphDoubleGradMatmul(TestCase):
dout = paddle.to_tensor(
np.ones([3, 3]), stop_gradient=False, dtype='float32'
)
(dx, dy) = paddle.grad(
[out], [x, y], [dout], retain_graph=True, create_graph=True
)
ddx = paddle.to_tensor(
np.ones([3, 3]), stop_gradient=False, dtype='float32'
)
ddy = ddx
dx_double_grad, dy_double_grad, ddout = paddle.grad(
[dx, dy],
[x, y, dout],
[ddx, ddy],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, dy_double_grad, ddout
def expected():
dx_double_grad_expected = np.matmul(
np.ones([3, 3], dtype="float32"),
np.ones([3, 3], dtype="float32"),
)
dy_double_grad_expected = np.matmul(
np.ones([3, 3], dtype="float32"),
np.ones([3, 3], dtype="float32"),
)
ddout_expected1 = np.matmul(
np.ones([3, 3], dtype="float32"), input_numpy_y
)
ddout_expected2 = np.matmul(
input_numpy_x, np.ones([3, 3], dtype="float32")
)
ddout_expected = ddout_expected1 + ddout_expected2
return (
dx_double_grad_expected,
dy_double_grad_expected,
@@ -773,27 +781,26 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddy = paddle.to_tensor(
np.ones([3, 3]), stop_gradient=False, dtype='float32'
)
# when x is not differentiated in the first grad, dy in the second grad could be None in the composite op
dx_double_grad, ddout = paddle.grad(
[dy],
[x, dout],
[ddy],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, ddout
def expected():
dx_double_grad_expected = np.matmul(
np.ones([3, 3], dtype="float32"),
np.ones([3, 3], dtype="float32"),
)
dy_double_grad_expected = np.zeros([3, 3], dtype="float32")
ddout_expected = np.matmul(
input_numpy_x, np.ones([3, 3], dtype="float32")
)
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
@@ -834,24 +841,23 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddy = paddle.to_tensor(
np.ones([3]), stop_gradient=False, dtype='float32'
)
# when x is not differentiated in the first grad, dy from the second grad could be None in the composite api.
dx_double_grad, ddout = paddle.grad(
[dy],
[x, dout],
[ddy],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, ddout
def expected():
dx_double_grad_expected = np.ones([3], dtype="float32")
dy_double_grad_expected = np.zeros([3], dtype="float32")
ddout_expected = np.matmul(
input_numpy_x, np.ones([3], dtype="float32")
)
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
@@ -892,23 +898,22 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddx = paddle.to_tensor(
np.ones([3]), stop_gradient=False, dtype='float32'
)
# when y is not differentiated in the first grad, dx from the second grad could be None in the composite api.
dy_double_grad, ddout = paddle.grad(
[dx],
[y, dout],
[ddx],
retain_graph=True,
create_graph=True,
)
return dy_double_grad, ddout
def expected():
dx_double_grad_expected = np.zeros([3], dtype="float32")
dy_double_grad_expected = np.ones([3], dtype="float32")
ddout_expected = np.matmul(
input_numpy_y, np.ones([3], dtype="float32")
)
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
@@ -920,6 +925,7 @@ class TestDygraphDoubleGradMatmul(TestCase):
for place in places:
paddle.device.set_device(place)
actual_results = actual()
for expected_result, actual_result in zip(
expected_results, actual_results
):
@@ -950,24 +956,22 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddy = paddle.to_tensor(
np.ones([1]), stop_gradient=False, dtype='float32'
)
dx_double_grad, ddout = paddle.grad(
[dy],
[x, dout],
[ddy],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, ddout
def expected():
dx_double_grad_expected = np.ones([2, 1], dtype="float32")
dy_double_grad_expected = np.zeros([1], dtype="float32")
ddout_expected = np.matmul(
input_numpy_x, np.ones([1], dtype="float32")
)
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
@@ -1008,21 +1012,19 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddx = paddle.to_tensor(
np.ones([2, 1]), stop_gradient=False, dtype='float32'
)
dy_double_grad, ddout = paddle.grad(
[dx],
[y, dout],
[ddx],
retain_graph=True,
create_graph=True,
)
return dy_double_grad, ddout
def expected():
dx_double_grad_expected = np.zeros([2, 1], dtype="float32")
dy_double_grad_expected = np.ones([1], dtype="float32") * 2
ddout_expected = np.ones([2], dtype="float32") * input_numpy_y[0]
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
@@ -1041,6 +1043,8 @@ class TestDygraphDoubleGradMatmul(TestCase):
expected_result, actual_result, rtol=1e-6
)
# TODO(Ruting) test complex dtype when composite api support
'''
# case7: ddx is none, dims = 1, complex dtype
def test_matmul_double_grad_case7(self):
input_numpy_x = np.random.random([3]).astype(
@@ -1069,19 +1073,17 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddx = paddle.to_tensor(
np.ones([3]), stop_gradient=False, dtype='complex64'
)
# when y is not differentiated in the first grad, dx from the second grad could be None in the composite api.
dy_double_grad, ddout = paddle.grad(
[dx],
[y, dout],
[ddx],
retain_graph=True,
create_graph=True,
)
return dy_double_grad, ddout
def expected():
dx_double_grad_expected = np.zeros(
[3], dtype="float32"
) + 0j * np.zeros([3], dtype="float32")
dy_double_grad_expected = np.ones(
[3], dtype="float32"
) + 0j * np.ones([3], dtype="float32")
@@ -1089,7 +1091,6 @@ class TestDygraphDoubleGradMatmul(TestCase):
input_numpy_y_conj, np.ones([3], dtype="float32")
)
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
@@ -1108,6 +1109,7 @@ class TestDygraphDoubleGradMatmul(TestCase):
expected_result, actual_result, rtol=1e-6
)
# case8: ddy is none, dims = 1, complex dtype
def test_matmul_double_grad_case8(self):
input_numpy_x = np.random.random([3]).astype(
@@ -1136,24 +1138,22 @@ class TestDygraphDoubleGradMatmul(TestCase):
ddy = paddle.to_tensor(
np.ones([3]), stop_gradient=False, dtype='complex64'
)
dx_double_grad, ddout = paddle.grad(
[dy],
[x, dout],
[ddy],
retain_graph=True,
create_graph=True,
)
return dx_double_grad, ddout
def expected():
dx_double_grad_expected = np.ones([3], dtype="float32")
dy_double_grad_expected = np.zeros([3], dtype="float32")
ddout_expected = np.matmul(
input_numpy_x_conj, np.ones([3], dtype="float32")
)
return (
dx_double_grad_expected,
dy_double_grad_expected,
ddout_expected,
)
@@ -1170,6 +1170,39 @@ class TestDygraphDoubleGradMatmul(TestCase):
np.testing.assert_allclose(
expected_result, actual_result, rtol=1e-6
)
'''
def test_value_error(self):
def test():
import paddle
import paddle.nn as nn
model = nn.Sequential(nn.Linear(3, 4))
x = paddle.randn([4, 1])
y = paddle.randn([4, 1])
z = paddle.randn([4, 1])
x.stop_gradient = False
y.stop_gradient = False
z.stop_gradient = False
out = model(paddle.concat((x, y, z), axis=1))
data = {
"x": x,
"y": y,
"z": z,
"u": out[:, 0:1],
"v": out[:, 1:2],
"w": out[:, 2:3],
"p": out[:, 3:4],
}
v = out[:, 1:2]
z = paddle.grad(v, x, create_graph=True)[0]
zz = paddle.grad(z, x, create_graph=True)[0]
with self.assertRaises(ValueError):
test()
if __name__ == '__main__':
......
@@ -89,17 +89,17 @@ class TestDygraphTripleGradMatmul(TestCase):
np.testing.assert_array_equal(new_c.numpy(), new_c_ref)
x_grad_ref = np.ones([3, 3]) * 0.0
assert x.grad is None
y_grad_ref = np.ones([3, 3]) * 0.0
assert y.grad is None
new_out_g_ref = np.ones([3, 3]) * 3.0
np.testing.assert_array_equal(new_out_g.grad.numpy(), new_out_g_ref)
new_x_g_g_ref = np.ones([3, 3]) * 0.0
new_y_g_g_ref = np.ones([3, 3]) * 3.0
assert new_x_g_g.grad is None
np.testing.assert_array_equal(new_y_g_g.grad.numpy(), new_y_g_g_ref)
@@ -359,13 +359,14 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
retain_graph=True,
create_graph=True,
)
# d_x, d_y should be None because d_ddout is None
d_dout, d_ddx, d_ddy = paddle.grad(
[dx_double_grad, dy_double_grad],
[dout, ddx, ddy],
retain_graph=False,
create_graph=False,
)
return d_dout, d_ddx, d_ddy
# case1: d_ddout is none, dims != 1
def test_matmul_triple_grad_case1(self):
@@ -377,14 +378,10 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
self.input_numpy_ddy = np.ones([3, 3], dtype="float32")
init_data()
d_x_expected = np.zeros([3, 3], dtype="float32")
d_y_expected = np.zeros([3, 3], dtype="float32")
d_dout_expected = np.ones([3, 3], dtype="float32") * 6
d_ddx_expected = np.ones([3, 3], dtype="float32") * 3
d_ddy_expected = np.ones([3, 3], dtype="float32") * 3
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
@@ -418,18 +415,6 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
self.input_numpy_ddy = np.ones([3], dtype="float32")
init_data()
d_x_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_y_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_dout_expected = np.ones([1], dtype="float32") * 6
d_ddx_expected = np.ones(
[
@@ -444,8 +429,6 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
dtype="float32",
)
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
@@ -475,8 +458,6 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
self.input_numpy_ddy = np.ones([1], dtype="float32")
init_data()
d_x_expected = np.zeros([3, 1], dtype="float32")
d_y_expected = np.zeros([1], dtype="float32")
d_dout_expected = (
np.ones(
[
@@ -489,8 +470,6 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
d_ddx_expected = np.ones([3, 1], dtype="float32")
d_ddy_expected = np.ones([1], dtype="float32") * 3
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
@@ -507,6 +486,7 @@ class TestDygraphTripleGradMatmulcase1(TestCase):
)
'''
# d_ddout is none, dtype is complex64
class TestDygraphTripleGradMatmulcase2(TestCase):
def setUp(self):
@@ -666,6 +646,7 @@ class TestDygraphTripleGradMatmulcase2(TestCase):
np.testing.assert_allclose(
expected_result, actual_result, rtol=1e-6
)
'''
# d_ddout is none, d_dx is none, dtype is float32
@@ -708,13 +689,15 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
retain_graph=True,
create_graph=True,
)
d_x, d_y, d_dout, d_ddx, d_ddy = paddle.grad( # d_x d_y is None because (double grad out_put ddout grad tensor)d_ddout is None
# d_ddy is None because (double grad out_put dx grad tensor) d_dx and d_ddout is None
d_dout, d_ddx = paddle.grad(
[dy_double_grad],
[dout, ddx],
retain_graph=False,
create_graph=False,
)
return d_dout, d_ddx
# case1: d_ddout is none, d_dx is none, dims != 1
def test_matmul_triple_grad_case1(self):
@@ -726,17 +709,11 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
self.input_numpy_ddy = np.ones([3, 3], dtype="float32")
init_data()
d_x_expected = np.zeros([3, 3], dtype="float32")
d_y_expected = np.zeros([3, 3], dtype="float32")
d_dout_expected = np.ones([3, 3], dtype="float32") * 3
d_ddx_expected = np.ones([3, 3], dtype="float32") * 3
d_ddy_expected = np.zeros([3, 3], dtype="float32")
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
)
for place in self.places:
@@ -767,18 +744,6 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
self.input_numpy_ddy = np.ones([3], dtype="float32")
init_data()
d_x_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_y_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_dout_expected = np.ones([1], dtype="float32") * 3
d_ddx_expected = np.ones(
[
@@ -786,18 +751,9 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
],
dtype="float32",
)
d_ddy_expected = np.zeros(
[
3,
],
dtype="float32",
)
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
)
for place in self.places:
@@ -824,8 +780,6 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
self.input_numpy_ddy = np.ones([1], dtype="float32")
init_data()
d_x_expected = np.zeros([3, 1], dtype="float32")
d_y_expected = np.zeros([1], dtype="float32")
d_dout_expected = np.ones(
[
3,
@@ -833,13 +787,9 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
dtype="float32",
)
d_ddx_expected = np.ones([3, 1], dtype="float32")
d_ddy_expected = np.zeros([1], dtype="float32")
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
)
for place in self.places:
@@ -853,6 +803,7 @@ class TestDygraphTripleGradMatmulcase3(TestCase):
)
'''
# d_ddout is none, d_dx is none, dtype is complex64
class TestDygraphTripleGradMatmulcase4(TestCase):
def setUp(self):
@@ -999,6 +950,7 @@ class TestDygraphTripleGradMatmulcase4(TestCase):
np.testing.assert_allclose(
expected_result, actual_result, rtol=1e-6
)
'''
# d_ddout is none, d_dy is none, dtype is float32
@@ -1041,13 +993,13 @@ class TestDygraphTripleGradMatmulcase5(TestCase):
retain_graph=True,
create_graph=True,
)
d_dout, d_ddy = paddle.grad(
[dx_double_grad],
[dout, ddy],
retain_graph=False,
create_graph=False,
)
return d_dout, d_ddy
# case1: d_ddout is none, d_dy is none, dims != 1
def test_matmul_triple_grad_case1(self):
@@ -1059,16 +1011,10 @@ class TestDygraphTripleGradMatmulcase5(TestCase):
self.input_numpy_ddy = np.ones([3, 3], dtype="float32")
init_data()
d_x_expected = np.zeros([3, 3], dtype="float32")
d_y_expected = np.zeros([3, 3], dtype="float32")
d_dout_expected = np.ones([3, 3], dtype="float32") * 3
d_ddx_expected = np.zeros([3, 3], dtype="float32") * 3
d_ddy_expected = np.ones([3, 3], dtype="float32") * 3
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
)
@@ -1100,25 +1046,7 @@ class TestDygraphTripleGradMatmulcase5(TestCase):
self.input_numpy_ddy = np.ones([3], dtype="float32")
init_data()
d_x_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_y_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_dout_expected = np.ones([1], dtype="float32") * 3
d_ddx_expected = np.zeros(
[
3,
],
dtype="float32",
)
d_ddy_expected = np.ones(
[
3,
@@ -1126,10 +1054,7 @@ class TestDygraphTripleGradMatmulcase5(TestCase):
dtype="float32",
)
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
)
@@ -1157,21 +1082,15 @@ class TestDygraphTripleGradMatmulcase5(TestCase):
self.input_numpy_ddy = np.ones([1], dtype="float32")
init_data()
d_x_expected = np.zeros([3, 1], dtype="float32")
d_y_expected = np.zeros([1], dtype="float32")
d_dout_expected = np.ones(
[
3,
],
dtype="float32",
)
d_ddx_expected = np.zeros([3, 1], dtype="float32")
d_ddy_expected = np.ones([1], dtype="float32") * 3
expected_results = (
d_x_expected,
d_y_expected,
d_dout_expected,
d_ddx_expected,
d_ddy_expected,
)
@@ -1186,6 +1105,8 @@ class TestDygraphTripleGradMatmulcase5(TestCase):
)
'''
TODO(Ruting) test complex dtype when composite api support
# d_ddout is none, d_dy is none, dtype is complex64
class TestDygraphTripleGradMatmulcase6(TestCase):
def setUp(self):
@@ -1332,7 +1253,7 @@ class TestDygraphTripleGradMatmulcase6(TestCase):
np.testing.assert_allclose(
expected_result, actual_result, rtol=1e-6
)
'''
if __name__ == '__main__':
unittest.main()