Unverified commit 8bfd62ff, authored by Zeng Jinle, committed by GitHub

Expose dygraph.grad api (#23124)

* expose dygraph.grad api, test=develop, test=document_fix

* add more parameter in dygraph.grad API, test=develop

* add only_inputs=True parameter, test=develop

* follow comments, test=develop, test=document_fix

* fix typo, test=develop, test=document_fix
Parent 0129f4b5
......@@ -83,7 +83,6 @@ void BasicEngine::CheckBackwardInputs(const OpBase& op) {
}
if (tensor && !tensor->IsInitialized()) {
// if grad var has OverridedStopGradient skip this Op
VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero";
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(op.place());
tensor->mutable_data(op.place(), var->DataType());
......@@ -139,16 +138,7 @@ void BasicEngine::PrepareDeps() {
q.pop();
for (auto& cur_op : *cur_node) {
PADDLE_ENFORCE_NE(
cur_op.GetInsMap().empty() && cur_op.GetOutsMap().empty(), true,
platform::errors::NotFound(
"Inputs and outputs of %s do not exist. "
"This may be because you call \"backward()\" twice for the same "
"subgraph. Please try to call \"stop_gradient = True\" or "
"\"detach()\" if you use some same vars between two "
"\"backward()\" "
"calls.",
cur_op.Type()));
cur_op.EnforceHasInOut();
PrepareGradAccumulators(cur_op);
}
......
......@@ -119,6 +119,20 @@ class OpBase {
void SetPlace(const platform::Place& place) { place_ = place; }
void EnforceHasInOut() const {
PADDLE_ENFORCE_NE(
ins_.empty() && outs_.empty(), true,
platform::errors::NotFound(
"Inputs and outputs of %s do not exist. This may be because:\n"
"1. You use some output variables of the previous batch as the "
"inputs of the current batch. Please try to call \"stop_gradient "
"= True\" or \"detach()\" for these variables.\n"
"2. You calculate backward twice for the same subgraph without "
"setting retain_graph=True. Please set retain_graph=True in the "
"first backward call.\n\n",
Type()));
}
static size_t GenerateUniqueId() {
static std::atomic<size_t> unique_id{0};
return unique_id.fetch_add(1);
......
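The new error message above recommends detaching variables that are carried across iterations. Below is a minimal sketch of that fix, assuming the usual fluid dygraph helpers (to_variable, detach, clear_gradient); the loop and variable names are illustrative and not part of this commit:

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    w = fluid.dygraph.to_variable(np.ones([2, 2], dtype='float32'))
    w.stop_gradient = False
    h = fluid.dygraph.to_variable(np.ones([2, 2], dtype='float32'))
    for step in range(2):
        # Detach the state carried over from the previous iteration so the
        # second backward() does not reach into the already-freed subgraph.
        h = h.detach()
        h = fluid.layers.matmul(h, w)
        loss = fluid.layers.reduce_mean(h)
        loss.backward()
        w.clear_gradient()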
......@@ -57,16 +57,15 @@ namespace imperative {
static void GetGraphInfoBetweenTargets(
std::unordered_set<VariableWrapper *> *input_target_grads,
std::unordered_set<VarBase *> *output_targets,
std::unordered_set<const OpBase *> *startup_ops_ptr,
std::unordered_map<const OpBase *, std::unordered_set<const OpBase *>>
*pending_ops_ptr,
std::unordered_map<const OpBase *, size_t> *op_deps_ptr,
std::unordered_set<OpBase *> *startup_ops_ptr,
std::unordered_map<OpBase *, std::unordered_set<OpBase *>> *pending_ops_ptr,
std::unordered_map<OpBase *, size_t> *op_deps_ptr,
std::unordered_set<VariableWrapper *> *related_grad_vars_ptr,
const std::unordered_set<VariableWrapper *> &no_grad_var_grad) {
/**
* Step 1. Find the candidate startup grad ops, prepared for following BFS.
*/
std::queue<std::pair<const OpBase *, const GradOpNode *>> q;
std::queue<std::pair<OpBase *, GradOpNode *>> q;
std::unordered_set<GradOpNode *> visited;
for (auto iter = output_targets->begin(); iter != output_targets->end();) {
auto *output_target = *iter;
......@@ -98,9 +97,8 @@ static void GetGraphInfoBetweenTargets(
* not all input_target_grads would be found.
*/
std::unordered_set<VariableWrapper *> found_input_target_grads;
std::unordered_set<const OpBase *> endpoint_ops;
std::unordered_map<const OpBase *, std::unordered_set<const OpBase *>>
preceding_ops;
std::unordered_set<OpBase *> endpoint_ops;
std::unordered_map<OpBase *, std::unordered_set<OpBase *>> preceding_ops;
while (!q.empty()) {
auto op_node_pair = q.front();
q.pop();
......@@ -153,8 +151,7 @@ static void GetGraphInfoBetweenTargets(
auto &target_vars = *related_grad_vars_ptr;
target_vars = *input_target_grads;
std::queue<std::pair<const OpBase * /*op*/, const OpBase * /*pending op*/>>
op_queue;
std::queue<std::pair<OpBase * /*op*/, OpBase * /*pending op*/>> op_queue;
for (auto &endpoint_op : endpoint_ops) {
op_queue.emplace(endpoint_op, nullptr);
}
......@@ -238,7 +235,7 @@ static void GetGraphInfoBetweenTargets(
for (auto iter = output_targets->begin(); iter != output_targets->end();) {
auto &grad_node = (*iter)->GradVarBase()->GradNode();
bool is_valid = std::find_if(grad_node->begin(), grad_node->end(),
[&](const OpBase &op) {
[&](OpBase &op) { // NOLINT
return startup_ops.count(&op) > 0;
}) != grad_node->end();
if (is_valid) {
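The startup_ops/pending_ops/op_deps structures collected here drive a topological execution of the backward sub-graph. A small language-agnostic sketch of that scheduling idea follows; the names are illustrative, not the C++ API:

from collections import deque

def run_partial_grad(startup_ops, pending_ops, op_deps, run_op):
    # startup_ops: ops with no unfinished predecessors
    # pending_ops: dict mapping an op to the ops waiting on it
    # op_deps:     dict mapping an op to its number of unfinished predecessors
    deps = dict(op_deps)
    queue = deque(startup_ops)
    while queue:
        op = queue.popleft()
        run_op(op)                        # execute one grad op
        for nxt in pending_ops.get(op, ()):
            deps[nxt] -= 1
            if deps[nxt] == 0:            # all predecessors finished
                queue.append(nxt)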
......@@ -518,12 +515,13 @@ class PartialGradTask {
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place,
const detail::BackwardStrategy &strategy, bool create_graph);
const detail::BackwardStrategy &strategy, bool create_graph,
bool retain_graph, bool allow_unused, bool only_inputs);
std::vector<std::shared_ptr<VarBase>> Run();
private:
void RunEachOp(const OpBase *op);
void RunEachOp(OpBase *op);
void PrepareInitialReadyVarsMap(const OpBase *op);
......@@ -536,10 +534,9 @@ class PartialGradTask {
}
private:
std::unordered_set<const OpBase *> startup_ops_;
std::unordered_map<const OpBase *, std::unordered_set<const OpBase *>>
pending_ops_;
std::unordered_map<const OpBase *, size_t> op_deps_;
std::unordered_set<OpBase *> startup_ops_;
std::unordered_map<OpBase *, std::unordered_set<OpBase *>> pending_ops_;
std::unordered_map<OpBase *, size_t> op_deps_;
ReadyGradVarInfoMap ready_grad_vars_;
......@@ -562,6 +559,9 @@ class PartialGradTask {
platform::Place place_;
bool create_graph_;
bool retain_graph_;
bool allow_unused_;
bool only_inputs_;
detail::BackwardStrategy strategy_;
};
......@@ -571,12 +571,19 @@ PartialGradTask::PartialGradTask(
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place, const detail::BackwardStrategy &strategy,
bool create_graph) {
bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) {
input_targets_ = input_targets;
place_ = place;
create_graph_ = create_graph;
retain_graph_ = retain_graph;
allow_unused_ = allow_unused;
only_inputs_ = only_inputs;
strategy_ = strategy;
PADDLE_ENFORCE_EQ(only_inputs_, true,
platform::errors::Unimplemented(
"only_inputs=False is not supported yet"));
for (auto &var : no_grad_vars) {
if (var && var->GradVarBase()) {
no_grad_var_grad_.insert(var->GradVarBase()->SharedVar().get());
......@@ -738,7 +745,7 @@ PartialGradTask::PartialGradTask(
std::vector<std::shared_ptr<VarBase>> PartialGradTask::Run() {
VLOG(10) << "Startup op number " << startup_ops_.size();
std::queue<const OpBase *> q;
std::queue<OpBase *> q;
for (auto *op : startup_ops_) {
q.push(op);
}
......@@ -746,8 +753,13 @@ std::vector<std::shared_ptr<VarBase>> PartialGradTask::Run() {
while (!q.empty()) {
auto *op = q.front();
q.pop();
VLOG(10) << "Start to run " << op->Type();
op->EnforceHasInOut();
RunEachOp(op);
if (!retain_graph_) {
op->ClearBackwardTrace();
}
VLOG(10) << "End to run " << op->Type();
auto iter = pending_ops_.find(op);
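Because ClearBackwardTrace() frees the recorded inputs/outputs when retain_graph_ is false, a second grad call over the same subgraph only succeeds if the first call keeps the graph. A small usage sketch with the Python API added in this commit:

import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.layers.ones(shape=[1], dtype='float32')
    x.stop_gradient = False
    y = x * x

    # retain_graph=True keeps the traced graph alive after the first call ...
    dx1, = fluid.dygraph.grad([y], [x], retain_graph=True)
    # ... so differentiating the same subgraph again is still possible.
    dx2, = fluid.dygraph.grad([y], [x])
    # Without retain_graph=True above, the second call would trip the
    # "Inputs and outputs of %s do not exist" check shown earlier.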
......@@ -773,7 +785,7 @@ std::vector<std::shared_ptr<VarBase>> PartialGradTask::Run() {
return CreateResult();
}
void PartialGradTask::RunEachOp(const OpBase *op) {
void PartialGradTask::RunEachOp(OpBase *op) {
// Prepare new inputs
NameVarMap<VarBase> tmp_ins;
for (auto &input_pair : op->GetInsMap()) {
......@@ -960,7 +972,8 @@ void PartialGradTask::PrepareInitialGradientAccumulators(const OpBase *op) {
std::vector<std::shared_ptr<VarBase>> PartialGradTask::CreateResult() {
std::vector<std::shared_ptr<VarBase>> result;
result.reserve(input_targets_.size());
for (auto &input_target : input_targets_) {
for (size_t i = 0; i < input_targets_.size(); ++i) {
auto &input_target = input_targets_[i];
PADDLE_ENFORCE_NOT_NULL(
input_target->GradVarBase(),
platform::errors::InvalidArgument("input should have gradient"));
......@@ -971,6 +984,12 @@ std::vector<std::shared_ptr<VarBase>> PartialGradTask::CreateResult() {
ready_var->SetOverridedStopGradient(!create_graph_);
result.emplace_back(std::move(ready_var));
} else { // return None if it does not appear in the graph
PADDLE_ENFORCE_EQ(allow_unused_, true,
platform::errors::InvalidArgument(
"The %d-th input does not appear in the backward "
"graph. Please check the input variable or set "
"allow_unused=True to get None result.",
i));
result.emplace_back();
}
}
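The allow_unused branch above backs the Python-level behaviour documented in the docstring later in this diff; a small sketch (variable names are illustrative):

import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.layers.ones(shape=[1], dtype='float32')
    z = fluid.layers.ones(shape=[1], dtype='float32')
    x.stop_gradient = False
    z.stop_gradient = False
    y = x * x  # z does not participate in the graph of y

    # allow_unused=False (the default) raises the InvalidArgument error above;
    # allow_unused=True returns None for the unreachable input instead.
    dx, dz = fluid.dygraph.grad([y], [x, z], allow_unused=True)
    assert dz is None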
......@@ -995,14 +1014,17 @@ PartialGradEngine::PartialGradEngine(
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place, const detail::BackwardStrategy &strategy,
bool create_graph)
bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs)
: input_targets_(input_targets),
output_targets_(output_targets),
output_grads_(output_grads),
no_grad_vars_(no_grad_vars),
place_(place),
strategy_(strategy),
create_graph_(create_graph) {}
create_graph_(create_graph),
retain_graph_(retain_graph),
allow_unused_(allow_unused),
only_inputs_(only_inputs) {}
std::vector<std::shared_ptr<VarBase>> PartialGradEngine::GetResult() const {
return results_;
......@@ -1017,7 +1039,8 @@ void PartialGradEngine::Clear() {
void PartialGradEngine::Execute() {
PartialGradTask task(input_targets_, output_targets_, output_grads_,
no_grad_vars_, place_, strategy_, create_graph_);
no_grad_vars_, place_, strategy_, create_graph_,
retain_graph_, allow_unused_, only_inputs_);
VLOG(10) << "Starts to execute PartialGradEngine";
results_ = task.Run();
Clear();
......
......@@ -32,8 +32,8 @@ class PartialGradEngine : public Engine {
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place,
const detail::BackwardStrategy &strategy,
bool create_graph);
const detail::BackwardStrategy &strategy, bool create_graph,
bool retain_graph, bool allow_unused, bool only_inputs);
void Execute() override;
......@@ -50,6 +50,9 @@ class PartialGradEngine : public Engine {
platform::Place place_;
detail::BackwardStrategy strategy_;
bool create_graph_;
bool retain_graph_;
bool allow_unused_;
bool only_inputs_;
std::vector<std::shared_ptr<VarBase>> results_;
};
......
......@@ -783,10 +783,11 @@ void BindImperative(py::module *m_ptr) {
const std::vector<std::shared_ptr<imperative::VarBase>> &no_grad_vars,
const platform::Place &place,
const imperative::detail::BackwardStrategy &strategy,
bool create_graph) {
imperative::PartialGradEngine engine(input_targets, output_targets,
output_grads, no_grad_vars, place,
strategy, create_graph);
bool create_graph, bool retain_graph, bool allow_unused,
bool only_inputs) {
imperative::PartialGradEngine engine(
input_targets, output_targets, output_grads, no_grad_vars, place,
strategy, create_graph, retain_graph, allow_unused, only_inputs);
engine.Execute();
return engine.GetResult();
},
......
......@@ -23,6 +23,7 @@ import objgraph
__all__ = [
'no_grad',
'grad',
'guard',
'enable_dygraph',
'disable_dygraph',
......@@ -254,9 +255,145 @@ def _print_debug_msg(parameter_list, limit=5, is_test=False):
def grad(outputs,
inputs,
grad_outputs=None,
no_grad_set=None,
retain_graph=None,
create_graph=False,
only_inputs=True,
allow_unused=False,
no_grad_vars=None,
backward_strategy=None):
'''
.. note::
**This API is ONLY available in Dygraph mode.**
This API computes the sum of gradients of `outputs` with respect to each `inputs` .
Parameters:
outputs (Variable|list(Variable)|tuple(Variable)): the output Variable or
Variable list/tuple of the graph to compute gradients.
inputs (Variable|list(Variable)|tuple(Variable)): the input Variable or
Variable list/tuple of the graph to compute gradients. The returned
values of this API are the gradients of `inputs` .
grad_outputs (Variable|list(Variable|None)|tuple(Variable|None), optional):
initial gradient values of `outputs` . If `grad_outputs` is None,
the initial gradient values of `outputs` would be Tensors filled with 1;
if `grad_outputs` is not None, it must have the same length as `outputs` ,
and in this case, the initial gradient value of the i-th `outputs` would
be: (1) a Tensor filled with 1 when the i-th element of `grad_outputs`
is None; (2) the i-th element of `grad_outputs` when the i-th element of
`grad_outputs` is a Variable. Default None.
retain_graph (bool, optional): whether to retain the forward graph which
is used to calculate the gradient. When it is True, the graph would
be retained, in which way users can calculate backward twice for the
same graph. When it is False, the graph would be freed. Default None,
which means it is equal to `create_graph` .
create_graph (bool, optional): whether to create the gradient graphs of
the computing process. When it is True, higher order derivatives can be
computed; when it is False, the gradient graphs of the computing process
would be discarded. Default False.
only_inputs (bool, optional): whether to compute the gradients of
`inputs` only. If it is False, the gradients of all remaining leaf
Variables in the graph would also be computed and accumulated.
If it is True, only the gradients of `inputs` would be computed.
Default True. only_inputs=False is under development and is not
supported yet.
allow_unused (bool, optional): whether to return None instead of raising an
error when some Variables of `inputs` are unreachable in the graph
(i.e., their gradients are None). If allow_unused=False, an error would
be raised; if allow_unused=True, None would be returned as their
gradients. Default False.
no_grad_vars (Variable|list(Variable)|tuple(Variable)|set(Variable), optional):
the Variables whose gradients should not be computed. Default None. (See the additional sketch after this docstring.)
backward_strategy (BackwardStrategy, optional): The backward strategy to
compute gradients. See :ref:`api_fluid_dygraph_BackwardStrategy` for
details. Default None.
Returns:
tuple: a tuple of Variables, whose length is the same as the Variable number
inside `inputs`, and the i-th returned Variable is the sum of gradients of
`outputs` with respect to the i-th `inputs`.
Examples 1:
.. code-block:: python
import paddle.fluid as fluid
def test_dygraph_grad(create_graph):
with fluid.dygraph.guard():
x = fluid.layers.ones(shape=[1], dtype='float32')
x.stop_gradient = False
y = x * x
# Since y = x * x, dx = 2 * x
dx = fluid.dygraph.grad(
outputs=[y],
inputs=[x],
create_graph=create_graph,
retain_graph=True)[0]
z = y + dx
# If create_graph = False, the gradient of dx
# would not be backpropagated. Therefore,
# z = x * x + dx, and x.gradient() = 2 * x = 2.0
# If create_graph = True, the gradient of dx
# would be backpropagated. Therefore,
# z = x * x + dx = x * x + 2 * x, and
# x.gradient() = 2 * x + 2 = 4.0
z.backward()
return x.gradient()
print(test_dygraph_grad(create_graph=False)) # [2.]
print(test_dygraph_grad(create_graph=True)) # [4.]
Examples 2:
.. code-block:: python
import paddle.fluid as fluid
fluid.enable_dygraph()
def test_dygraph_grad(grad_outputs=None):
x = fluid.layers.fill_constant(shape=[1], value=2.0, dtype='float32')
x.stop_gradient = False
y1 = x * x
y2 = x * 3
# If grad_outputs=None, dy1 = [1], dy2 = [1].
# If grad_outputs=[g1, g2], then:
# - dy1 = [1] if g1 is None else g1
# - dy2 = [1] if g2 is None else g2
# Since y1 = x * x, dx = 2 * x * dy1.
# Since y2 = x * 3, dx = 3 * dy2.
# Therefore, the final result would be:
# dx = 2 * x * dy1 + 3 * dy2 = 4 * dy1 + 3 * dy2.
dx = fluid.dygraph.grad(
outputs=[y1, y2],
inputs=[x],
grad_outputs=grad_outputs)[0]
return dx.numpy()
THREE = fluid.layers.fill_constant(shape=[1], value=3.0, dtype='float32')
FOUR = fluid.layers.fill_constant(shape=[1], value=4.0, dtype='float32')
# dy1 = [1], dy2 = [1]
print(test_dygraph_grad(None)) # [7.]
# dy1 = [1], dy2 = [4]
print(test_dygraph_grad([None, FOUR])) # [16.]
# dy1 = [4], dy2 = [1]
print(test_dygraph_grad([FOUR, None])) # [19.]
# dy1 = [3], dy2 = [4]
print(test_dygraph_grad([THREE, FOUR])) # [24.]
'''
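As an aside (not part of the committed docstring), `no_grad_vars` cuts the gradient flow through the listed variables; a minimal sketch under that reading of the parameter:

import paddle.fluid as fluid

fluid.enable_dygraph()

x = fluid.layers.fill_constant(shape=[1], value=2.0, dtype='float32')
x.stop_gradient = False
y1 = x * x   # this branch contributes 2 * x = 4 to dx
y2 = x * 3   # listed in no_grad_vars below, so it contributes nothing
loss = y1 + y2

dx, = fluid.dygraph.grad(outputs=[loss], inputs=[x], no_grad_vars=[y2])
print(dx.numpy())  # expected: [4.] under this assumption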
def check_in_out(in_out_list, name):
assert in_out_list is not None, "{} should not be None".format(name)
......@@ -294,18 +431,18 @@ def grad(outputs,
assert len(grad_outputs) == len(
outputs), "The length of grad_outputs must be equal to outputs"
if no_grad_set is None:
no_grad_set = []
elif isinstance(no_grad_set, core.VarBase):
no_grad_set = [no_grad_set]
elif isinstance(no_grad_set, (list, tuple, set)):
no_grad_set = list(no_grad_set)
for var in no_grad_set:
if no_grad_vars is None:
no_grad_vars = []
elif isinstance(no_grad_vars, core.VarBase):
no_grad_vars = [no_grad_vars]
elif isinstance(no_grad_vars, (list, tuple, set)):
no_grad_vars = list(no_grad_vars)
for var in no_grad_vars:
assert isinstance(
var, core.VarBase), "no_grad_set can only contains Variable"
var, core.VarBase), "no_grad_vars can only contains Variable"
else:
raise AssertionError(
"no_grad_set must be None, Variable or list/tuple/set of Variables")
"no_grad_vars must be None, Variable or list/tuple/set of Variables")
if backward_strategy is None:
backward_strategy = core.BackwardStrategy()
......@@ -315,10 +452,22 @@ def grad(outputs,
assert isinstance(create_graph, bool), "create_graph must be True or False"
if retain_graph is None:
retain_graph = create_graph
assert isinstance(retain_graph,
bool), "retain_graph must be None, True or False"
assert isinstance(allow_unused, bool), "allow_unused must be True or False"
assert isinstance(only_inputs, bool), "only_inputs must be True or False"
assert only_inputs, "only_inputs=False is not supported yet"
place = core.Place()
place.set_place(framework._current_expected_place())
return core.dygraph_partial_grad(inputs, outputs, grad_outputs, no_grad_set,
place, backward_strategy, create_graph)
return core.dygraph_partial_grad(
inputs, outputs, grad_outputs, no_grad_vars, place, backward_strategy,
create_graph, retain_graph, allow_unused, only_inputs)
@framework.dygraph_only
......
......@@ -17,7 +17,6 @@ from paddle.fluid.wrapped_decorator import wrap_decorator
import unittest
from unittest import TestCase
import numpy as np
from paddle.fluid.dygraph.base import grad
def _dygraph_guard_(func):
......@@ -48,16 +47,20 @@ class TestDygraphDoubleGrad(TestCase):
outputs,
inputs,
grad_outputs=None,
no_grad_set=None,
create_graph=False):
no_grad_vars=None,
retain_graph=None,
create_graph=False,
allow_unused=False):
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = self.sort_sum_gradient
return grad(
return fluid.dygraph.grad(
outputs=outputs,
inputs=inputs,
grad_outputs=grad_outputs,
no_grad_set=no_grad_set,
no_grad_vars=no_grad_vars,
retain_graph=retain_graph,
create_graph=create_graph,
allow_unused=allow_unused,
backward_strategy=backward_strategy)
@dygraph_guard
......@@ -84,10 +87,11 @@ class TestDygraphDoubleGrad(TestCase):
[random_var(shape)], [random_var(shape)])
with self.assertRaises(AssertionError):
self.grad([random_var(shape)], [random_var(shape)], no_grad_set=[1])
self.grad(
[random_var(shape)], [random_var(shape)], no_grad_vars=[1])
with self.assertRaises(AssertionError):
self.grad([random_var(shape)], [random_var(shape)], no_grad_set=1)
self.grad([random_var(shape)], [random_var(shape)], no_grad_vars=1)
@dygraph_guard
def test_simple_example(self):
......@@ -96,17 +100,20 @@ class TestDygraphDoubleGrad(TestCase):
y = x + 1
for create_graph in [False, True]:
dx, = self.grad([x], [x], create_graph=create_graph)
dx, = self.grad(
[x], [x], create_graph=create_graph, retain_graph=True)
self.assertEqual(dx.shape, x.shape)
self.assertTrue(np.all(dx.numpy() == 1))
self.assertNotEqual(dx.stop_gradient, create_graph)
dx_mul_2, = self.grad([y, x], [x], create_graph=create_graph)
dx_mul_2, = self.grad(
[y, x], [x], create_graph=create_graph, retain_graph=True)
self.assertEqual(dx_mul_2.shape, x.shape)
self.assertTrue(np.all(dx_mul_2.numpy() == 2))
self.assertNotEqual(dx_mul_2.stop_gradient, create_graph)
none_grad, = self.grad([x], [y], create_graph=create_graph)
none_grad, = self.grad(
[x], [y], create_graph=create_graph, allow_unused=True)
self.assertTrue(none_grad is None)
grad_with_none_and_not_none, = self.grad(
......@@ -160,7 +167,8 @@ class TestDygraphDoubleGrad(TestCase):
outputs=[y, z],
inputs=[x],
grad_outputs=[grad_y, grad_z],
create_graph=create_graph)
create_graph=create_graph,
retain_graph=True)
grad_y_np = ones_grad_y if grad_y is None else grad_y.numpy(
)
......@@ -216,7 +224,7 @@ class TestDygraphDoubleGrad(TestCase):
self.assertTrue(np.allclose(x_grad_actual, x_grad_expected))
@dygraph_guard
def test_example_with_gradient_accumulation_and_no_grad_set(self):
def test_example_with_gradient_accumulation_and_no_grad_vars(self):
x = random_var(self.shape)
x_np = x.numpy()
numel = x_np.size
......@@ -231,7 +239,7 @@ class TestDygraphDoubleGrad(TestCase):
del y1, z, w
dx_actual, = self.grad(
[w_mean], [x], create_graph=True, no_grad_set=[y2])
[w_mean], [x], create_graph=True, no_grad_vars=[y2])
self.assertFalse(y2.stop_gradient)
self.assertFalse(dx_actual.stop_gradient)
......
......@@ -368,11 +368,10 @@ def loss_cls(cls, label, cfg):
def calc_gradients(outputs, inputs, no_grad_set):
if fluid.in_dygraph_mode():
from paddle.fluid.dygraph.base import grad
return grad(
return fluid.dygraph.grad(
outputs=outputs,
inputs=inputs,
no_grad_set=no_grad_set,
no_grad_vars=no_grad_set,
create_graph=True)
else:
return fluid.gradients(
......