Unverified · Commit a2b80145 · Authored by: Zhanlue Yang · Committed by: GitHub

[DoubleGrad PR #7] paddle.grad() to copy backward graph before backward run (#41306)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Supported higher-order grad node generation

* [DoubleGrad PR #4] Supported higher-order GradNode generation

* [DoubleGrad #4] Bug Fixes to Double Grad Node Generation

* Fixed yaml typo

* Fixed yaml typo

* Fixed minor issues

* [DoubleGrad PR #5] Enabled gradient computations for grad_tensors passed to paddle.grad()

* Fixed minor issue

* Fixed CI-Inference issue

* Fixed CI-inference issues

* [DoubleGrad PR #7] paddle.grad() to copy backward graph before backward run

* Fixed minor issues

* Fixed issue with backward graph construction logic

* Fixed implementation issues with backward graph reconstruction

* Fixed unittest issue

* Fixed issues
Parent 5936fa6e
@@ -25,7 +25,10 @@ class GradNodeAccumulation : public GradNodeBase {
   // Constructor: configure fwd input tensors to grad node
   explicit GradNodeAccumulation(AutogradMeta* meta) : GradNodeBase(1, 1) {
     VLOG(6) << "Construct GradNodeAccumulation";
-    weak_grad_ = meta->WeakGrad();
+    if (meta) {
+      weak_grad_ = meta->WeakGrad();
+    }
     SetDefaultGradInOutMeta();
   }
@@ -40,11 +43,6 @@ class GradNodeAccumulation : public GradNodeBase {
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
-  bool IsTensorWrappersCleared() override {
-    VLOG(6) << "Do nothing here now";
-    return false;
-  }
   std::string name() { return "GradNodeAccumulation"; }
   /**
@@ -58,6 +56,11 @@ class GradNodeAccumulation : public GradNodeBase {
   inline bool ReduceHooksRegistered() { return reduce_hooks_.size() != 0; }
   void ApplyReduceHooks();
+  std::shared_ptr<GradNodeBase> Copy() const override {
+    return std::shared_ptr<GradNodeAccumulation>(
+        new GradNodeAccumulation(nullptr));
+  }
+
  private:
   std::weak_ptr<paddle::experimental::Tensor> weak_grad_;
......
@@ -44,11 +44,6 @@ class GradNodeScale : public GradNodeBase {
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
-  bool IsTensorWrappersCleared() override {
-    VLOG(6) << "Do nothing here now";
-    return false;
-  }
   void SetTensorWrappers_X(
       const std::vector<paddle::experimental::Tensor>& tensors);
@@ -56,6 +51,12 @@ class GradNodeScale : public GradNodeBase {
   std::string name() override { return ""; }
   // Members: define fwd input tensors
   // For Scale there is no fwd input tensor needed
+  std::shared_ptr<GradNodeBase> Copy() const override {
+    auto copied_node = std::make_shared<GradNodeScale>(*this);
+    return copied_node;
+  }
+
  private:
   float scale_{1.0};
 };
......
@@ -2479,22 +2479,23 @@ static std::string GenerateGradNodeHeaderContents(
       "\n"
       " void ClearTensorWrappers() override { \n"
       "%s\n"
-      " is_tensor_wrappers_cleared = true;\n"
+      " SetIsTensorWrappersCleared(true);\n"
       " }\n"
       " std::string name() override { return \" GradNode%s \"; } \n "
       "\n"
+      "std::shared_ptr<GradNodeBase> Copy() const override {{\n "
+      " auto copied_node = std::shared_ptr<GradNode%s>(new "
+      "GradNode%s(*this));\n "
+      " return copied_node;\n "
+      "}}\n "
+      "\n"
       " // SetX, SetY, ...\n"
       "%s\n"
       " // SetAttrMap\n"
       "%s\n"
-      " bool IsTensorWrappersCleared() override { \n"
-      " return is_tensor_wrappers_cleared;\n"
-      " }\n"
       " private:\n"
       " // TensorWrappers\n"
       "%s\n"
-      " bool is_tensor_wrappers_cleared = false;\n"
-      "\n"
       " // Attribute Map\n"
       "%s\n"
       "};";
@@ -2601,8 +2602,9 @@ static std::string GenerateGradNodeHeaderContents(
   std::string grad_node_str = paddle::string::Sprintf(
       GRAD_NODE_TEMPLATE, op_type, op_type, op_type, op_type, op_type, op_type,
-      op_type, clear_tensor_wrappers_str, op_type, set_tensor_wrappers_str,
-      set_attr_map_str, tensor_wrapper_members_str, attr_members_str);
+      op_type, clear_tensor_wrappers_str, op_type, op_type, op_type,
+      set_tensor_wrappers_str, set_attr_map_str, tensor_wrapper_members_str,
+      attr_members_str);
   return grad_node_str;
 }
......
@@ -125,7 +125,13 @@ class {} : public egr::GradNodeBase {{
   void ClearTensorWrappers() override {{
     {}
-    is_tensor_wrappers_cleared = true;
+    SetIsTensorWrappersCleared(true);
+  }}
+
+  std::shared_ptr<GradNodeBase> Copy() const override {{
+    auto copied_node = std::shared_ptr<{}>(new {}(*this));
+    return copied_node;
   }}
   // SetTensorWrapperX, SetTensorWrapperY, ...
@@ -133,15 +139,10 @@ class {} : public egr::GradNodeBase {{
   // SetAttributes
   {}
-  bool IsTensorWrappersCleared() override {{
-    return is_tensor_wrappers_cleared;
-  }}
  private:
   // TensorWrappers
   {}
-  bool is_tensor_wrappers_cleared = false;
   // Attributes
   {}
 }};
@@ -1218,9 +1219,10 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         grad_node_name = GetGradNodeName(forward_op_name)
         self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
             grad_node_name, grad_node_name, grad_node_name, grad_node_name,
-            grad_node_name, clear_tensor_wrapper_str,
-            set_tensor_wrapper_methods_str, set_attribute_methods_str,
-            tensor_wrapper_members_str, attribute_members_str)
+            grad_node_name, clear_tensor_wrapper_str, grad_node_name,
+            grad_node_name, set_tensor_wrapper_methods_str,
+            set_attribute_methods_str, tensor_wrapper_members_str,
+            attribute_members_str)
         logging.info(f"Generated Node Declaration: {self.node_declaration_str}")
......
@@ -50,7 +50,16 @@ class GeneralGrad {
     for (size_t i = 0; i < num_inputs; i++) {
       AutogradMeta* auto_grad_meta =
           EagerUtils::unsafe_autograd_meta(inputs[i]);
-      auto target_node = auto_grad_meta->GetMutableGradNode().get();
+      auto* target_node = auto_grad_meta->GetMutableGradNode().get();
+
+      if (orig_to_copied_node_mapping_.count(target_node)) {
+        target_node = orig_to_copied_node_mapping_[target_node];
+      } else {
+        VLOG(6) << "Unable to find target node in "
+                   "orig_to_copied_node_mapping_, likely indicating an "
+                   "unused input";
+      }
+
       PADDLE_ENFORCE_NOT_NULL(target_node,
                               paddle::platform::errors::Fatal(
                                   "There is no grad op for %s:[%d] or it's"
@@ -249,7 +258,15 @@ class GeneralGrad {
     for (size_t i = 0; i < inputs.size(); ++i) {
       auto& input = inputs[i];
       AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
-      auto target_node = auto_grad_meta->GetMutableGradNode().get();
+
+      auto* target_node = auto_grad_meta->GetMutableGradNode().get();
+      if (orig_to_copied_node_mapping_.count(target_node)) {
+        target_node = orig_to_copied_node_mapping_[target_node];
+      } else {
+        VLOG(6) << "Unable to find target node in "
+                   "orig_to_copied_node_mapping_, likely indicating an unused "
+                   "input";
+      }
       auto iter = results_map.find(target_node);
       if (iter != results_map.end()) {
@@ -326,6 +343,78 @@ class GeneralGrad {
     potential_stop_nodes.clear();
     depending_nodes.clear();
     results_map.clear();
+    copied_grad_nodes_.clear();
+    orig_to_copied_node_mapping_.clear();
+  }
+
+  GradNodeBase* CopyGradNode(const std::shared_ptr<GradNodeBase>& orig_node) {
+    if (orig_to_copied_node_mapping_.count(orig_node.get())) {
+      return orig_to_copied_node_mapping_[orig_node.get()];
+    }
+    std::shared_ptr<GradNodeBase> copied_node = orig_node->Copy();
+
+    // Save node and update mapping
+    orig_to_copied_node_mapping_[orig_node.get()] = copied_node.get();
+    copied_grad_nodes_.push_back(copied_node);
+
+    return copied_node.get();
+  }
+
+  void ReconstructBackwardGraph(
+      const std::queue<GradNodeBase*>& orig_init_queue) {
+    std::queue<GradNodeBase*> queue = orig_init_queue;
+    std::unordered_set<GradNodeBase*> visited;
+
+    // BFS and recursively copy the grad nodes
+    while (!queue.empty()) {
+      GradNodeBase* orig_node = queue.front();
+      queue.pop();
+      if (visited.count(orig_node)) {
+        continue;
+      }
+      visited.insert(orig_node);
+
+      PADDLE_ENFORCE(
+          orig_to_copied_node_mapping_.count(orig_node),
+          paddle::platform::errors::Fatal(
+              "Cannot reconstruct backward graph,"
+              "unable to find copied target for certain grad node."));
+      GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node];
+
+      const std::vector<std::vector<Edge>>& orig_edges = orig_node->GetEdges();
+      std::vector<std::vector<Edge>>& copied_edges =
+          copied_node->GetMutableEdges();
+      for (size_t i = 0; i < orig_edges.size(); i++) {
+        for (size_t j = 0; j < orig_edges[i].size(); j++) {
+          const Edge& orig_edge = orig_edges[i][j];
+          Edge& copied_edge = copied_edges[i][j];
+
+          std::shared_ptr<GradNodeBase> orig_next_node =
+              orig_edge.GetMutableGradNode();
+          if (!orig_next_node) continue;
+
+          // Copy Next Node
+          std::shared_ptr<GradNodeBase> copied_next_node;
+          if (orig_to_copied_node_mapping_.count(orig_next_node.get())) {
+            copied_next_node =
+                orig_to_copied_node_mapping_[orig_next_node.get()]
+                    ->shared_from_this();
+          } else {
+            copied_next_node = orig_next_node->Copy();
+            orig_to_copied_node_mapping_[orig_next_node.get()] =
+                copied_next_node.get();
+            copied_grad_nodes_.push_back(copied_next_node);
+          }
+
+          // Update Edge's Grad Node
+          copied_edge.SetGradNode(copied_next_node);
+
+          // Update BFS queue
+          queue.push(orig_next_node.get());
+        }
+      }
+    }
   }
  private:
@@ -345,6 +434,10 @@ class GeneralGrad {
                      std::unordered_set<GradNodeBase*> /* pre nodes */>
       depending_nodes;
   std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
+
+  std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
+  std::unordered_map<GradNodeBase*, GradNodeBase*> orig_to_copied_node_mapping_;
+
   DISABLE_COPY_AND_ASSIGN(GeneralGrad);
 };
@@ -444,6 +537,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
   // 1. Init queue with starting nodes
   // 2. Prepare initial input buffers
   std::queue<GradNodeBase*> queue;
+  std::queue<GradNodeBase*> orig_queue;
   std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
       node_input_buffers_dict;
   for (size_t i = 0; i < tensors.size(); i++) {
@@ -468,6 +562,16 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     // TODO(zhanlve): Copy and Modify GradNode if is_general_grad
     GradNodeBase* grad_node = shared_grad_node.get();
+    if (is_general_grad) {
+      // Save orig grad node
+      orig_queue.push(grad_node);
+
+      // Replace grad_node with copied grad_node
+      grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);
+
+      // Record potential startup grad node
+      GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
+    }
     // Prepare GradTensorHolder
     if (!node_input_buffers_dict.count(grad_node)) {
@@ -504,9 +608,11 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     // Prepare queue, potential startup_nodes
     queue.push(grad_node);
-    if (is_general_grad) {
-      GeneralGrad::Instance().GetPotentialStartupNodes()->emplace(grad_node);
-    }
+  }
+
+  if (is_general_grad) {
+    // Copy Backward Graph
+    GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
   }
   VLOG(6) << "Update In degree Map for backward";
......
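Note on the change above: GeneralGrad now clones every reachable grad node exactly once through orig_to_copied_node_mapping_ and then rewires the cloned edges, so the general-grad pass of paddle.grad() runs on a private copy of the backward graph instead of mutating the user-visible one. A minimal standalone sketch of that copy-with-a-mapping technique follows (illustrative C++ only; the Node type and CopyGraph helper below are not Paddle APIs):

// Standalone sketch of the copy-before-backward idea (illustrative only, not
// Paddle code): clone each node once via an orig -> copy map, then rewire the
// copied edges so the copy never points back into the original graph.
#include <memory>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<std::shared_ptr<Node>> next;  // outgoing edges
};

std::vector<std::shared_ptr<Node>> CopyGraph(
    const std::vector<std::shared_ptr<Node>>& roots) {
  std::unordered_map<Node*, std::shared_ptr<Node>> orig_to_copy;
  std::queue<Node*> queue;
  std::unordered_set<Node*> visited;

  // Clone a node on first sight; reuse the same clone afterwards.
  auto get_copy = [&](Node* orig) {
    auto it = orig_to_copy.find(orig);
    if (it == orig_to_copy.end()) {
      it = orig_to_copy.emplace(orig, std::make_shared<Node>(*orig)).first;
    }
    return it->second;
  };

  std::vector<std::shared_ptr<Node>> copied_roots;
  for (const auto& root : roots) {
    copied_roots.push_back(get_copy(root.get()));
    queue.push(root.get());
  }

  // BFS over the original graph, mirroring every edge onto the copies.
  while (!queue.empty()) {
    Node* orig = queue.front();
    queue.pop();
    if (!visited.insert(orig).second) continue;

    std::shared_ptr<Node> copy = get_copy(orig);
    copy->next.clear();
    for (const auto& orig_next : orig->next) {
      copy->next.push_back(get_copy(orig_next.get()));
      queue.push(orig_next.get());
    }
  }
  return copied_roots;
}

The real implementation differs mainly in that it clones polymorphically through GradNodeBase::Copy() and rewires the existing Edge objects in place via SetGradNode().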
@@ -36,9 +36,10 @@ class RunCustomOpNode : public GradNodeBase {
   }
   // Functor: perform backward computations
-  virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      std::vector<std::vector<paddle::experimental::Tensor>>& grads,
-      bool create_graph = false)  // NOLINT
+  virtual std::vector<std::vector<paddle::experimental::Tensor>>
+  operator()(  // NOLINT
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
+      bool create_graph = false)  // NOLINT
       override;
   std::string name() {
@@ -64,13 +65,15 @@ class RunCustomOpNode : public GradNodeBase {
   }
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
-  bool IsTensorWrappersCleared() override {
-    VLOG(6) << "Do nothing here now";
-    return false;
-  }
   void SetAttrs(const std::vector<paddle::any>& attr) { attrs_ = attr; }
+  std::shared_ptr<GradNodeBase> Copy() const override {
+    auto copied_node =
+        std::shared_ptr<RunCustomOpNode>(new RunCustomOpNode(*this));
+    return copied_node;
+  }
+
  public:
   std::unordered_map<int, std::vector<egr::TensorWrapper>> fwd_outs;
   std::unordered_map<int, std::vector<egr::TensorWrapper>> fwd_ins;
......
@@ -326,6 +326,10 @@ const std::vector<std::vector<Edge>>& GradNodeBase::GetEdges() const {
   return adj_edges_;
 }
+std::vector<std::vector<Edge>>& GradNodeBase::GetMutableEdges() {
+  return adj_edges_;
+}
+
 std::vector<std::vector<paddle::experimental::Tensor>>
 GradNodeBase::ApplyGradientHooks(
     const std::vector<std::vector<paddle::experimental::Tensor>>& tensors) {
......
@@ -113,7 +113,11 @@ class GradNodeBase : public std::enable_shared_from_this<GradNodeBase> {
   virtual void ClearTensorWrappers() = 0;
-  virtual bool IsTensorWrappersCleared() = 0;
+  /**
+   * Self-Copy interface designed for use in DoubleGrad
+   * **/
+  virtual std::shared_ptr<GradNodeBase> Copy() const = 0;
+
   /**
    * AddEdges is designed to set input tensors' backward Node as current
    * node's Edges.
@@ -191,6 +195,16 @@ class GradNodeBase : public std::enable_shared_from_this<GradNodeBase> {
   /**
    * GetEdges is designed to get all edges of current node**/
   const std::vector<std::vector<Edge>>& GetEdges() const;
+  std::vector<std::vector<Edge>>& GetMutableEdges();
+
+  /**
+   * The following interfaces are designed for no_need_buffer
+   * **/
+  bool IsTensorWrappersCleared() { return is_tensor_wrappers_cleared_; }
+
+  void SetIsTensorWrappersCleared(bool is_tensor_wrappers_cleared) {
+    is_tensor_wrappers_cleared_ = is_tensor_wrappers_cleared;
+  }
  private:
   // TODO(zhanlve): Merge adj_edges_ into GradOutMeta
@@ -218,6 +232,7 @@ class GradNodeBase : public std::enable_shared_from_this<GradNodeBase> {
   // We handle complex to real conversion only if any complex GradIn is involved
   bool need_complex_to_real_ = false;
   int64_t next_hook_id_{0};
+  bool is_tensor_wrappers_cleared_ = false;
 };
 class Edge {
@@ -246,6 +261,11 @@ class Edge {
     return grad_node_;
   }
+  void SetGradNode(const std::shared_ptr<GradNodeBase>& node) {
+    VLOG(6) << "Reseting Edge's Grad Node";
+    grad_node_ = node;
+  }
+
   std::pair<size_t, size_t> GetEdgeRankInfo() const {
     return std::make_pair(in_slot_id_, in_rank_);
   }
......
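The Copy() pure virtual introduced above is a conventional polymorphic-clone hook: each concrete node copy-constructs its own dynamic type and returns it through the base pointer, which is what the GradNodeAccumulation, GradNodeScale, RunCustomOpNode, GradNodePyLayer and GradNodeRunProgram overrides in this commit do. A toy sketch of the pattern (illustrative types only, not Paddle's classes):

// Toy polymorphic-clone sketch (illustrative only, not Paddle's classes).
#include <memory>

struct NodeBase {
  virtual ~NodeBase() = default;
  // Each subclass returns a copy of its own dynamic type.
  virtual std::shared_ptr<NodeBase> Copy() const = 0;
};

struct MyNode : NodeBase {
  int slot = 0;
  std::shared_ptr<NodeBase> Copy() const override {
    // Copy-construct the derived type; callers only see the base pointer.
    return std::shared_ptr<MyNode>(new MyNode(*this));
  }
};

Copy-constructing the derived type carries over its members (attributes, tensor wrappers), while ReconstructBackwardGraph later replaces the copied node's edges so they point at other copies rather than at the original graph.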
@@ -40,11 +40,6 @@ class GradNodePyLayer : public GradNodeBase {
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
-  bool IsTensorWrappersCleared() override {
-    VLOG(6) << "Do nothing here now";
-    return false;
-  }
   std::string name() {
     return "GradNodePyLayer_" + std::string(Py_TYPE(ctx_)->tp_name);
   }
@@ -72,6 +67,12 @@ class GradNodePyLayer : public GradNodeBase {
     }
   }
+  std::shared_ptr<GradNodeBase> Copy() const override {
+    auto copied_node =
+        std::shared_ptr<GradNodePyLayer>(new GradNodePyLayer(*this));
+    return copied_node;
+  }
+
  private:
   PyObject* ctx_{nullptr};
   PyObject* outputs_{nullptr};
......
@@ -32,7 +32,7 @@ class GradTestNode : public egr::GradNodeBase {
   GradTestNode() : GradNodeBase() { val_ = 1.0; }
   std::string name() override { return "GradTestNode"; }
   std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      std::vector<std::vector<paddle::experimental::Tensor>>& grads,
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
       bool create_graph = false) override {
     val_ = std::dynamic_pointer_cast<phi::DenseTensor>(grads[0][0].impl())
                ->data<float>()[0];
@@ -50,10 +50,14 @@ class GradTestNode : public egr::GradNodeBase {
     return res;
   }
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
-  bool IsTensorWrappersCleared() override {
-    VLOG(6) << "Do nothing here now";
-    return false;
+
+  std::shared_ptr<GradNodeBase> Copy() const override {
+    {
+      auto copied_node = std::shared_ptr<GradTestNode>(new GradTestNode(*this));
+      return copied_node;
+    }
   }
   float val_;
 };
 }  // namespace eager_test
@@ -407,10 +407,6 @@ class GradNodeRunProgram : public egr::GradNodeBase {
   }
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
-  bool IsTensorWrappersCleared() override {
-    VLOG(6) << "Do nothing here now";
-    return false;
-  }
   // SetAttrMap
   void SetAttrMap(const paddle::framework::AttributeMap &attrs) {
@@ -468,6 +464,12 @@ class GradNodeRunProgram : public egr::GradNodeBase {
     }
   }
+  std::shared_ptr<GradNodeBase> Copy() const override {
+    auto copied_node =
+        std::shared_ptr<GradNodeRunProgram>(new GradNodeRunProgram(*this));
+    return copied_node;
+  }
+
  private:
   // TensorWrappers
   std::vector<paddle::experimental::Tensor> x_;
......
@@ -639,5 +639,40 @@ class TestDoubleGradResNet(TestCase):
         self.assertTrue(np.array_equal(egr_g_numpy, g_numpy))
+
+class TestDoubleGradBasics(TestCase):
+    def test_matmul(self):
+        input_numpy = np.ones([3, 3]) * 2
+        with _test_eager_guard():
+            x = paddle.to_tensor(
+                input_numpy, stop_gradient=False, dtype='float32')
+            y = paddle.to_tensor(
+                input_numpy, stop_gradient=False, dtype='float32')
+            grad_out = paddle.to_tensor(
+                np.ones([3, 3]), stop_gradient=False, dtype='float32')
+
+            out = paddle.matmul(x, y, False, False)
+            new_x_g, new_y_g = paddle.grad(
+                [out], [x, y], [grad_out], retain_graph=True, create_graph=True)
+            new_x_g.backward()
+
+            out_ref = np.ones([3, 3]) * 12.0
+            self.assertTrue(np.array_equal(out.numpy(), out_ref))
+
+            new_x_g_ref = np.ones([3, 3]) * 6.0
+            new_y_g_ref = np.ones([3, 3]) * 6.0
+            self.assertTrue(np.array_equal(new_x_g.numpy(), new_x_g_ref))
+            self.assertTrue(np.array_equal(new_y_g.numpy(), new_y_g_ref))
+
+            x_grad_ref = np.ones([3, 3]) * 0.0
+            self.assertTrue(np.array_equal(x.grad.numpy(), x_grad_ref))
+
+            y_grad_ref = np.ones([3, 3]) * 3.0
+            self.assertTrue(np.array_equal(y.grad.numpy(), y_grad_ref))
+
+            grad_out_grad_ref = np.ones([3, 3]) * 6.0
+            self.assertTrue(
+                np.array_equal(grad_out.grad.numpy(), grad_out_grad_ref))
+
 if __name__ == '__main__':
     unittest.main()
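For readers checking the new test: the expected values follow from the standard matmul gradients, assuming new_x_g.backward() applies an all-ones cotangent w to the non-scalar output (which is what the asserted numbers imply). With x = y = 2·1, v = grad_out = 1 (all 3×3 matrices of the stated constants):

\begin{aligned}
\mathrm{out} &= xy = 12\cdot\mathbf{1}, \qquad
\mathrm{new\_x\_g} = v\,y^{\top} = 6\cdot\mathbf{1}, \qquad
\mathrm{new\_y\_g} = x^{\top}v = 6\cdot\mathbf{1},\\
x.\mathrm{grad} &= \frac{\partial}{\partial x}\langle w,\, v\,y^{\top}\rangle = 0, \qquad
y.\mathrm{grad} = w^{\top}v = 3\cdot\mathbf{1}, \qquad
\mathrm{grad\_out}.\mathrm{grad} = w\,y = 6\cdot\mathbf{1},
\end{aligned}

matching the 12.0, 6.0, 6.0, 0.0, 3.0 and 6.0 references asserted above; the last three require the copied backward graph that this PR constructs, since they differentiate through new_x_g = v·yᵀ itself.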