From a31d7328b70fde9c860dfad918222d7b2d842d71 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Fri, 20 Mar 2020 07:34:01 -0500 Subject: [PATCH] Add dygraph double grad implementation (#22939) * add double grad implementation for dygraph, test=develop * polish code, add uts, test=develop * fix place bug, test=develop * polish codes, add more uts for coverages, test=develop * add no_grad_set, test=develop * add star gan ut, test=develop * follow comments, test=develop --- paddle/fluid/framework/grad_op_desc_maker.h | 16 +- .../no_need_buffer_vars_inference.cc | 1 + paddle/fluid/framework/type_defs.h | 2 +- paddle/fluid/imperative/CMakeLists.txt | 3 +- paddle/fluid/imperative/basic_engine.cc | 267 +++++ paddle/fluid/imperative/basic_engine.h | 57 + paddle/fluid/imperative/dygraph_grad_maker.h | 83 +- paddle/fluid/imperative/engine.cc | 255 ---- paddle/fluid/imperative/engine.h | 53 +- paddle/fluid/imperative/execution_context.h | 199 ++++ .../fluid/imperative/gradient_accumulator.cc | 206 +++- .../fluid/imperative/gradient_accumulator.h | 24 +- paddle/fluid/imperative/infer_shape_context.h | 371 ++++++ .../fluid/imperative/infer_var_type_context.h | 188 +++ paddle/fluid/imperative/layer.cc | 96 +- paddle/fluid/imperative/layer.h | 811 +------------ paddle/fluid/imperative/op_base.h | 211 ++++ .../fluid/imperative/partial_grad_engine.cc | 1028 +++++++++++++++++ paddle/fluid/imperative/partial_grad_engine.h | 58 + paddle/fluid/imperative/prepared_operator.cc | 3 + .../imperative/saved_variable_wrapper_list.h | 87 ++ .../tests/test_gradient_accmulator.cc | 197 ++++ paddle/fluid/imperative/tests/test_layer.cc | 28 +- paddle/fluid/imperative/tests/test_tracer.cc | 66 +- paddle/fluid/imperative/tracer.cc | 63 +- paddle/fluid/imperative/tracer.h | 21 +- paddle/fluid/imperative/type_defs.h | 25 +- paddle/fluid/imperative/variable_wrapper.h | 55 + .../operators/fused/fused_bn_activation_op.h | 1 + paddle/fluid/operators/minus_op.cc | 18 +- .../operators/reduce_ops/reduce_mean_op.cc | 25 +- paddle/fluid/operators/sum_op.cc | 32 +- .../pybind/global_value_getter_setter.cc | 2 + paddle/fluid/pybind/imperative.cc | 26 +- paddle/fluid/pybind/protobuf.cc | 4 +- python/paddle/fluid/backward.py | 7 +- python/paddle/fluid/dygraph/base.py | 71 ++ .../unittests/test_imperative_auto_prune.py | 6 +- .../unittests/test_imperative_double_grad.py | 278 +++++ ...perative_star_gan_with_gradient_penalty.py | 614 ++++++++++ 40 files changed, 4219 insertions(+), 1339 deletions(-) create mode 100644 paddle/fluid/imperative/basic_engine.cc create mode 100644 paddle/fluid/imperative/basic_engine.h delete mode 100644 paddle/fluid/imperative/engine.cc create mode 100644 paddle/fluid/imperative/execution_context.h create mode 100644 paddle/fluid/imperative/infer_shape_context.h create mode 100644 paddle/fluid/imperative/infer_var_type_context.h create mode 100644 paddle/fluid/imperative/op_base.h create mode 100644 paddle/fluid/imperative/partial_grad_engine.cc create mode 100644 paddle/fluid/imperative/partial_grad_engine.h create mode 100644 paddle/fluid/imperative/saved_variable_wrapper_list.h create mode 100644 python/paddle/fluid/tests/unittests/test_imperative_double_grad.py create mode 100644 python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index 6bd55a8bcd0..8d55c79a0dd 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ 
b/paddle/fluid/framework/grad_op_desc_maker.h @@ -209,14 +209,13 @@ class SingleGradOpMaker public: using GradOpBaseMakerBase::GradOpBaseMakerBase; - std::vector> operator()() const { - std::vector> retv{ - std::make_shared()}; + std::shared_ptr operator()() const { + auto node = this->NewGradNode(); { - imperative::TracedGradOp grad_op(retv.front()); - this->Apply(&grad_op); + imperative::TracedGradOp traced_grad_op(node); + this->Apply(&traced_grad_op); } - return retv; + return node->empty() ? nullptr : node; } protected: @@ -262,8 +261,9 @@ class EmptyGradOpMaker final : public imperative::GradOpBaseMakerBase { public: using GradOpBaseMakerBase::GradOpBaseMakerBase; - std::vector> operator()() const final { - return {}; + + std::shared_ptr operator()() const final { + return nullptr; } }; diff --git a/paddle/fluid/framework/no_need_buffer_vars_inference.cc b/paddle/fluid/framework/no_need_buffer_vars_inference.cc index 79bb1e223cf..07b84a151fe 100644 --- a/paddle/fluid/framework/no_need_buffer_vars_inference.cc +++ b/paddle/fluid/framework/no_need_buffer_vars_inference.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/no_need_buffer_vars_inference.h" #include #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/imperative/saved_variable_wrapper_list.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 7e9a812d97b..0ff2b2fd732 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -56,7 +56,7 @@ using GradOpMakerFN = std::function>( const std::vector& grad_block)>; using DygraphGradOpMakerFN = - std::function>( + std::function( const std::string& /*op_type*/, const imperative::NameVarBaseMap& /*var_base_map_in*/, const imperative::NameVarBaseMap& /*var_base_map_out*/, diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 7db8d003b40..0403cf25c72 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -6,7 +6,8 @@ cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator add_subdirectory(jit) cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer) -cc_library(engine SRCS engine.cc DEPS layer gradient_accumulator) +cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator) +cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator) cc_library(imperative_profiler SRCS profiler.cc) if(NOT WIN32) if(WITH_NCCL) diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc new file mode 100644 index 00000000000..9a9283a65ad --- /dev/null +++ b/paddle/fluid/imperative/basic_engine.cc @@ -0,0 +1,267 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
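The reworked grad-op makers above now build one shared grad node per traced forward op, and the new BasicEngine / partial-grad engine listed in the diffstat walk those nodes to run the backward pass. A minimal Python sketch of the double-grad behaviour this enables, assuming the feature is exposed through fluid.dygraph.grad with a create_graph flag (the Python-side signature is not shown in this hunk, so the exact argument names are an assumption):

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.ones([2, 3], dtype='float32'))
    x.stop_gradient = False
    y = x * x

    # First-order gradient; create_graph=True keeps the traced grad ops
    # differentiable so a second backward pass can run through them.
    dx = fluid.dygraph.grad(outputs=[y], inputs=[x], create_graph=True)[0]

    # Second-order use: build a scalar from dx and back-propagate through it.
    loss = fluid.layers.reduce_sum(dx * dx)
    loss.backward()

    # d(sum((2x)^2))/dx = 8x, so every entry should be 8 when x == 1.
    print(x.gradient())
```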
+ +#include "paddle/fluid/imperative/basic_engine.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/imperative/gradient_accumulator.h" +#include "paddle/fluid/imperative/layer.h" +#include "paddle/fluid/imperative/op_base.h" +#include "paddle/fluid/imperative/tracer.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace imperative { + +void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) { + backward_strategy_ = strategy; + init_node_ = var->GradVarBase()->GradNode(); + var->GradVarBase()->ClearGradNode(); + + if (init_node_ == nullptr || var->OverridedStopGradient()) { + VLOG(3) << "Skip auto grad since there is no grad op for var or loss is " + "stop_gradient=True: " + << var->Name(); + return; + } + + VLOG(3) << "start backward"; + + PADDLE_ENFORCE_EQ( + var->HasGradVar(), true, + platform::errors::NotFound("Grad variable not exist for variable %s", + var->Name())); + + auto& fwd_var = var->Var().Get(); + auto* grad_var = + var->GradVarBase()->MutableVar()->GetMutable(); + VLOG(6) << "init loss grad:" << var->GradVarBase()->Name() + << " as stop_gradient false"; + var->GradVarBase()->InnerSetOverridedStopGradient(false); + auto* dev_ctx = platform::DeviceContextPool::Instance().Get(fwd_var.place()); + grad_var->Resize(fwd_var.dims()); + grad_var->mutable_data(fwd_var.place(), fwd_var.type()); + operators::math::set_constant(*dev_ctx, grad_var, 1.0); +} + +void BasicEngine::CheckBackwardInputs(const OpBase& op) { + for (auto& pair : op.GetInsMap()) { + if (!pair.second.IsGrad()) { + continue; + } + + for (auto& var : pair.second) { + if (!var) { + continue; + } + + auto* inner_var = var->MutableVar(); + framework::Tensor* tensor = nullptr; + if (!inner_var->IsInitialized() || + inner_var->IsType()) { + tensor = inner_var->GetMutable(); + } + + if (tensor && !tensor->IsInitialized()) { + // if grad var has OverridedStopGradient skip this Op + VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero"; + auto* dev_ctx = platform::DeviceContextPool::Instance().Get(op.place()); + tensor->mutable_data(op.place(), var->DataType()); + operators::math::set_constant(*dev_ctx, tensor, 0.0); + } + } + } +} + +void BasicEngine::PrepareGradAccumulators(const OpBase& op) { + for (const auto& pair : op.GetOutsMap()) { + if (!pair.second.IsGrad()) { + continue; + } + + for (const auto& var : pair.second) { + if (!var) continue; + + auto& accumulator = accumulators_[var.get()]; + if (!accumulator) { + if (backward_strategy_.sorted_sum_gradient_) { + accumulator.reset(new SortedGradientAccumulator(var.get())); + } else { + accumulator.reset(new EagerGradientAccumulator(var.get())); + } + } + + accumulator->IncreaseRefCnt(); + + VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "(" + << var.get() << ") with reference count " + << accumulator->RefCnt(); + } + } +} + +void BasicEngine::PrepareDeps() { + PADDLE_ENFORCE_EQ( + node_deps_.empty(), true, + platform::errors::AlreadyExists("Op deps must be initialized here")); + PADDLE_ENFORCE_EQ( + accumulators_.empty(), true, + platform::errors::AlreadyExists("Accumulators must be initialized here")); + + std::queue q; + std::unordered_set visited; + + q.push(init_node_.get()); + visited.insert(init_node_.get()); + + while (!q.empty()) { + auto* cur_node = q.front(); + q.pop(); + + for (auto& cur_op : *cur_node) { + PADDLE_ENFORCE_NE( + cur_op.GetInsMap().empty() 
&& cur_op.GetOutsMap().empty(), true, + platform::errors::NotFound( + "Inputs and outputs of %s do not exist. " + "This may be because you call \"backward()\" twice for the same " + "subgraph. Please try to call \"stop_gradient = True\" or " + "\"detach()\" if you use some same vars between two " + "\"backward()\" " + "calls.", + cur_op.Type())); + PrepareGradAccumulators(cur_op); + } + + const auto& grad_pending_nodes = cur_node->GradPendingNodes(); + for (auto& grad_pending_node : grad_pending_nodes) { + PADDLE_ENFORCE_NOT_NULL( + grad_pending_node, + platform::errors::NotFound("Grad pending node should not be null")); + ++node_deps_[grad_pending_node.get()]; + if (visited.count(grad_pending_node.get()) == 0) { + visited.insert(grad_pending_node.get()); + q.push(grad_pending_node.get()); + } + } + } +} + +void BasicEngine::Execute() { + if (init_node_ == nullptr) { + return; + } + + PrepareDeps(); + // Start execute Computation graph + std::queue> q; + q.push(std::move(init_node_)); + + size_t op_num = 0; + + while (!q.empty()) { + auto shared_cur_node = std::move(q.front()); + q.pop(); + + for (auto& cur_op : *shared_cur_node) { + ++op_num; + + // CheckBackWardInput + CheckBackwardInputs(cur_op); + + // Step 1: Run Backward + auto& bwd_ins = cur_op.GetInsMap(); + auto& bwd_outs = cur_op.GetOutsMap(); + + NameVarMap tmp_outs(bwd_outs); + // 1. construct the output map 2. replace the element in the map + // A var may be coresponding to several grad var in one op + for (auto& pair : tmp_outs) { + if (!pair.second.IsGrad()) { + continue; + } + + for (auto& var : pair.second) { + if (!var) { + continue; + } + + auto iter = accumulators_.find(var.get()); + PADDLE_ENFORCE_EQ( + iter != accumulators_.end(), true, + platform::errors::NotFound("Cannot find gradient of variable %s", + var->Name())); + if (!var->OverridedStopGradient() && iter->second->RefCnt() == 1) { + continue; + } + + var = std::make_shared("Gtmp@"); + need_accu_var_list_.emplace_back(iter->second.get(), var); + } + } + + { + VLOG(3) << "Start to execute grad op " << cur_op.Type(); + OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(), + cur_op.place()); + } + + // Step 2: Sum Gradient + for (auto& pair : need_accu_var_list_) { + pair.first->Add(std::move(pair.second), cur_op.id()); + } + + need_accu_var_list_.clear(); + + VLOG(3) << "Remove op after op " << cur_op.Type() << " runs"; + cur_op.ClearBackwardTrace(); + } + + // Step 3: Collect ready ops + for (auto& grad_pending_node : shared_cur_node->GradPendingNodes()) { + PADDLE_ENFORCE_NOT_NULL(grad_pending_node, + platform::errors::NotFound( + "Grad pending node should not be nullptr")); + auto iter = node_deps_.find(grad_pending_node.get()); + if (iter == node_deps_.end()) { + continue; + } + + if (--(iter->second) == 0) { + q.push(grad_pending_node); + } + } + } + Clear(); + + VLOG(1) << "Backward op number: " << op_num; +} + +void BasicEngine::Clear() { + init_node_.reset(); + node_deps_.clear(); + accumulators_.clear(); + need_accu_var_list_.clear(); +} + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/basic_engine.h b/paddle/fluid/imperative/basic_engine.h new file mode 100644 index 00000000000..2d517bb43d3 --- /dev/null +++ b/paddle/fluid/imperative/basic_engine.h @@ -0,0 +1,57 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/imperative/backward_strategy.h" +#include "paddle/fluid/imperative/engine.h" +#include "paddle/fluid/imperative/gradient_accumulator.h" + +namespace paddle { +namespace imperative { + +class VarBase; +class OpBase; + +class BasicEngine : public Engine { + public: + void Init(VarBase* var, const detail::BackwardStrategy& strategy); + + void Execute() override; + + private: + void PrepareDeps(); + + void CheckBackwardInputs(const OpBase& op); + + void PrepareGradAccumulators(const OpBase& op); + + void Clear(); + + private: + std::shared_ptr init_node_; + detail::BackwardStrategy backward_strategy_; + std::unordered_map node_deps_; + std::unordered_map> + accumulators_; + std::vector>> + need_accu_var_list_; +}; + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index 7bd2643b14d..757f4193690 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -18,9 +18,11 @@ #include #include #include +#include #include #include "paddle/fluid/imperative/layer.h" +#include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/macros.h" @@ -51,11 +53,8 @@ class GradOpBaseMakerBase { attrs_(attrs) {} virtual ~GradOpBaseMakerBase() = default; - virtual std::vector> operator()() const = 0; - static std::shared_ptr CreateOp() { - return std::make_shared(); - } + virtual std::shared_ptr operator()() const = 0; TracedVarList InputGrad( const std::string& name, bool drop_empty_grad = true) const { @@ -138,6 +137,10 @@ class GradOpBaseMakerBase { return var_base_map_out_.count(name) > 0; } + static std::shared_ptr NewGradNode() { + return std::make_shared(); + } + private: template TracedVarList GetVarBaseList(const std::string& name, @@ -149,7 +152,13 @@ class GradOpBaseMakerBase { if (iterator != data_map.end()) { vec_temp.reserve(iterator->second.size()); + bool is_valid = false; for (auto& var_base_temp : iterator->second) { + if (!var_base_temp) { + vec_temp.emplace_back(); + continue; + } + if (kRole == TracedVarRole::kBackward) { if (!var_base_temp->HasGradVar()) { VLOG(6) << "GradVarBase of var " << var_base_temp->Name() @@ -168,6 +177,11 @@ class GradOpBaseMakerBase { } else { vec_temp.emplace_back(var_base_temp); } + is_valid = true; + } + + if (!is_valid) { + vec_temp.clear(); } } @@ -185,44 +199,63 @@ class TracedGradOp { DISABLE_COPY_AND_ASSIGN(TracedGradOp); public: - explicit TracedGradOp(const std::shared_ptr& op) : op_(op) {} + explicit TracedGradOp(const std::shared_ptr& node) + : node_(node), op_(&(node->emplace_back())) {} ~TracedGradOp() { - op_->SetGradPendingOps( - {grad_pending_ops_.begin(), grad_pending_ops_.end()}); - op_->CheckAttrs(); + if (UNLIKELY(op_->GetOutsMap().empty())) { + node_->pop_back(); + } else { + op_->CheckAttrs(); + } } template void SetInput(const std::string& name, const TracedVarList& vars) { + if 
(vars.empty()) { + return; + } + if (kRole == TracedVarRole::kBackward) { for (auto& var : vars) { - var->AddGradOp(op_); + if (var && !var->OverridedStopGradient()) { + var->SetGradNode(node_); + } } } - op_->SetInput(name, ToVarWrapperList(vars)); + + auto var_wrappers = ToVarWrapperList(vars); + if (!var_wrappers.empty()) { + op_->SetInput(name, std::move(var_wrappers), + kRole == TracedVarRole::kBackward); + } } template void SetOutput(const std::string& name, const TracedVarList& vars) { + if (vars.empty()) { + return; + } + if (kRole == TracedVarRole::kBackward) { if (vars.size() == 1 && vars.front()->OverridedStopGradient()) { - op_->SetOutput(name, VariableWrapperList{}); return; } else { for (auto& var : vars) { - if (!var->OverridedStopGradient()) { - for (auto& op : var->GradOps()) { - grad_pending_ops_.emplace(op); - } + if (var && !var->OverridedStopGradient() && var->GradNode()) { + node_->InsertGradPendingNode(var->GradNode()); } } } } - op_->SetOutput(name, ToVarWrapperList(vars)); + auto var_wrappers = ToVarWrapperList(vars); + if (!var_wrappers.empty()) { + op_->SetOutput(name, std::move(var_wrappers), + kRole == TracedVarRole::kBackward); + } } void SetType(const std::string& type) { op_->SetType(type); } @@ -247,19 +280,31 @@ class TracedGradOp { } private: + template static std::vector> ToVarWrapperList( const std::vector>& vars) { std::vector> result; result.reserve(vars.size()); + bool has_valid = false; for (auto& var : vars) { - result.emplace_back(var->SharedVar()); + if (UNLIKELY(!var || (kRole == TracedVarRole::kBackward && + var->OverridedStopGradient()))) { + result.emplace_back(); + } else { + result.emplace_back(var->SharedVar()); + has_valid = true; + } + } + + if (!has_valid) { + result.clear(); } return result; } private: - const std::shared_ptr& op_; - std::unordered_set> grad_pending_ops_; + const std::shared_ptr& node_; + OpBase* op_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/engine.cc b/paddle/fluid/imperative/engine.cc deleted file mode 100644 index cdc09040056..00000000000 --- a/paddle/fluid/imperative/engine.cc +++ /dev/null @@ -1,255 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
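The TracedGradOp changes above implement per-variable pruning while the grad node graph is recorded: inputs and outputs whose stop_gradient flag is set are dropped instead of getting a grad node attached. A small sketch of the user-visible effect, using the ordinary fluid dygraph API (variable names are illustrative only):

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    a = fluid.dygraph.to_variable(np.random.rand(2, 2).astype('float32'))
    b = fluid.dygraph.to_variable(np.random.rand(2, 2).astype('float32'))
    a.stop_gradient = False
    b.stop_gradient = True   # treated as a constant: no grad node is attached

    loss = fluid.layers.reduce_sum(fluid.layers.matmul(a, b))
    loss.backward()

    print(a.gradient())  # populated: a participates in the backward graph
    # b receives no gradient; the traced grad op simply omits it as an output
```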
- -#include "paddle/fluid/imperative/engine.h" - -#include -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/imperative/gradient_accumulator.h" -#include "paddle/fluid/imperative/layer.h" -#include "paddle/fluid/imperative/tracer.h" -#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/profiler.h" - -namespace paddle { -namespace imperative { - -void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) { - backward_strategy_ = strategy; - const auto& ops = var->GradVarBase()->GradOps(); - var->ClearGradOps(); - - if (ops.empty() || var->OverridedStopGradient()) { - VLOG(3) << "Skip auto grad since there is no grad op for var or loss is " - "stop_gradient=True: " - << var->Name(); - return; - } else { - bool valid = false; - for (const auto& op : ops) { - if (op) { - valid = true; - } - } - if (!valid) { - VLOG(3) << "Skip auto grad since all grad op of start VarBase is nullptr"; - return; - } - } - - init_ops_ = ops; - var->GradVarBase()->ClearGradOps(); - VLOG(3) << "start backward"; - - PADDLE_ENFORCE_EQ(var->HasGradVar(), true, - "Grad variable not exist for variable %s", var->Name()); - - auto& fwd_var = var->Var().Get(); - auto* grad_var = - var->GradVarBase()->MutableVar()->GetMutable(); - VLOG(6) << "init loss grad:" << var->GradVarBase()->Name() - << " as stop_gradient false"; - var->GradVarBase()->InnerSetOverridedStopGradient(false); - auto* dev_ctx = platform::DeviceContextPool::Instance().Get(fwd_var.place()); - grad_var->Resize(fwd_var.dims()); - grad_var->mutable_data(fwd_var.place(), fwd_var.type()); - operators::math::set_constant(*dev_ctx, grad_var, 1.0); -} - -void BasicEngine::CheckBackwardInputs(OpBase* op) { - for (auto& pair : op->GetInsMap()) { - for (auto& var : pair.second) { - if (!var || op->IsAllowedEmptyVar(var.get())) { - continue; - } - - auto* inner_var = var->MutableVar(); - framework::Tensor* tensor = nullptr; - if (!inner_var->IsInitialized() || - inner_var->IsType()) { - tensor = inner_var->GetMutable(); - } - - if (tensor && !tensor->IsInitialized()) { - // if grad var has OverridedStopGradient skip this Op - VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero"; - auto* dev_ctx = - platform::DeviceContextPool::Instance().Get(op->place()); - tensor->mutable_data(op->place(), var->DataType()); - operators::math::set_constant(*dev_ctx, tensor, 0.0); - } - } - } -} - -void BasicEngine::PrepareGradAccumulators(OpBase* op) { - for (const auto& pair : op->GetOutsMap()) { - for (const auto& var : pair.second) { - if (!var) continue; - - auto& accumulator = accumulators_[var.get()]; - if (!accumulator) { - if (backward_strategy_.sorted_sum_gradient_) { - accumulator.reset(new SortedGradientAccumulator(var.get())); - } else { - accumulator.reset(new EagerGradientAccumulator(var.get())); - } - } - - accumulator->IncreaseRefCnt(); - - VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() - << "with reference count " << accumulator->RefCnt(); - } - } -} - -void BasicEngine::PrepareDeps() { - PADDLE_ENFORCE_EQ(op_deps_.empty(), true, "Op deps must be initialized here"); - PADDLE_ENFORCE_EQ(accumulators_.empty(), true, - "Accumulators must be initialized here"); - - std::queue q; - std::unordered_set visited; - for (const auto& init_op : init_ops_) { - q.push(init_op.get()); - visited.insert(init_op.get()); - } - - while (!q.empty()) { - auto* cur_op = q.front(); - q.pop(); - - PADDLE_ENFORCE_NE( - cur_op->GetInsMap().empty() && 
cur_op->GetOutsMap().empty(), true, - platform::errors::NotFound( - "Inputs and outputs of %s do not exist. " - "This may be because you call \"backward()\" twice for the same " - "subgraph. Please try to call \"stop_gradient = True\" or " - "\"detach()\" if you use some same vars between two \"backward()\" " - "calls.", - cur_op->Type())); - - PrepareGradAccumulators(cur_op); - - const auto& grad_pending_ops = cur_op->GradPendingOps(); - for (auto& grad_pending_op : grad_pending_ops) { - PADDLE_ENFORCE_NOT_NULL(grad_pending_op); - ++op_deps_[grad_pending_op.get()]; - if (visited.count(grad_pending_op.get()) == 0) { - visited.insert(grad_pending_op.get()); - q.push(grad_pending_op.get()); - } - } - } -} - -void BasicEngine::SumGradient(OpBase* op, std::shared_ptr src, - VariableWrapper* dst) { - auto iter = accumulators_.find(dst); - PADDLE_ENFORCE_EQ(iter != accumulators_.end(), true, - "Cannot find gradient of variable %s", dst->Name()); - iter->second->Add(std::move(src), op->id()); -} - -void BasicEngine::Execute() { - PrepareDeps(); - // Start execute Computation graph - std::queue> q; - for (const auto& init_op : init_ops_) { - q.push(std::move(init_op)); - } - - size_t op_num = 0; - - while (!q.empty()) { - auto shared_cur_op = std::move(q.front()); - q.pop(); - - auto* cur_op = shared_cur_op.get(); - ++op_num; - - // CheckBackWardInput - CheckBackwardInputs(cur_op); - - // Step 1: Run Backward - auto& bwd_ins = cur_op->GetInsMap(); - auto& bwd_outs = cur_op->GetOutsMap(); - - NameVarMap tmp_outs(bwd_outs); - // 1. construct the output map 2. replace the element in the map - // A var may be coresponding to several grad var in one op - for (auto it = tmp_outs.begin(); it != tmp_outs.end(); ++it) { - for (size_t i = 0; i < it->second.size(); ++i) { - auto tmp_var = - std::make_shared("Gtmp@"); // Do not need grad - - auto var = it->second[i]; - it->second[i] = tmp_var; - if (var) { - need_accu_var_list_.emplace_back(var.get(), std::move(tmp_var)); - } - } - } - - { - VLOG(3) << "Start to execute grad op " << cur_op->Type(); - OpBase::Run(cur_op->InnerOp(), bwd_ins, tmp_outs, cur_op->Attrs(), - cur_op->place()); - } - - // Step 2: Sum Gradient - - if (need_accu_var_list_.size() > 0) { - for (auto& pair : need_accu_var_list_) { - SumGradient(cur_op, std::move(pair.second), pair.first); - } - } - - need_accu_var_list_.clear(); - - // Step 3: Collect ready ops - - for (auto& grad_pending_op : cur_op->GradPendingOps()) { - PADDLE_ENFORCE_NOT_NULL(grad_pending_op); - auto iter = op_deps_.find(grad_pending_op.get()); - if (iter == op_deps_.end()) { - continue; - } - - VLOG(3) << "Found grad_pending op of " << cur_op->Type(); - // An Op is ready to go while its deps comes to zero - - if (--(iter->second) == 0) { - q.push(grad_pending_op); - VLOG(3) << "Push grad_pending op " << grad_pending_op->Type() - << " into queue"; - } - } - - // Step 4: Delete op to collect unused variables - VLOG(3) << "Remove op after op " << cur_op->Type() << " runs"; - cur_op->ClearBackwardTrace(); - } - Clear(); - - VLOG(1) << "Backward op number: " << op_num; -} -} // namespace imperative -} // namespace paddle diff --git a/paddle/fluid/imperative/engine.h b/paddle/fluid/imperative/engine.h index a66077c4bf0..df2d5ca78b2 100644 --- a/paddle/fluid/imperative/engine.h +++ b/paddle/fluid/imperative/engine.h @@ -14,63 +14,18 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/imperative/backward_strategy.h" -#include 
"paddle/fluid/imperative/gradient_accumulator.h" -#include "paddle/fluid/imperative/layer.h" +#include "paddle/fluid/platform/macros.h" namespace paddle { namespace imperative { -// It seems there is no need for Engine to be an -// singleton, we can have multi-engine to run -// mutil-graoh. For future use we may expose a interface -// to Python to support class Engine { + DISABLE_COPY_AND_ASSIGN(Engine); + public: + Engine() = default; virtual ~Engine() = default; virtual void Execute() = 0; - virtual void Init(VarBase* var, const detail::BackwardStrategy& strategy) = 0; -}; - -class BasicEngine : public Engine { - public: - void Init(VarBase* var, const detail::BackwardStrategy& strategy) override; - - void Execute() override; - - private: - void PrepareDeps(); - - void CheckBackwardInputs(OpBase* op); - - void PrepareGradAccumulators(OpBase* op); - - void SumGradient(OpBase* op, std::shared_ptr src, - VariableWrapper* dst); - - // TODO(jiabin): maybe we can optimize the performance of engine by cache the - // result - void Clear() { - init_ops_.clear(); - op_deps_.clear(); - accumulators_.clear(); - } - - std::vector> init_ops_; - detail::BackwardStrategy backward_strategy_; - std::unordered_map op_deps_; - std::unordered_map> - accumulators_; - - std::vector>> - need_accu_var_list_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/execution_context.h b/paddle/fluid/imperative/execution_context.h new file mode 100644 index 00000000000..0537370b074 --- /dev/null +++ b/paddle/fluid/imperative/execution_context.h @@ -0,0 +1,199 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/imperative/type_defs.h" + +namespace paddle { +namespace imperative { + +template +class DygraphExecutionContext : public framework::ExecutionContext { + using Variable = framework::Variable; + + public: + DygraphExecutionContext(const framework::OperatorBase& op, + const framework::Scope& scope, + const platform::DeviceContext& device_context, + const framework::RuntimeContext& ctx, + std::vector* configs, + const NameVarMap& var_base_map_in, + const NameVarMap& var_base_map_out, + const framework::AttributeMap& attrs) + : ExecutionContext(op, scope, device_context, ctx, configs), + var_base_map_in_(var_base_map_in), + var_base_map_out_(var_base_map_out), + attrs_(attrs) {} + + std::string InputName(const std::string& name) const override { + auto it = var_base_map_in_.find(name); + PADDLE_ENFORCE_NE(it, var_base_map_in_.end(), + platform::errors::PreconditionNotMet( + "Can not find [%s] in Input", name)); + return it->second[0] ? 
it->second[0]->Name() : framework::kEmptyVarName; + } + + std::vector InputNames(const std::string& name) const override { + auto it = var_base_map_in_.find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_in_.end(), + platform::errors::NotFound("Can not find [%s] in Input", name)); + std::vector vec_res; + vec_res.reserve(it->second.size()); + for (size_t i = 0; i < it->second.size(); ++i) { + if (it->second[i]) { + vec_res.push_back(it->second[i]->Name()); + } else { + vec_res.push_back(framework::kEmptyVarName); + } + } + return vec_res; + } + + std::string OutputName(const std::string& name) const override { + auto it = var_base_map_out_.find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_out_.end(), + platform::errors::NotFound("Can not find [%s] in Output", name)); + return it->second[0] ? it->second[0]->Name() : framework::kEmptyVarName; + } + + std::vector OutputNames(const std::string& name) const override { + auto it = var_base_map_out_.find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_out_.end(), + platform::errors::NotFound("Can not find [%s] in Output", name)); + std::vector vec_res; + vec_res.reserve(it->second.size()); + for (size_t i = 0; i < it->second.size(); ++i) { + if (it->second[i]) { + vec_res.push_back(it->second[i]->Name()); + } else { + vec_res.push_back(framework::kEmptyVarName); + } + } + return vec_res; + } + + bool HasAttr(const std::string& name) const override { + return attrs_.count(name) != 0; + } + + const framework::AttributeMap& Attrs() const override { return attrs_; } + + const framework::Attribute& GetAttr(const std::string& name) const override { + auto it = attrs_.find(name); + + PADDLE_ENFORCE_NE( + it, attrs_.end(), + platform::errors::NotFound("can not find [%s] in attrs", name)); + + return it->second; + } + + std::vector InNameList() const override { + std::vector vec_temp; + vec_temp.reserve(var_base_map_in_.size()); + + for (auto& v : var_base_map_in_) { + vec_temp.push_back(v.first); + } + + return vec_temp; + } + + bool HasInput(const std::string& name) const override { + auto it = var_base_map_in_.find(name); + return (it != var_base_map_in_.end() && it->second.size() > 0); + } + + bool HasOutput(const std::string& name) const override { + auto it = var_base_map_out_.find(name); + return (it != var_base_map_out_.end() && it->second.size() > 0); + } + + size_t InputSize(const std::string& name) const override { + return InputNames(name).size(); + } + + size_t OutputSize(const std::string& name) const override { + return OutputNames(name).size(); + } + + const Variable* InputVar(const std::string& name) const override { + auto it = var_base_map_in_.find(name); + if (it == var_base_map_in_.end()) { + return nullptr; + } + + return it->second.empty() || it->second[0] == nullptr + ? nullptr + : it->second[0]->MutableVar(); + } + + Variable* OutputVar(const std::string& name) const override { + auto it = var_base_map_out_.find(name); + if (it == var_base_map_out_.end()) { + return nullptr; + } + + return it->second.empty() || it->second[0] == nullptr + ? nullptr + : it->second[0]->MutableVar(); + } + + const std::vector MultiInputVar( + const std::string& name) const override { + auto it = var_base_map_in_.find(name); + if (it == var_base_map_in_.end()) { + return {}; + } + std::vector vec_res; + vec_res.reserve(it->second.size()); + for (size_t i = 0; i < it->second.size(); ++i) { + vec_res.push_back(it->second[i] ? 
it->second[i]->MutableVar() : nullptr); + } + + return vec_res; + } + + std::vector MultiOutputVar( + const std::string& name) const override { + auto it = var_base_map_out_.find(name); + if (it == var_base_map_out_.end()) { + return {}; + } + std::vector vec_res; + vec_res.reserve(it->second.size()); + for (size_t i = 0; i < it->second.size(); ++i) { + vec_res.push_back(it->second[i] ? it->second[i]->MutableVar() : nullptr); + } + + return vec_res; + } + + private: + const NameVarMap& var_base_map_in_; + const NameVarMap& var_base_map_out_; + const framework::AttributeMap& attrs_; +}; + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc index 7ef77f1aa7e..e8ea7dc9263 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -29,6 +29,39 @@ namespace paddle { namespace imperative { +static void MoveOrCopyVar(framework::Variable* dst, framework::Variable* src, + bool force_copy) { + if (!force_copy) { + *dst = std::move(*src); + return; + } + + VLOG(10) << "Copy occurs when accumulating gradients"; + if (src->IsType()) { + auto& src_tensor = src->Get(); + if (!dst->IsType()) { + dst->Clear(); + } + auto* dst_tensor = dst->GetMutable(); + framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor); + dst_tensor->set_lod(src_tensor.lod()); + } else if (src->IsType()) { + auto& src_selected_rows = src->Get(); + if (!dst->IsType()) { + dst->Clear(); + } + auto* dst_selected_rows = dst->GetMutable(); + framework::TensorCopy(src_selected_rows.value(), + src_selected_rows.value().place(), + dst_selected_rows->mutable_value()); + dst_selected_rows->set_rows(src_selected_rows.rows()); + dst_selected_rows->set_height(src_selected_rows.height()); + } else { + PADDLE_THROW(platform::errors::PermissionDenied( + "Only support LoDTensor and SelectedRows for gradient accumulation")); + } +} + template class TensorAddFunctor : public boost::static_visitor<> { public: @@ -141,6 +174,49 @@ void SelectedRowsAddToTensor(const framework::Variable& src, framework::DataTypeToString(data_type))); } +static void SelectedRowsAddTensor( + const framework::Variable& src_selected_rows_var, + const framework::Variable& src_tensor_var, + framework::Variable* dst_tensor_var) { + const auto& src_selected_rows = + src_selected_rows_var.Get(); + const auto& src_tensor = src_tensor_var.Get(); + const auto& place = src_tensor.place(); + auto data_type = src_tensor.type(); + auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); + + auto* dst_tensor = dst_tensor_var->GetMutable(); + dst_tensor->Resize(src_tensor.dims()); + dst_tensor->mutable_data(place, data_type); + +#define PADDLE_SELECTED_ROWS_ADD_TENSOR(dev_ctx_type, cpp_type) \ + if (data_type == framework::DataTypeTrait::DataType()) { \ + paddle::operators::math::SelectedRowsAddTensor \ + functor; \ + functor(*(dynamic_cast(dev_ctx)), src_selected_rows, \ + src_tensor, dst_tensor); \ + return; \ + } + +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(place)) { + PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, float); + PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, double); + } else { +#endif + PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, float); + PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, double); +#ifdef PADDLE_WITH_CUDA + } +#endif + + PADDLE_THROW(platform::errors::InvalidArgument( + "Not supported data type %s for 
SelectedRowsAddToTensor", + framework::DataTypeToString(data_type))); + +#undef PADDLE_SELECTED_ROWS_ADD_TENSOR +} + // Note(chenweihang): when two selected rows need to be added, // adding one to another is not equal to merging two selected rows // to one then add it to a empty selected rows, the after is correct @@ -189,7 +265,7 @@ std::shared_ptr SelectedRowsMerge( } void VariableWrapperAdd(std::shared_ptr var, - VariableWrapper* var_) { + VariableWrapper* var_, bool unchange_input) { auto& src = var->Var(); auto* dst = var_->MutableVar(); if (dst->IsType()) { @@ -204,10 +280,15 @@ void VariableWrapperAdd(std::shared_ptr var, } } else { if (src.IsType()) { - auto* src_mutable = var->MutableVar(); - SelectedRowsAddToTensor(*dst, src_mutable); - *dst = std::move(*(var->MutableVar())); - var_->SetType(framework::proto::VarType::LOD_TENSOR); + if (unchange_input) { + framework::Variable new_dst; + SelectedRowsAddTensor(*dst, src, &new_dst); + *dst = std::move(new_dst); + } else { + auto* src_mutable = var->MutableVar(); + SelectedRowsAddToTensor(*dst, src_mutable); + *dst = std::move(*(var->MutableVar())); + } } else if (src.IsType()) { auto temp = SelectedRowsMerge(src, *dst); *dst = std::move(*(temp->MutableVar())); @@ -234,18 +315,23 @@ static platform::Place GetPlaceOfVar( } void EagerGradientAccumulator::Add(std::shared_ptr var, - size_t trace_id) { + size_t trace_id, bool unchange_input) { + /** + * If var has grad node, it indicates that this var would be an input + * of a grad op. Therefore, it should not be changed. + */ + if (var->HasGradNode()) { + unchange_input = true; + } + auto* dst_var = var_->MutableVar(); platform::Place place = GetPlaceOfVar(var); if (!var_->OverridedStopGradient()) { VLOG(3) << "Sum Gradient for: " << var_->Name(); if (cur_cnt_ == 0) { - if (var->Var().IsType()) { - var_->SetType(framework::proto::VarType::SELECTED_ROWS); - } - *dst_var = std::move(*(var->MutableVar())); + MoveOrCopyVar(dst_var, var->MutableVar(), unchange_input); } else { - VariableWrapperAdd(var, var_); + VariableWrapperAdd(var, var_, unchange_input); } } else { if (!var_->Var().IsInitialized() || @@ -268,75 +354,91 @@ void EagerGradientAccumulator::Add(std::shared_ptr var, } } ++cur_cnt_; + + if (var_->Var().IsType()) { + var_->SetType(framework::proto::VarType::LOD_TENSOR); + } else if (var_->Var().IsType()) { + var_->SetType(framework::proto::VarType::SELECTED_ROWS); + } } void SortedGradientAccumulator::Add(std::shared_ptr var, - size_t trace_id) { + size_t trace_id, bool unchange_input) { auto* dst_var = var_->MutableVar(); platform::Place place = GetPlaceOfVar(var); if (!var_->OverridedStopGradient()) { if (ref_cnt_ == 1) { - if (var->Var().IsType()) { - var_->SetType(framework::proto::VarType::SELECTED_ROWS); - *dst_var = std::move(*(var->MutableVar())); - } else { - *dst_var = std::move(*(var->MutableVar())); - } + MoveOrCopyVar(dst_var, var->MutableVar(), + unchange_input || var->HasGradNode()); } else { if (tmp_grad_vars_.empty()) { tmp_grad_vars_.reserve(ref_cnt_); } - tmp_grad_vars_.emplace_back(std::move(var), trace_id); + tmp_grad_vars_.emplace_back(std::move(var), trace_id, unchange_input); if (tmp_grad_vars_.size() != ref_cnt_) { return; } - std::sort( - tmp_grad_vars_.begin(), tmp_grad_vars_.end(), - [](const std::pair, size_t>& p1, - const std::pair, size_t>& p2) { - return p1.second > p2.second; - }); + std::sort(tmp_grad_vars_.begin(), tmp_grad_vars_.end(), + [](const SavedVarInfo& info1, const SavedVarInfo& info2) { + return info1.trace_id > info2.trace_id; + }); + 
+ for (auto& var_info : tmp_grad_vars_) { + if (var_info.var->HasGradNode()) { + var_info.unchange_input = true; + } + } #ifdef PADDLE_WITH_CUDA if (paddle::platform::is_gpu_place(place)) { bool dst_varbase_is_initialized = false; // accumulate selected rows firstly - for (size_t i = 0; i < tmp_grad_vars_.size(); ++i) { - if (tmp_grad_vars_[i] - .first->Var() - .IsType()) { - if (!dst_varbase_is_initialized) { - dst_varbase_is_initialized = true; - var_->SetType(framework::proto::VarType::SELECTED_ROWS); - *dst_var = std::move(*(tmp_grad_vars_[i].first->MutableVar())); - } else { - VariableWrapperAdd(tmp_grad_vars_[i].first, var_); - } + for (auto& var_info : tmp_grad_vars_) { + if (!var_info.var->Var().IsType()) { + continue; } - } - // accumulate lod tensor - for (size_t i = 0; i < tmp_grad_vars_.size(); ++i) { + if (!dst_varbase_is_initialized) { dst_varbase_is_initialized = true; - *dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar())); + MoveOrCopyVar(dst_var, var_info.var->MutableVar(), + var_info.unchange_input); + } else { + VariableWrapperAdd(var_info.var, var_, var_info.unchange_input); } - if (tmp_grad_vars_[i].first->Var().IsType()) { - VariableWrapperAdd(tmp_grad_vars_[i].first, var_); + + var_info.var = nullptr; + } + + for (auto& var_info : tmp_grad_vars_) { + if (!var_info.var) { + continue; + } + + PADDLE_ENFORCE_EQ(var_info.var->Var().IsType(), + true, platform::errors::PermissionDenied( + "Gradient var must be LoDTensor")); + + if (!dst_varbase_is_initialized) { + dst_varbase_is_initialized = true; + MoveOrCopyVar(dst_var, var_info.var->MutableVar(), + var_info.unchange_input); + } else { + VariableWrapperAdd(var_info.var, var_, var_info.unchange_input); } + + var_info.var = nullptr; } } else { #endif - if (tmp_grad_vars_[0].first->Var().IsType()) { - var_->SetType(framework::proto::VarType::SELECTED_ROWS); - *dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar())); - } else { - *dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar())); - } + MoveOrCopyVar(dst_var, tmp_grad_vars_[0].var->MutableVar(), + tmp_grad_vars_[0].unchange_input); for (size_t i = 1; i < tmp_grad_vars_.size(); ++i) { - VariableWrapperAdd(tmp_grad_vars_[i].first, var_); + VariableWrapperAdd(tmp_grad_vars_[i].var, var_, + tmp_grad_vars_[i].unchange_input); + tmp_grad_vars_[i].var = nullptr; } #ifdef PADDLE_WITH_CUDA } @@ -364,6 +466,12 @@ void SortedGradientAccumulator::Add(std::shared_ptr var, // looks like tmp_grad_vars will not have any member but just in case tmp_grad_vars_.clear(); } + + if (var_->Var().IsType()) { + var_->SetType(framework::proto::VarType::LOD_TENSOR); + } else if (var_->Var().IsType()) { + var_->SetType(framework::proto::VarType::SELECTED_ROWS); + } } } // namespace imperative diff --git a/paddle/fluid/imperative/gradient_accumulator.h b/paddle/fluid/imperative/gradient_accumulator.h index c2ae3139be3..a8ccb2a38d3 100644 --- a/paddle/fluid/imperative/gradient_accumulator.h +++ b/paddle/fluid/imperative/gradient_accumulator.h @@ -26,7 +26,8 @@ class GradientAccumulator { public: explicit GradientAccumulator(VariableWrapper* var) : var_(var) {} - virtual void Add(std::shared_ptr var, size_t trace_id) = 0; + virtual void Add(std::shared_ptr var, size_t trace_id, + bool unchange_input = false) = 0; virtual ~GradientAccumulator() = default; @@ -43,7 +44,8 @@ class EagerGradientAccumulator : public GradientAccumulator { public: using GradientAccumulator::GradientAccumulator; - void Add(std::shared_ptr var, size_t trace_id) override; + void Add(std::shared_ptr var, 
size_t trace_id, + bool unchange_input) override; private: size_t cur_cnt_{0}; @@ -53,11 +55,23 @@ class SortedGradientAccumulator : public GradientAccumulator { public: using GradientAccumulator::GradientAccumulator; - void Add(std::shared_ptr var, size_t trace_id) override; + void Add(std::shared_ptr var, size_t trace_id, + bool unchange_input) override; private: - std::vector, size_t>> - tmp_grad_vars_; + struct SavedVarInfo { + SavedVarInfo(std::shared_ptr&& v, size_t id, + bool enable_unchange_input) + : var(std::move(v)), + trace_id(id), + unchange_input(enable_unchange_input) {} + + std::shared_ptr var; + size_t trace_id; + bool unchange_input; + }; + + std::vector tmp_grad_vars_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h new file mode 100644 index 00000000000..a1f307e0e14 --- /dev/null +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -0,0 +1,371 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/shape_inference.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/imperative/type_defs.h" +#include "paddle/fluid/imperative/variable_wrapper.h" + +namespace paddle { +namespace imperative { + +template +class DygraphInferShapeContext : public framework::InferShapeContext { + using DDim = framework::DDim; + + public: + DygraphInferShapeContext(const NameVarMap* in, + const NameVarMap* out, + const framework::AttributeMap* attr) + : var_base_map_in_(in), var_base_map_out_(out), attrs_(attr) {} + + bool HasInput(const std::string& name) const override { + // has only one input + auto it = var_base_map_in_->find(name); + + if (it == var_base_map_in_->end()) { + return false; + } + const auto& in = it->second; + if (in.size() == 0) return false; + PADDLE_ENFORCE_EQ( + in.size(), 1UL, + platform::errors::PreconditionNotMet( + "Input %s should not have more than one inputs", name)); + return in[0] != nullptr; + } + + bool HasOutput(const std::string& name) const override { + // has only one output + auto it = var_base_map_out_->find(name); + if (it == var_base_map_out_->end()) { + return false; + } + const auto& out = it->second; + if (out.size() == 0) { + return false; + } + PADDLE_ENFORCE_EQ( + out.size(), 1UL, + platform::errors::PreconditionNotMet( + "Output %s should not have more than one outputs", name)); + return out[0] != nullptr; + } + + bool HasInputs(const std::string& name) const override { + auto it = var_base_map_in_->find(name); + if (it == var_base_map_in_->end() || it->second.empty()) { + return false; + } + for (auto& input : it->second) { + if (input == nullptr) { + return false; + } + } + return true; + } + + bool HasOutputs(const std::string& name) const override { + auto it = var_base_map_out_->find(name); + if (it == var_base_map_out_->end() || it->second.empty()) { 
+ return false; + } + for (auto& output : it->second) { + if (output == nullptr) { + return false; + } + } + return true; + } + + framework::AttrReader Attrs() const override { + return framework::AttrReader(*attrs_); + } + + std::vector Inputs(const std::string& name) const override { + std::vector vec_res; + auto it = var_base_map_in_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_in_->end(), + platform::errors::NotFound("can not find [%s] in input", name)); + + vec_res.reserve(it->second.size()); + for (auto& var : it->second) { + if (var) { + vec_res.push_back(var->Name()); + } else { + vec_res.push_back(framework::kEmptyVarName); + } + } + + return vec_res; + } + + std::vector Outputs(const std::string& name) const override { + std::vector vec_res; + auto it = var_base_map_out_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_out_->end(), + platform::errors::NotFound("can not find [%s] in output", name)); + + vec_res.reserve(it->second.size()); + for (auto& var : it->second) { + if (var) { + vec_res.push_back(var->Name()); + } else { + vec_res.push_back(framework::kEmptyVarName); + } + } + + return vec_res; + } + + void ShareDim(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) override { + auto in_it = var_base_map_in_->find(in); + auto out_it = var_base_map_out_->find(out); + PADDLE_ENFORCE_NE( + in_it, var_base_map_in_->end(), + platform::errors::NotFound("can not found [%s] in input", in)); + PADDLE_ENFORCE_GT(in_it->second.size(), i, + platform::errors::PreconditionNotMet( + "Inputs %s should have %llu argument", in, i)); + PADDLE_ENFORCE_NE( + out_it, var_base_map_out_->end(), + platform::errors::NotFound("can not found [%s] in input", in)); + PADDLE_ENFORCE_GT(out_it->second.size(), j, + platform::errors::PreconditionNotMet( + "Outputs %s should have %llu argument", out, j)); + + framework::Variable* in_var = in_it->second[i]->MutableVar(); + framework::Variable* out_var = out_it->second[j]->MutableVar(); + + PADDLE_ENFORCE_EQ(in_var->Type(), out_var->Type(), + platform::errors::PreconditionNotMet( + "The type of %s and %s is not the same.", in, out)); + + if (in_var->IsType()) { + auto& in_lod_tensor = in_var->Get(); + auto* out_lod_tensor = out_var->GetMutable(); + out_lod_tensor->Resize(in_lod_tensor.dims()); + } else { + auto& in_sele_rows = in_var->Get(); + auto out_sele_rows = out_var->GetMutable(); + out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims()); + out_sele_rows->set_rows(in_sele_rows.rows()); + out_sele_rows->set_height(in_sele_rows.height()); + } + } + + void ShareAllLoD(const std::string& in, + const std::string& out) const override { + // do nothing + } + void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) const override { + // do nothing + } + + bool IsRuntime() const override { return true; } + + // TODO(paddle-dev): Can this be template? 
+ std::vector GetInputVarPtrs( + const std::string& name) override { + PADDLE_THROW(platform::errors::PermissionDenied( + "GetInputVarPtrs not support in dygraph runtime context")); + } + + std::vector GetOutputVarPtrs( + const std::string& name) override { + PADDLE_THROW(platform::errors::PermissionDenied( + "GetOutputVarPtrs not support in dygraph runtime context")); + } + + DDim GetInputDim(const std::string& name) const override { + auto it = var_base_map_in_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_in_->end(), + platform::errors::NotFound("can not find [%s] in input", name)); + PADDLE_ENFORCE_EQ( + it->second.size(), 1UL, + platform::errors::PreconditionNotMet( + "Input(%s) should hold one element, but now it holds %d", name, + it->second.size())); + return this->GetDim(it->second[0]->MutableVar()); + } + + std::vector GetInputsDim(const std::string& name) const override { + // const std::vector& vars = InputVars(name); + std::vector vec_res; + auto it = var_base_map_in_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_in_->end(), + platform::errors::NotFound("can not find [%s] in output", name)); + vec_res.reserve(it->second.size()); + for (size_t i = 0; i < it->second.size(); ++i) { + if (it->second[i]) { + vec_res.emplace_back(GetDim(it->second[i]->MutableVar())); + } else { + vec_res.emplace_back(); + } + } + + return vec_res; + } + + std::vector GetInputsVarType( + const std::string& name) const override { + std::vector vec_res; + auto it = var_base_map_in_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_in_->end(), + platform::errors::NotFound("can not find [%s] in input", name)); + vec_res.reserve(it->second.size()); + for (size_t i = 0; i < it->second.size(); ++i) { + if (it->second[i]) { + vec_res.emplace_back( + framework::ToVarType(it->second[i]->MutableVar()->Type())); + } else { + vec_res.emplace_back(); + } + } + return vec_res; + } + + std::vector GetOutputsVarType( + const std::string& name) const override { + std::vector vec_res; + auto it = var_base_map_out_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_out_->end(), + platform::errors::NotFound("can not find [%s] in output", name)); + vec_res.reserve(it->second.size()); + for (size_t i = 0; i < it->second.size(); ++i) { + if (it->second[i]) { + vec_res.emplace_back( + framework::ToVarType(it->second[i]->MutableVar()->Type())); + } else { + vec_res.emplace_back(static_cast(-1)); + } + } + return vec_res; + } + + void SetOutputDim(const std::string& name, const DDim& dim) override { + auto it = var_base_map_out_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_out_->end(), + platform::errors::NotFound("can not find [%s] in output", name)); + + if (it->second[0]) { + SetDim(it->second[0]->MutableVar(), dim); + } + } + + void SetOutputsDim(const std::string& name, + const std::vector& dims) override { + auto it = var_base_map_out_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_out_->end(), + platform::errors::NotFound("can not find [%s] in output", name)); + + PADDLE_ENFORCE_EQ(it->second.size(), dims.size(), + platform::errors::PreconditionNotMet( + "dim size [%d] is not match output var number [%d]", + dims.size(), it->second.size())); + + for (size_t i = 0; i < dims.size(); ++i) { + if (it->second[i]) { + SetDim(it->second[i]->MutableVar(), dims[i]); + } + } + } + + int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override { + PADDLE_THROW(platform::errors::PermissionDenied( + "GetLoDLevel function not support in dygraph mode")); + } + + void SetLoDLevel(const 
std::string& out, int32_t lod_level, + size_t j = 0) const override { + PADDLE_THROW(platform::errors::PermissionDenied( + "SetLoDLevel function not support in dygraph mode")); + } + + protected: + DDim GetDim(framework::Variable* var) const { + PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet( + "Input variable should not be null")); + if (var->IsType()) { + return var->Get().dims(); + } else if (var->IsType()) { + return var->Get().GetCompleteDims(); + } else { + PADDLE_THROW(platform::errors::PermissionDenied( + "Only LoDTensor/SelectedRows support 'GetDim', but Variables " + "type_id is xx.")); + } + } + + std::vector GetRepeatedDims(const std::string& name) const override { + PADDLE_THROW(platform::errors::PermissionDenied( + "GetRepeatedDims not support in dygraph runtime")); + } + + void SetDim(framework::Variable* var, const DDim& dim) { + if (var->IsType()) { + var->GetMutable()->Resize(dim); + } else if (var->IsType()) { + var->GetMutable()->set_height(dim[0]); + } else { + PADDLE_THROW(platform::errors::PermissionDenied( + "Variable type_id %s, expect LoDTensor/SelectedRows.")); + } + } + + void SetDims(const std::vector& vars, + const std::vector& dims) { + size_t length = vars.size(); + PADDLE_ENFORCE_EQ( + length, dims.size(), + platform::errors::PreconditionNotMet( + "Vars number [%d] should be equal with dims number [%d]", length, + dims.size())); + for (size_t i = 0; i < length; ++i) { + if (vars[i] == nullptr) { + continue; + } + SetDim(vars[i], dims[i]); + } + } + + void SetRepeatedDims(const std::string& name, + const std::vector& dims) override { + PADDLE_THROW(platform::errors::PermissionDenied( + "SetRepeatedDims not support in dygraph runtime")); + } + + private: + const NameVarMap* var_base_map_in_; + const NameVarMap* var_base_map_out_; + const framework::AttributeMap* attrs_; +}; + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/infer_var_type_context.h b/paddle/fluid/imperative/infer_var_type_context.h new file mode 100644 index 00000000000..e46ac315d2d --- /dev/null +++ b/paddle/fluid/imperative/infer_var_type_context.h @@ -0,0 +1,188 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
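The accumulator bookkeeping and the dygraph infer-shape/var-type contexts above are shared by both the basic and partial-grad engines; the star-GAN gradient-penalty unit test listed in the diffstat is the typical end-to-end consumer of this double-grad path. A hedged sketch of that usage pattern, again assuming fluid.dygraph.grad with create_graph=True (the discriminator and the interpolation details are placeholders, not the test's actual code):

```python
import numpy as np
import paddle.fluid as fluid

def gradient_penalty(discriminator, real, fake):
    # Interpolate between real and fake samples with a random scalar weight.
    alpha = float(np.random.rand())
    interp = real * alpha + fake * (1.0 - alpha)
    interp.stop_gradient = False

    pred = discriminator(interp)

    # create_graph=True keeps the computed gradient differentiable, so the
    # penalty term can itself be back-propagated through.
    grads = fluid.dygraph.grad(outputs=[pred], inputs=[interp],
                               create_graph=True)[0]
    grads = fluid.layers.reshape(grads, [grads.shape[0], -1])
    norm = fluid.layers.sqrt(
        fluid.layers.reduce_sum(fluid.layers.square(grads), dim=1))
    return fluid.layers.reduce_mean(fluid.layers.square(norm - 1.0))
```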
+ +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/framework/var_type_inference.h" +#include "paddle/fluid/imperative/type_defs.h" +#include "paddle/fluid/imperative/variable_wrapper.h" + +namespace paddle { +namespace imperative { + +// infer var type context for imperative mode +template +class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { + public: + RuntimeInferVarTypeContext(const NameVarMap& inputs, + const NameVarMap& outputs, + const framework::AttributeMap& attrs_map) + : InferVarTypeContext(nullptr, nullptr), + inputs_(inputs), + outputs_(outputs), + attrs_(attrs_map), + input_names_(), + output_names_(), + var_set_() { + input_names_.reserve(inputs_.size()); + for (auto& it : inputs_) { + for (auto& var : it.second) { + if (var) { + input_names_[it.first].emplace_back(var->Name()); + var_set_[var->Name()] = var.get(); + } + } + } + + output_names_.reserve(outputs_.size()); + for (auto& it : outputs_) { + for (auto& var : it.second) { + if (var) { + output_names_[it.first].emplace_back(var->Name()); + var_set_[var->Name()] = var.get(); + } + } + } + } + + virtual ~RuntimeInferVarTypeContext() {} + + framework::Attribute GetAttr(const std::string& name) const override { + auto iter = attrs_.find(name); + PADDLE_ENFORCE_EQ( + iter != attrs_.end(), true, + platform::errors::NotFound("Cannot find attribute %s", name)); + return iter->second; + } + + bool HasVar(const std::string& name) const override { + return var_set_.count(name) > 0; + } + + bool HasInput(const std::string& name) const override { + auto it = inputs_.find(name); + return (it != inputs_.end() && it->second.size() > 0); + } + + bool HasOutput(const std::string& name) const override { + auto it = outputs_.find(name); + return (it != outputs_.end() && it->second.size() > 0); + } + + const std::vector& Input( + const std::string& name) const override { + auto iter = input_names_.find(name); + PADDLE_ENFORCE_EQ( + iter != input_names_.end(), true, + platform::errors::NotFound("Cannot find input var %s", name)); + return iter->second; + } + + const std::vector& Output( + const std::string& name) const override { + auto iter = output_names_.find(name); + + PADDLE_ENFORCE_EQ( + iter != output_names_.end(), true, + platform::errors::NotFound("Cannot find output var %s", name)); + return iter->second; + } + + framework::proto::VarType::Type GetType( + const std::string& name) const override { + auto iter = var_set_.find(name); + + PADDLE_ENFORCE_EQ( + iter != var_set_.end(), true, + platform::errors::NotFound("Cannot find var %s in GetType", name)); + return iter->second->Type(); + } + + void SetType(const std::string& name, + framework::proto::VarType::Type type) override { + if (name == "kLookupTablePath") { + VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++"; + } else { + var_set_[name]->SetType(type); + if ((var_set_[name]->MutableVar()->IsInitialized() == true) && + (var_set_[name]->MutableVar()->Type() != type)) { + var_set_[name]->MutableVar()->Clear(); + } + } + } + + framework::proto::VarType::Type GetDataType( + const std::string& name) const override { + auto iter = var_set_.find(name); + + PADDLE_ENFORCE_EQ( + iter != var_set_.end(), true, + platform::errors::NotFound("Cannot find var %s in GetDataType", name)); + return iter->second->DataType(); + } + + void SetDataType(const std::string& name, + framework::proto::VarType::Type type) override { + var_set_[name]->SetDataType(type); + } + + std::vector 
GetDataTypes( + const std::string& name) const override { + PADDLE_THROW(platform::errors::PermissionDenied( + "GetDataTypes is not supported in runtime InferVarType")); + } + + void SetDataTypes(const std::string& name, + const std::vector& + multiple_data_type) override { + PADDLE_THROW(platform::errors::PermissionDenied( + "SetDataTypes is not supported in runtime InferVarType")); + } + + std::vector GetShape(const std::string& name) const override { + PADDLE_THROW(platform::errors::PermissionDenied( + "Do not handle Shape in runtime InferVarType")); + } + + void SetShape(const std::string& name, + const std::vector& dims) override { + PADDLE_THROW(platform::errors::PermissionDenied( + "Do not handle Shape in runtime InferVarType")); + } + + int32_t GetLoDLevel(const std::string& name) const override { + PADDLE_THROW(platform::errors::PermissionDenied( + "Do not handle LoDLevel in runtime InferVarType")); + } + + void SetLoDLevel(const std::string& name, int32_t lod_level) override { + PADDLE_THROW(platform::errors::PermissionDenied( + "Do not handle LoDLevel in runtime InferVarType")); + } + + private: + const NameVarMap& inputs_; + const NameVarMap& outputs_; + const framework::AttributeMap& attrs_; + std::unordered_map> input_names_; + std::unordered_map> output_names_; + std::unordered_map var_set_; +}; + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index face5f2b50b..3936435273f 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -19,6 +19,10 @@ #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/imperative/execution_context.h" +#include "paddle/fluid/imperative/infer_shape_context.h" +#include "paddle/fluid/imperative/infer_var_type_context.h" +#include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/imperative/prepared_operator.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/device_context.h" @@ -180,7 +184,7 @@ static std::string LayerDebugStringImpl(const std::string& op_type, size_t i = 0; for (auto& pair : ins) { if (i > 0) ss << ", "; - ss << DebugString(pair.first, pair.second); + ss << DebugString(pair.first, pair.second); ++i; } @@ -188,7 +192,7 @@ static std::string LayerDebugStringImpl(const std::string& op_type, i = 0; for (auto& pair : outs) { if (i > 0) ss << ", "; - ss << DebugString(pair.first, pair.second); + ss << DebugString(pair.first, pair.second); ++i; } return ss.str(); @@ -206,6 +210,27 @@ std::string LayerDebugString(const std::string& op_type, return LayerDebugStringImpl(op_type, ins, outs); } +VarBase::VarBase(bool has_grad, const std::shared_ptr& var) + : var_(var), grad_node_(var->GetGradNode()) { + if (has_grad) { + if (auto grad_var = var_->GetGradVar()) { + grad_var_ = std::make_shared(false, grad_var); + } else { + grad_var_ = std::make_shared(false, GradVarName()); + var_->SetGradVar(grad_var_->var_); + } + } + + if (IsDebugEnabled()) { + VLOG(10) << "Construct VarBase: " << Name(); + name_set_.Insert(Name()); + } +} + +size_t VarBase::GradOpNum() const { + return grad_node_ ? 
grad_node_->size() : 0; +} + void VarBase::ClearGradient() { if (grad_var_) { if (grad_var_->Var().IsType()) { @@ -292,8 +317,6 @@ void OpBase::SetType(const std::string& type) { } void OpBase::ClearBackwardTrace() { - grad_pending_ops_.clear(); - allow_empty_vars_.clear(); ins_.clear(); outs_.clear(); } @@ -308,14 +331,16 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel"); auto& info = op.Info(); if (info.infer_var_type_) { - RuntimeInferVarTypeContext infer_var_type_ctx(ins, &outs, attrs); + RuntimeInferVarTypeContext infer_var_type_ctx(ins, outs, attrs); info.infer_var_type_(&infer_var_type_ctx); } // Initialize output var type for (auto& var_pair : outs) { for (auto& var : var_pair.second) { - InitializeVariable(var->MutableVar(), var->Type()); + if (var) { + InitializeVariable(var->MutableVar(), var->Type()); + } } } @@ -344,5 +369,64 @@ void OpBase::Run(const framework::OperatorBase& op, OpBaseRunImpl(op, ins, outs, attrs, place); } +static void ClearNoNeedBufferInputs(OpBase* op) { + auto& inferer = op->Info().NoNeedBufferVarsInferer(); + if (!inferer) return; + auto* ins = op->GetMutableInsMap(); + const auto& no_need_buffer_slots = + inferer(*ins, op->GetOutsMap(), op->Attrs()); + if (no_need_buffer_slots.empty()) return; + + for (auto& slot : no_need_buffer_slots) { + auto iter = ins->find(slot); + if (iter == ins->end()) continue; + VLOG(2) << "Clear data buffer of " << slot << " in " << op->Type(); + + PADDLE_ENFORCE_EQ( + iter->second.IsGrad(), false, + platform::errors::InvalidArgument( + "Only forward variable buffers can be clear, this may be a bug")); + + for (auto& each_var : *(iter->second.MutableVarList())) { + if (!each_var) continue; + + auto& var = each_var->Var(); + PADDLE_ENFORCE_EQ(var.IsType(), true, + platform::errors::PermissionDenied( + "NoNeedBufferVars only support LoDTensor")); + // TODO(zjl): support higher order derivatives + auto new_var = new VariableWrapper(each_var->Name()); + auto* new_tensor = + new_var->MutableVar()->GetMutable(); + auto& old_tensor = var.Get(); + new_tensor->Resize(old_tensor.dims()); + new_tensor->set_lod(old_tensor.lod()); + each_var.reset(new_var); + } + } +} + +std::shared_ptr CreateGradOpNode( + const framework::OperatorBase& op, const NameVarBaseMap& ins, + const NameVarBaseMap& outs, const framework::AttributeMap& attrs, + const platform::Place& place) { + const auto& info = op.Info(); + if (!info.dygraph_grad_op_maker_) { + return nullptr; + } + + auto grad_node = info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs); + if (grad_node && !grad_node->empty()) { + for (auto& op : *grad_node) { + op.SetId(OpBase::GenerateUniqueId()); + op.SetPlace(place); + ClearNoNeedBufferInputs(&op); + } + return grad_node; + } else { + return nullptr; + } +} + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 855e3e81995..9a9f0f9bef7 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -14,27 +14,22 @@ #pragma once #include -#include #include #include #include #include -#include // NOLINT #include #include #include #include #include #include -#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/type_defs.h" -#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_type.h" -#include 
"paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/imperative/flags.h" +#include "paddle/fluid/imperative/saved_variable_wrapper_list.h" #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/variable_wrapper.h" #include "paddle/fluid/platform/enforce.h" @@ -63,9 +58,15 @@ class VarBase { public: static std::vector AliveVarNames(); + + public: explicit VarBase(bool has_grad, const std::string& name) : var_(std::make_shared(name)), grad_var_(has_grad ? new VarBase(false, GradVarName()) : nullptr) { + if (has_grad) { + var_->SetGradVar(grad_var_->var_); + } + if (IsDebugEnabled()) { VLOG(10) << "Construct VarBase: " << Name(); name_set_.Insert(Name()); @@ -74,6 +75,9 @@ class VarBase { explicit VarBase(const std::string& name) : VarBase(true, name) {} + // NOTE(zengjinle): be careful when you use this constructor!!! + explicit VarBase(bool has_grad, const std::shared_ptr& var); + ~VarBase() { VLOG(10) << "Destruct VarBase: " << Name(); if (IsDebugEnabled()) { @@ -95,9 +99,15 @@ class VarBase { const std::shared_ptr& MutableGradVarBase() { if (grad_var_ == nullptr) { - grad_var_ = std::make_shared(false, GradVarName()); - // NOTE(zhiqiu): we should keep grad_var_'s stop_gradient property same as - // fwd varbase + if (auto grad_var_wrapper = var_->GetGradVar()) { + grad_var_ = std::make_shared(false, grad_var_wrapper); + } else { + grad_var_ = std::make_shared(false, GradVarName()); + var_->SetGradVar(grad_var_->var_); + grad_var_->var_->SetGradNode(grad_var_->grad_node_); + } + // NOTE(zhiqiu): we should keep grad_var_'s stop_gradient property + // same as fwd varbase grad_var_->SetOverridedStopGradient(var_->InnerOverridedStopGradient()); } return grad_var_; @@ -140,18 +150,16 @@ class VarBase { bool Persistable() const { return var_->Persistable(); } // Only grad var is allowed to call these 2 methods - void AddGradOp(const std::shared_ptr& op) { - if (op && - std::find(grad_ops_.begin(), grad_ops_.end(), op) == grad_ops_.end()) { - grad_ops_.emplace_back(op); - } + void SetGradNode(const std::shared_ptr& node) { + grad_node_ = node; + var_->SetGradNode(node); } - const std::vector>& GradOps() const { - return grad_ops_; - } + size_t GradOpNum() const; - void ClearGradOps() { grad_ops_.clear(); } + const std::shared_ptr& GradNode() const { return grad_node_; } + + void ClearGradNode() { SetGradNode(nullptr); } const std::string& Name() const { return var_->Name(); } @@ -191,15 +199,18 @@ class VarBase { const std::shared_ptr var_; std::shared_ptr grad_var_; - std::vector> grad_ops_; + + /** + * NOTE(zengjinle): should consider whether to implement an inlined vector + * or other things like that. 
+ */ + std::shared_ptr grad_node_; mutable size_t copied_counter_ = 0; static ThreadSafeNameSet name_set_; }; -using VariableWrapperList = std::vector>; - class Layer { public: virtual ~Layer() {} @@ -210,760 +221,10 @@ class Layer { } }; -template -class DygraphExecutionContext : public framework::ExecutionContext { - using Variable = framework::Variable; - - public: - DygraphExecutionContext(const framework::OperatorBase& op, - const framework::Scope& scope, - const platform::DeviceContext& device_context, - const framework::RuntimeContext& ctx, - std::vector* configs, - const NameVarMap& var_base_map_in, - const NameVarMap& var_base_map_out, - const framework::AttributeMap& attrs) - : ExecutionContext(op, scope, device_context, ctx, configs), - var_base_map_in_(var_base_map_in), - var_base_map_out_(var_base_map_out), - attrs_(attrs) {} - - std::string InputName(const std::string& name) const override { - auto it = var_base_map_in_.find(name); - PADDLE_ENFORCE_NE(it, var_base_map_in_.end(), - platform::errors::PreconditionNotMet( - "Can not find [%s] in Input", name)); - return it->second[0]->Name(); - } - std::vector InputNames(const std::string& name) const override { - auto it = var_base_map_in_.find(name); - PADDLE_ENFORCE_NE( - it, var_base_map_in_.end(), - platform::errors::NotFound("Can not find [%s] in Input", name)); - std::vector vec_res; - vec_res.reserve(it->second.size()); - for (size_t i = 0; i < it->second.size(); ++i) { - vec_res.push_back(it->second[i]->Name()); - } - return vec_res; - } - - std::string OutputName(const std::string& name) const override { - auto it = var_base_map_out_.find(name); - PADDLE_ENFORCE_NE( - it, var_base_map_out_.end(), - platform::errors::NotFound("Can not find [%s] in Output", name)); - return it->second[0]->Name(); - } - - std::vector OutputNames(const std::string& name) const override { - auto it = var_base_map_out_.find(name); - PADDLE_ENFORCE_NE( - it, var_base_map_out_.end(), - platform::errors::NotFound("Can not find [%s] in Output", name)); - std::vector vec_res; - vec_res.reserve(it->second.size()); - for (size_t i = 0; i < it->second.size(); ++i) { - vec_res.push_back(it->second[i]->Name()); - } - return vec_res; - } - - bool HasAttr(const std::string& name) const override { - return attrs_.count(name) != 0; - } - - const framework::AttributeMap& Attrs() const override { return attrs_; } - - const framework::Attribute& GetAttr(const std::string& name) const override { - auto it = attrs_.find(name); - - PADDLE_ENFORCE_NE( - it, attrs_.end(), - platform::errors::NotFound("can not find [%s] in attrs", name)); - - return it->second; - } - - std::vector InNameList() const override { - std::vector vec_temp; - vec_temp.reserve(var_base_map_in_.size()); - - for (auto& v : var_base_map_in_) { - vec_temp.push_back(v.first); - } - - return vec_temp; - } - bool HasInput(const std::string& name) const override { - auto it = var_base_map_in_.find(name); - return (it != var_base_map_in_.end() && it->second.size() > 0); - } - - bool HasOutput(const std::string& name) const override { - auto it = var_base_map_out_.find(name); - return (it != var_base_map_out_.end() && it->second.size() > 0); - } - - size_t InputSize(const std::string& name) const override { - return InputNames(name).size(); - } - - size_t OutputSize(const std::string& name) const override { - return OutputNames(name).size(); - } - - const Variable* InputVar(const std::string& name) const override { - auto it = var_base_map_in_.find(name); - if (it == var_base_map_in_.end()) { - 
return nullptr; - } - - return it->second.empty() ? nullptr : it->second[0]->MutableVar(); - } - - Variable* OutputVar(const std::string& name) const override { - auto it = var_base_map_out_.find(name); - if (it == var_base_map_out_.end()) { - return nullptr; - } - - return it->second.empty() ? nullptr : it->second[0]->MutableVar(); - } - - const std::vector MultiInputVar( - const std::string& name) const override { - auto it = var_base_map_in_.find(name); - if (it == var_base_map_in_.end()) { - return {}; - } - std::vector vec_res; - vec_res.reserve(it->second.size()); - for (size_t i = 0; i < it->second.size(); ++i) { - vec_res.push_back(it->second[i]->MutableVar()); - } - - return vec_res; - } - - std::vector MultiOutputVar( - const std::string& name) const override { - auto it = var_base_map_out_.find(name); - if (it == var_base_map_out_.end()) { - return {}; - } - std::vector vec_res; - vec_res.reserve(it->second.size()); - for (size_t i = 0; i < it->second.size(); ++i) { - vec_res.push_back(it->second[i]->MutableVar()); - } - - return vec_res; - } - - private: - const NameVarMap& var_base_map_in_; - const NameVarMap& var_base_map_out_; - const framework::AttributeMap& attrs_; -}; - -// infer var type context for imperative mode -template -class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { - public: - RuntimeInferVarTypeContext(const NameVarMap& inputs, - const NameVarMap* outputs, - const framework::AttributeMap& attrs_map) - : InferVarTypeContext(nullptr, nullptr), - inputs_(inputs), - outputs_(outputs), - attrs_(attrs_map), - input_names_(), - output_names_(), - var_set_() { - input_names_.reserve(inputs_.size()); - for (auto& it : inputs_) { - for (auto& var : it.second) { - input_names_[it.first].emplace_back(var->Name()); - var_set_[var->Name()] = var.get(); - } - } - - output_names_.reserve(outputs_->size()); - for (auto& it : *outputs_) { - for (auto& var : it.second) { - output_names_[it.first].emplace_back(var->Name()); - var_set_[var->Name()] = var.get(); - } - } - } - - virtual ~RuntimeInferVarTypeContext() {} - - framework::Attribute GetAttr(const std::string& name) const override { - auto iter = attrs_.find(name); - PADDLE_ENFORCE_EQ(iter != attrs_.end(), true, "Cannot find attribute %s", - name); - return iter->second; - } - - bool HasVar(const std::string& name) const override { - return var_set_.count(name) > 0; - } - - bool HasInput(const std::string& name) const override { - auto it = inputs_.find(name); - return (it != inputs_.end() && it->second.size() > 0); - } - - bool HasOutput(const std::string& name) const override { - PADDLE_ENFORCE_NOT_NULL(outputs_); - auto it = outputs_->find(name); - return (it != outputs_->end() && it->second.size() > 0); - } - - const std::vector& Input( - const std::string& name) const override { - auto iter = input_names_.find(name); - PADDLE_ENFORCE_EQ(iter != input_names_.end(), true, "Cannot find input %s", - name); - return iter->second; - } - - const std::vector& Output( - const std::string& name) const override { - auto iter = output_names_.find(name); - - PADDLE_ENFORCE_EQ(iter != output_names_.end(), true, - "Cannot find output %s", name); - return iter->second; - } - - framework::proto::VarType::Type GetType( - const std::string& name) const override { - auto iter = var_set_.find(name); - - PADDLE_ENFORCE_EQ(iter != var_set_.end(), true, - "Cannot find var %s in GetType", name); - return iter->second->Type(); - } - - void SetType(const std::string& name, - framework::proto::VarType::Type type) 
override { - if (name == "kLookupTablePath") { - VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++"; - } else { - var_set_[name]->SetType(type); - if ((var_set_[name]->MutableVar()->IsInitialized() == true) && - (var_set_[name]->MutableVar()->Type() != type)) { - var_set_[name]->MutableVar()->Clear(); - } - } - } - - framework::proto::VarType::Type GetDataType( - const std::string& name) const override { - auto iter = var_set_.find(name); - - PADDLE_ENFORCE_EQ(iter != var_set_.end(), true, - "Cannot find var %s in GetDataType", name); - return iter->second->DataType(); - } - - void SetDataType(const std::string& name, - framework::proto::VarType::Type type) override { - var_set_[name]->SetDataType(type); - } - - std::vector GetDataTypes( - const std::string& name) const override { - PADDLE_THROW("GetDataTypes is not supported in runtime InferVarType"); - } - - void SetDataTypes(const std::string& name, - const std::vector& - multiple_data_type) override { - PADDLE_THROW("SetDataTypes is not supported in runtime InferVarType"); - } - - std::vector GetShape(const std::string& name) const override { - PADDLE_THROW("Do not handle Shape in runtime InferVarType"); - } - - void SetShape(const std::string& name, - const std::vector& dims) override { - PADDLE_THROW("Do not handle Shape in runtime InferVarType"); - } - - int32_t GetLoDLevel(const std::string& name) const override { - PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType"); - } - - void SetLoDLevel(const std::string& name, int32_t lod_level) override { - PADDLE_THROW("Do not handle LoDLevel in runtime InferVarType"); - } - - private: - const NameVarMap& inputs_; - const NameVarMap* outputs_; - const framework::AttributeMap& attrs_; - std::unordered_map> input_names_; - std::unordered_map> output_names_; - std::unordered_map var_set_; -}; - -// TODO(zjl): to support py_func layer -class OpBase { - DISABLE_COPY_AND_ASSIGN(OpBase); - - public: - OpBase() = default; - - ~OpBase() { VLOG(3) << "Destruct Op: " << Type(); } - - size_t id() const { return id_; } - - const std::string& Type() const { return op_->Type(); } - - const framework::AttributeMap& Attrs() const { return attrs_; } - - const framework::OpInfo& Info() const { return op_->Info(); } - - const framework::OperatorBase& InnerOp() const { return *op_; } - - void ClearBackwardTrace(); - - const std::vector>& GradPendingOps() const { - return grad_pending_ops_; - } - - void SetGradPendingOps(std::vector> pending_ops) { - grad_pending_ops_ = std::move(pending_ops); - } - - NameVarMap* GetMutableOutsMap() { return &outs_; } - - NameVarMap* GetMutableInsMap() { return &ins_; } - - const NameVarMap& GetInsMap() { return ins_; } - - const NameVarMap& GetOutsMap() { return outs_; } - - const platform::Place& place() const { return place_; } - - // TODO(jiabin) prepare for backward hook - void RegisterBackwardHooks(const std::function& func) { - backward_hooks_.emplace_back(func); - } - - void InvokeBackwardHooks() { - for (const auto& func : backward_hooks_) { - func(); - VLOG(5) << "Invoke Backward Hook for: " << Type() << std::endl; - } - } - - void SetType(const std::string& type); - - void CheckAttrs() { - auto& info = op_->Info(); - if (info.Checker() != nullptr) { - info.Checker()->Check(&attrs_, true); - } - } - - void SetInput(const std::string& name, VariableWrapperList vars) { - ins_[name] = std::move(vars); - } - - void SetOutput(const std::string& name, VariableWrapperList vars) { - outs_[name] = std::move(vars); - } - - void SetAttrMap(const 
framework::AttributeMap& attrs) { attrs_ = attrs; } - - void SetAttr(const std::string& name, const framework::Attribute& v) { - attrs_[name] = v; - } - - void SetBlockAttr(const std::string& name, framework::BlockDesc* block) { - PADDLE_THROW("SetBlockAttr is not support in dygraph OpBase"); - } - - const framework::AttributeMap& Attrs() { return attrs_; } - - void SetId(size_t id) { id_ = id; } - - void SetPlace(const platform::Place& place) { place_ = place; } - - bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; } - - const framework::Attribute& GetAttr(const std::string& name) const { - auto it = attrs_.find(name); - PADDLE_ENFORCE(it != attrs_.end(), "can not find attribute [%s]", name); - - return it->second; - } - - template - inline const T& Attr(const std::string& name) const { - return boost::get(GetAttr(name)); - } - - void AddAllowedEmptyVar(const VariableWrapper* var) { - allow_empty_vars_.emplace(var); - } - - bool IsAllowedEmptyVar(const VariableWrapper* var) { - return allow_empty_vars_.count(var) > 0; - } - - static void Run(const framework::OperatorBase& op, - const NameVarMap& ins, - const NameVarMap& outs, - const framework::AttributeMap& attrs, - const platform::Place& place); - - static void Run(const framework::OperatorBase& op, - const NameVarMap& ins, - const NameVarMap& outs, - const framework::AttributeMap& attrs, - const platform::Place& place); - - private: - NameVarMap ins_; - NameVarMap outs_; - framework::AttributeMap attrs_; - std::unique_ptr op_; - - std::vector> grad_pending_ops_; - platform::Place place_; - - std::unordered_set allow_empty_vars_; - - size_t id_{-1UL}; - - std::vector> backward_hooks_; -}; - -template -class DygraphInferShapeContext : public framework::InferShapeContext { - using DDim = framework::DDim; - - public: - DygraphInferShapeContext(const NameVarMap* in, - const NameVarMap* out, - const framework::AttributeMap* attr) - : var_base_map_in_(in), var_base_map_out_(out), attrs_(attr) {} - - bool HasInput(const std::string& name) const override { - // has only one input - auto it = var_base_map_in_->find(name); - - if (it == var_base_map_in_->end()) { - return false; - } - const auto& in = it->second; - if (in.size() == 0) return false; - PADDLE_ENFORCE_EQ( - in.size(), 1UL, - platform::errors::PreconditionNotMet( - "Input %s should not have more than one inputs", name)); - return in[0] != nullptr; - } - - bool HasOutput(const std::string& name) const override { - // has only one output - auto it = var_base_map_out_->find(name); - if (it == var_base_map_out_->end()) { - return false; - } - const auto& out = it->second; - if (out.size() == 0) { - return false; - } - PADDLE_ENFORCE_EQ( - out.size(), 1UL, - platform::errors::PreconditionNotMet( - "Output %s should not have more than one outputs", name)); - return out[0] != nullptr; - } - - bool HasInputs(const std::string& name) const override { - auto it = var_base_map_in_->find(name); - if (it == var_base_map_in_->end() || it->second.empty()) { - return false; - } - for (auto& input : it->second) { - if (input == nullptr) { - return false; - } - } - return true; - } - - bool HasOutputs(const std::string& name) const override { - auto it = var_base_map_out_->find(name); - if (it == var_base_map_out_->end() || it->second.empty()) { - return false; - } - for (auto& output : it->second) { - if (output == nullptr) { - return false; - } - } - return true; - } - - framework::AttrReader Attrs() const override { - return framework::AttrReader(*attrs_); - } - - 
std::vector Inputs(const std::string& name) const override { - // return op_.Inputs(name); - std::vector vec_res; - auto it = var_base_map_in_->find(name); - PADDLE_ENFORCE_NE( - it, var_base_map_in_->end(), - platform::errors::NotFound("can not find [%s] in input", name)); - - vec_res.reserve(it->second.size()); - for (auto& var : it->second) { - vec_res.push_back(var->Name()); - } - - return vec_res; - } - - std::vector Outputs(const std::string& name) const override { - std::vector vec_res; - auto it = var_base_map_out_->find(name); - PADDLE_ENFORCE_NE( - it, var_base_map_out_->end(), - platform::errors::NotFound("can not find [%s] in output", name)); - - vec_res.reserve(it->second.size()); - for (auto& var : it->second) { - vec_res.push_back(var->Name()); - } - - return vec_res; - } - - void ShareDim(const std::string& in, const std::string& out, size_t i = 0, - size_t j = 0) override { - auto in_it = var_base_map_in_->find(in); - auto out_it = var_base_map_out_->find(out); - PADDLE_ENFORCE_NE( - in_it, var_base_map_in_->end(), - platform::errors::NotFound("can not found [%s] in input", in)); - PADDLE_ENFORCE_GT(in_it->second.size(), i, - platform::errors::PreconditionNotMet( - "Inputs %s should have %llu argument", in, i)); - PADDLE_ENFORCE_NE( - out_it, var_base_map_out_->end(), - platform::errors::NotFound("can not found [%s] in input", in)); - PADDLE_ENFORCE_GT(out_it->second.size(), j, - platform::errors::PreconditionNotMet( - "Outputs %s should have %llu argument", out, j)); - - framework::Variable* in_var = in_it->second[i]->MutableVar(); - framework::Variable* out_var = out_it->second[j]->MutableVar(); - - PADDLE_ENFORCE_EQ(in_var->Type(), out_var->Type(), - platform::errors::PreconditionNotMet( - "The type of %s and %s is not the same.", in, out)); - - if (in_var->IsType()) { - auto& in_lod_tensor = in_var->Get(); - auto* out_lod_tensor = out_var->GetMutable(); - out_lod_tensor->Resize(in_lod_tensor.dims()); - } else { - auto& in_sele_rows = in_var->Get(); - auto out_sele_rows = out_var->GetMutable(); - out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims()); - out_sele_rows->set_rows(in_sele_rows.rows()); - out_sele_rows->set_height(in_sele_rows.height()); - } - } - - void ShareAllLoD(const std::string& in, - const std::string& out) const override { - // do nothing - } - void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, - size_t j = 0) const override { - // do nothing - } - - bool IsRuntime() const override { return true; } - - // TODO(paddle-dev): Can this be template? 
- std::vector GetInputVarPtrs( - const std::string& name) override { - PADDLE_THROW(platform::errors::PermissionDenied( - "GetInputVarPtrs not support in dygraph runtime context")); - } - - std::vector GetOutputVarPtrs( - const std::string& name) override { - PADDLE_THROW(platform::errors::PermissionDenied( - "GetOutputVarPtrs not support in dygraph runtime context")); - } - - DDim GetInputDim(const std::string& name) const override { - auto it = var_base_map_in_->find(name); - PADDLE_ENFORCE_NE( - it, var_base_map_in_->end(), - platform::errors::NotFound("can not find [%s] in input", name)); - PADDLE_ENFORCE_EQ( - it->second.size(), 1UL, - platform::errors::PreconditionNotMet( - "Input(%s) should hold one element, but now it holds %d", name, - it->second.size())); - return this->GetDim(it->second[0]->MutableVar()); - } - - std::vector GetInputsDim(const std::string& name) const override { - // const std::vector& vars = InputVars(name); - std::vector vec_res; - auto it = var_base_map_in_->find(name); - PADDLE_ENFORCE_NE( - it, var_base_map_in_->end(), - platform::errors::NotFound("can not find [%s] in output", name)); - vec_res.reserve(it->second.size()); - for (size_t i = 0; i < it->second.size(); ++i) { - vec_res.emplace_back(GetDim(it->second[i]->MutableVar())); - } - - return vec_res; - } - - std::vector GetInputsVarType( - const std::string& name) const override { - std::vector vec_res; - auto it = var_base_map_in_->find(name); - PADDLE_ENFORCE_NE( - it, var_base_map_in_->end(), - platform::errors::NotFound("can not find [%s] in input", name)); - vec_res.reserve(it->second.size()); - for (size_t i = 0; i < it->second.size(); ++i) { - vec_res.emplace_back( - framework::ToVarType(it->second[i]->MutableVar()->Type())); - } - return vec_res; - } - - std::vector GetOutputsVarType( - const std::string& name) const override { - std::vector vec_res; - auto it = var_base_map_out_->find(name); - PADDLE_ENFORCE_NE( - it, var_base_map_out_->end(), - platform::errors::NotFound("can not find [%s] in output", name)); - vec_res.reserve(it->second.size()); - for (size_t i = 0; i < it->second.size(); ++i) { - vec_res.emplace_back( - framework::ToVarType(it->second[i]->MutableVar()->Type())); - } - return vec_res; - } - - void SetOutputDim(const std::string& name, const DDim& dim) override { - auto it = var_base_map_out_->find(name); - PADDLE_ENFORCE_NE( - it, var_base_map_out_->end(), - platform::errors::NotFound("can not find [%s] in output", name)); - - SetDim(it->second[0]->MutableVar(), dim); - } - - void SetOutputsDim(const std::string& name, - const std::vector& dims) override { - auto it = var_base_map_out_->find(name); - PADDLE_ENFORCE_NE( - it, var_base_map_out_->end(), - platform::errors::NotFound("can not find [%s] in output", name)); - - PADDLE_ENFORCE_EQ(it->second.size(), dims.size(), - platform::errors::PreconditionNotMet( - "dim size [%d] is not match output var number [%d]", - dims.size(), it->second.size())); - - for (size_t i = 0; i < dims.size(); ++i) { - SetDim(it->second[i]->MutableVar(), dims[i]); - } - } - - int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override { - PADDLE_THROW(platform::errors::PermissionDenied( - "GetLoDLevel function not support in dygraph mode")); - } - - void SetLoDLevel(const std::string& out, int32_t lod_level, - size_t j = 0) const override { - PADDLE_THROW(platform::errors::PermissionDenied( - "SetLoDLevel function not support in dygraph mode")); - } - - protected: - DDim GetDim(framework::Variable* var) const { - 
PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet( - "Input variable should not be null")); - if (var->IsType()) { - return var->Get().dims(); - } else if (var->IsType()) { - return var->Get().GetCompleteDims(); - } else { - PADDLE_THROW(platform::errors::PermissionDenied( - "Only LoDTensor/SelectedRows support 'GetDim', but Variables " - "type_id is xx.")); - } - } - - std::vector GetRepeatedDims(const std::string& name) const override { - PADDLE_THROW(platform::errors::PermissionDenied( - "GetRepeatedDims not support in dygraph runtime")); - } - - void SetDim(framework::Variable* var, const DDim& dim) { - if (var->IsType()) { - var->GetMutable()->Resize(dim); - } else if (var->IsType()) { - var->GetMutable()->set_height(dim[0]); - } else { - PADDLE_THROW(platform::errors::PermissionDenied( - "Variable type_id %s, expect LoDTensor/SelectedRows.")); - } - } - - void SetDims(const std::vector& vars, - const std::vector& dims) { - size_t length = vars.size(); - PADDLE_ENFORCE_EQ( - length, dims.size(), - platform::errors::PreconditionNotMet( - "Vars number [%d] should be equal with dims number [%d]", length, - dims.size())); - for (size_t i = 0; i < length; ++i) { - if (vars[i] == nullptr) { - continue; - } - SetDim(vars[i], dims[i]); - } - } - - void SetRepeatedDims(const std::string& name, - const std::vector& dims) override { - PADDLE_THROW(platform::errors::PermissionDenied( - "SetRepeatedDims not support in dygraph runtime")); - } - - private: - const NameVarMap* var_base_map_in_; - const NameVarMap* var_base_map_out_; - const framework::AttributeMap* attrs_; -}; +std::shared_ptr CreateGradOpNode( + const framework::OperatorBase& op, const NameVarBaseMap& ins, + const NameVarBaseMap& outs, const framework::AttributeMap& attrs, + const platform::Place& place); } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h new file mode 100644 index 00000000000..b3044cef49b --- /dev/null +++ b/paddle/fluid/imperative/op_base.h @@ -0,0 +1,211 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
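+//
+// OpBase describes a single traced grad op: its inputs and outputs are saved
+// variable wrapper lists plus an attribute map. GradOpNode owns the vector of
+// OpBase instances created for one forward op and records the grad-pending
+// nodes that must run after it; it replaces the old per-VarBase grad_ops_
+// list so that the backward engines can traverse the graph node by node.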
+ +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/imperative/type_defs.h" +#include "paddle/fluid/imperative/variable_wrapper.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace imperative { + +// TODO(zjl): to support py_func layer +class OpBase { + public: + OpBase() = default; + + OpBase(const OpBase&) = delete; + + OpBase(OpBase&&) = default; + + OpBase& operator=(const OpBase&) = delete; + + OpBase& operator=(OpBase&&) = default; + + ~OpBase() { VLOG(3) << "Destruct Op: " << Type(); } + + const std::string& Type() const { return op_->Type(); } + + const framework::AttributeMap& Attrs() const { return attrs_; } + + const framework::OpInfo& Info() const { return op_->Info(); } + + const framework::OperatorBase& InnerOp() const { return *op_; } + + void ClearBackwardTrace(); + + NameVarMap* GetMutableOutsMap() { return &outs_; } + + NameVarMap* GetMutableInsMap() { return &ins_; } + + const NameVarMap& GetInsMap() const { return ins_; } + + const NameVarMap& GetOutsMap() const { return outs_; } + + void SetType(const std::string& type); + + void CheckAttrs() { + auto& info = op_->Info(); + if (info.Checker() != nullptr) { + info.Checker()->Check(&attrs_, true); + } + } + + void SetInput(const std::string& name, VariableWrapperList vars, + bool is_grad) { + auto& in_vars = ins_[name]; + *(in_vars.MutableVarList()) = std::move(vars); + in_vars.SetIsGrad(is_grad); + } + + void SetOutput(const std::string& name, VariableWrapperList vars, + bool is_grad) { + auto& out_vars = outs_[name]; + *(out_vars.MutableVarList()) = std::move(vars); + out_vars.SetIsGrad(is_grad); + } + + void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; } + + void SetAttr(const std::string& name, const framework::Attribute& v) { + attrs_[name] = v; + } + + void SetBlockAttr(const std::string& name, framework::BlockDesc* block) { + PADDLE_THROW(platform::errors::PermissionDenied( + "SetBlockAttr is not support in dygraph OpBase")); + } + + const framework::AttributeMap& Attrs() { return attrs_; } + + bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; } + + const framework::Attribute& GetAttr(const std::string& name) const { + auto it = attrs_.find(name); + PADDLE_ENFORCE_NE( + it, attrs_.end(), + platform::errors::NotFound("can not find attribute [%s]", name)); + return it->second; + } + + template + inline const T& Attr(const std::string& name) const { + return boost::get(GetAttr(name)); + } + + size_t id() const { return id_; } + + void SetId(size_t id) { id_ = id; } + + const platform::Place& place() const { return place_; } + + void SetPlace(const platform::Place& place) { place_ = place; } + + static size_t GenerateUniqueId() { + static std::atomic unique_id{0}; + return unique_id.fetch_add(1); + } + + static void Run(const framework::OperatorBase& op, + const NameVarMap& ins, + const NameVarMap& outs, + const framework::AttributeMap& attrs, + const platform::Place& place); + + static void Run(const framework::OperatorBase& op, + const NameVarMap& ins, + const NameVarMap& outs, + const framework::AttributeMap& attrs, + const platform::Place& place); + + private: + NameVarMap ins_; + NameVarMap outs_; + framework::AttributeMap attrs_; + std::unique_ptr op_; + platform::Place place_; + size_t id_{-1UL}; + + std::vector> backward_hooks_; +}; + +class GradOpNode { + public: + GradOpNode() = default; + + void reserve(size_t size) { ops_.reserve(size); } + + 
size_t size() const { return ops_.size(); } + + bool empty() const { return ops_.empty(); } + + void clear() { ops_.clear(); } + + void pop_back() { ops_.pop_back(); } + + template + OpBase& emplace_back(ARGS&&... args) { // NOLINT + ops_.emplace_back(std::forward(args)...); + return ops_.back(); + } + + const OpBase& back() const { return ops_.back(); } + + OpBase& back() { return ops_.back(); } + + OpBase& operator[](size_t idx) { return ops_[idx]; } + + const OpBase& operator[](size_t idx) const { return ops_[idx]; } + + /* Iterator related */ + using Iterator = std::vector::iterator; + using ConstIterator = std::vector::const_iterator; + + Iterator begin() { return ops_.begin(); } + + Iterator end() { return ops_.end(); } + + ConstIterator begin() const { return ops_.begin(); } + + ConstIterator end() const { return ops_.end(); } + + void InsertGradPendingNode(const std::shared_ptr& node) { + if (node && + std::find(grad_pending_nodes_.begin(), grad_pending_nodes_.end(), + node) == grad_pending_nodes_.end()) { + grad_pending_nodes_.emplace_back(node); + } + } + + const std::vector>& GradPendingNodes() const { + return grad_pending_nodes_; + } + + private: + DISABLE_COPY_AND_ASSIGN(GradOpNode); + + private: + std::vector ops_; + std::vector> grad_pending_nodes_; +}; + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc new file mode 100644 index 00000000000..99cabaf70b0 --- /dev/null +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -0,0 +1,1028 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/imperative/partial_grad_engine.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/imperative/gradient_accumulator.h" +#include "paddle/fluid/imperative/layer.h" +#include "paddle/fluid/imperative/op_base.h" +#include "paddle/fluid/imperative/tracer.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace imperative { + +/** + * This function prunes the graph to get the ops between `output_targets` + * and `input_target_grads`. + * + * + * The inputs are: + * + * - input_target_grads: the input target grads. It may be changed. + * - output_targets: the output target vars. It may be changed. + * + * + * The outputs are: + * + * - startup_op_ptr: startup ops of the pruned graph. + * - pending_ops_ptr: contains all the pending ops of each op in the graph. + * - op_deps_ptr: the preceding op number of each op in the graph. + * - related_grad_vars_ptr: all grad vars in the pruned graph. 
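+ *
+ * For example, if the forward pass is y = op1(x) and z = op2(y), and the
+ * caller requests the gradient of z w.r.t. x, then output_targets = {z} and
+ * input_target_grads = {x@GRAD}. The pruned graph keeps op2_grad and
+ * op1_grad: op2_grad is the startup op (its dep count is 0), op1_grad is the
+ * pending op of op2_grad, and {x@GRAD, y@GRAD, z@GRAD} are the related grad
+ * vars.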
+ */ +static void GetGraphInfoBetweenTargets( + std::unordered_set *input_target_grads, + std::unordered_set *output_targets, + std::unordered_set *startup_ops_ptr, + std::unordered_map> + *pending_ops_ptr, + std::unordered_map *op_deps_ptr, + std::unordered_set *related_grad_vars_ptr, + const std::unordered_set &no_grad_var_grad) { + /** + * Step 1. Find the candidate startup grad ops, prepared for following BFS. + */ + std::queue> q; + std::unordered_set visited; + for (auto iter = output_targets->begin(); iter != output_targets->end();) { + auto *output_target = *iter; + PADDLE_ENFORCE_NOT_NULL( + output_target, + platform::errors::NotFound("output_target must not be nullptr")); + if (output_target->OverridedStopGradient() || + output_target->GradVarBase() == nullptr || + output_target->GradVarBase()->GradNode() == nullptr) { + VLOG(10) << output_target->Name() + << " is pruned because it stops gradient or has no grad var"; + iter = output_targets->erase(iter); + continue; + } + + auto &grad_node = output_target->GradVarBase()->GradNode(); + if (visited.count(grad_node.get()) == 0) { + for (auto &op : *grad_node) { + q.emplace(&op, grad_node.get()); + } + } + ++iter; + } + + /** + * Step 2. BFS the graph and find all grad ops which generate the + * input_target_grads. Notice that not all candidate startup ops + * would be connected with input_target_grads, that is to say, + * not all input_target_grads would be found. + */ + std::unordered_set found_input_target_grads; + std::unordered_set endpoint_ops; + std::unordered_map> + preceding_ops; + while (!q.empty()) { + auto op_node_pair = q.front(); + q.pop(); + + auto *op = op_node_pair.first; + auto *node = op_node_pair.second; + + for (auto &output_pair : op->GetOutsMap()) { + if (!output_pair.second.IsGrad()) { + VLOG(10) << "WARNING: " << op->Type() << " outputs a forward var"; + continue; + } + + for (auto &out_var : output_pair.second) { + if (out_var && input_target_grads->count(out_var.get()) > 0) { + VLOG(10) << "Found endpoint op " << op->Type() << " which generates " + << out_var->Name(); + found_input_target_grads.insert(out_var.get()); + endpoint_ops.emplace(op); + } + } + } + + for (auto &pending_node : node->GradPendingNodes()) { + if (visited.count(pending_node.get()) == 0) { + for (auto &pending_op : *pending_node) { + preceding_ops[&pending_op].insert(op); + q.emplace(&pending_op, pending_node.get()); + } + } + } + } + + /** + * Step 3. Based on the found input_target_grads, BFS the graph in reverse + * order. `target_vars` would record all grad vars in the graph, and + * `startup_ops` would be the final startup ops of the graph. 
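+ * Candidate ops visited in Step 2 that do not lie on a path towards the
+ * found input_target_grads are dropped during this reverse traversal.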
+ */ + *input_target_grads = found_input_target_grads; + + auto &pending_ops = *pending_ops_ptr; + pending_ops.clear(); + + auto &startup_ops = *startup_ops_ptr; + startup_ops.clear(); + + auto &op_deps = *op_deps_ptr; + op_deps.clear(); + + auto &target_vars = *related_grad_vars_ptr; + target_vars = *input_target_grads; + + std::queue> + op_queue; + for (auto &endpoint_op : endpoint_ops) { + op_queue.emplace(endpoint_op, nullptr); + } + + while (!op_queue.empty()) { + auto op_pair = op_queue.front(); + auto *op = op_pair.first; + auto *pending_op = op_pair.second; + + op_queue.pop(); + + bool is_valid = false; + for (auto &output_pair : op->GetOutsMap()) { + if (!output_pair.second.IsGrad()) { + continue; + } + + for (auto &out_var : output_pair.second) { + if (out_var && target_vars.count(out_var.get()) > 0) { + is_valid = true; + break; + } + } + + if (is_valid) { + break; + } + } + + if (!is_valid) { + continue; + } + + is_valid = false; + for (auto &input_pair : op->GetInsMap()) { + if (!input_pair.second.IsGrad()) { + continue; + } + + for (auto &in_var : input_pair.second) { + if (in_var && no_grad_var_grad.count(in_var.get()) == 0) { + target_vars.insert(in_var.get()); + is_valid = true; + } + } + } + + if (!is_valid) { + continue; + } + + op_deps[op]; + if (pending_op) { + VLOG(10) << "Pending op of " << op->Type() << " is " + << pending_op->Type(); + pending_ops[op].insert(pending_op); + ++op_deps[pending_op]; + } else { + pending_ops[op]; + } + + auto iter = preceding_ops.find(op); + if (iter != preceding_ops.end()) { + for (auto &preceding_op : iter->second) { + op_queue.emplace(preceding_op, op); + } + } + } + + for (auto &pair : op_deps) { + if (pair.second == 0) { + auto *op = pair.first; + VLOG(10) << "Found startup op " << op->Type(); + startup_ops.insert(op); + } + } + + /** + * Step 4. 
Prune output_targets which is not the input of startup_ops + */ + for (auto iter = output_targets->begin(); iter != output_targets->end();) { + auto &grad_node = (*iter)->GradVarBase()->GradNode(); + bool is_valid = std::find_if(grad_node->begin(), grad_node->end(), + [&](const OpBase &op) { + return startup_ops.count(&op) > 0; + }) != grad_node->end(); + if (is_valid) { + ++iter; + } else { + iter = output_targets->erase(iter); + } + } +} + +// Get debug string of op types contained in `node` +static std::string GradOpTypes(const GradOpNode &node) { + std::vector node_types; + for (auto &op : node) { + node_types.emplace_back(op.Type()); + } + return string::join_strings(node_types, ','); +} + +// Get debug string of grad node of `var`'s gradient +static std::string GradOpTypes(const VarBase &var) { + if (!var.GradVarBase() || !var.GradVarBase()->GradNode()) { + return ""; + } else { + return GradOpTypes(*(var.GradVarBase()->GradNode())); + } +} + +// Get pending op types of `node` +static std::string GradPendingOpTypes(const GradOpNode &node) { + std::vector node_types; + for (auto &n : node.GradPendingNodes()) { + node_types.emplace_back(GradOpTypes(*n)); + } + return string::join_strings(node_types, ','); +} + +static void FillConstantLike(const VariableWrapper &ref_var, + VariableWrapper *dst_var, + const platform::Place &place, float value) { + auto &ref_tensor = ref_var.Var().Get(); + auto *dst_tensor = dst_var->MutableVar()->GetMutable(); + auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place); + dst_tensor->Resize(ref_tensor.dims()); + dst_tensor->mutable_data(place, ref_var.DataType()); + operators::math::set_constant(*dev_ctx, dst_tensor, value); +} + +/** + * A data structure for gradient accumulation + */ +class GradientAccumulationInfo { + private: + using PartialGradGradTraceIdPair = + std::pair /*partial grad grad var*/, + size_t /*trace_id*/>; + + public: + explicit GradientAccumulationInfo(const std::shared_ptr &var, + bool sort_gradient, bool create_graph) + : mapped_grad_var_(var.get()), + sort_gradient_(sort_gradient), + create_graph_(create_graph) {} + + void IncreaseTotalRefCnt() { + ++total_ref_cnt_; + + // The gradient accumulator is needed only when total_ref_cnt_ > 1. + // grad_var_ would be created only when total_ref_cnt_ > 1. + if (total_ref_cnt_ > 1) { + if (!grad_var_) { + grad_var_ = std::make_shared(true, mapped_grad_var_->Name()); + grad_var_->SetOverridedStopGradient(false); + if (sort_gradient_) { + accumulator_.reset( + new SortedGradientAccumulator(grad_var_->SharedVar().get())); + } else { + accumulator_.reset( + new EagerGradientAccumulator(grad_var_->SharedVar().get())); + } + accumulator_->IncreaseRefCnt(); + } + accumulator_->IncreaseRefCnt(); + } + } + + size_t TotalRefCnt() { return total_ref_cnt_; } + + const std::shared_ptr &GradVarBase() const { return grad_var_; } + + std::shared_ptr GradVar() const { + return grad_var_ == nullptr ? 
nullptr : grad_var_->SharedVar(); + } + + VariableWrapper *MappedGradVar() { return mapped_grad_var_; } + + std::vector> SumGradient( + std::shared_ptr grad_var_partial, size_t trace_id, + bool *is_finished, bool unchange_input = false) { + PADDLE_ENFORCE_NOT_NULL(grad_var_partial, + platform::errors::PermissionDenied( + "Partial grad of %s would not be nullptr", + mapped_grad_var_->Name())); + PADDLE_ENFORCE_GT(total_ref_cnt_, 1, + platform::errors::PermissionDenied( + "Gradient accumulation should not be called when " + "reference count is 1 or 0")); + + ++cur_ref_cnt_; + PADDLE_ENFORCE_LE(cur_ref_cnt_, total_ref_cnt_, + platform::errors::PermissionDenied( + "Reference count overflows, this may be a bug")); + + *is_finished = (cur_ref_cnt_ == total_ref_cnt_); + accumulator_->Add(grad_var_partial, trace_id, unchange_input); + + if (create_graph_) { + VLOG(10) << "Store partial grad grad for double grad " + << mapped_grad_var_->Name(); + partial_grad_grads_.emplace_back(grad_var_partial->GetWeakGradVar(), + trace_id); + } + + if (!(*is_finished) || !create_graph_) { + return {}; + } + + if (sort_gradient_) { + std::sort(partial_grad_grads_.begin(), partial_grad_grads_.end(), + [](const PartialGradGradTraceIdPair &p1, + const PartialGradGradTraceIdPair &p2) { + return p1.second > p2.second; + }); + } + + // Only when create_graph_ = True, the return value would be not empty + std::vector> result; + result.reserve(partial_grad_grads_.size()); + for (auto &pair : partial_grad_grads_) { + if (auto var = pair.first.lock()) { + result.emplace_back(var); + } + } + return result; + } + + private: + std::shared_ptr grad_var_; + VariableWrapper *mapped_grad_var_; + std::unique_ptr accumulator_; + std::vector partial_grad_grads_; + size_t total_ref_cnt_{0}; + size_t cur_ref_cnt_{0}; + bool sort_gradient_; + bool create_graph_; +}; + +class ReadyGradVarInfoMap { + private: + struct ReadyVarInfo { + std::shared_ptr var; + size_t cur_ref_cnt{0}; + size_t total_ref_cnt{0}; + }; + + public: + void IncreaseRefCnt(const VariableWrapper *var) { + ++(vars_[var].total_ref_cnt); + } + + std::shared_ptr Get(const VariableWrapper *var, + const platform::Place &place, bool *is_last) { + auto iter = vars_.find(var); + PADDLE_ENFORCE_EQ( + iter != vars_.end(), true, + platform::errors::NotFound("Variable %s not found, this may be a bug", + var->Name())); + auto &ready_var = iter->second; + PADDLE_ENFORCE_LT(ready_var.cur_ref_cnt, ready_var.total_ref_cnt, + platform::errors::PermissionDenied( + "Reference count overflows for %s", var->Name())); + + if (ready_var.var == nullptr && ready_var.cur_ref_cnt == 0) { + ready_var.var = std::make_shared(var->Name()); + VLOG(10) << "Fill zero for " << var->Name() << " because it is not ready"; + FillConstantLike(*var, ready_var.var->SharedVar().get(), place, 0.0f); + } else { + PADDLE_ENFORCE_NOT_NULL( + ready_var.var, + platform::errors::NotFound( + "%s is not found when reference count does not decreases to 0")); + } + + if (++ready_var.cur_ref_cnt == ready_var.total_ref_cnt) { + *is_last = true; + return std::move(ready_var.var); // move to set ready_var.var to nullptr + } else { + *is_last = false; + return ready_var.var; + } + } + + // Set a var as a ready var. + // If the var is one of target vars, store it inside `target_vars_` as well. 
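+ // Returns true when `mapped_var` is the input of some following op and is
+ // therefore recorded in `vars_`; returns false otherwise.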
+ bool Set(const VariableWrapper *mapped_var, + const std::shared_ptr &var) { + PADDLE_ENFORCE_NOT_NULL( + var, + platform::errors::PermissionDenied( + "Cannot set nullptr as ready grad var for %s", mapped_var->Name())); + { + auto target_iter = target_vars_.find(mapped_var); + if (target_iter != target_vars_.end()) { + PADDLE_ENFORCE_EQ( + target_iter->second, nullptr, + platform::errors::PermissionDenied("Cannot set target var %s twice", + mapped_var->Name())); + target_iter->second = var; + } + } + + auto iter = vars_.find(mapped_var); + if (iter != vars_.end()) { // This var is ready for next op's input + auto &ready_var = iter->second; + PADDLE_ENFORCE_EQ( + ready_var.var, nullptr, + platform::errors::PermissionDenied("Cannot set target var %s twice", + mapped_var->Name())); + PADDLE_ENFORCE_EQ( + ready_var.cur_ref_cnt, 0, + platform::errors::PermissionDenied( + "Reference count must be 0 when ready var %s is set", + mapped_var->Name())); + ready_var.var = var; + return true; + } else { + VLOG(10) << "Do not record " << mapped_var->Name() + << " because it is not input of any following ops"; + return false; + } + } + + void Clear() { + vars_.clear(); + target_vars_.clear(); + } + + // Mark a var as target var + void SetTarget(const VariableWrapper *var) { + PADDLE_ENFORCE_EQ(target_vars_[var], nullptr, + platform::errors::PermissionDenied( + "Target var would not be generated when marking")); + } + + // Get target var + const std::shared_ptr &GetTarget(const VariableWrapper *var) const { + auto iter = target_vars_.find(var); + PADDLE_ENFORCE_EQ(iter != target_vars_.end(), true, + platform::errors::NotFound("Target var %s does not exist", + var->Name())); + PADDLE_ENFORCE_NOT_NULL( + iter->second, platform::errors::PermissionDenied( + "Target var %s should not be nullptr", var->Name())); + return iter->second; + } + + private: + std::unordered_map vars_; + std::unordered_map> + target_vars_; +}; + +class PartialGradTask { + public: + PartialGradTask(const std::vector> &input_targets, + const std::vector> &output_targets, + const std::vector> &output_grads, + const std::vector> &no_grad_vars, + const platform::Place &place, + const detail::BackwardStrategy &strategy, bool create_graph); + + std::vector> Run(); + + private: + void RunEachOp(const OpBase *op); + + void PrepareInitialReadyVarsMap(const OpBase *op); + + void PrepareInitialGradientAccumulators(const OpBase *op); + + std::vector> CreateResult(); + + bool IsValidGradVar(const std::shared_ptr &var) const { + return var && no_grad_var_grad_.count(var.get()) == 0; + } + + private: + std::unordered_set startup_ops_; + std::unordered_map> + pending_ops_; + std::unordered_map op_deps_; + + ReadyGradVarInfoMap ready_grad_vars_; + + std::unordered_map> + grad_accumulators_; + + std::vector> double_grad_nodes_; + + std::vector< + std::pair>> + grads_to_accumulate_; + + // Input targets that are reachable + std::vector> input_targets_; + std::unordered_set input_target_grads_; + + std::unordered_set no_grad_var_grad_; + std::vector> reset_stop_gradient_vars_; + + platform::Place place_; + bool create_graph_; + detail::BackwardStrategy strategy_; +}; + +PartialGradTask::PartialGradTask( + const std::vector> &input_targets, + const std::vector> &output_targets, + const std::vector> &output_grads, + const std::vector> &no_grad_vars, + const platform::Place &place, const detail::BackwardStrategy &strategy, + bool create_graph) { + input_targets_ = input_targets; + place_ = place; + create_graph_ = create_graph; + strategy_ = strategy; + + 
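+ // Collect the grad vars of `no_grad_vars`; they are excluded from the
+ // pruned graph, and a grad op whose grad inputs all fall into this set is
+ // pruned as well.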
for (auto &var : no_grad_vars) { + if (var && var->GradVarBase()) { + no_grad_var_grad_.insert(var->GradVarBase()->SharedVar().get()); + } + } + + PADDLE_ENFORCE_EQ( + input_targets.empty(), false, + platform::errors::PermissionDenied("inputs can not be empty")); + PADDLE_ENFORCE_EQ( + output_targets.empty(), false, + platform::errors::PermissionDenied("outputs can not be empty")); + + std::unordered_set out_set; + for (auto &output : output_targets) { + PADDLE_ENFORCE_NOT_NULL(output, + platform::errors::PermissionDenied( + "Variable inside outputs should not be null")); + PADDLE_ENFORCE_EQ( + output->GradVarBase() && !output->OverridedStopGradient(), true, + platform::errors::PermissionDenied( + "Variable %s inside outputs has no gradient", output->Name())); + PADDLE_ENFORCE_EQ( + out_set.count(output.get()), 0, + platform::errors::AlreadyExists("outputs contain duplicate variable %s", + output->Name())); + PADDLE_ENFORCE_EQ(IsValidGradVar(output->GradVarBase()->SharedVar()), true, + platform::errors::PermissionDenied( + "outputs contain var that is inside no_grad_set")); + + out_set.insert(output.get()); + } + + std::unordered_set in_set; + std::unordered_set one_grad_vars; + for (auto &input : input_targets) { + PADDLE_ENFORCE_NOT_NULL(input, + platform::errors::PermissionDenied( + "Variable inside inputs should not be null")); + PADDLE_ENFORCE_EQ( + input->GradVarBase() && !input->OverridedStopGradient(), true, + platform::errors::PermissionDenied( + "Variable %s inside inputs has no gradient", input->Name())); + PADDLE_ENFORCE_EQ( + in_set.count(input.get()), 0, + platform::errors::AlreadyExists("inputs contain duplicate variable %s", + input->Name())); + in_set.insert(input.get()); + input_target_grads_.insert(input->GradVarBase()->SharedVar().get()); + + PADDLE_ENFORCE_EQ(IsValidGradVar(input->GradVarBase()->SharedVar()), true, + platform::errors::PermissionDenied( + "inputs contain var that is inside no_grad_set")); + + // Record same vars between inputs and outputs + if (out_set.count(input.get()) > 0) { + one_grad_vars.insert(input->GradVarBase()->SharedVar().get()); + } + } + + std::unordered_set related_grad_vars; + GetGraphInfoBetweenTargets(&input_target_grads_, &out_set, &startup_ops_, + &pending_ops_, &op_deps_, &related_grad_vars, + no_grad_var_grad_); + + for (auto &op_pair : pending_ops_) { + auto *op = op_pair.first; + PrepareInitialReadyVarsMap(op); + PrepareInitialGradientAccumulators(op); + } + + for (auto &input_grad : input_target_grads_) { + ready_grad_vars_.SetTarget(input_grad); + } + + for (auto &one_grad : one_grad_vars) { + VLOG(10) << "Add same in/out target " << one_grad->Name(); + input_target_grads_.insert(one_grad); + ready_grad_vars_.SetTarget(one_grad); + } + + VLOG(10) << "Valid op number " << pending_ops_.size(); + + if (!output_grads.empty()) { + PADDLE_ENFORCE_EQ(output_targets.size(), output_grads.size(), + platform::errors::InvalidArgument( + "grad_outputs number should be equal to outputs")); + } + + for (size_t i = 0; i < output_targets.size(); ++i) { + auto *mapped_out_grad_var = + output_targets[i]->GradVarBase()->SharedVar().get(); + + if (related_grad_vars.count(mapped_out_grad_var) == 0 && + one_grad_vars.count(mapped_out_grad_var) == 0) { + VLOG(10) << mapped_out_grad_var->Name() << " should be None"; + continue; + } + + std::shared_ptr out_grad_var; + bool unchange_input = false; + if (output_grads.empty() || output_grads[i] == nullptr) { + VLOG(10) << "Fill 1.0f for " << output_targets[i]->Name(); + out_grad_var = std::make_shared( + 
framework::GradVarName(output_targets[i]->Name())); + FillConstantLike(*(output_targets[i]->SharedVar()), out_grad_var.get(), + place_, 1.0f); + } else { + VLOG(10) << "Use user provided grad var for " + << output_targets[i]->Name(); + const auto &out_tensor = + output_targets[i]->Var().Get(); + const auto &grad_tensor = + output_grads[i]->Var().Get(); + PADDLE_ENFORCE_EQ( + grad_tensor.dims(), out_tensor.dims(), + platform::errors::InvalidArgument( + "The %d-th grad_output's shape does not match the %d-th output", + i, i)); + PADDLE_ENFORCE_EQ(grad_tensor.type(), out_tensor.type(), + platform::errors::InvalidArgument( + "The %d-th grad_output's data type does not " + "match the %d-th output", + i, i)); + out_grad_var = output_grads[i]->SharedVar(); + PADDLE_ENFORCE_EQ(IsValidGradVar(out_grad_var), true, + platform::errors::PermissionDenied( + "grad_outputs contain var inside no_grad_set")); + + if (out_grad_var->OverridedStopGradient()) { + VLOG(10) << "Grad var " << out_grad_var->Name() + << " should reset stop gradient"; + reset_stop_gradient_vars_.emplace_back(out_grad_var); + } + + unchange_input = true; + } + + out_grad_var->SetOverridedStopGradient(false); + auto grad_accumulator_iter = grad_accumulators_.find(mapped_out_grad_var); + if (grad_accumulator_iter == grad_accumulators_.end()) { + ready_grad_vars_.Set(mapped_out_grad_var, + std::make_shared(false, out_grad_var)); + VLOG(10) << "Fill 1.0f or user-provided gradient as ready var " + << out_grad_var->Name(); + } else { + auto &accumulator = grad_accumulator_iter->second; + accumulator->IncreaseTotalRefCnt(); + bool is_finished = false; + accumulator->SumGradient(out_grad_var, 0, &is_finished, unchange_input); + PADDLE_ENFORCE_EQ( + is_finished, false, + platform::errors::Fatal("gradient accumulator should not finish")); + VLOG(10) << "Add 1.0f or user-provided gradient to gradient accumulator" + << out_grad_var->Name(); + } + } +} + +std::vector> PartialGradTask::Run() { + VLOG(10) << "Startup op number " << startup_ops_.size(); + std::queue q; + for (auto *op : startup_ops_) { + q.push(op); + } + + while (!q.empty()) { + auto *op = q.front(); + q.pop(); + VLOG(10) << "Start to run " << op->Type(); + RunEachOp(op); + VLOG(10) << "End to run " << op->Type(); + + auto iter = pending_ops_.find(op); + if (iter == pending_ops_.end()) { + VLOG(10) << "Finish running because " << op->Type() + << " has no pending ops"; + continue; + } + + for (auto &pending_op : iter->second) { + auto dep_iter = op_deps_.find(pending_op); + PADDLE_ENFORCE_EQ( + dep_iter != op_deps_.end(), true, + platform::errors::Fatal("Dependency number of %s does not exist", + pending_op->Type())); + if (--(dep_iter->second) == 0) { + q.push(pending_op); + } + } + } + + VLOG(10) << "Created " << double_grad_nodes_.size() << " double grad ops"; + return CreateResult(); +} + +void PartialGradTask::RunEachOp(const OpBase *op) { + // Prepare new inputs + NameVarMap tmp_ins; + for (auto &input_pair : op->GetInsMap()) { + auto &new_inputs = tmp_ins[input_pair.first]; + new_inputs.reserve(input_pair.second.size()); + + if (!input_pair.second.IsGrad()) { + for (auto &fwd_var : input_pair.second) { + if (fwd_var) { + new_inputs.emplace_back(new VarBase(true, fwd_var)); + VLOG(10) << "Unpacked forward var " << fwd_var->Name() + << ", grad ops: " << GradOpTypes(*new_inputs.back()); + } else { + new_inputs.emplace_back(); + } + } + } else { + for (auto &grad_var : input_pair.second) { + if (grad_var) { + bool is_last; + new_inputs.emplace_back( + 
ready_grad_vars_.Get(grad_var.get(), op->place(), &is_last)); + VLOG(10) << "Got ready grad var " << grad_var->Name() << " " + << new_inputs.back().get(); + } else { + new_inputs.emplace_back(); + } + } + } + } + + // Prepare new outputs + NameVarMap tmp_outs; + for (auto &output_pair : op->GetOutsMap()) { + auto &new_outputs = tmp_outs[output_pair.first]; + if (!output_pair.second.IsGrad()) { + for (auto &fwd_var : output_pair.second) { + // unpack forward var + if (fwd_var) { + new_outputs.emplace_back(new VarBase(true, fwd_var)); + VLOG(10) << "Unpacked forward var " << fwd_var->Name(); + } else { + new_outputs.emplace_back(); + } + } + } else { + for (auto &grad_var : output_pair.second) { + if (IsValidGradVar(grad_var)) { + VLOG(10) << "Creating output grad var " << grad_var->Name(); + auto new_grad_var_iter = grad_accumulators_.find(grad_var.get()); + PADDLE_ENFORCE_EQ(new_grad_var_iter != grad_accumulators_.end(), true, + platform::errors::Fatal( + "Cannot find gradient accumulator of %s %p", + grad_var->Name(), grad_var.get())); + + auto new_grad_var = std::make_shared(true, grad_var->Name()); + new_grad_var->SetOverridedStopGradient(false); + if (new_grad_var_iter->second->TotalRefCnt() > 1) { + grads_to_accumulate_.emplace_back(new_grad_var_iter->second.get(), + new_grad_var->SharedVar()); + } else { + PADDLE_ENFORCE_EQ( + new_grad_var_iter->second->GradVar(), nullptr, + platform::errors::AlreadyExists( + "When reference count is 1, the grad var should not be " + "created in gradient accumulator")); + grad_accumulators_.erase(new_grad_var_iter); + ready_grad_vars_.Set(grad_var.get(), new_grad_var); + } + VLOG(10) << "Created output grad var " << grad_var->Name(); + new_outputs.emplace_back(std::move(new_grad_var)); + } else { + new_outputs.emplace_back(); + } + } + } + } + + // Run op + OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), op->place()); + + if (create_graph_) { + auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, + op->Attrs(), op->place()); + if (double_grad_node) { + VLOG(10) << "Create " << double_grad_node->size() + << " double grad op(s) for " << op->Type() + << ", pending ops: " << GradPendingOpTypes(*double_grad_node); + double_grad_nodes_.emplace_back(std::move(double_grad_node)); + } + } + + VLOG(10) << "There are " << grads_to_accumulate_.size() << " to sum gradient"; + + // Gradient accumulation and add assign op + for (auto &pair : grads_to_accumulate_) { + auto *accumulator_info = pair.first; + auto &grad_var = pair.second; + + bool is_finished = false; + VLOG(10) << "Start to sum " << accumulator_info->MappedGradVar()->Name(); + auto partial_grad_grads = accumulator_info->SumGradient( + std::move(grad_var), op->id(), &is_finished); + + if (is_finished) { + VLOG(10) << "Sum has finished for " + << accumulator_info->MappedGradVar()->Name() << " " + << accumulator_info->GradVarBase(); + ready_grad_vars_.Set(accumulator_info->MappedGradVar(), + accumulator_info->GradVarBase()); + } + + if (partial_grad_grads.empty()) { + continue; + } + + auto sum_grad_var_grad = + accumulator_info->GradVarBase()->MutableGradVarBase(); + sum_grad_var_grad->SetOverridedStopGradient(false); + + auto assign_node = std::make_shared(); + sum_grad_var_grad->SetGradNode(assign_node); + + VLOG(10) << "Add " << partial_grad_grads.size() << " assign op for " + << sum_grad_var_grad->Name(); + + for (auto &grad_grad : partial_grad_grads) { + auto *assign_op = &(assign_node->emplace_back()); + assign_op->SetType("assign"); // Can use "scale" as static 
graph mode + assign_op->SetInput("X", {sum_grad_var_grad->SharedVar()}, true); + assign_op->SetOutput("Out", {grad_grad}, true); + assign_op->CheckAttrs(); + assign_op->SetId(OpBase::GenerateUniqueId()); + assign_op->SetPlace(op->place()); + + if (auto grad_pending_node = grad_grad->GetGradNode()) { + assign_node->InsertGradPendingNode(std::move(grad_pending_node)); + } + } + VLOG(10) << "Pending ops of assign is " << GradPendingOpTypes(*assign_node); + grad_accumulators_.erase(accumulator_info->MappedGradVar()); + double_grad_nodes_.emplace_back(assign_node); + } + + grads_to_accumulate_.clear(); +} + +void PartialGradTask::PrepareInitialReadyVarsMap(const OpBase *op) { + for (auto &in_var_pair : op->GetInsMap()) { + if (!in_var_pair.second.IsGrad()) { + continue; + } + + for (auto &var : in_var_pair.second) { + if (var) { + ready_grad_vars_.IncreaseRefCnt(var.get()); + } + } + } +} + +void PartialGradTask::PrepareInitialGradientAccumulators(const OpBase *op) { + for (auto &out_var_pair : op->GetOutsMap()) { + if (!out_var_pair.second.IsGrad()) { + continue; + } + + for (auto &var : out_var_pair.second) { + if (var == nullptr) { + continue; + } + + auto &accumulator = grad_accumulators_[var.get()]; + + if (!accumulator) { + accumulator.reset(new GradientAccumulationInfo( + var, strategy_.sorted_sum_gradient_, create_graph_)); + } + + accumulator->IncreaseTotalRefCnt(); + } + } +} + +std::vector> PartialGradTask::CreateResult() { + std::vector> result; + result.reserve(input_targets_.size()); + for (auto &input_target : input_targets_) { + PADDLE_ENFORCE_NOT_NULL( + input_target->GradVarBase(), + platform::errors::InvalidArgument("input should have gradient")); + auto *original_grad_var = input_target->GradVarBase()->SharedVar().get(); + auto iter = input_target_grads_.find(original_grad_var); + if (iter != input_target_grads_.end()) { + auto ready_var = ready_grad_vars_.GetTarget(original_grad_var); + ready_var->SetOverridedStopGradient(!create_graph_); + result.emplace_back(std::move(ready_var)); + } else { // return None if it does not appear in the graph + result.emplace_back(); + } + } + + for (auto &weak_var : reset_stop_gradient_vars_) { + if (auto var = weak_var.lock()) { + VLOG(10) << "Reset " << var->Name() << " stop gradient"; + var->SetOverridedStopGradient(!var->OverridedStopGradient()); + } + } + + ready_grad_vars_.Clear(); + grad_accumulators_.clear(); + double_grad_nodes_.clear(); + reset_stop_gradient_vars_.clear(); + return result; +} + +PartialGradEngine::PartialGradEngine( + const std::vector> &input_targets, + const std::vector> &output_targets, + const std::vector> &output_grads, + const std::vector> &no_grad_vars, + const platform::Place &place, const detail::BackwardStrategy &strategy, + bool create_graph) + : input_targets_(input_targets), + output_targets_(output_targets), + output_grads_(output_grads), + no_grad_vars_(no_grad_vars), + place_(place), + strategy_(strategy), + create_graph_(create_graph) {} + +std::vector> PartialGradEngine::GetResult() const { + return results_; +} + +void PartialGradEngine::Clear() { + input_targets_.clear(); + output_targets_.clear(); + output_grads_.clear(); + no_grad_vars_.clear(); +} + +void PartialGradEngine::Execute() { + PartialGradTask task(input_targets_, output_targets_, output_grads_, + no_grad_vars_, place_, strategy_, create_graph_); + VLOG(10) << "Starts to execute PartialGradEngine"; + results_ = task.Run(); + Clear(); +} + +} // namespace imperative +} // namespace paddle diff --git 
a/paddle/fluid/imperative/partial_grad_engine.h b/paddle/fluid/imperative/partial_grad_engine.h new file mode 100644 index 00000000000..fde4703ad42 --- /dev/null +++ b/paddle/fluid/imperative/partial_grad_engine.h @@ -0,0 +1,58 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/imperative/backward_strategy.h" +#include "paddle/fluid/imperative/engine.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace imperative { + +class VarBase; + +class PartialGradEngine : public Engine { + public: + PartialGradEngine(const std::vector> &input_targets, + const std::vector> &output_targets, + const std::vector> &output_grads, + const std::vector> &no_grad_vars, + const platform::Place &place, + const detail::BackwardStrategy &strategy, + bool create_graph); + + void Execute() override; + + std::vector> GetResult() const; + + private: + void Clear(); + + private: + std::vector> input_targets_; + std::vector> output_targets_; + std::vector> output_grads_; + std::vector> no_grad_vars_; + platform::Place place_; + detail::BackwardStrategy strategy_; + bool create_graph_; + + std::vector> results_; +}; + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 93a8959099b..c4aa2f7392a 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -14,6 +14,9 @@ #include "paddle/fluid/imperative/prepared_operator.h" #include +#include "paddle/fluid/imperative/execution_context.h" +#include "paddle/fluid/imperative/infer_shape_context.h" +#include "paddle/fluid/imperative/infer_var_type_context.h" namespace paddle { namespace imperative { diff --git a/paddle/fluid/imperative/saved_variable_wrapper_list.h b/paddle/fluid/imperative/saved_variable_wrapper_list.h new file mode 100644 index 00000000000..1e7aba6dcaf --- /dev/null +++ b/paddle/fluid/imperative/saved_variable_wrapper_list.h @@ -0,0 +1,87 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
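The PartialGradEngine declared just above is a thin wrapper: Execute() builds a PartialGradTask from the saved targets, runs it, stores the results, and then drops its references via Clear(). The scheduling inside PartialGradTask::Run() (shown earlier) reduces to a dependency-counted traversal of the pending ops; a small Python sketch of that loop, with container names mirroring the C++ members (illustrative only, not part of the patch):

from collections import deque

def run_backward(startup_ops, pending_ops, op_deps, run_op):
    # startup_ops: ops with no unfinished producers; pending_ops: op -> ops
    # waiting on it; op_deps: op -> number of unfinished producers.
    queue = deque(startup_ops)
    while queue:
        op = queue.popleft()
        run_op(op)                                # RunEachOp(op) in the C++ code
        for pending in pending_ops.get(op, ()):
            op_deps[pending] -= 1
            if op_deps[pending] == 0:             # all of its producers finished
                queue.append(pending)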
+ +#pragma once + +#include +#include +#include + +namespace paddle { +namespace imperative { + +class VariableWrapper; + +class SavedVariableWrapperList { + public: + SavedVariableWrapperList() : vars_(), is_grad_(false) {} + + template + explicit SavedVariableWrapperList(bool is_grad, Args&&... args) + : vars_(std::forward(args)...), is_grad_(is_grad) {} + + bool IsGrad() const { return is_grad_; } + + void SetIsGrad(bool is_grad) { is_grad_ = is_grad; } + + const std::vector>& VarList() const { + return vars_; + } + + std::vector>* MutableVarList() { + return &vars_; + } + + /* Borrow method from std::vector */ + size_t size() const { return vars_.size(); } + + bool empty() const { return vars_.empty(); } + + template + void emplace_back(ARGS&&... args) { + vars_.emplace_back(std::forward(args)...); + } + + using Iterator = std::vector>::iterator; + + using ConstIterator = + std::vector>::const_iterator; + + Iterator begin() { return vars_.begin(); } + + Iterator end() { return vars_.end(); } + + ConstIterator begin() const { return vars_.begin(); } + + ConstIterator end() const { return vars_.end(); } + + std::shared_ptr& operator[](size_t idx) { + return vars_[idx]; + } + + const std::shared_ptr& operator[](size_t idx) const { + return vars_[idx]; + } + + operator const std::vector>&() const { + return vars_; + } + + private: + std::vector> vars_; + bool is_grad_; +}; + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/tests/test_gradient_accmulator.cc b/paddle/fluid/imperative/tests/test_gradient_accmulator.cc index 29a51733c93..0d567f5d055 100644 --- a/paddle/fluid/imperative/tests/test_gradient_accmulator.cc +++ b/paddle/fluid/imperative/tests/test_gradient_accmulator.cc @@ -117,5 +117,202 @@ TEST(test_add_functor, add_functor) { #endif } +static void CopyVar(const framework::Variable& var, + framework::Variable* dst_ptr) { + auto& dst = *dst_ptr; + dst.Clear(); + if (var.IsType()) { + const auto& src_tensor = var.Get(); + auto* dst_tensor = dst.GetMutable(); + framework::TensorCopySync(src_tensor, src_tensor.place(), dst_tensor); + } else { + const auto& src_selected_rows = var.Get(); + auto* dst_selected_rows = dst.GetMutable(); + dst_selected_rows->set_rows(src_selected_rows.rows()); + dst_selected_rows->set_height(src_selected_rows.height()); + framework::TensorCopySync(src_selected_rows.value(), + src_selected_rows.value().place(), + dst_selected_rows->mutable_value()); + } +} + +static bool IsEqualVar(const framework::Variable& var1, + const framework::Variable& var2) { + if (var1.Type() != var2.Type()) { + return false; + } + + framework::Tensor t1, t2; + + if (var1.IsType()) { + framework::TensorCopySync(var1.Get(), + platform::CPUPlace(), &t1); + framework::TensorCopySync(var2.Get(), + platform::CPUPlace(), &t2); + } else { + auto& s1 = var1.Get(); + auto& s2 = var2.Get(); + + if (s1.height() != s2.height()) { + return false; + } + + if (s1.rows().size() != s2.rows().size()) { + return false; + } + + auto row1_data = s1.rows().data(); + auto row2_data = s2.rows().data(); + if (std::memcmp(row1_data, row2_data, + s1.rows().size() * sizeof(*row1_data)) != 0) { + return false; + } + + framework::TensorCopySync(var1.Get().value(), + platform::CPUPlace(), &t1); + framework::TensorCopySync(var2.Get().value(), + platform::CPUPlace(), &t2); + } + + if (t1.type() != t2.type() || t1.dims() != t2.dims()) { + return false; + } + + auto* t1_p = t1.data(); + auto* t2_p = t2.data(); + return std::memcmp(t1_p, t2_p, + t1.numel() * 
framework::SizeOfType(t1.type())) == 0; +} + +template +static framework::Variable RandomTensor(const framework::DDim& dims, + const platform::Place& place, + int low = -10, int high = 10) { + framework::Tensor cpu_tensor; + cpu_tensor.Resize(dims); + auto* ptr = cpu_tensor.mutable_data(platform::CPUPlace()); + std::uniform_int_distribution dist(low, high); + std::random_device rd; + std::mt19937 engine(rd()); + for (int64_t i = 0; i < cpu_tensor.numel(); ++i) { + ptr[i] = dist(engine); + } + + framework::Variable ret; + framework::TensorCopySync(cpu_tensor, place, + ret.GetMutable()); + return ret; +} + +template +static framework::Variable RandomSelectedRows(framework::DDim dims, + const platform::Place& place, + int64_t row_number, int low = -10, + int high = 10) { + auto height = dims[0]; + dims[0] = row_number; + + framework::Variable ret; + auto* sr = ret.GetMutable(); + auto tensor_var = RandomTensor(dims, place, low, high); + sr->mutable_value()->ShareDataWith( + tensor_var.template Get()); + sr->set_height(height); + sr->mutable_rows()->resize(row_number); + auto* row_data = sr->mutable_rows()->data(); + std::uniform_int_distribution dist(0, height - 1); + std::random_device rd; + std::mt19937 engine(rd()); + for (int64_t i = 0; i < dims[0]; ++i) { + row_data[i] = dist(engine); + } + return ret; +} + +static std::unique_ptr CreateAccumulator( + const std::shared_ptr& var, bool sort_gradient) { + if (sort_gradient) { + return std::unique_ptr( + new SortedGradientAccumulator(var.get())); + } else { + return std::unique_ptr( + new EagerGradientAccumulator(var.get())); + } +} + +static void TestGradientAccumulatorTestUnchangeInput( + const platform::Place& place, bool sort_gradient) { + framework::DDim dim{10, 20}; + int64_t maximum_row_number = 100; + + std::uniform_int_distribution dist(1, maximum_row_number); + int seed; + { + std::random_device rd; + seed = rd(); + } + + std::mt19937 engine(seed); + + auto create_var = [&](bool use_tensor) { + if (use_tensor) { + return RandomTensor(dim, place); + } else { + return RandomSelectedRows(dim, place, dist(engine)); + } + }; + + std::vector use_tensors = {false, true}; + + for (auto use_tensor1 : use_tensors) { + for (auto use_tensor2 : use_tensors) { + auto g_var1 = std::make_shared("g_var1"); + g_var1->SetOverridedStopGradient(false); + auto g_accum1 = CreateAccumulator(g_var1, sort_gradient); + g_accum1->IncreaseRefCnt(); + g_accum1->IncreaseRefCnt(); + + auto g_var2 = std::make_shared("g_var2"); + g_var2->SetOverridedStopGradient(false); + auto g_accum2 = CreateAccumulator(g_var2, sort_gradient); + g_accum2->IncreaseRefCnt(); + g_accum2->IncreaseRefCnt(); + + auto var1 = create_var(use_tensor1); + auto var_wrapper1_1 = std::make_shared("tmp1_1"); + auto var_wrapper2_1 = std::make_shared("tmp2_1"); + CopyVar(var1, var_wrapper1_1->MutableVar()); + CopyVar(var1, var_wrapper2_1->MutableVar()); + + auto var2 = create_var(use_tensor2); + auto var_wrapper1_2 = std::make_shared("tmp1_2"); + auto var_wrapper2_2 = std::make_shared("tmp2_2"); + CopyVar(var2, var_wrapper1_2->MutableVar()); + CopyVar(var2, var_wrapper2_2->MutableVar()); + + g_accum1->Add(var_wrapper1_1, 0, false); + g_accum1->Add(var_wrapper1_2, 1, false); + + g_accum2->Add(var_wrapper2_1, 0, true); + g_accum2->Add(var_wrapper2_2, 1, true); + + ASSERT_TRUE(IsEqualVar(var_wrapper2_1->Var(), var1)); + ASSERT_TRUE(IsEqualVar(var_wrapper2_2->Var(), var2)); + ASSERT_TRUE(IsEqualVar(g_var1->Var(), g_var2->Var())); + } + } +} + +TEST(test_gradient_accumulator, test_unchange_input) { + 
for (auto sort_gradient : {false, true}) { + TestGradientAccumulatorTestUnchangeInput(platform::CPUPlace(), + sort_gradient); +#ifdef PADDLE_WITH_CUDA + TestGradientAccumulatorTestUnchangeInput(platform::CUDAPlace(0), + sort_gradient); +#endif + } +} + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/tests/test_layer.cc b/paddle/fluid/imperative/tests/test_layer.cc index dabfb5f7bcd..9f7cb3344fb 100644 --- a/paddle/fluid/imperative/tests/test_layer.cc +++ b/paddle/fluid/imperative/tests/test_layer.cc @@ -21,6 +21,9 @@ #include #include #include "gtest/gtest.h" +#include "paddle/fluid/imperative/execution_context.h" +#include "paddle/fluid/imperative/infer_shape_context.h" +#include "paddle/fluid/imperative/infer_var_type_context.h" #include "paddle/fluid/imperative/layer.h" namespace imperative = paddle::imperative; @@ -45,7 +48,7 @@ TEST(test_layer, test_runtime_context) { imperative::NameVarBaseMap outs = {out_pair}; framework::AttributeMap attrs; auto *ctx = new imperative::RuntimeInferVarTypeContext( - ins, &outs, attrs); + ins, outs, attrs); ASSERT_TRUE(ctx->HasVar("vin")); ASSERT_TRUE(ctx->HasInput("X")); ASSERT_TRUE(ctx->HasOutput("Out")); @@ -120,11 +123,12 @@ TEST(test_layer, test_debug_string) { ASSERT_TRUE(res_sr.find("SelectedRows") != std::string::npos); } -static std::shared_ptr CreateOpBase( +static std::shared_ptr CreateGradNode( size_t id, const std::string &type, const imperative::NameVarBaseMap &ins, const imperative::NameVarBaseMap &outs, const framework::AttributeMap &attrs, const platform::Place &place) { - auto op = std::make_shared(); + auto node = std::make_shared(); + auto *op = &(node->emplace_back()); op->SetId(id); op->SetPlace(place); op->SetType(type); @@ -134,7 +138,7 @@ static std::shared_ptr CreateOpBase( for (auto &var : pair.second) { vars.emplace_back(var->SharedVar()); } - op->SetInput(pair.first, vars); + op->SetInput(pair.first, vars, false); } for (auto &pair : outs) { @@ -142,10 +146,10 @@ static std::shared_ptr CreateOpBase( for (auto &var : pair.second) { vars.emplace_back(var->SharedVar()); } - op->SetOutput(pair.first, vars); + op->SetOutput(pair.first, vars, false); } - return op; + return node; } TEST(test_layer, test_clear_backward_info) { @@ -163,19 +167,21 @@ TEST(test_layer, test_clear_backward_info) { framework::AttributeMap concat_att_map; concat_att_map["axis"] = 1; - auto op = CreateOpBase(0, "mul", ins, outs, concat_att_map, place); - auto preceding_op = CreateOpBase(0, "mul", ins, outs, concat_att_map, place); - op->SetGradPendingOps({preceding_op}); + auto node = CreateGradNode(0, "mul", ins, outs, concat_att_map, place); + auto pending_node = + CreateGradNode(0, "mul", ins, outs, concat_att_map, place); + node->InsertGradPendingNode(pending_node); + + ASSERT_EQ(node->size(), 1UL); + auto *op = &(node->back()); ASSERT_GT(op->GetInsMap().size(), 0UL); ASSERT_GT(op->GetOutsMap().size(), 0UL); - ASSERT_GT(op->GradPendingOps().size(), 0UL); op->ClearBackwardTrace(); ASSERT_EQ(op->GetInsMap().size(), 0UL); ASSERT_EQ(op->GetOutsMap().size(), 0UL); - ASSERT_EQ(op->GradPendingOps().size(), 0UL); } TEST(test_layer, test_varbase_basic) { diff --git a/paddle/fluid/imperative/tests/test_tracer.cc b/paddle/fluid/imperative/tests/test_tracer.cc index 9dd0ae63332..5852e60a481 100644 --- a/paddle/fluid/imperative/tests/test_tracer.cc +++ b/paddle/fluid/imperative/tests/test_tracer.cc @@ -22,6 +22,7 @@ #include #include #include "gtest/gtest.h" +#include "paddle/fluid/imperative/basic_engine.h" #include 
"paddle/fluid/imperative/tracer.h" #include "paddle/fluid/memory/memcpy.h" @@ -148,9 +149,9 @@ TEST(test_tracer, test_track_backward_output) { framework::AttributeMap mul_attr_map; mul_attr_map["use_mkldnn"] = false; tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); - ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL); - ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL); - ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL); + ASSERT_EQ(x_in->GradVarBase()->GradOpNum(), 0UL); + ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL); + ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL); } TEST(test_tracer, test_track_backward_input) { @@ -188,9 +189,9 @@ TEST(test_tracer, test_track_backward_input) { mul_attr_map["use_mkldnn"] = false; tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); - ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL); - ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL); - ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL); + ASSERT_EQ(x_in->GradVarBase()->GradOpNum(), 0UL); + ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL); + ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL); } #if defined(PADDLE_WITH_CUDA) TEST(test_tracer, test_trace_op_with_multi_device_inputs) { @@ -240,9 +241,9 @@ TEST(test_tracer, test_trace_op_with_multi_device_inputs) { tracer.TraceOp("reduce_sum", reduce_in, reduce_out, reduce_attr_map, gpu_place, true); detail::BackwardStrategy back_st; - imperative::Engine* engine = tracer.GetDefaultEngine(); - engine->Init(reduce_sum_out.get(), back_st); - engine->Execute(); + imperative::BasicEngine engine; + engine.Init(reduce_sum_out.get(), back_st); + engine.Execute(); framework::LoDTensor rlt; framework::TensorCopySync(vout->Var().Get(), place, @@ -346,14 +347,14 @@ TEST(test_tracer, test_var_without_grad_var) { ASSERT_EQ(out_tensor.data()[i], 20.0); } - ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL); - ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL); - ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL); + ASSERT_EQ(x_in->GradVarBase()->GradOpNum(), 0UL); + ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL); + ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL); detail::BackwardStrategy back_st; - imperative::Engine* engine = tracer.GetDefaultEngine(); - engine->Init(vout.get(), back_st); - engine->Execute(); + imperative::BasicEngine engine; + engine.Init(vout.get(), back_st); + engine.Execute(); // check the grad framework::LoDTensor x_grad; @@ -382,7 +383,7 @@ static void TestVarOpDestructionMain(const platform::Place& place, size_t loop_num = 10) { WeakPtrSet var_wrappers; WeakPtrSet var_bases; - WeakPtrSet op_bases; + WeakPtrSet op_bases; Tracer tracer; @@ -413,30 +414,31 @@ static void TestVarOpDestructionMain(const platform::Place& place, NameVarBaseMap{{"Out", {z}}}, framework::AttributeMap{}, place, true); - ASSERT_EQ(z->GradOps().size(), 0UL); - ASSERT_EQ(z->GradVarBase()->GradOps().size(), 1UL); - auto new_op = z->GradVarBase()->GradOps()[0]; + ASSERT_EQ(z->GradOpNum(), 0UL); + ASSERT_EQ(z->GradVarBase()->GradOpNum(), 1UL); + auto new_op = z->GradVarBase()->GradNode(); - ASSERT_EQ(x->GradOps().size(), 0UL); - ASSERT_EQ(y->GradOps().size(), 0UL); + ASSERT_EQ(x->GradOpNum(), 0UL); + ASSERT_EQ(y->GradOpNum(), 0UL); - std::unordered_set> expected_pending_ops; + std::unordered_set> expected_pending_ops; if (i == 0) { - ASSERT_EQ(x->GradVarBase()->GradOps().size(), 0UL); - ASSERT_EQ(y->GradVarBase()->GradOps().size(), 0UL); + ASSERT_EQ(x->GradVarBase()->GradOpNum(), 0UL); + 
ASSERT_EQ(y->GradVarBase()->GradOpNum(), 0UL); } else { - ASSERT_EQ(x->GradVarBase()->GradOps().size(), 1UL); - ASSERT_EQ(y->GradVarBase()->GradOps().size(), 0UL); + ASSERT_EQ(x->GradVarBase()->GradOpNum(), 1UL); + ASSERT_EQ(y->GradVarBase()->GradOpNum(), 0UL); - for (auto& op : x->GradVarBase()->GradOps()) { - expected_pending_ops.emplace(op); + if (x->GradVarBase()->GradNode()) { + expected_pending_ops.emplace(x->GradVarBase()->GradNode()); } - for (auto& op : y->GradVarBase()->GradOps()) { - expected_pending_ops.emplace(op); + + if (y->GradVarBase()->GradNode()) { + expected_pending_ops.emplace(y->GradVarBase()->GradNode()); } - std::unordered_set> actual_pending_ops; - for (auto& op : new_op->GradPendingOps()) { + std::unordered_set> actual_pending_ops; + for (auto& op : new_op->GradPendingNodes()) { actual_pending_ops.emplace(op); } diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index e99bd6700a1..9db241fb0e9 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -16,6 +16,7 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/string/string_helper.h" @@ -31,49 +32,6 @@ void SetCurrentTracer(const std::shared_ptr& tracer) { VLOG(6) << "Set current tracer: " << g_current_tracer; } -static void ClearNoNeedBufferInputs(OpBase* op) { - auto& inferer = op->Info().NoNeedBufferVarsInferer(); - if (!inferer) return; - auto* ins = op->GetMutableInsMap(); - const auto& no_need_buffer_slots = - inferer(*ins, op->GetOutsMap(), op->Attrs()); - if (no_need_buffer_slots.empty()) return; - - for (auto& slot : no_need_buffer_slots) { - auto iter = ins->find(slot); - if (iter == ins->end()) continue; - VLOG(2) << "Clear data buffer of " << slot << " in " << op->Type(); - - for (auto& each_var : iter->second) { - if (!each_var) continue; - - auto& var = each_var->Var(); - PADDLE_ENFORCE_EQ(var.IsType(), true, - "Only support LoDTensor"); - // TODO(zjl): support higher order derivatives - auto new_var = new VariableWrapper(each_var->Name()); - auto* new_tensor = - new_var->MutableVar()->GetMutable(); - auto& old_tensor = var.Get(); - new_tensor->Resize(old_tensor.dims()); - new_tensor->set_lod(old_tensor.lod()); - each_var.reset(new_var); - op->AddAllowedEmptyVar(new_var); - } - } -} - -static std::vector> CreateGradOpBases( - const framework::OpInfo& info, const std::string& type, - const NameVarBaseMap& in, const NameVarBaseMap& out, - const framework::AttributeMap& attrs) { - if (info.dygraph_grad_op_maker_) { - return info.dygraph_grad_op_maker_(type, in, out, attrs); - } else { - return {}; - } -} - static void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) { for (const auto& name_pair : outs) { for (const auto& vb : name_pair.second) { @@ -103,7 +61,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, } if (ComputeRequiredGrad(ins, outs, trace_backward)) { - TraceBackward(op_info, type, ins, outs, attrs, place); + CreateGradOpNode(*op, ins, outs, attrs, place); } else { VLOG(3) << "No Grad to track for Op: " << type; } @@ -133,22 +91,5 @@ bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins, return false; } -void Tracer::TraceBackward(const framework::OpInfo& info, - const std::string& type, const NameVarBaseMap& ins, - const NameVarBaseMap& outs, - const framework::AttributeMap& attrs, - const platform::Place& place) { - auto grad_op_bases = 
CreateGradOpBases(info, type, ins, outs, attrs); - auto grad_op_num = grad_op_bases.size(); - if (grad_op_num == 0) return; - - size_t trace_id = GenerateUniqueId(); - for (auto& grad_op : grad_op_bases) { - grad_op->SetPlace(place); - grad_op->SetId(trace_id); - ClearNoNeedBufferInputs(grad_op.get()); - } -} - } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index f9971c588d3..90758c4acb9 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -21,7 +21,7 @@ #include #include #include "ThreadPool.h" -#include "paddle/fluid/imperative/engine.h" +#include "paddle/fluid/imperative/basic_engine.h" #include "paddle/fluid/imperative/jit/program_desc_tracer.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/platform/macros.h" @@ -46,7 +46,7 @@ class Tracer { public: Tracer() - : engine_(new BasicEngine()), + : basic_engine_(new BasicEngine()), program_desc_tracer_(new jit::ProgramDescTracer()), generator_(new UniqueNameGenerator()) { expected_place_ = platform::CPUPlace(); @@ -64,8 +64,6 @@ class Tracer { bool ComputeRequiredGrad(const NameVarBaseMap& ins, const NameVarBaseMap& outs, bool trace_backward); - Engine* GetDefaultEngine() const { return engine_.get(); } - void SetEnableProgramDescTracing(bool enabled) { enable_program_desc_tracing_ = enabled; } @@ -82,6 +80,8 @@ class Tracer { return generator_->Generate(key); } + BasicEngine* GetEngine() const { return basic_engine_.get(); } + platform::Place ExpectedPlace() const { return expected_place_; } void SetExpectedPlace(platform::Place place) { expected_place_ = place; } @@ -91,18 +91,7 @@ class Tracer { void SetNoGrad(bool no_grad) { no_grad_ = no_grad; } private: - void TraceBackward(const framework::OpInfo& info, const std::string& type, - const NameVarBaseMap& ins, const NameVarBaseMap& outs, - const framework::AttributeMap& attrs, - const platform::Place& place); - - static size_t GenerateUniqueId() { - static std::atomic id{0}; - return id.fetch_add(1); - } - - private: - std::unique_ptr engine_; + std::unique_ptr basic_engine_; std::unique_ptr program_desc_tracer_; bool enable_program_desc_tracing_{false}; std::unique_ptr generator_; diff --git a/paddle/fluid/imperative/type_defs.h b/paddle/fluid/imperative/type_defs.h index eb7cc9c400f..74fd152e72a 100644 --- a/paddle/fluid/imperative/type_defs.h +++ b/paddle/fluid/imperative/type_defs.h @@ -23,18 +23,37 @@ namespace paddle { namespace imperative { class VariableWrapper; +class SavedVariableWrapperList; class VarBase; class OpBase; +class GradOpNode; class Tracer; +using WeakNameVarBaseMap = + std::map>>; + +namespace details { template -using NameVarMap = std::map>>; +struct NameVarMapTrait {}; + +template <> +struct NameVarMapTrait { + using Type = std::map>>; +}; + +template <> +struct NameVarMapTrait { + using Type = std::map; +}; +} // namespace details + +template +using NameVarMap = typename details::NameVarMapTrait::Type; using NameVarBaseMap = NameVarMap; using NameVariableWrapperMap = NameVarMap; -using WeakNameVarBaseMap = - std::map>>; +using VariableWrapperList = std::vector>; } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h index a10490866fd..9c2ff39e867 100644 --- a/paddle/fluid/imperative/variable_wrapper.h +++ b/paddle/fluid/imperative/variable_wrapper.h @@ -14,14 +14,20 @@ #pragma once +#include #include #include 
"paddle/fluid/framework/variable.h" namespace paddle { namespace imperative { +class VarBase; +class GradOpNode; + class VariableWrapper { public: + friend class VarBase; + explicit VariableWrapper(const std::string& name) : name_(name) {} const framework::Variable& Var() const { return var_; } @@ -31,6 +37,10 @@ class VariableWrapper { // This is used for python api void SetOverridedStopGradient(bool stop_gradient) { overrided_stop_gradient_ = static_cast(stop_gradient); + + if (auto grad_var = grad_var_.lock()) { + grad_var->SetOverridedStopGradient(stop_gradient); + } } // This is used for python api @@ -47,6 +57,10 @@ class VariableWrapper { VLOG(6) << "Ignore Stop gradient conversion for Var: " << Name() << "Set value is: " << overrided_stop_gradient_; } + + if (auto grad_var = grad_var_.lock()) { + grad_var->InnerSetOverridedStopGradient(stop_gradient); + } } void SetPersistable(bool persistable) { persistable_ = persistable; } @@ -65,6 +79,18 @@ class VariableWrapper { data_type_ = data_type; } + std::shared_ptr GetGradVar() const { + return grad_var_.lock(); + } + + const std::weak_ptr& GetWeakGradVar() const { + return grad_var_; + } + + std::shared_ptr GetGradNode() const { return grad_node_.lock(); } + + bool HasGradNode() const { return !grad_node_.expired(); } + framework::proto::VarType::Type DataType() const { const framework::Tensor* tensor = nullptr; if (var_.IsInitialized()) { @@ -85,6 +111,32 @@ class VariableWrapper { } } + private: + void SetGradVar(const std::shared_ptr& var) { + auto shared_var = grad_var_.lock(); + if (shared_var != var) { + PADDLE_ENFORCE_EQ(shared_var, nullptr, + platform::errors::PermissionDenied( + "Cannot set gradient var wrapper twice")); + grad_var_ = var; + } + } + + void SetGradNode(const std::shared_ptr& grad_node) { + if (!grad_node) { + grad_node_.reset(); + return; + } + + auto shared_node = grad_node_.lock(); + if (shared_node != grad_node) { + PADDLE_ENFORCE_EQ( + shared_node, nullptr, + platform::errors::PermissionDenied("Cannot set gradient op twice")); + grad_node_ = grad_node; + } + } + private: framework::Variable var_; std::string name_; @@ -96,6 +148,9 @@ class VariableWrapper { framework::proto::VarType::Type type_{framework::proto::VarType::LOD_TENSOR}; framework::proto::VarType::Type data_type_{framework::proto::VarType::FP32}; + + std::weak_ptr grad_var_; + std::weak_ptr grad_node_; }; } // namespace imperative diff --git a/paddle/fluid/operators/fused/fused_bn_activation_op.h b/paddle/fluid/operators/fused/fused_bn_activation_op.h index 1a18a129ee9..0b7b75fe6f2 100644 --- a/paddle/fluid/operators/fused/fused_bn_activation_op.h +++ b/paddle/fluid/operators/fused/fused_bn_activation_op.h @@ -21,6 +21,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/var_type_inference.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/minus_op.cc b/paddle/fluid/operators/minus_op.cc index db6d2f596af..08696e32fca 100644 --- a/paddle/fluid/operators/minus_op.cc +++ b/paddle/fluid/operators/minus_op.cc @@ -109,31 +109,29 @@ class MinusGradMaker : public imperative::GradOpBaseMakerBase { public: using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase; - std::vector> operator()() const override { - std::vector> ops; + std::shared_ptr operator()() const override { auto x_g = this->InputGrad("X"); + auto y_g = this->InputGrad("Y"); + + auto node = this->NewGradNode(); + if (!x_g.empty()) { - auto x_g_op = CreateOp(); - imperative::TracedGradOp op(x_g_op); + imperative::TracedGradOp op(node); op.SetType("scale"); op.SetInput("X", this->OutputGrad("Out")); op.SetOutput("Out", x_g); op.SetAttr("scale", 1.0f); - ops.emplace_back(x_g_op); } - auto y_g = this->InputGrad("Y"); if (!y_g.empty()) { - auto y_g_op = CreateOp(); - imperative::TracedGradOp op(y_g_op); + imperative::TracedGradOp op(node); op.SetType("scale"); op.SetInput("X", this->OutputGrad("Out")); op.SetOutput("Out", y_g); op.SetAttr("scale", -1.0f); - ops.emplace_back(y_g_op); } - return ops; + return node; } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc index bcc8c411740..1d6e2d52ed5 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc @@ -64,21 +64,22 @@ class ReduceMeanDoubleGradOpBaseMaker : public imperative::GradOpBaseMakerBase { public: using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase; - std::vector> operator()() const override { - std::vector> ops; - auto x_gg = OutputGrad(framework::GradVarName("X")); // input ddx + std::shared_ptr operator()() const override { auto out_grads = InputGrad(framework::GradVarName("Out")); if (!out_grads.empty()) { - auto out_grad_op = CreateOp(); - imperative::TracedGradOp op(out_grad_op); - op.SetType("reduce_mean"); - op.SetInput("X", x_gg); - op.SetAttrMap(Attrs()); - op.SetOutput("Out", out_grads); - ops.emplace_back(out_grad_op); + auto x_gg = OutputGrad(framework::GradVarName("X")); // input ddx + auto node = this->NewGradNode(); + { + imperative::TracedGradOp op(node); + op.SetType("reduce_mean"); + op.SetInput("X", x_gg); + op.SetAttrMap(Attrs()); + op.SetOutput("Out", out_grads); + } + return node; + } else { + return nullptr; } - - return ops; } }; DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ReduceMeanGradNoNeedBufferVarInference, diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 1876cd8221e..a854a11f64c 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -272,25 +272,25 @@ class SumGradOpBaseMaker : public imperative::GradOpBaseMakerBase { public: using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase; - std::vector> operator()() const override { + std::shared_ptr operator()() const override { auto x_grads = InputGrad("X", false); using InputGradsType = decltype(x_grads); - std::vector> grad_ops; - grad_ops.reserve(x_grads.size()); - auto og = OutputGrad("Out"); - std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops), - [&og](const std::shared_ptr& x_grad) { - auto grad_op = CreateOp(); - 
imperative::TracedGradOp op(grad_op); - op.SetType("scale"); - op.SetInput("X", og); - op.SetOutput("Out", InputGradsType{x_grad}); - op.SetAttr("scale", 1.0f); - return grad_op; - }); - - return grad_ops; + if (!x_grads.empty()) { + auto node = this->NewGradNode(); + node->reserve(x_grads.size()); + auto og = OutputGrad("Out"); + for (auto& x_grad : x_grads) { + imperative::TracedGradOp op(node); + op.SetType("scale"); + op.SetInput("X", og); + op.SetOutput("Out", InputGradsType{x_grad}); + op.SetAttr("scale", 1.0f); + } + return node; + } else { + return nullptr; + } } }; diff --git a/paddle/fluid/pybind/global_value_getter_setter.cc b/paddle/fluid/pybind/global_value_getter_setter.cc index cf928326c5a..2bc419f49b1 100644 --- a/paddle/fluid/pybind/global_value_getter_setter.cc +++ b/paddle/fluid/pybind/global_value_getter_setter.cc @@ -34,6 +34,7 @@ DECLARE_bool(free_idle_chunk); DECLARE_bool(free_when_no_cache_hit); #ifdef PADDLE_WITH_CUDA DECLARE_uint64(gpu_memory_limit_mb); +DECLARE_bool(cudnn_deterministic); #endif DECLARE_string(allocator_strategy); DECLARE_bool(enable_parallel_graph); @@ -180,6 +181,7 @@ static void RegisterGlobalVarGetterSetter() { REGISTER_GLOBAL_VAR_GETTER_ONLY(FLAGS_free_when_no_cache_hit); #ifdef PADDLE_WITH_CUDA REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_gpu_memory_limit_mb); + REGISTER_GLOBAL_VAR_GETTER_SETTER(FLAGS_cudnn_deterministic); #endif } diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index cbe37e0f728..932d1b3992b 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -25,9 +25,11 @@ limitations under the License. */ #include #include #include "paddle/fluid/imperative/backward_strategy.h" +#include "paddle/fluid/imperative/basic_engine.h" #include "paddle/fluid/imperative/data_loader.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/nccl_context.h" +#include "paddle/fluid/imperative/partial_grad_engine.h" #include "paddle/fluid/imperative/profiler.h" #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/type_defs.h" @@ -599,10 +601,9 @@ void BindImperative(py::module *m_ptr) { const imperative::Tracer &tracer) { // TODO(jiabin): when we impl more backward execution we can select // them - - imperative::Engine *engine = tracer.GetDefaultEngine(); - VLOG(3) << "Start backward"; + auto *engine = tracer.GetEngine(); engine->Init(&self, bckst); + VLOG(3) << "Start backward"; engine->Execute(); VLOG(3) << "Finish backward"; }, @@ -772,6 +773,25 @@ void BindImperative(py::module *m_ptr) { }, [](imperative::ParallelStrategy &self, const std::string &ep) { self.current_endpoint_ = ep; }); + + m.def( + "dygraph_partial_grad", + [](const std::vector> &input_targets, + const std::vector> + &output_targets, + const std::vector> &output_grads, + const std::vector> &no_grad_vars, + const platform::Place &place, + const imperative::detail::BackwardStrategy &strategy, + bool create_graph) { + imperative::PartialGradEngine engine(input_targets, output_targets, + output_grads, no_grad_vars, place, + strategy, create_graph); + engine.Execute(); + return engine.GetResult(); + }, + py::call_guard()); + #if defined(PADDLE_WITH_NCCL) py::class_ nccl_ctx(m, "NCCLParallelContext"); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 04c1457a053..f9a9120e1f1 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -256,7 +256,9 @@ void BindOpDesc(pybind11::module *m) { .def("set_is_target", 
&pd::OpDesc::SetIsTarget) .def("serialize_to_string", SerializeMessage) .def("block", [](pd::OpDesc &self) { return self.Block(); }, - pybind11::return_value_policy::reference); + pybind11::return_value_policy::reference) + .def("inputs", &pd::OpDesc::Inputs) + .def("outputs", &pd::OpDesc::Outputs); } } // namespace pybind diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 095ca5772d7..51324a21a1e 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -883,9 +883,10 @@ def _append_backward_ops_(block, # Set device for grad_op according to forward Op device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() - op_device = op.desc.attr(device_attr_name) - for op_desc in grad_op_desc: - op_desc._set_attr(device_attr_name, op_device) + if op.desc.has_attr(device_attr_name): + op_device = op.desc.attr(device_attr_name) + for op_desc in grad_op_desc: + op_desc._set_attr(device_attr_name, op_device) # If input_grad_names_set is not None, extend grad_op_descs only when # any input grad in outputs of previous grad ops. diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index e0460eadcb2..0bf620ea217 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -249,6 +249,77 @@ def _print_debug_msg(parameter_list, limit=5, is_test=False): return unique_name_size, tracer_var_size, alive_cpp_var_size +@framework.dygraph_only +def grad(outputs, + inputs, + grad_outputs=None, + no_grad_set=None, + create_graph=False, + backward_strategy=None): + def check_in_out(in_out_list, name): + assert in_out_list is not None, "{} should not be None".format(name) + + if isinstance(in_out_list, (list, tuple)): + assert len(in_out_list) > 0, "{} cannot be empty".format(name) + for each_var in in_out_list: + assert isinstance( + each_var, + core.VarBase), "Elements of {} must be Variable".format( + name) + return in_out_list + else: + assert isinstance( + in_out_list, + core.VarBase), "{} must be Variable or list of Variable".format( + name) + return [in_out_list] + + outputs = check_in_out(outputs, 'outputs') + inputs = check_in_out(inputs, 'inputs') + + if grad_outputs is not None: + if not isinstance(grad_outputs, (list, tuple)): + grad_outputs = [grad_outputs] + + for each_var in grad_outputs: + if each_var is not None: + assert isinstance( + each_var, core.VarBase + ), "grad_outputs must be None, a Variable or a list containing None or Variables" + else: + grad_outputs = [] + + if len(grad_outputs) > 0: + assert len(grad_outputs) == len( + outputs), "The length of grad_outputs must be equal to outputs" + + if no_grad_set is None: + no_grad_set = [] + elif isinstance(no_grad_set, core.VarBase): + no_grad_set = [no_grad_set] + elif isinstance(no_grad_set, (list, tuple, set)): + no_grad_set = list(no_grad_set) + for var in no_grad_set: + assert isinstance( + var, core.VarBase), "no_grad_set can only contains Variable" + else: + raise AssertionError( + "no_grad_set must be None, Variable or list/tuple/set of Variables") + + if backward_strategy is None: + backward_strategy = core.BackwardStrategy() + + assert isinstance(backward_strategy, core.BackwardStrategy), \ + "backward_strategy must be type paddle.fluid.dygraph.BackwardStrategy" + + assert isinstance(create_graph, bool), "create_graph must be True or False" + + place = core.Place() + place.set_place(framework._current_expected_place()) + return core.dygraph_partial_grad(inputs, outputs, grad_outputs, no_grad_set, + 
place, backward_strategy, create_graph) + + @framework.dygraph_only def to_variable(value, name=None, zero_copy=None): """ diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py index d9c6df62da4..24981d4b6ab 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py @@ -222,7 +222,7 @@ class TestImperativeAutoPrune(unittest.TestCase): out = fluid.layers.concat(input=[out1, out2, c], axis=1) out.backward() self.assertTrue(linear.weight.gradient() is None) - self.assertTrue((out1.gradient() == 0).all()) + self.assertTrue(out1.gradient() is None) def test_auto_prune7(self): with fluid.dygraph.guard(): @@ -241,7 +241,7 @@ class TestImperativeAutoPrune(unittest.TestCase): backward_strategy = fluid.dygraph.BackwardStrategy() out.backward(backward_strategy) self.assertTrue(linear.weight.gradient() is None) - self.assertTrue((out1.gradient() == 0).all()) + self.assertTrue(out1.gradient() is None) def test_auto_prune8(self): with fluid.dygraph.guard(): @@ -315,7 +315,7 @@ class TestImperativeAutoPrune(unittest.TestCase): backward_strategy.sort_sum_gradient = True out.backward(backward_strategy) self.assertTrue(linear.weight.gradient() is None) - self.assertTrue((out1.gradient() == 0).all()) + self.assertTrue(out1.gradient() is None) def test_auto_prune_with_optimizer(self): vocab_size = 100 diff --git a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py new file mode 100644 index 00000000000..26d3ae7b3c7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py @@ -0,0 +1,278 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
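A minimal usage sketch of the grad() API added to dygraph/base.py above, following the same pattern as the tests below (the tensor shape and variable names are illustrative, not part of the patch):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import grad

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.random.rand(5, 10).astype('float32'))
    x.stop_gradient = False
    y = fluid.layers.relu(x)
    loss = fluid.layers.reduce_mean(y * y)

    # create_graph=True keeps the double-grad graph alive, so dx can itself
    # be differentiated by an ordinary backward() call afterwards.
    dx, = grad(outputs=[loss], inputs=[x], create_graph=True)
    second_loss = fluid.layers.reduce_mean(dx * dx)
    second_loss.backward()
    print(x.gradient().shape)  # gradient of second_loss w.r.t. x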
+ +import paddle.fluid as fluid +from paddle.fluid.wrapped_decorator import wrap_decorator +import unittest +from unittest import TestCase +import numpy as np +from paddle.fluid.dygraph.base import grad + + +def _dygraph_guard_(func): + def __impl__(*args, **kwargs): + if fluid.in_dygraph_mode(): + return func(*args, **kwargs) + else: + with fluid.dygraph.guard(): + return func(*args, **kwargs) + + return __impl__ + + +dygraph_guard = wrap_decorator(_dygraph_guard_) + + +def random_var(size, low=-1, high=1, dtype='float32'): + x_np = np.random.uniform(low=low, high=high, size=size).astype(dtype) + return fluid.dygraph.to_variable(x_np) + + +class TestDygraphDoubleGrad(TestCase): + def setUp(self): + self.sort_sum_gradient = False + self.shape = [5, 10] + + def grad(self, + outputs, + inputs, + grad_outputs=None, + no_grad_set=None, + create_graph=False): + backward_strategy = fluid.dygraph.BackwardStrategy() + backward_strategy.sort_sum_gradient = self.sort_sum_gradient + return grad( + outputs=outputs, + inputs=inputs, + grad_outputs=grad_outputs, + no_grad_set=no_grad_set, + create_graph=create_graph, + backward_strategy=backward_strategy) + + @dygraph_guard + def test_exception(self): + with self.assertRaises(AssertionError): + self.grad(None, None) + + shape = self.shape + + with self.assertRaises(AssertionError): + self.grad(1, random_var(shape)) + + with self.assertRaises(AssertionError): + self.grad(random_var(shape), 1) + + with self.assertRaises(AssertionError): + self.grad([1], [random_var(shape)]) + + with self.assertRaises(AssertionError): + self.grad([random_var(shape)], [1]) + + with self.assertRaises(AssertionError): + self.grad([random_var(shape), random_var(shape)], + [random_var(shape)], [random_var(shape)]) + + with self.assertRaises(AssertionError): + self.grad([random_var(shape)], [random_var(shape)], no_grad_set=[1]) + + with self.assertRaises(AssertionError): + self.grad([random_var(shape)], [random_var(shape)], no_grad_set=1) + + @dygraph_guard + def test_simple_example(self): + x = random_var(self.shape) + x.stop_gradient = False + y = x + 1 + + for create_graph in [False, True]: + dx, = self.grad([x], [x], create_graph=create_graph) + self.assertEqual(dx.shape, x.shape) + self.assertTrue(np.all(dx.numpy() == 1)) + self.assertNotEqual(dx.stop_gradient, create_graph) + + dx_mul_2, = self.grad([y, x], [x], create_graph=create_graph) + self.assertEqual(dx_mul_2.shape, x.shape) + self.assertTrue(np.all(dx_mul_2.numpy() == 2)) + self.assertNotEqual(dx_mul_2.stop_gradient, create_graph) + + none_grad, = self.grad([x], [y], create_graph=create_graph) + self.assertTrue(none_grad is None) + + grad_with_none_and_not_none, = self.grad( + [x, y], [y], create_graph=create_graph) + self.assertTrue(grad_with_none_and_not_none.shape, x.shape) + self.assertTrue(np.all(grad_with_none_and_not_none.numpy() == 1)) + self.assertNotEqual(grad_with_none_and_not_none.stop_gradient, + create_graph) + + @dygraph_guard + def test_none_one_initial_gradient(self): + x = random_var(self.shape) + x.stop_gradient = False + + y = fluid.layers.relu(x) + y = y * y + z = y * y + + x_np = x.numpy() + relu_x_np = np.maximum(x_np, 0).astype('float32') + relu_x_grad_np = (x_np > 0).astype('float32') + dy_expected = (relu_x_np * relu_x_grad_np * 2).astype('float32') + dz_expected = (np.power(relu_x_np, 3) * relu_x_grad_np * + 4).astype('float32') + + random_grad_y = random_var(y.shape) + random_grad_z = random_var(z.shape) + ones_grad_y = np.ones(y.shape).astype('float32') + ones_grad_z = 
+
+        original_random_grad_y = random_grad_y.numpy()
+        original_random_grad_z = random_grad_z.numpy()
+
+        for grad_y in [random_grad_y]:
+            for grad_z in [random_grad_z]:
+                for create_graph in [False, True]:
+                    dx_actual, = self.grad(
+                        outputs=[y, z],
+                        inputs=[x],
+                        grad_outputs=[grad_y, grad_z],
+                        create_graph=create_graph)
+
+                    grad_y_np = ones_grad_y if grad_y is None else grad_y.numpy()
+                    grad_z_np = ones_grad_z if grad_z is None else grad_z.numpy()
+
+                    dx_expected = dy_expected * grad_y_np + dz_expected * grad_z_np
+                    self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected))
+
+                    if grad_y is not None:
+                        self.assertTrue(grad_y.stop_gradient)
+                        self.assertTrue(
+                            np.array_equal(grad_y.numpy(),
+                                           original_random_grad_y))
+
+                    if grad_z is not None:
+                        self.assertTrue(grad_z.stop_gradient)
+                        self.assertTrue(
+                            np.array_equal(grad_z.numpy(),
+                                           original_random_grad_z))
+
+    @dygraph_guard
+    def test_example_with_gradient_accumulation_and_create_graph(self):
+        x = random_var(self.shape)
+        x_np = x.numpy()
+        numel = x_np.size
+        x.stop_gradient = False
+
+        y = fluid.layers.relu(x)
+        z = y + 1
+        w = z * z
+
+        w_mean = fluid.layers.reduce_mean(w)
+        del y, z, w
+
+        dx_actual, = self.grad([w_mean], [x], create_graph=True)
+        del w_mean
+
+        self.assertFalse(dx_actual.stop_gradient)
+
+        # Theoretical result based on math calculation
+        dx_expected = (1.0 / float(numel) * (np.maximum(x_np, 0) + 1) *
+                       (x_np > 0) * 2).astype('float32')
+        self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected))
+
+        loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x)
+        loss.backward()
+
+        x_grad_actual = x.gradient()
+        x_grad_expected = (2.0 / float(numel) *
+                           (x_np + dx_expected *
+                            (x_np > 0) * 2 / float(numel))).astype('float32')
+        self.assertTrue(np.allclose(x_grad_actual, x_grad_expected))
+
+    @dygraph_guard
+    def test_example_with_gradient_accumulation_and_no_grad_set(self):
+        x = random_var(self.shape)
+        x_np = x.numpy()
+        numel = x_np.size
+        x.stop_gradient = False
+
+        y1 = fluid.layers.relu(x)
+        y2 = fluid.layers.relu(x)
+        z = y1 + y2
+        w = z * z
+
+        w_mean = fluid.layers.reduce_mean(w)
+        del y1, z, w
+
+        dx_actual, = self.grad(
+            [w_mean], [x], create_graph=True, no_grad_set=[y2])
+
+        self.assertFalse(y2.stop_gradient)
+        self.assertFalse(dx_actual.stop_gradient)
+
+        dx_expected = (1.0 / float(numel) * (np.maximum(x_np, 0) + y2.numpy()) *
+                       (x_np > 0) * 2).astype('float32')
+        self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected))
+
+        loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x)
+        loss.backward()
+
+        x_grad_actual = x.gradient()
+        x_grad_expected = (2.0 / float(numel) *
+                           (x_np + dx_expected *
+                            (x_np > 0) * 4 / float(numel))).astype('float32')
+        self.assertTrue(np.allclose(x_grad_actual, x_grad_expected))
+
+    @dygraph_guard
+    def test_example_with_gradient_accumulation_and_not_create_graph(self):
+        x = random_var(self.shape)
+        x_np = x.numpy()
+        numel = x_np.size
+        x.stop_gradient = False
+
+        y = fluid.layers.relu(x)
+        z = y + 1
+        w = z * z
+
+        w_mean = fluid.layers.reduce_mean(w)
+        del y, z, w
+
+        dx_actual, = self.grad([w_mean], [x], create_graph=False)
+        del w_mean
+
+        self.assertTrue(dx_actual.stop_gradient)
+
+        dx_expected = (1.0 / float(numel) * (np.maximum(x_np, 0) + 1) *
+                       (x_np > 0) * 2).astype('float32')
+
+        self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected))
+
+        loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x)
+        loss.backward()
+
+        x_grad_actual = x.gradient()
+        x_grad_expected = (2.0 * x_np /
float(numel)).astype('float32') + self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + + +class TestDygraphDoubleGradSortGradient(TestDygraphDoubleGrad): + def setUp(self): + self.sort_sum_gradient = True + self.shape = [5, 10] + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py new file mode 100644 index 00000000000..6122cc6ab88 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_star_gan_with_gradient_penalty.py @@ -0,0 +1,614 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import numpy as np +import unittest + +if fluid.is_compiled_with_cuda(): + fluid.core.globals()['FLAGS_cudnn_deterministic'] = True + + +class Config(object): + def __init__(self, place, sort_sum_gradient=True): + self.place = place + + if isinstance(place, fluid.CPUPlace): + # CPU cases are extremely slow + self.g_base_dims = 1 + self.d_base_dims = 1 + + self.g_repeat_num = 1 + self.d_repeat_num = 1 + + self.image_size = 32 + else: + self.g_base_dims = 64 + self.d_base_dims = 64 + + self.g_repeat_num = 6 + self.d_repeat_num = 6 + + self.image_size = 256 + + self.c_dim = 10 + self.batch_size = 1 + + self.seed = 1 + + self.lambda_rec = 10 + self.lambda_gp = 10 + + self.iterations = 10 + + self.sort_sum_gradient = sort_sum_gradient + + +def create_mnist_dataset(cfg): + def create_target_label(label): + return label + # return (label + 1) % cfg.c_dim # fake label target + + def create_one_hot(label): + ret = np.zeros([cfg.c_dim]) + ret[label] = 1 + return ret + + def __impl__(): + dataset = paddle.dataset.mnist.train() + image_reals = [] + label_orgs = [] + label_trgs = [] + num = 0 + + for image_real, label_org in dataset(): + image_real = np.reshape(np.array(image_real), [28, 28]) + image_real = np.resize(image_real, [cfg.image_size, cfg.image_size]) + image_real = np.array([image_real] * 3) + + label_trg = create_target_label(label_org) + + image_reals.append(np.array(image_real)) + label_orgs.append(create_one_hot(label_org)) + label_trgs.append(create_one_hot(label_trg)) + + if len(image_reals) == cfg.batch_size: + image_real_np = np.array(image_reals).astype('float32') + label_org_np = np.array(label_orgs).astype('float32') + label_trg_np = np.array(label_trgs).astype('float32') + + yield image_real_np, label_org_np, label_trg_np + + num += 1 + if num == cfg.iterations: + break + + image_reals = [] + label_orgs = [] + label_trgs = [] + + return __impl__ + + +class InstanceNorm(fluid.dygraph.Layer): + def __init__(self, num_channels, epsilon=1e-5): + super(InstanceNorm, self).__init__() + self.epsilon = epsilon + + self.scale = self.create_parameter(shape=[num_channels], is_bias=False) + self.bias = self.create_parameter(shape=[num_channels], is_bias=True) + + def forward(self, input): + if 
fluid.in_dygraph_mode(): + inputs = {'X': [input], 'Scale': [self.scale], 'Bias': [self.bias]} + attrs = {'epsilon': self.epsilon} + return fluid.core.ops.instance_norm(inputs, attrs)['Y'][0] + else: + return fluid.layers.instance_norm( + input, + epsilon=self.epsilon, + param_attr=fluid.ParamAttr(self.scale.name), + bias_attr=fluid.ParamAttr(self.bias.name)) + + +class Conv2DLayer(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters=64, + filter_size=7, + stride=1, + padding=0, + norm=None, + use_bias=False, + relufactor=None): + super(Conv2DLayer, self).__init__() + self._conv = fluid.dygraph.Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + bias_attr=None if use_bias else False) + + if norm is not None: + self._norm = InstanceNorm(num_filters) + else: + self._norm = None + + self.relufactor = relufactor + + def forward(self, input): + conv = self._conv(input) + + if self._norm: + conv = self._norm(conv) + + if self.relufactor is not None: + conv = fluid.layers.leaky_relu(conv, alpha=self.relufactor) + + return conv + + +class Deconv2DLayer(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters=64, + filter_size=7, + stride=1, + padding=0, + norm=None, + use_bias=False, + relufactor=None): + super(Deconv2DLayer, self).__init__() + + self._deconv = fluid.dygraph.Conv2DTranspose( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + bias_attr=None if use_bias else False) + + if norm is not None: + self._norm = InstanceNorm(num_filters) + else: + self._norm = None + + self.relufactor = relufactor + + def forward(self, input): + deconv = self._deconv(input) + + if self._norm: + deconv = self._norm(deconv) + + if self.relufactor is not None: + deconv = fluid.layers.leaky_relu(deconv, alpha=self.relufactor) + + return deconv + + +class ResidualBlock(fluid.dygraph.Layer): + def __init__(self, num_channels, num_filters): + super(ResidualBlock, self).__init__() + self._conv0 = Conv2DLayer( + num_channels=num_channels, + num_filters=num_filters, + filter_size=3, + stride=1, + padding=1, + norm=True, + relufactor=0) + + self._conv1 = Conv2DLayer( + num_channels=num_filters, + num_filters=num_filters, + filter_size=3, + stride=1, + padding=1, + norm=True, + relufactor=None) + + def forward(self, input): + conv0 = self._conv0(input) + conv1 = self._conv1(conv0) + return input + conv1 + + +class Generator(fluid.dygraph.Layer): + def __init__(self, cfg, num_channels=3): + super(Generator, self).__init__() + conv_base = Conv2DLayer( + num_channels=cfg.c_dim + num_channels, + num_filters=cfg.g_base_dims, + filter_size=7, + stride=1, + padding=3, + norm=True, + relufactor=0) + + sub_layers = [conv_base] + cur_channels = cfg.g_base_dims + for i in range(2): + sub_layer = Conv2DLayer( + num_channels=cur_channels, + num_filters=cur_channels * 2, + filter_size=4, + stride=2, + padding=1, + norm=True, + relufactor=0) + + cur_channels *= 2 + sub_layers.append(sub_layer) + + self._conv0 = fluid.dygraph.Sequential(*sub_layers) + + repeat_num = cfg.g_repeat_num + sub_layers = [] + for i in range(repeat_num): + res_block = ResidualBlock( + num_channels=cur_channels, num_filters=cfg.g_base_dims * 4) + sub_layers.append(res_block) + + self._res_block = fluid.dygraph.Sequential(*sub_layers) + + cur_channels = cfg.g_base_dims * 4 + sub_layers = [] + for i in range(2): + rate = 2**(1 - i) + deconv = Deconv2DLayer( + 
num_channels=cur_channels, + num_filters=cfg.g_base_dims * rate, + filter_size=4, + stride=2, + padding=1, + relufactor=0, + norm=True) + cur_channels = cfg.g_base_dims * rate + sub_layers.append(deconv) + + self._deconv = fluid.dygraph.Sequential(*sub_layers) + + self._conv1 = Conv2DLayer( + num_channels=cur_channels, + num_filters=3, + filter_size=7, + stride=1, + padding=3, + relufactor=None) + + def forward(self, input, label_trg): + shape = input.shape + label_trg_e = fluid.layers.reshape(label_trg, + [-1, label_trg.shape[1], 1, 1]) + label_trg_e = fluid.layers.expand( + x=label_trg_e, expand_times=[1, 1, shape[2], shape[3]]) + + input1 = fluid.layers.concat([input, label_trg_e], 1) + + conv0 = self._conv0(input1) + res_block = self._res_block(conv0) + deconv = self._deconv(res_block) + conv1 = self._conv1(deconv) + out = fluid.layers.tanh(conv1) + return out + + +class Discriminator(fluid.dygraph.Layer): + def __init__(self, cfg, num_channels=3): + super(Discriminator, self).__init__() + + cur_dim = cfg.d_base_dims + + conv_base = Conv2DLayer( + num_channels=num_channels, + num_filters=cur_dim, + filter_size=4, + stride=2, + padding=1, + relufactor=0.2) + + repeat_num = cfg.d_repeat_num + sub_layers = [conv_base] + for i in range(1, repeat_num): + sub_layer = Conv2DLayer( + num_channels=cur_dim, + num_filters=cur_dim * 2, + filter_size=4, + stride=2, + padding=1, + relufactor=0.2) + cur_dim *= 2 + sub_layers.append(sub_layer) + + self._conv0 = fluid.dygraph.Sequential(*sub_layers) + + kernel_size = int(cfg.image_size / np.power(2, repeat_num)) + + self._conv1 = Conv2DLayer( + num_channels=cur_dim, + num_filters=1, + filter_size=3, + stride=1, + padding=1) + + self._conv2 = Conv2DLayer( + num_channels=cur_dim, + num_filters=cfg.c_dim, + filter_size=kernel_size) + + def forward(self, input): + conv = self._conv0(input) + out1 = self._conv1(conv) + out2 = self._conv2(conv) + return out1, out2 + + +def loss_cls(cls, label, cfg): + cls_shape = cls.shape + cls = fluid.layers.reshape( + cls, [-1, cls_shape[1] * cls_shape[2] * cls_shape[3]]) + return fluid.layers.reduce_sum( + fluid.layers.sigmoid_cross_entropy_with_logits(cls, + label)) / cfg.batch_size + + +def calc_gradients(outputs, inputs, no_grad_set): + if fluid.in_dygraph_mode(): + from paddle.fluid.dygraph.base import grad + return grad( + outputs=outputs, + inputs=inputs, + no_grad_set=no_grad_set, + create_graph=True) + else: + return fluid.gradients( + targets=outputs, inputs=inputs, no_grad_set=no_grad_set) + + +def gradient_penalty(f, real, fake, no_grad_set, cfg): + def _interpolate(a, b): + shape = [a.shape[0]] + alpha = fluid.layers.uniform_random_batch_size_like( + input=a, shape=shape, min=0.1, max=1.0, seed=cfg.seed) + + inner = fluid.layers.elementwise_mul( + b, 1.0 - alpha, axis=0) + fluid.layers.elementwise_mul( + a, alpha, axis=0) + return inner + + x = _interpolate(real, fake) + pred, _ = f(x) + if isinstance(pred, tuple): + pred = pred[0] + + gradient = calc_gradients( + outputs=[pred], inputs=[x], no_grad_set=no_grad_set) + + if gradient is None: + return None + + gradient = gradient[0] + grad_shape = gradient.shape + + gradient = fluid.layers.reshape( + gradient, [-1, grad_shape[1] * grad_shape[2] * grad_shape[3]]) + + epsilon = 1e-16 + norm = fluid.layers.sqrt( + fluid.layers.reduce_sum( + fluid.layers.square(gradient), dim=1) + epsilon) + + gp = fluid.layers.reduce_mean(fluid.layers.square(norm - 1.0)) + return gp + + +def get_generator_loss(image_real, label_org, label_trg, generator, + discriminator, cfg): + 
fake_img = generator(image_real, label_trg) + rec_img = generator(fake_img, label_org) + g_loss_rec = fluid.layers.reduce_mean( + fluid.layers.abs(fluid.layers.elementwise_sub(image_real, rec_img))) + + pred_fake, cls_fake = discriminator(fake_img) + + g_loss_fake = -fluid.layers.mean(pred_fake) + g_loss_cls = loss_cls(cls_fake, label_trg, cfg) + g_loss = g_loss_fake + cfg.lambda_rec * g_loss_rec + g_loss_cls + return g_loss + + +def get_discriminator_loss(image_real, label_org, label_trg, generator, + discriminator, cfg): + fake_img = generator(image_real, label_trg) + pred_real, cls_real = discriminator(image_real) + pred_fake, _ = discriminator(fake_img) + d_loss_cls = loss_cls(cls_real, label_org, cfg) + d_loss_fake = fluid.layers.mean(pred_fake) + d_loss_real = -fluid.layers.mean(pred_real) + d_loss = d_loss_real + d_loss_fake + d_loss_cls + + d_loss_gp = gradient_penalty(discriminator, image_real, fake_img, + discriminator.parameters(), cfg) + if d_loss_gp is not None: + d_loss += cfg.lambda_gp * d_loss_gp + + return d_loss + + +def build_optimizer(layer, cfg, loss=None): + learning_rate = 1e-3 + beta1 = 0.5 + beta2 = 0.999 + if fluid.in_dygraph_mode(): + return fluid.optimizer.Adam( + learning_rate=learning_rate, + beta1=beta1, + beta2=beta2, + parameter_list=layer.parameters()) + else: + optimizer = fluid.optimizer.Adam( + learning_rate=learning_rate, beta1=beta1, beta2=beta2) + + optimizer.minimize(loss, parameter_list=layer.parameters()) + return optimizer + + +class DyGraphTrainModel(object): + def __init__(self, cfg): + fluid.default_startup_program().random_seed = cfg.seed + fluid.default_main_program().random_seed = cfg.seed + + self.generator = Generator(cfg) + self.discriminator = Discriminator(cfg) + + self.g_optimizer = build_optimizer(self.generator, cfg) + self.d_optimizer = build_optimizer(self.discriminator, cfg) + + self.cfg = cfg + + self.backward_strategy = fluid.dygraph.BackwardStrategy() + self.backward_strategy.sort_sum_gradient = cfg.sort_sum_gradient + + def run(self, image_real, label_org, label_trg): + image_real = fluid.dygraph.to_variable(image_real) + label_org = fluid.dygraph.to_variable(label_org) + label_trg = fluid.dygraph.to_variable(label_trg) + + g_loss = get_generator_loss(image_real, label_org, label_trg, + self.generator, self.discriminator, + self.cfg) + g_loss.backward(self.backward_strategy) + if self.g_optimizer: + self.g_optimizer.minimize(g_loss) + self.generator.clear_gradients() + + d_loss = get_discriminator_loss(image_real, label_org, label_trg, + self.generator, self.discriminator, + self.cfg) + d_loss.backward(self.backward_strategy) + if self.d_optimizer: + self.d_optimizer.minimize(d_loss) + self.discriminator.clear_gradients() + + return g_loss.numpy()[0], d_loss.numpy()[0] + + +class StaticGraphTrainModel(object): + def __init__(self, cfg): + self.cfg = cfg + + def create_data_layer(): + image_real = fluid.data( + shape=[None, 3, cfg.image_size, cfg.image_size], + dtype='float32', + name='image_real') + label_org = fluid.data( + shape=[None, cfg.c_dim], dtype='float32', name='label_org') + label_trg = fluid.data( + shape=[None, cfg.c_dim], dtype='float32', name='label_trg') + return image_real, label_org, label_trg + + self.gen_program = fluid.Program() + gen_startup_program = fluid.Program() + + with fluid.program_guard(self.gen_program, gen_startup_program): + self.gen_program.random_seed = cfg.seed + gen_startup_program.random_seed = cfg.seed + with fluid.unique_name.guard(): + image_real, label_org, label_trg = 
create_data_layer() + generator = Generator(cfg) + discriminator = Discriminator(cfg) + g_loss = get_generator_loss(image_real, label_org, label_trg, + generator, discriminator, cfg) + build_optimizer(generator, cfg, loss=g_loss) + + self.dis_program = fluid.Program() + dis_startup_program = fluid.Program() + with fluid.program_guard(self.dis_program, dis_startup_program): + self.dis_program.random_seed = cfg.seed + dis_startup_program.random_seed = cfg.seed + with fluid.unique_name.guard(): + image_real, label_org, label_trg = create_data_layer() + generator = Generator(cfg) + discriminator = Discriminator(cfg) + d_loss = get_discriminator_loss(image_real, label_org, + label_trg, generator, + discriminator, cfg) + build_optimizer(discriminator, cfg, loss=d_loss) + + self.executor = fluid.Executor(cfg.place) + self.scope = fluid.Scope() + + with fluid.scope_guard(self.scope): + self.executor.run(gen_startup_program) + self.executor.run(dis_startup_program) + + self.g_loss = g_loss + self.d_loss = d_loss + + def run(self, image_real, label_org, label_trg): + feed = { + 'image_real': image_real, + 'label_org': label_org, + 'label_trg': label_trg + } + with fluid.scope_guard(self.scope): + g_loss_val = self.executor.run(self.gen_program, + feed=feed, + fetch_list=[self.g_loss])[0] + d_loss_val = self.executor.run(self.dis_program, + feed=feed, + fetch_list=[self.d_loss])[0] + return g_loss_val[0], d_loss_val[0] + + +class TestStarGANWithGradientPenalty(unittest.TestCase): + def test_main(self): + self.place_test(fluid.CPUPlace()) + + if fluid.is_compiled_with_cuda(): + self.place_test(fluid.CUDAPlace(0)) + + def place_test(self, place): + cfg = Config(place) + + dataset = create_mnist_dataset(cfg) + dataset = fluid.io.cache(dataset) + + static_graph_model = StaticGraphTrainModel(cfg) + static_loss = [] + for batch_id, (image_real, label_org, + label_trg) in enumerate(dataset()): + loss = static_graph_model.run(image_real, label_org, label_trg) + static_loss.append(loss) + + dygraph_loss = [] + with fluid.dygraph.guard(cfg.place): + dygraph_model = DyGraphTrainModel(cfg) + for batch_id, (image_real, label_org, + label_trg) in enumerate(dataset()): + loss = dygraph_model.run(image_real, label_org, label_trg) + dygraph_loss.append(loss) + + for (g_loss_s, d_loss_s), (g_loss_d, d_loss_d) in zip(static_loss, + dygraph_loss): + self.assertEqual(g_loss_s, g_loss_d) + self.assertEqual(d_loss_s, d_loss_d) + + +if __name__ == '__main__': + unittest.main() -- GitLab
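For reference, a minimal sketch of the double-grad workflow that test_imperative_double_grad.py and the gradient-penalty test above exercise, using only the API introduced in this patch (fluid.dygraph.base.grad with create_graph=True); the variable names and shapes are illustrative, and the remaining grad() arguments are assumed to keep their defaults:

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import grad

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(np.random.rand(5, 10).astype('float32'))
    x.stop_gradient = False

    w_mean = fluid.layers.reduce_mean(fluid.layers.relu(x))

    # First-order gradient d(w_mean)/dx; create_graph=True keeps it differentiable.
    dx, = grad(outputs=[w_mean], inputs=[x], create_graph=True)

    # Build a loss on the first-order gradient (as the gradient-penalty test does)
    # and backpropagate through it, accumulating a second-order gradient into x.
    loss = fluid.layers.reduce_mean(dx * dx)
    loss.backward()

    print(x.gradient().shape)  # (5, 10)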