未验证 提交 a2b80145 编写于 作者: Z Zhanlue Yang 提交者: GitHub

[DoubleGrad PR #7] paddle.grad() to copy backward graph before backward run (#41306)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Supported higher-order grad node generation

* [DoubleGrad PR #4] Supported higher-order GradNode generation

* [DoubleGrad #4] Bug Fixes to Double Grad Node Generation

* Fixed yaml typo

* Fixed yaml typo

* fixed minor issues

* [DoubleGrad PR #5] Enabled gradient computations for grad_tensors passed to paddle.grad()

* Fixed minor issue

* Fixed CI-Inference issue

* Fixed CI-inference issues

* [DoubleGrad PR #7] paddle.grad() to copy backward graph before backward run

* Fixed minor issues

* Fixed issue with backward graph construction logic

* Fixed implementation issues with backward graph reconstruction

* Fixed unittest issue

* Fixed issues
上级 5936fa6e
......@@ -25,7 +25,10 @@ class GradNodeAccumulation : public GradNodeBase {
// Constructor: configure fwd input tensors to grad node
explicit GradNodeAccumulation(AutogradMeta* meta) : GradNodeBase(1, 1) {
VLOG(6) << "Construct GradNodeAccumulation";
weak_grad_ = meta->WeakGrad();
if (meta) {
weak_grad_ = meta->WeakGrad();
}
SetDefaultGradInOutMeta();
}
......@@ -40,11 +43,6 @@ class GradNodeAccumulation : public GradNodeBase {
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
std::string name() { return "GradNodeAccumulation"; }
/**
......@@ -58,6 +56,11 @@ class GradNodeAccumulation : public GradNodeBase {
inline bool ReduceHooksRegistered() { return reduce_hooks_.size() != 0; }
void ApplyReduceHooks();
std::shared_ptr<GradNodeBase> Copy() const override {
return std::shared_ptr<GradNodeAccumulation>(
new GradNodeAccumulation(nullptr));
}
private:
std::weak_ptr<paddle::experimental::Tensor> weak_grad_;
......
......@@ -44,11 +44,6 @@ class GradNodeScale : public GradNodeBase {
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
void SetTensorWrappers_X(
const std::vector<paddle::experimental::Tensor>& tensors);
......@@ -56,6 +51,12 @@ class GradNodeScale : public GradNodeBase {
std::string name() override { return ""; }
// Members: define fwd input tensors
// For Scale there is no fwd input tensor needed
std::shared_ptr<GradNodeBase> Copy() const override {
auto copied_node = std::make_shared<GradNodeScale>(*this);
return copied_node;
}
private:
float scale_{1.0};
};
......
......@@ -2479,22 +2479,23 @@ static std::string GenerateGradNodeHeaderContents(
"\n"
" void ClearTensorWrappers() override { \n"
"%s\n"
" is_tensor_wrappers_cleared = true;\n"
" SetIsTensorWrappersCleared(true);\n"
" }\n"
" std::string name() override { return \" GradNode%s \"; } \n "
"\n"
"std::shared_ptr<GradNodeBase> Copy() const override {{\n "
" auto copied_node = std::shared_ptr<GradNode%s>(new "
"GradNode%s(*this));\n "
" return copied_node;\n "
"}}\n "
"\n"
" // SetX, SetY, ...\n"
"%s\n"
" // SetAttrMap\n"
"%s\n"
" bool IsTensorWrappersCleared() override { \n"
" return is_tensor_wrappers_cleared;\n"
" }\n"
" private:\n"
" // TensorWrappers\n"
"%s\n"
" bool is_tensor_wrappers_cleared = false;\n"
"\n"
" // Attribute Map\n"
"%s\n"
"};";
......@@ -2601,8 +2602,9 @@ static std::string GenerateGradNodeHeaderContents(
std::string grad_node_str = paddle::string::Sprintf(
GRAD_NODE_TEMPLATE, op_type, op_type, op_type, op_type, op_type, op_type,
op_type, clear_tensor_wrappers_str, op_type, set_tensor_wrappers_str,
set_attr_map_str, tensor_wrapper_members_str, attr_members_str);
op_type, clear_tensor_wrappers_str, op_type, op_type, op_type,
set_tensor_wrappers_str, set_attr_map_str, tensor_wrapper_members_str,
attr_members_str);
return grad_node_str;
}
......
......@@ -125,7 +125,13 @@ class {} : public egr::GradNodeBase {{
void ClearTensorWrappers() override {{
{}
is_tensor_wrappers_cleared = true;
SetIsTensorWrappersCleared(true);
}}
std::shared_ptr<GradNodeBase> Copy() const override {{
auto copied_node = std::shared_ptr<{}>(new {}(*this));
return copied_node;
}}
// SetTensorWrapperX, SetTensorWrapperY, ...
......@@ -133,15 +139,10 @@ class {} : public egr::GradNodeBase {{
// SetAttributes
{}
bool IsTensorWrappersCleared() override {{
return is_tensor_wrappers_cleared;
}}
private:
// TensorWrappers
{}
bool is_tensor_wrappers_cleared = false;
// Attributes
{}
}};
......@@ -1218,9 +1219,10 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_node_name = GetGradNodeName(forward_op_name)
self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_node_name, grad_node_name,
grad_node_name, clear_tensor_wrapper_str,
set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str)
grad_node_name, clear_tensor_wrapper_str, grad_node_name,
grad_node_name, set_tensor_wrapper_methods_str,
set_attribute_methods_str, tensor_wrapper_members_str,
attribute_members_str)
logging.info(f"Generated Node Declaration: {self.node_declaration_str}")
......
......@@ -50,7 +50,16 @@ class GeneralGrad {
for (size_t i = 0; i < num_inputs; i++) {
AutogradMeta* auto_grad_meta =
EagerUtils::unsafe_autograd_meta(inputs[i]);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
auto* target_node = auto_grad_meta->GetMutableGradNode().get();
if (orig_to_copied_node_mapping_.count(target_node)) {
target_node = orig_to_copied_node_mapping_[target_node];
} else {
VLOG(6) << "Unable to find target node in "
"orig_to_copied_node_mapping_, likely indicating an "
"unused input";
}
PADDLE_ENFORCE_NOT_NULL(target_node,
paddle::platform::errors::Fatal(
"There is no grad op for %s:[%d] or it's"
......@@ -249,7 +258,15 @@ class GeneralGrad {
for (size_t i = 0; i < inputs.size(); ++i) {
auto& input = inputs[i];
AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
auto* target_node = auto_grad_meta->GetMutableGradNode().get();
if (orig_to_copied_node_mapping_.count(target_node)) {
target_node = orig_to_copied_node_mapping_[target_node];
} else {
VLOG(6) << "Unable to find target node in "
"orig_to_copied_node_mapping_, likely indicating an unused "
"input";
}
auto iter = results_map.find(target_node);
if (iter != results_map.end()) {
......@@ -326,6 +343,78 @@ class GeneralGrad {
potential_stop_nodes.clear();
depending_nodes.clear();
results_map.clear();
copied_grad_nodes_.clear();
orig_to_copied_node_mapping_.clear();
}
GradNodeBase* CopyGradNode(const std::shared_ptr<GradNodeBase>& orig_node) {
if (orig_to_copied_node_mapping_.count(orig_node.get())) {
return orig_to_copied_node_mapping_[orig_node.get()];
}
std::shared_ptr<GradNodeBase> copied_node = orig_node->Copy();
// Save node and update mapping
orig_to_copied_node_mapping_[orig_node.get()] = copied_node.get();
copied_grad_nodes_.push_back(copied_node);
return copied_node.get();
}
void ReconstructBackwardGraph(
const std::queue<GradNodeBase*>& orig_init_queue) {
std::queue<GradNodeBase*> queue = orig_init_queue;
std::unordered_set<GradNodeBase*> visited;
// BFS and recursively copy the grad nodes
while (!queue.empty()) {
GradNodeBase* orig_node = queue.front();
queue.pop();
if (visited.count(orig_node)) {
continue;
}
visited.insert(orig_node);
PADDLE_ENFORCE(
orig_to_copied_node_mapping_.count(orig_node),
paddle::platform::errors::Fatal(
"Cannot reconstruct backward graph,"
"unable to find copied target for certain grad node."));
GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node];
const std::vector<std::vector<Edge>>& orig_edges = orig_node->GetEdges();
std::vector<std::vector<Edge>>& copied_edges =
copied_node->GetMutableEdges();
for (size_t i = 0; i < orig_edges.size(); i++) {
for (size_t j = 0; j < orig_edges[i].size(); j++) {
const Edge& orig_edge = orig_edges[i][j];
Edge& copied_edge = copied_edges[i][j];
std::shared_ptr<GradNodeBase> orig_next_node =
orig_edge.GetMutableGradNode();
if (!orig_next_node) continue;
// Copy Next Node
std::shared_ptr<GradNodeBase> copied_next_node;
if (orig_to_copied_node_mapping_.count(orig_next_node.get())) {
copied_next_node =
orig_to_copied_node_mapping_[orig_next_node.get()]
->shared_from_this();
} else {
copied_next_node = orig_next_node->Copy();
orig_to_copied_node_mapping_[orig_next_node.get()] =
copied_next_node.get();
copied_grad_nodes_.push_back(copied_next_node);
}
// Update Edge's Grad Node
copied_edge.SetGradNode(copied_next_node);
// Update BFS queue
queue.push(orig_next_node.get());
}
}
}
}
private:
......@@ -345,6 +434,10 @@ class GeneralGrad {
std::unordered_set<GradNodeBase*> /* pre nodes */>
depending_nodes;
std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
std::unordered_map<GradNodeBase*, GradNodeBase*> orig_to_copied_node_mapping_;
DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};
......@@ -444,6 +537,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// 1. Init queue with starting nodes
// 2. Prepare initial input buffers
std::queue<GradNodeBase*> queue;
std::queue<GradNodeBase*> orig_queue;
std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
node_input_buffers_dict;
for (size_t i = 0; i < tensors.size(); i++) {
......@@ -468,6 +562,16 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// TODO(zhanlve): Copy and Modify GradNode if is_general_grad
GradNodeBase* grad_node = shared_grad_node.get();
if (is_general_grad) {
// Save orig grad node
orig_queue.push(grad_node);
// Replace grad_node with copied grad_node
grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);
// Record potential startup grad node
GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
}
// Prepare GradTensorHolder
if (!node_input_buffers_dict.count(grad_node)) {
......@@ -504,9 +608,11 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// Prepare queue, potential startup_nodes
queue.push(grad_node);
if (is_general_grad) {
GeneralGrad::Instance().GetPotentialStartupNodes()->emplace(grad_node);
}
}
if (is_general_grad) {
// Copy Backward Graph
GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
}
VLOG(6) << "Update In degree Map for backward";
......
......@@ -36,9 +36,10 @@ class RunCustomOpNode : public GradNodeBase {
}
// Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph = false) // NOLINT
virtual std::vector<std::vector<paddle::experimental::Tensor>>
operator()( // NOLINT
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
bool create_graph = false) // NOLINT
override;
std::string name() {
......@@ -64,13 +65,15 @@ class RunCustomOpNode : public GradNodeBase {
}
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
void SetAttrs(const std::vector<paddle::any>& attr) { attrs_ = attr; }
std::shared_ptr<GradNodeBase> Copy() const override {
auto copied_node =
std::shared_ptr<RunCustomOpNode>(new RunCustomOpNode(*this));
return copied_node;
}
public:
std::unordered_map<int, std::vector<egr::TensorWrapper>> fwd_outs;
std::unordered_map<int, std::vector<egr::TensorWrapper>> fwd_ins;
......
......@@ -326,6 +326,10 @@ const std::vector<std::vector<Edge>>& GradNodeBase::GetEdges() const {
return adj_edges_;
}
std::vector<std::vector<Edge>>& GradNodeBase::GetMutableEdges() {
return adj_edges_;
}
std::vector<std::vector<paddle::experimental::Tensor>>
GradNodeBase::ApplyGradientHooks(
const std::vector<std::vector<paddle::experimental::Tensor>>& tensors) {
......
......@@ -113,7 +113,11 @@ class GradNodeBase : public std::enable_shared_from_this<GradNodeBase> {
virtual void ClearTensorWrappers() = 0;
virtual bool IsTensorWrappersCleared() = 0;
/**
* Self-Copy interface designed for use in DoubleGrad
* **/
virtual std::shared_ptr<GradNodeBase> Copy() const = 0;
/**
* AddEdges is designed to set input tensors' backward Node as current
* node's Edges.
......@@ -191,6 +195,16 @@ class GradNodeBase : public std::enable_shared_from_this<GradNodeBase> {
/**
* GetEdges is designed to get all edges of current node**/
const std::vector<std::vector<Edge>>& GetEdges() const;
std::vector<std::vector<Edge>>& GetMutableEdges();
/**
* The following interfaces are designed for no_need_buffer
* **/
bool IsTensorWrappersCleared() { return is_tensor_wrappers_cleared_; }
void SetIsTensorWrappersCleared(bool is_tensor_wrappers_cleared) {
is_tensor_wrappers_cleared_ = is_tensor_wrappers_cleared;
}
private:
// TODO(zhanlve): Merge adj_edges_ into GradOutMeta
......@@ -218,6 +232,7 @@ class GradNodeBase : public std::enable_shared_from_this<GradNodeBase> {
// We handle complex to real conversion only if any complex GradIn is involved
bool need_complex_to_real_ = false;
int64_t next_hook_id_{0};
bool is_tensor_wrappers_cleared_ = false;
};
class Edge {
......@@ -246,6 +261,11 @@ class Edge {
return grad_node_;
}
void SetGradNode(const std::shared_ptr<GradNodeBase>& node) {
VLOG(6) << "Reseting Edge's Grad Node";
grad_node_ = node;
}
std::pair<size_t, size_t> GetEdgeRankInfo() const {
return std::make_pair(in_slot_id_, in_rank_);
}
......
......@@ -40,11 +40,6 @@ class GradNodePyLayer : public GradNodeBase {
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
std::string name() {
return "GradNodePyLayer_" + std::string(Py_TYPE(ctx_)->tp_name);
}
......@@ -72,6 +67,12 @@ class GradNodePyLayer : public GradNodeBase {
}
}
std::shared_ptr<GradNodeBase> Copy() const override {
auto copied_node =
std::shared_ptr<GradNodePyLayer>(new GradNodePyLayer(*this));
return copied_node;
}
private:
PyObject* ctx_{nullptr};
PyObject* outputs_{nullptr};
......
......@@ -32,7 +32,7 @@ class GradTestNode : public egr::GradNodeBase {
GradTestNode() : GradNodeBase() { val_ = 1.0; }
std::string name() override { return "GradTestNode"; }
std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads,
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
bool create_graph = false) override {
val_ = std::dynamic_pointer_cast<phi::DenseTensor>(grads[0][0].impl())
->data<float>()[0];
......@@ -50,10 +50,14 @@ class GradTestNode : public egr::GradNodeBase {
return res;
}
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
std::shared_ptr<GradNodeBase> Copy() const override {
{
auto copied_node = std::shared_ptr<GradTestNode>(new GradTestNode(*this));
return copied_node;
}
}
float val_;
};
} // namespace eager_test
......@@ -407,10 +407,6 @@ class GradNodeRunProgram : public egr::GradNodeBase {
}
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}
// SetAttrMap
void SetAttrMap(const paddle::framework::AttributeMap &attrs) {
......@@ -468,6 +464,12 @@ class GradNodeRunProgram : public egr::GradNodeBase {
}
}
std::shared_ptr<GradNodeBase> Copy() const override {
auto copied_node =
std::shared_ptr<GradNodeRunProgram>(new GradNodeRunProgram(*this));
return copied_node;
}
private:
// TensorWrappers
std::vector<paddle::experimental::Tensor> x_;
......
......@@ -639,5 +639,40 @@ class TestDoubleGradResNet(TestCase):
self.assertTrue(np.array_equal(egr_g_numpy, g_numpy))
class TestDoubleGradBasics(TestCase):
def test_matmul(self):
input_numpy = np.ones([3, 3]) * 2
with _test_eager_guard():
x = paddle.to_tensor(
input_numpy, stop_gradient=False, dtype='float32')
y = paddle.to_tensor(
input_numpy, stop_gradient=False, dtype='float32')
grad_out = paddle.to_tensor(
np.ones([3, 3]), stop_gradient=False, dtype='float32')
out = paddle.matmul(x, y, False, False)
new_x_g, new_y_g = paddle.grad(
[out], [x, y], [grad_out], retain_graph=True, create_graph=True)
new_x_g.backward()
out_ref = np.ones([3, 3]) * 12.0
self.assertTrue(np.array_equal(out.numpy(), out_ref))
new_x_g_ref = np.ones([3, 3]) * 6.0
new_y_g_ref = np.ones([3, 3]) * 6.0
self.assertTrue(np.array_equal(new_x_g.numpy(), new_x_g_ref))
self.assertTrue(np.array_equal(new_y_g.numpy(), new_y_g_ref))
x_grad_ref = np.ones([3, 3]) * 0.0
self.assertTrue(np.array_equal(x.grad.numpy(), x_grad_ref))
y_grad_ref = np.ones([3, 3]) * 3.0
self.assertTrue(np.array_equal(y.grad.numpy(), y_grad_ref))
grad_out_grad_ref = np.ones([3, 3]) * 6.0
self.assertTrue(
np.array_equal(grad_out.grad.numpy(), grad_out_grad_ref))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册