未验证 提交 a31d7328 编写于 作者: Z Zeng Jinle 提交者: GitHub

Add dygraph double grad implementation (#22939)

* add double grad implementation for dygraph, test=develop

* polish code, add uts, test=develop

* fix place bug, test=develop

* polish codes, add more uts for coverages, test=develop

* add no_grad_set, test=develop

* add star gan ut, test=develop

* follow comments, test=develop
上级 995a6376
......@@ -209,14 +209,13 @@ class SingleGradOpMaker<imperative::OpBase>
public:
using GradOpBaseMakerBase::GradOpBaseMakerBase;
std::vector<std::shared_ptr<imperative::OpBase>> operator()() const {
std::vector<std::shared_ptr<imperative::OpBase>> retv{
std::make_shared<imperative::OpBase>()};
std::shared_ptr<imperative::GradOpNode> operator()() const {
auto node = this->NewGradNode();
{
imperative::TracedGradOp grad_op(retv.front());
this->Apply(&grad_op);
imperative::TracedGradOp traced_grad_op(node);
this->Apply(&traced_grad_op);
}
return retv;
return node->empty() ? nullptr : node;
}
protected:
......@@ -262,8 +261,9 @@ class EmptyGradOpMaker<imperative::OpBase> final
: public imperative::GradOpBaseMakerBase {
public:
using GradOpBaseMakerBase::GradOpBaseMakerBase;
std::vector<std::shared_ptr<imperative::OpBase>> operator()() const final {
return {};
std::shared_ptr<imperative::GradOpNode> operator()() const final {
return nullptr;
}
};
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include <string>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/imperative/saved_variable_wrapper_list.h"
namespace paddle {
namespace framework {
......
......@@ -56,7 +56,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
const std::vector<BlockDesc*>& grad_block)>;
using DygraphGradOpMakerFN =
std::function<std::vector<std::shared_ptr<imperative::OpBase>>(
std::function<std::shared_ptr<imperative::GradOpNode>(
const std::string& /*op_type*/,
const imperative::NameVarBaseMap& /*var_base_map_in*/,
const imperative::NameVarBaseMap& /*var_base_map_out*/,
......
......@@ -6,7 +6,8 @@ cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator
add_subdirectory(jit)
cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer)
cc_library(engine SRCS engine.cc DEPS layer gradient_accumulator)
cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator)
cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator)
cc_library(imperative_profiler SRCS profiler.cc)
if(NOT WIN32)
if(WITH_NCCL)
......
......@@ -12,17 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/basic_engine.h"
#include <algorithm>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -32,33 +35,22 @@ namespace imperative {
void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
backward_strategy_ = strategy;
const auto& ops = var->GradVarBase()->GradOps();
var->ClearGradOps();
init_node_ = var->GradVarBase()->GradNode();
var->GradVarBase()->ClearGradNode();
if (ops.empty() || var->OverridedStopGradient()) {
if (init_node_ == nullptr || var->OverridedStopGradient()) {
VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
"stop_gradient=True: "
<< var->Name();
return;
} else {
bool valid = false;
for (const auto& op : ops) {
if (op) {
valid = true;
}
}
if (!valid) {
VLOG(3) << "Skip auto grad since all grad op of start VarBase is nullptr";
return;
}
}
init_ops_ = ops;
var->GradVarBase()->ClearGradOps();
VLOG(3) << "start backward";
PADDLE_ENFORCE_EQ(var->HasGradVar(), true,
"Grad variable not exist for variable %s", var->Name());
PADDLE_ENFORCE_EQ(
var->HasGradVar(), true,
platform::errors::NotFound("Grad variable not exist for variable %s",
var->Name()));
auto& fwd_var = var->Var().Get<framework::LoDTensor>();
auto* grad_var =
......@@ -72,10 +64,14 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
operators::math::set_constant(*dev_ctx, grad_var, 1.0);
}
void BasicEngine::CheckBackwardInputs(OpBase* op) {
for (auto& pair : op->GetInsMap()) {
void BasicEngine::CheckBackwardInputs(const OpBase& op) {
for (auto& pair : op.GetInsMap()) {
if (!pair.second.IsGrad()) {
continue;
}
for (auto& var : pair.second) {
if (!var || op->IsAllowedEmptyVar(var.get())) {
if (!var) {
continue;
}
......@@ -89,17 +85,20 @@ void BasicEngine::CheckBackwardInputs(OpBase* op) {
if (tensor && !tensor->IsInitialized()) {
// if grad var has OverridedStopGradient skip this Op
VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero";
auto* dev_ctx =
platform::DeviceContextPool::Instance().Get(op->place());
tensor->mutable_data(op->place(), var->DataType());
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(op.place());
tensor->mutable_data(op.place(), var->DataType());
operators::math::set_constant(*dev_ctx, tensor, 0.0);
}
}
}
}
void BasicEngine::PrepareGradAccumulators(OpBase* op) {
for (const auto& pair : op->GetOutsMap()) {
void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
for (const auto& pair : op.GetOutsMap()) {
if (!pair.second.IsGrad()) {
continue;
}
for (const auto& var : pair.second) {
if (!var) continue;
......@@ -114,142 +113,155 @@ void BasicEngine::PrepareGradAccumulators(OpBase* op) {
accumulator->IncreaseRefCnt();
VLOG(3) << "Prepare to acccumulate variable grad " << var->Name()
<< "with reference count " << accumulator->RefCnt();
VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "("
<< var.get() << ") with reference count "
<< accumulator->RefCnt();
}
}
}
void BasicEngine::PrepareDeps() {
PADDLE_ENFORCE_EQ(op_deps_.empty(), true, "Op deps must be initialized here");
PADDLE_ENFORCE_EQ(accumulators_.empty(), true,
"Accumulators must be initialized here");
std::queue<OpBase*> q;
std::unordered_set<OpBase*> visited;
for (const auto& init_op : init_ops_) {
q.push(init_op.get());
visited.insert(init_op.get());
}
PADDLE_ENFORCE_EQ(
node_deps_.empty(), true,
platform::errors::AlreadyExists("Op deps must be initialized here"));
PADDLE_ENFORCE_EQ(
accumulators_.empty(), true,
platform::errors::AlreadyExists("Accumulators must be initialized here"));
std::queue<GradOpNode*> q;
std::unordered_set<GradOpNode*> visited;
q.push(init_node_.get());
visited.insert(init_node_.get());
while (!q.empty()) {
auto* cur_op = q.front();
auto* cur_node = q.front();
q.pop();
PADDLE_ENFORCE_NE(
cur_op->GetInsMap().empty() && cur_op->GetOutsMap().empty(), true,
platform::errors::NotFound(
"Inputs and outputs of %s do not exist. "
"This may be because you call \"backward()\" twice for the same "
"subgraph. Please try to call \"stop_gradient = True\" or "
"\"detach()\" if you use some same vars between two \"backward()\" "
"calls.",
cur_op->Type()));
PrepareGradAccumulators(cur_op);
const auto& grad_pending_ops = cur_op->GradPendingOps();
for (auto& grad_pending_op : grad_pending_ops) {
PADDLE_ENFORCE_NOT_NULL(grad_pending_op);
++op_deps_[grad_pending_op.get()];
if (visited.count(grad_pending_op.get()) == 0) {
visited.insert(grad_pending_op.get());
q.push(grad_pending_op.get());
for (auto& cur_op : *cur_node) {
PADDLE_ENFORCE_NE(
cur_op.GetInsMap().empty() && cur_op.GetOutsMap().empty(), true,
platform::errors::NotFound(
"Inputs and outputs of %s do not exist. "
"This may be because you call \"backward()\" twice for the same "
"subgraph. Please try to call \"stop_gradient = True\" or "
"\"detach()\" if you use some same vars between two "
"\"backward()\" "
"calls.",
cur_op.Type()));
PrepareGradAccumulators(cur_op);
}
const auto& grad_pending_nodes = cur_node->GradPendingNodes();
for (auto& grad_pending_node : grad_pending_nodes) {
PADDLE_ENFORCE_NOT_NULL(
grad_pending_node,
platform::errors::NotFound("Grad pending node should not be null"));
++node_deps_[grad_pending_node.get()];
if (visited.count(grad_pending_node.get()) == 0) {
visited.insert(grad_pending_node.get());
q.push(grad_pending_node.get());
}
}
}
}
void BasicEngine::SumGradient(OpBase* op, std::shared_ptr<VariableWrapper> src,
VariableWrapper* dst) {
auto iter = accumulators_.find(dst);
PADDLE_ENFORCE_EQ(iter != accumulators_.end(), true,
"Cannot find gradient of variable %s", dst->Name());
iter->second->Add(std::move(src), op->id());
}
void BasicEngine::Execute() {
if (init_node_ == nullptr) {
return;
}
PrepareDeps();
// Start execute Computation graph
std::queue<std::shared_ptr<OpBase>> q;
for (const auto& init_op : init_ops_) {
q.push(std::move(init_op));
}
std::queue<std::shared_ptr<GradOpNode>> q;
q.push(std::move(init_node_));
size_t op_num = 0;
while (!q.empty()) {
auto shared_cur_op = std::move(q.front());
auto shared_cur_node = std::move(q.front());
q.pop();
auto* cur_op = shared_cur_op.get();
++op_num;
for (auto& cur_op : *shared_cur_node) {
++op_num;
// CheckBackWardInput
CheckBackwardInputs(cur_op);
// CheckBackWardInput
CheckBackwardInputs(cur_op);
// Step 1: Run Backward
auto& bwd_ins = cur_op->GetInsMap();
auto& bwd_outs = cur_op->GetOutsMap();
// Step 1: Run Backward
auto& bwd_ins = cur_op.GetInsMap();
auto& bwd_outs = cur_op.GetOutsMap();
NameVarMap<VariableWrapper> tmp_outs(bwd_outs);
// 1. construct the output map 2. replace the element in the map
// A var may be coresponding to several grad var in one op
for (auto it = tmp_outs.begin(); it != tmp_outs.end(); ++it) {
for (size_t i = 0; i < it->second.size(); ++i) {
auto tmp_var =
std::make_shared<VariableWrapper>("Gtmp@"); // Do not need grad
NameVarMap<VariableWrapper> tmp_outs(bwd_outs);
// 1. construct the output map 2. replace the element in the map
// A var may be coresponding to several grad var in one op
for (auto& pair : tmp_outs) {
if (!pair.second.IsGrad()) {
continue;
}
auto var = it->second[i];
it->second[i] = tmp_var;
if (var) {
need_accu_var_list_.emplace_back(var.get(), std::move(tmp_var));
for (auto& var : pair.second) {
if (!var) {
continue;
}
auto iter = accumulators_.find(var.get());
PADDLE_ENFORCE_EQ(
iter != accumulators_.end(), true,
platform::errors::NotFound("Cannot find gradient of variable %s",
var->Name()));
if (!var->OverridedStopGradient() && iter->second->RefCnt() == 1) {
continue;
}
var = std::make_shared<VariableWrapper>("Gtmp@");
need_accu_var_list_.emplace_back(iter->second.get(), var);
}
}
}
{
VLOG(3) << "Start to execute grad op " << cur_op->Type();
OpBase::Run(cur_op->InnerOp(), bwd_ins, tmp_outs, cur_op->Attrs(),
cur_op->place());
}
// Step 2: Sum Gradient
{
VLOG(3) << "Start to execute grad op " << cur_op.Type();
OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
cur_op.place());
}
if (need_accu_var_list_.size() > 0) {
// Step 2: Sum Gradient
for (auto& pair : need_accu_var_list_) {
SumGradient(cur_op, std::move(pair.second), pair.first);
pair.first->Add(std::move(pair.second), cur_op.id());
}
}
need_accu_var_list_.clear();
need_accu_var_list_.clear();
// Step 3: Collect ready ops
VLOG(3) << "Remove op after op " << cur_op.Type() << " runs";
cur_op.ClearBackwardTrace();
}
for (auto& grad_pending_op : cur_op->GradPendingOps()) {
PADDLE_ENFORCE_NOT_NULL(grad_pending_op);
auto iter = op_deps_.find(grad_pending_op.get());
if (iter == op_deps_.end()) {
// Step 3: Collect ready ops
for (auto& grad_pending_node : shared_cur_node->GradPendingNodes()) {
PADDLE_ENFORCE_NOT_NULL(grad_pending_node,
platform::errors::NotFound(
"Grad pending node should not be nullptr"));
auto iter = node_deps_.find(grad_pending_node.get());
if (iter == node_deps_.end()) {
continue;
}
VLOG(3) << "Found grad_pending op of " << cur_op->Type();
// An Op is ready to go while its deps comes to zero
if (--(iter->second) == 0) {
q.push(grad_pending_op);
VLOG(3) << "Push grad_pending op " << grad_pending_op->Type()
<< " into queue";
q.push(grad_pending_node);
}
}
// Step 4: Delete op to collect unused variables
VLOG(3) << "Remove op after op " << cur_op->Type() << " runs";
cur_op->ClearBackwardTrace();
}
Clear();
VLOG(1) << "Backward op number: " << op_num;
}
void BasicEngine::Clear() {
init_node_.reset();
node_deps_.clear();
accumulators_.clear();
need_accu_var_list_.clear();
}
} // namespace imperative
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"
namespace paddle {
namespace imperative {
class VarBase;
class OpBase;
class BasicEngine : public Engine {
public:
void Init(VarBase* var, const detail::BackwardStrategy& strategy);
void Execute() override;
private:
void PrepareDeps();
void CheckBackwardInputs(const OpBase& op);
void PrepareGradAccumulators(const OpBase& op);
void Clear();
private:
std::shared_ptr<GradOpNode> init_node_;
detail::BackwardStrategy backward_strategy_;
std::unordered_map<GradOpNode*, size_t> node_deps_;
std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
accumulators_;
std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
need_accu_var_list_;
};
} // namespace imperative
} // namespace paddle
......@@ -18,9 +18,11 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
......@@ -51,11 +53,8 @@ class GradOpBaseMakerBase {
attrs_(attrs) {}
virtual ~GradOpBaseMakerBase() = default;
virtual std::vector<std::shared_ptr<OpBase>> operator()() const = 0;
static std::shared_ptr<OpBase> CreateOp() {
return std::make_shared<OpBase>();
}
virtual std::shared_ptr<GradOpNode> operator()() const = 0;
TracedVarList<VarBase, TracedVarRole::kBackward> InputGrad(
const std::string& name, bool drop_empty_grad = true) const {
......@@ -138,6 +137,10 @@ class GradOpBaseMakerBase {
return var_base_map_out_.count(name) > 0;
}
static std::shared_ptr<GradOpNode> NewGradNode() {
return std::make_shared<GradOpNode>();
}
private:
template <TracedVarRole kRole>
TracedVarList<VarBase, kRole> GetVarBaseList(const std::string& name,
......@@ -149,7 +152,13 @@ class GradOpBaseMakerBase {
if (iterator != data_map.end()) {
vec_temp.reserve(iterator->second.size());
bool is_valid = false;
for (auto& var_base_temp : iterator->second) {
if (!var_base_temp) {
vec_temp.emplace_back();
continue;
}
if (kRole == TracedVarRole::kBackward) {
if (!var_base_temp->HasGradVar()) {
VLOG(6) << "GradVarBase of var " << var_base_temp->Name()
......@@ -168,6 +177,11 @@ class GradOpBaseMakerBase {
} else {
vec_temp.emplace_back(var_base_temp);
}
is_valid = true;
}
if (!is_valid) {
vec_temp.clear();
}
}
......@@ -185,44 +199,63 @@ class TracedGradOp {
DISABLE_COPY_AND_ASSIGN(TracedGradOp);
public:
explicit TracedGradOp(const std::shared_ptr<OpBase>& op) : op_(op) {}
explicit TracedGradOp(const std::shared_ptr<GradOpNode>& node)
: node_(node), op_(&(node->emplace_back())) {}
~TracedGradOp() {
op_->SetGradPendingOps(
{grad_pending_ops_.begin(), grad_pending_ops_.end()});
op_->CheckAttrs();
if (UNLIKELY(op_->GetOutsMap().empty())) {
node_->pop_back();
} else {
op_->CheckAttrs();
}
}
template <TracedVarRole kRole>
void SetInput(const std::string& name,
const TracedVarList<VarBase, kRole>& vars) {
if (vars.empty()) {
return;
}
if (kRole == TracedVarRole::kBackward) {
for (auto& var : vars) {
var->AddGradOp(op_);
if (var && !var->OverridedStopGradient()) {
var->SetGradNode(node_);
}
}
}
op_->SetInput(name, ToVarWrapperList(vars));
auto var_wrappers = ToVarWrapperList<kRole>(vars);
if (!var_wrappers.empty()) {
op_->SetInput(name, std::move(var_wrappers),
kRole == TracedVarRole::kBackward);
}
}
template <TracedVarRole kRole>
void SetOutput(const std::string& name,
const TracedVarList<VarBase, kRole>& vars) {
if (vars.empty()) {
return;
}
if (kRole == TracedVarRole::kBackward) {
if (vars.size() == 1 && vars.front()->OverridedStopGradient()) {
op_->SetOutput(name, VariableWrapperList{});
return;
} else {
for (auto& var : vars) {
if (!var->OverridedStopGradient()) {
for (auto& op : var->GradOps()) {
grad_pending_ops_.emplace(op);
}
if (var && !var->OverridedStopGradient() && var->GradNode()) {
node_->InsertGradPendingNode(var->GradNode());
}
}
}
}
op_->SetOutput(name, ToVarWrapperList(vars));
auto var_wrappers = ToVarWrapperList<kRole>(vars);
if (!var_wrappers.empty()) {
op_->SetOutput(name, std::move(var_wrappers),
kRole == TracedVarRole::kBackward);
}
}
void SetType(const std::string& type) { op_->SetType(type); }
......@@ -247,19 +280,31 @@ class TracedGradOp {
}
private:
template <TracedVarRole kRole>
static std::vector<std::shared_ptr<VariableWrapper>> ToVarWrapperList(
const std::vector<std::shared_ptr<VarBase>>& vars) {
std::vector<std::shared_ptr<VariableWrapper>> result;
result.reserve(vars.size());
bool has_valid = false;
for (auto& var : vars) {
result.emplace_back(var->SharedVar());
if (UNLIKELY(!var || (kRole == TracedVarRole::kBackward &&
var->OverridedStopGradient()))) {
result.emplace_back();
} else {
result.emplace_back(var->SharedVar());
has_valid = true;
}
}
if (!has_valid) {
result.clear();
}
return result;
}
private:
const std::shared_ptr<OpBase>& op_;
std::unordered_set<std::shared_ptr<OpBase>> grad_pending_ops_;
const std::shared_ptr<GradOpNode>& node_;
OpBase* op_;
};
} // namespace imperative
......
......@@ -14,63 +14,18 @@
#pragma once
#include <cstddef>
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace imperative {
// It seems there is no need for Engine to be an
// singleton, we can have multi-engine to run
// mutil-graoh. For future use we may expose a interface
// to Python to support
class Engine {
DISABLE_COPY_AND_ASSIGN(Engine);
public:
Engine() = default;
virtual ~Engine() = default;
virtual void Execute() = 0;
virtual void Init(VarBase* var, const detail::BackwardStrategy& strategy) = 0;
};
class BasicEngine : public Engine {
public:
void Init(VarBase* var, const detail::BackwardStrategy& strategy) override;
void Execute() override;
private:
void PrepareDeps();
void CheckBackwardInputs(OpBase* op);
void PrepareGradAccumulators(OpBase* op);
void SumGradient(OpBase* op, std::shared_ptr<VariableWrapper> src,
VariableWrapper* dst);
// TODO(jiabin): maybe we can optimize the performance of engine by cache the
// result
void Clear() {
init_ops_.clear();
op_deps_.clear();
accumulators_.clear();
}
std::vector<std::shared_ptr<OpBase>> init_ops_;
detail::BackwardStrategy backward_strategy_;
std::unordered_map<OpBase*, size_t> op_deps_;
std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
accumulators_;
std::vector<std::pair<VariableWrapper*, std::shared_ptr<VariableWrapper>>>
need_accu_var_list_;
};
} // namespace imperative
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/type_defs.h"
namespace paddle {
namespace imperative {
template <typename VarType>
class DygraphExecutionContext : public framework::ExecutionContext {
using Variable = framework::Variable;
public:
DygraphExecutionContext(const framework::OperatorBase& op,
const framework::Scope& scope,
const platform::DeviceContext& device_context,
const framework::RuntimeContext& ctx,
std::vector<framework::KernelConfig>* configs,
const NameVarMap<VarType>& var_base_map_in,
const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap& attrs)
: ExecutionContext(op, scope, device_context, ctx, configs),
var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out),
attrs_(attrs) {}
std::string InputName(const std::string& name) const override {
auto it = var_base_map_in_.find(name);
PADDLE_ENFORCE_NE(it, var_base_map_in_.end(),
platform::errors::PreconditionNotMet(
"Can not find [%s] in Input", name));
return it->second[0] ? it->second[0]->Name() : framework::kEmptyVarName;
}
std::vector<std::string> InputNames(const std::string& name) const override {
auto it = var_base_map_in_.find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_in_.end(),
platform::errors::NotFound("Can not find [%s] in Input", name));
std::vector<std::string> vec_res;
vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) {
if (it->second[i]) {
vec_res.push_back(it->second[i]->Name());
} else {
vec_res.push_back(framework::kEmptyVarName);
}
}
return vec_res;
}
std::string OutputName(const std::string& name) const override {
auto it = var_base_map_out_.find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_out_.end(),
platform::errors::NotFound("Can not find [%s] in Output", name));
return it->second[0] ? it->second[0]->Name() : framework::kEmptyVarName;
}
std::vector<std::string> OutputNames(const std::string& name) const override {
auto it = var_base_map_out_.find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_out_.end(),
platform::errors::NotFound("Can not find [%s] in Output", name));
std::vector<std::string> vec_res;
vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) {
if (it->second[i]) {
vec_res.push_back(it->second[i]->Name());
} else {
vec_res.push_back(framework::kEmptyVarName);
}
}
return vec_res;
}
bool HasAttr(const std::string& name) const override {
return attrs_.count(name) != 0;
}
const framework::AttributeMap& Attrs() const override { return attrs_; }
const framework::Attribute& GetAttr(const std::string& name) const override {
auto it = attrs_.find(name);
PADDLE_ENFORCE_NE(
it, attrs_.end(),
platform::errors::NotFound("can not find [%s] in attrs", name));
return it->second;
}
std::vector<std::string> InNameList() const override {
std::vector<std::string> vec_temp;
vec_temp.reserve(var_base_map_in_.size());
for (auto& v : var_base_map_in_) {
vec_temp.push_back(v.first);
}
return vec_temp;
}
bool HasInput(const std::string& name) const override {
auto it = var_base_map_in_.find(name);
return (it != var_base_map_in_.end() && it->second.size() > 0);
}
bool HasOutput(const std::string& name) const override {
auto it = var_base_map_out_.find(name);
return (it != var_base_map_out_.end() && it->second.size() > 0);
}
size_t InputSize(const std::string& name) const override {
return InputNames(name).size();
}
size_t OutputSize(const std::string& name) const override {
return OutputNames(name).size();
}
const Variable* InputVar(const std::string& name) const override {
auto it = var_base_map_in_.find(name);
if (it == var_base_map_in_.end()) {
return nullptr;
}
return it->second.empty() || it->second[0] == nullptr
? nullptr
: it->second[0]->MutableVar();
}
Variable* OutputVar(const std::string& name) const override {
auto it = var_base_map_out_.find(name);
if (it == var_base_map_out_.end()) {
return nullptr;
}
return it->second.empty() || it->second[0] == nullptr
? nullptr
: it->second[0]->MutableVar();
}
const std::vector<Variable*> MultiInputVar(
const std::string& name) const override {
auto it = var_base_map_in_.find(name);
if (it == var_base_map_in_.end()) {
return {};
}
std::vector<Variable*> vec_res;
vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) {
vec_res.push_back(it->second[i] ? it->second[i]->MutableVar() : nullptr);
}
return vec_res;
}
std::vector<Variable*> MultiOutputVar(
const std::string& name) const override {
auto it = var_base_map_out_.find(name);
if (it == var_base_map_out_.end()) {
return {};
}
std::vector<Variable*> vec_res;
vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) {
vec_res.push_back(it->second[i] ? it->second[i]->MutableVar() : nullptr);
}
return vec_res;
}
private:
const NameVarMap<VarType>& var_base_map_in_;
const NameVarMap<VarType>& var_base_map_out_;
const framework::AttributeMap& attrs_;
};
} // namespace imperative
} // namespace paddle
......@@ -29,6 +29,39 @@
namespace paddle {
namespace imperative {
static void MoveOrCopyVar(framework::Variable* dst, framework::Variable* src,
bool force_copy) {
if (!force_copy) {
*dst = std::move(*src);
return;
}
VLOG(10) << "Copy occurs when accumulating gradients";
if (src->IsType<framework::LoDTensor>()) {
auto& src_tensor = src->Get<framework::LoDTensor>();
if (!dst->IsType<framework::LoDTensor>()) {
dst->Clear();
}
auto* dst_tensor = dst->GetMutable<framework::LoDTensor>();
framework::TensorCopy(src_tensor, src_tensor.place(), dst_tensor);
dst_tensor->set_lod(src_tensor.lod());
} else if (src->IsType<framework::SelectedRows>()) {
auto& src_selected_rows = src->Get<framework::SelectedRows>();
if (!dst->IsType<framework::SelectedRows>()) {
dst->Clear();
}
auto* dst_selected_rows = dst->GetMutable<framework::SelectedRows>();
framework::TensorCopy(src_selected_rows.value(),
src_selected_rows.value().place(),
dst_selected_rows->mutable_value());
dst_selected_rows->set_rows(src_selected_rows.rows());
dst_selected_rows->set_height(src_selected_rows.height());
} else {
PADDLE_THROW(platform::errors::PermissionDenied(
"Only support LoDTensor and SelectedRows for gradient accumulation"));
}
}
template <typename T>
class TensorAddFunctor : public boost::static_visitor<> {
public:
......@@ -141,6 +174,49 @@ void SelectedRowsAddToTensor(const framework::Variable& src,
framework::DataTypeToString(data_type)));
}
static void SelectedRowsAddTensor(
const framework::Variable& src_selected_rows_var,
const framework::Variable& src_tensor_var,
framework::Variable* dst_tensor_var) {
const auto& src_selected_rows =
src_selected_rows_var.Get<framework::SelectedRows>();
const auto& src_tensor = src_tensor_var.Get<framework::LoDTensor>();
const auto& place = src_tensor.place();
auto data_type = src_tensor.type();
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto* dst_tensor = dst_tensor_var->GetMutable<framework::LoDTensor>();
dst_tensor->Resize(src_tensor.dims());
dst_tensor->mutable_data(place, data_type);
#define PADDLE_SELECTED_ROWS_ADD_TENSOR(dev_ctx_type, cpp_type) \
if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
paddle::operators::math::SelectedRowsAddTensor<dev_ctx_type, cpp_type> \
functor; \
functor(*(dynamic_cast<dev_ctx_type*>(dev_ctx)), src_selected_rows, \
src_tensor, dst_tensor); \
return; \
}
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place)) {
PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, float);
PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CUDADeviceContext, double);
} else {
#endif
PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, float);
PADDLE_SELECTED_ROWS_ADD_TENSOR(platform::CPUDeviceContext, double);
#ifdef PADDLE_WITH_CUDA
}
#endif
PADDLE_THROW(platform::errors::InvalidArgument(
"Not supported data type %s for SelectedRowsAddToTensor",
framework::DataTypeToString(data_type)));
#undef PADDLE_SELECTED_ROWS_ADD_TENSOR
}
// Note(chenweihang): when two selected rows need to be added,
// adding one to another is not equal to merging two selected rows
// to one then add it to a empty selected rows, the after is correct
......@@ -189,7 +265,7 @@ std::shared_ptr<VariableWrapper> SelectedRowsMerge(
}
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
VariableWrapper* var_) {
VariableWrapper* var_, bool unchange_input) {
auto& src = var->Var();
auto* dst = var_->MutableVar();
if (dst->IsType<framework::LoDTensor>()) {
......@@ -204,10 +280,15 @@ void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
}
} else {
if (src.IsType<framework::LoDTensor>()) {
auto* src_mutable = var->MutableVar();
SelectedRowsAddToTensor(*dst, src_mutable);
*dst = std::move(*(var->MutableVar()));
var_->SetType(framework::proto::VarType::LOD_TENSOR);
if (unchange_input) {
framework::Variable new_dst;
SelectedRowsAddTensor(*dst, src, &new_dst);
*dst = std::move(new_dst);
} else {
auto* src_mutable = var->MutableVar();
SelectedRowsAddToTensor(*dst, src_mutable);
*dst = std::move(*(var->MutableVar()));
}
} else if (src.IsType<framework::SelectedRows>()) {
auto temp = SelectedRowsMerge(src, *dst);
*dst = std::move(*(temp->MutableVar()));
......@@ -234,18 +315,23 @@ static platform::Place GetPlaceOfVar(
}
void EagerGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
size_t trace_id) {
size_t trace_id, bool unchange_input) {
/**
* If var has grad node, it indicates that this var would be an input
* of a grad op. Therefore, it should not be changed.
*/
if (var->HasGradNode()) {
unchange_input = true;
}
auto* dst_var = var_->MutableVar();
platform::Place place = GetPlaceOfVar(var);
if (!var_->OverridedStopGradient()) {
VLOG(3) << "Sum Gradient for: " << var_->Name();
if (cur_cnt_ == 0) {
if (var->Var().IsType<framework::SelectedRows>()) {
var_->SetType(framework::proto::VarType::SELECTED_ROWS);
}
*dst_var = std::move(*(var->MutableVar()));
MoveOrCopyVar(dst_var, var->MutableVar(), unchange_input);
} else {
VariableWrapperAdd(var, var_);
VariableWrapperAdd(var, var_, unchange_input);
}
} else {
if (!var_->Var().IsInitialized() ||
......@@ -268,75 +354,91 @@ void EagerGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
}
}
++cur_cnt_;
if (var_->Var().IsType<framework::LoDTensor>()) {
var_->SetType(framework::proto::VarType::LOD_TENSOR);
} else if (var_->Var().IsType<framework::SelectedRows>()) {
var_->SetType(framework::proto::VarType::SELECTED_ROWS);
}
}
void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
size_t trace_id) {
size_t trace_id, bool unchange_input) {
auto* dst_var = var_->MutableVar();
platform::Place place = GetPlaceOfVar(var);
if (!var_->OverridedStopGradient()) {
if (ref_cnt_ == 1) {
if (var->Var().IsType<framework::SelectedRows>()) {
var_->SetType(framework::proto::VarType::SELECTED_ROWS);
*dst_var = std::move(*(var->MutableVar()));
} else {
*dst_var = std::move(*(var->MutableVar()));
}
MoveOrCopyVar(dst_var, var->MutableVar(),
unchange_input || var->HasGradNode());
} else {
if (tmp_grad_vars_.empty()) {
tmp_grad_vars_.reserve(ref_cnt_);
}
tmp_grad_vars_.emplace_back(std::move(var), trace_id);
tmp_grad_vars_.emplace_back(std::move(var), trace_id, unchange_input);
if (tmp_grad_vars_.size() != ref_cnt_) {
return;
}
std::sort(
tmp_grad_vars_.begin(), tmp_grad_vars_.end(),
[](const std::pair<std::shared_ptr<VariableWrapper>, size_t>& p1,
const std::pair<std::shared_ptr<VariableWrapper>, size_t>& p2) {
return p1.second > p2.second;
});
std::sort(tmp_grad_vars_.begin(), tmp_grad_vars_.end(),
[](const SavedVarInfo& info1, const SavedVarInfo& info2) {
return info1.trace_id > info2.trace_id;
});
for (auto& var_info : tmp_grad_vars_) {
if (var_info.var->HasGradNode()) {
var_info.unchange_input = true;
}
}
#ifdef PADDLE_WITH_CUDA
if (paddle::platform::is_gpu_place(place)) {
bool dst_varbase_is_initialized = false;
// accumulate selected rows firstly
for (size_t i = 0; i < tmp_grad_vars_.size(); ++i) {
if (tmp_grad_vars_[i]
.first->Var()
.IsType<framework::SelectedRows>()) {
if (!dst_varbase_is_initialized) {
dst_varbase_is_initialized = true;
var_->SetType(framework::proto::VarType::SELECTED_ROWS);
*dst_var = std::move(*(tmp_grad_vars_[i].first->MutableVar()));
} else {
VariableWrapperAdd(tmp_grad_vars_[i].first, var_);
}
for (auto& var_info : tmp_grad_vars_) {
if (!var_info.var->Var().IsType<framework::SelectedRows>()) {
continue;
}
}
// accumulate lod tensor
for (size_t i = 0; i < tmp_grad_vars_.size(); ++i) {
if (!dst_varbase_is_initialized) {
dst_varbase_is_initialized = true;
*dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
MoveOrCopyVar(dst_var, var_info.var->MutableVar(),
var_info.unchange_input);
} else {
VariableWrapperAdd(var_info.var, var_, var_info.unchange_input);
}
if (tmp_grad_vars_[i].first->Var().IsType<framework::LoDTensor>()) {
VariableWrapperAdd(tmp_grad_vars_[i].first, var_);
var_info.var = nullptr;
}
for (auto& var_info : tmp_grad_vars_) {
if (!var_info.var) {
continue;
}
PADDLE_ENFORCE_EQ(var_info.var->Var().IsType<framework::LoDTensor>(),
true, platform::errors::PermissionDenied(
"Gradient var must be LoDTensor"));
if (!dst_varbase_is_initialized) {
dst_varbase_is_initialized = true;
MoveOrCopyVar(dst_var, var_info.var->MutableVar(),
var_info.unchange_input);
} else {
VariableWrapperAdd(var_info.var, var_, var_info.unchange_input);
}
var_info.var = nullptr;
}
} else {
#endif
if (tmp_grad_vars_[0].first->Var().IsType<framework::SelectedRows>()) {
var_->SetType(framework::proto::VarType::SELECTED_ROWS);
*dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
} else {
*dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
}
MoveOrCopyVar(dst_var, tmp_grad_vars_[0].var->MutableVar(),
tmp_grad_vars_[0].unchange_input);
for (size_t i = 1; i < tmp_grad_vars_.size(); ++i) {
VariableWrapperAdd(tmp_grad_vars_[i].first, var_);
VariableWrapperAdd(tmp_grad_vars_[i].var, var_,
tmp_grad_vars_[i].unchange_input);
tmp_grad_vars_[i].var = nullptr;
}
#ifdef PADDLE_WITH_CUDA
}
......@@ -364,6 +466,12 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
// looks like tmp_grad_vars will not have any member but just in case
tmp_grad_vars_.clear();
}
if (var_->Var().IsType<framework::LoDTensor>()) {
var_->SetType(framework::proto::VarType::LOD_TENSOR);
} else if (var_->Var().IsType<framework::SelectedRows>()) {
var_->SetType(framework::proto::VarType::SELECTED_ROWS);
}
}
} // namespace imperative
......
......@@ -26,7 +26,8 @@ class GradientAccumulator {
public:
explicit GradientAccumulator(VariableWrapper* var) : var_(var) {}
virtual void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id) = 0;
virtual void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id,
bool unchange_input = false) = 0;
virtual ~GradientAccumulator() = default;
......@@ -43,7 +44,8 @@ class EagerGradientAccumulator : public GradientAccumulator {
public:
using GradientAccumulator::GradientAccumulator;
void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id) override;
void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id,
bool unchange_input) override;
private:
size_t cur_cnt_{0};
......@@ -53,11 +55,23 @@ class SortedGradientAccumulator : public GradientAccumulator {
public:
using GradientAccumulator::GradientAccumulator;
void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id) override;
void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id,
bool unchange_input) override;
private:
std::vector<std::pair<std::shared_ptr<VariableWrapper>, size_t>>
tmp_grad_vars_;
struct SavedVarInfo {
SavedVarInfo(std::shared_ptr<VariableWrapper>&& v, size_t id,
bool enable_unchange_input)
: var(std::move(v)),
trace_id(id),
unchange_input(enable_unchange_input) {}
std::shared_ptr<VariableWrapper> var;
size_t trace_id;
bool unchange_input;
};
std::vector<SavedVarInfo> tmp_grad_vars_;
};
} // namespace imperative
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
namespace paddle {
namespace imperative {
template <typename VarType>
class DygraphInferShapeContext : public framework::InferShapeContext {
using DDim = framework::DDim;
public:
DygraphInferShapeContext(const NameVarMap<VarType>* in,
const NameVarMap<VarType>* out,
const framework::AttributeMap* attr)
: var_base_map_in_(in), var_base_map_out_(out), attrs_(attr) {}
bool HasInput(const std::string& name) const override {
// has only one input
auto it = var_base_map_in_->find(name);
if (it == var_base_map_in_->end()) {
return false;
}
const auto& in = it->second;
if (in.size() == 0) return false;
PADDLE_ENFORCE_EQ(
in.size(), 1UL,
platform::errors::PreconditionNotMet(
"Input %s should not have more than one inputs", name));
return in[0] != nullptr;
}
bool HasOutput(const std::string& name) const override {
// has only one output
auto it = var_base_map_out_->find(name);
if (it == var_base_map_out_->end()) {
return false;
}
const auto& out = it->second;
if (out.size() == 0) {
return false;
}
PADDLE_ENFORCE_EQ(
out.size(), 1UL,
platform::errors::PreconditionNotMet(
"Output %s should not have more than one outputs", name));
return out[0] != nullptr;
}
bool HasInputs(const std::string& name) const override {
auto it = var_base_map_in_->find(name);
if (it == var_base_map_in_->end() || it->second.empty()) {
return false;
}
for (auto& input : it->second) {
if (input == nullptr) {
return false;
}
}
return true;
}
bool HasOutputs(const std::string& name) const override {
auto it = var_base_map_out_->find(name);
if (it == var_base_map_out_->end() || it->second.empty()) {
return false;
}
for (auto& output : it->second) {
if (output == nullptr) {
return false;
}
}
return true;
}
framework::AttrReader Attrs() const override {
return framework::AttrReader(*attrs_);
}
std::vector<std::string> Inputs(const std::string& name) const override {
std::vector<std::string> vec_res;
auto it = var_base_map_in_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_in_->end(),
platform::errors::NotFound("can not find [%s] in input", name));
vec_res.reserve(it->second.size());
for (auto& var : it->second) {
if (var) {
vec_res.push_back(var->Name());
} else {
vec_res.push_back(framework::kEmptyVarName);
}
}
return vec_res;
}
std::vector<std::string> Outputs(const std::string& name) const override {
std::vector<std::string> vec_res;
auto it = var_base_map_out_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(),
platform::errors::NotFound("can not find [%s] in output", name));
vec_res.reserve(it->second.size());
for (auto& var : it->second) {
if (var) {
vec_res.push_back(var->Name());
} else {
vec_res.push_back(framework::kEmptyVarName);
}
}
return vec_res;
}
void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override {
auto in_it = var_base_map_in_->find(in);
auto out_it = var_base_map_out_->find(out);
PADDLE_ENFORCE_NE(
in_it, var_base_map_in_->end(),
platform::errors::NotFound("can not found [%s] in input", in));
PADDLE_ENFORCE_GT(in_it->second.size(), i,
platform::errors::PreconditionNotMet(
"Inputs %s should have %llu argument", in, i));
PADDLE_ENFORCE_NE(
out_it, var_base_map_out_->end(),
platform::errors::NotFound("can not found [%s] in input", in));
PADDLE_ENFORCE_GT(out_it->second.size(), j,
platform::errors::PreconditionNotMet(
"Outputs %s should have %llu argument", out, j));
framework::Variable* in_var = in_it->second[i]->MutableVar();
framework::Variable* out_var = out_it->second[j]->MutableVar();
PADDLE_ENFORCE_EQ(in_var->Type(), out_var->Type(),
platform::errors::PreconditionNotMet(
"The type of %s and %s is not the same.", in, out));
if (in_var->IsType<framework::LoDTensor>()) {
auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
out_lod_tensor->Resize(in_lod_tensor.dims());
} else {
auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
out_sele_rows->set_rows(in_sele_rows.rows());
out_sele_rows->set_height(in_sele_rows.height());
}
}
void ShareAllLoD(const std::string& in,
const std::string& out) const override {
// do nothing
}
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override {
// do nothing
}
bool IsRuntime() const override { return true; }
// TODO(paddle-dev): Can this be template?
std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) override {
PADDLE_THROW(platform::errors::PermissionDenied(
"GetInputVarPtrs not support in dygraph runtime context"));
}
std::vector<framework::InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) override {
PADDLE_THROW(platform::errors::PermissionDenied(
"GetOutputVarPtrs not support in dygraph runtime context"));
}
DDim GetInputDim(const std::string& name) const override {
auto it = var_base_map_in_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_in_->end(),
platform::errors::NotFound("can not find [%s] in input", name));
PADDLE_ENFORCE_EQ(
it->second.size(), 1UL,
platform::errors::PreconditionNotMet(
"Input(%s) should hold one element, but now it holds %d", name,
it->second.size()));
return this->GetDim(it->second[0]->MutableVar());
}
std::vector<DDim> GetInputsDim(const std::string& name) const override {
// const std::vector<Variable*>& vars = InputVars(name);
std::vector<DDim> vec_res;
auto it = var_base_map_in_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_in_->end(),
platform::errors::NotFound("can not find [%s] in output", name));
vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) {
if (it->second[i]) {
vec_res.emplace_back(GetDim(it->second[i]->MutableVar()));
} else {
vec_res.emplace_back();
}
}
return vec_res;
}
std::vector<framework::proto::VarType::Type> GetInputsVarType(
const std::string& name) const override {
std::vector<framework::proto::VarType::Type> vec_res;
auto it = var_base_map_in_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_in_->end(),
platform::errors::NotFound("can not find [%s] in input", name));
vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) {
if (it->second[i]) {
vec_res.emplace_back(
framework::ToVarType(it->second[i]->MutableVar()->Type()));
} else {
vec_res.emplace_back();
}
}
return vec_res;
}
std::vector<framework::proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override {
std::vector<framework::proto::VarType::Type> vec_res;
auto it = var_base_map_out_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(),
platform::errors::NotFound("can not find [%s] in output", name));
vec_res.reserve(it->second.size());
for (size_t i = 0; i < it->second.size(); ++i) {
if (it->second[i]) {
vec_res.emplace_back(
framework::ToVarType(it->second[i]->MutableVar()->Type()));
} else {
vec_res.emplace_back(static_cast<framework::proto::VarType::Type>(-1));
}
}
return vec_res;
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
auto it = var_base_map_out_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(),
platform::errors::NotFound("can not find [%s] in output", name));
if (it->second[0]) {
SetDim(it->second[0]->MutableVar(), dim);
}
}
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override {
auto it = var_base_map_out_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(),
platform::errors::NotFound("can not find [%s] in output", name));
PADDLE_ENFORCE_EQ(it->second.size(), dims.size(),
platform::errors::PreconditionNotMet(
"dim size [%d] is not match output var number [%d]",
dims.size(), it->second.size()));
for (size_t i = 0; i < dims.size(); ++i) {
if (it->second[i]) {
SetDim(it->second[i]->MutableVar(), dims[i]);
}
}
}
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"GetLoDLevel function not support in dygraph mode"));
}
void SetLoDLevel(const std::string& out, int32_t lod_level,
size_t j = 0) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"SetLoDLevel function not support in dygraph mode"));
}
protected:
DDim GetDim(framework::Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet(
"Input variable should not be null"));
if (var->IsType<framework::LoDTensor>()) {
return var->Get<framework::LoDTensor>().dims();
} else if (var->IsType<framework::SelectedRows>()) {
return var->Get<framework::SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW(platform::errors::PermissionDenied(
"Only LoDTensor/SelectedRows support 'GetDim', but Variables "
"type_id is xx."));
}
}
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"GetRepeatedDims not support in dygraph runtime"));
}
void SetDim(framework::Variable* var, const DDim& dim) {
if (var->IsType<framework::LoDTensor>()) {
var->GetMutable<framework::LoDTensor>()->Resize(dim);
} else if (var->IsType<framework::SelectedRows>()) {
var->GetMutable<framework::SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW(platform::errors::PermissionDenied(
"Variable type_id %s, expect LoDTensor/SelectedRows."));
}
}
void SetDims(const std::vector<framework::Variable*>& vars,
const std::vector<DDim>& dims) {
size_t length = vars.size();
PADDLE_ENFORCE_EQ(
length, dims.size(),
platform::errors::PreconditionNotMet(
"Vars number [%d] should be equal with dims number [%d]", length,
dims.size()));
for (size_t i = 0; i < length; ++i) {
if (vars[i] == nullptr) {
continue;
}
SetDim(vars[i], dims[i]);
}
}
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
PADDLE_THROW(platform::errors::PermissionDenied(
"SetRepeatedDims not support in dygraph runtime"));
}
private:
const NameVarMap<VarType>* var_base_map_in_;
const NameVarMap<VarType>* var_base_map_out_;
const framework::AttributeMap* attrs_;
};
} // namespace imperative
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
namespace paddle {
namespace imperative {
// infer var type context for imperative mode
template <typename VarType>
class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
public:
RuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
const NameVarMap<VarType>& outputs,
const framework::AttributeMap& attrs_map)
: InferVarTypeContext(nullptr, nullptr),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs_map),
input_names_(),
output_names_(),
var_set_() {
input_names_.reserve(inputs_.size());
for (auto& it : inputs_) {
for (auto& var : it.second) {
if (var) {
input_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var.get();
}
}
}
output_names_.reserve(outputs_.size());
for (auto& it : outputs_) {
for (auto& var : it.second) {
if (var) {
output_names_[it.first].emplace_back(var->Name());
var_set_[var->Name()] = var.get();
}
}
}
}
virtual ~RuntimeInferVarTypeContext() {}
framework::Attribute GetAttr(const std::string& name) const override {
auto iter = attrs_.find(name);
PADDLE_ENFORCE_EQ(
iter != attrs_.end(), true,
platform::errors::NotFound("Cannot find attribute %s", name));
return iter->second;
}
bool HasVar(const std::string& name) const override {
return var_set_.count(name) > 0;
}
bool HasInput(const std::string& name) const override {
auto it = inputs_.find(name);
return (it != inputs_.end() && it->second.size() > 0);
}
bool HasOutput(const std::string& name) const override {
auto it = outputs_.find(name);
return (it != outputs_.end() && it->second.size() > 0);
}
const std::vector<std::string>& Input(
const std::string& name) const override {
auto iter = input_names_.find(name);
PADDLE_ENFORCE_EQ(
iter != input_names_.end(), true,
platform::errors::NotFound("Cannot find input var %s", name));
return iter->second;
}
const std::vector<std::string>& Output(
const std::string& name) const override {
auto iter = output_names_.find(name);
PADDLE_ENFORCE_EQ(
iter != output_names_.end(), true,
platform::errors::NotFound("Cannot find output var %s", name));
return iter->second;
}
framework::proto::VarType::Type GetType(
const std::string& name) const override {
auto iter = var_set_.find(name);
PADDLE_ENFORCE_EQ(
iter != var_set_.end(), true,
platform::errors::NotFound("Cannot find var %s in GetType", name));
return iter->second->Type();
}
void SetType(const std::string& name,
framework::proto::VarType::Type type) override {
if (name == "kLookupTablePath") {
VLOG(2) << "SUPER UGLY FIX, remove this when move imperative mode in C++";
} else {
var_set_[name]->SetType(type);
if ((var_set_[name]->MutableVar()->IsInitialized() == true) &&
(var_set_[name]->MutableVar()->Type() != type)) {
var_set_[name]->MutableVar()->Clear();
}
}
}
framework::proto::VarType::Type GetDataType(
const std::string& name) const override {
auto iter = var_set_.find(name);
PADDLE_ENFORCE_EQ(
iter != var_set_.end(), true,
platform::errors::NotFound("Cannot find var %s in GetDataType", name));
return iter->second->DataType();
}
void SetDataType(const std::string& name,
framework::proto::VarType::Type type) override {
var_set_[name]->SetDataType(type);
}
std::vector<framework::proto::VarType::Type> GetDataTypes(
const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"GetDataTypes is not supported in runtime InferVarType"));
}
void SetDataTypes(const std::string& name,
const std::vector<framework::proto::VarType::Type>&
multiple_data_type) override {
PADDLE_THROW(platform::errors::PermissionDenied(
"SetDataTypes is not supported in runtime InferVarType"));
}
std::vector<int64_t> GetShape(const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"Do not handle Shape in runtime InferVarType"));
}
void SetShape(const std::string& name,
const std::vector<int64_t>& dims) override {
PADDLE_THROW(platform::errors::PermissionDenied(
"Do not handle Shape in runtime InferVarType"));
}
int32_t GetLoDLevel(const std::string& name) const override {
PADDLE_THROW(platform::errors::PermissionDenied(
"Do not handle LoDLevel in runtime InferVarType"));
}
void SetLoDLevel(const std::string& name, int32_t lod_level) override {
PADDLE_THROW(platform::errors::PermissionDenied(
"Do not handle LoDLevel in runtime InferVarType"));
}
private:
const NameVarMap<VarType>& inputs_;
const NameVarMap<VarType>& outputs_;
const framework::AttributeMap& attrs_;
std::unordered_map<std::string, std::vector<std::string>> input_names_;
std::unordered_map<std::string, std::vector<std::string>> output_names_;
std::unordered_map<std::string, VarType*> var_set_;
};
} // namespace imperative
} // namespace paddle
......@@ -19,6 +19,10 @@
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/infer_shape_context.h"
#include "paddle/fluid/imperative/infer_var_type_context.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/prepared_operator.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -180,7 +184,7 @@ static std::string LayerDebugStringImpl(const std::string& op_type,
size_t i = 0;
for (auto& pair : ins) {
if (i > 0) ss << ", ";
ss << DebugString(pair.first, pair.second);
ss << DebugString<VarType>(pair.first, pair.second);
++i;
}
......@@ -188,7 +192,7 @@ static std::string LayerDebugStringImpl(const std::string& op_type,
i = 0;
for (auto& pair : outs) {
if (i > 0) ss << ", ";
ss << DebugString(pair.first, pair.second);
ss << DebugString<VarType>(pair.first, pair.second);
++i;
}
return ss.str();
......@@ -206,6 +210,27 @@ std::string LayerDebugString(const std::string& op_type,
return LayerDebugStringImpl<VariableWrapper>(op_type, ins, outs);
}
VarBase::VarBase(bool has_grad, const std::shared_ptr<VariableWrapper>& var)
: var_(var), grad_node_(var->GetGradNode()) {
if (has_grad) {
if (auto grad_var = var_->GetGradVar()) {
grad_var_ = std::make_shared<VarBase>(false, grad_var);
} else {
grad_var_ = std::make_shared<VarBase>(false, GradVarName());
var_->SetGradVar(grad_var_->var_);
}
}
if (IsDebugEnabled()) {
VLOG(10) << "Construct VarBase: " << Name();
name_set_.Insert(Name());
}
}
size_t VarBase::GradOpNum() const {
return grad_node_ ? grad_node_->size() : 0;
}
void VarBase::ClearGradient() {
if (grad_var_) {
if (grad_var_->Var().IsType<framework::SelectedRows>()) {
......@@ -292,8 +317,6 @@ void OpBase::SetType(const std::string& type) {
}
void OpBase::ClearBackwardTrace() {
grad_pending_ops_.clear();
allow_empty_vars_.clear();
ins_.clear();
outs_.clear();
}
......@@ -308,14 +331,16 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
auto& info = op.Info();
if (info.infer_var_type_) {
RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(ins, &outs, attrs);
RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(ins, outs, attrs);
info.infer_var_type_(&infer_var_type_ctx);
}
// Initialize output var type
for (auto& var_pair : outs) {
for (auto& var : var_pair.second) {
InitializeVariable(var->MutableVar(), var->Type());
if (var) {
InitializeVariable(var->MutableVar(), var->Type());
}
}
}
......@@ -344,5 +369,64 @@ void OpBase::Run(const framework::OperatorBase& op,
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, place);
}
static void ClearNoNeedBufferInputs(OpBase* op) {
auto& inferer = op->Info().NoNeedBufferVarsInferer();
if (!inferer) return;
auto* ins = op->GetMutableInsMap();
const auto& no_need_buffer_slots =
inferer(*ins, op->GetOutsMap(), op->Attrs());
if (no_need_buffer_slots.empty()) return;
for (auto& slot : no_need_buffer_slots) {
auto iter = ins->find(slot);
if (iter == ins->end()) continue;
VLOG(2) << "Clear data buffer of " << slot << " in " << op->Type();
PADDLE_ENFORCE_EQ(
iter->second.IsGrad(), false,
platform::errors::InvalidArgument(
"Only forward variable buffers can be clear, this may be a bug"));
for (auto& each_var : *(iter->second.MutableVarList())) {
if (!each_var) continue;
auto& var = each_var->Var();
PADDLE_ENFORCE_EQ(var.IsType<framework::LoDTensor>(), true,
platform::errors::PermissionDenied(
"NoNeedBufferVars only support LoDTensor"));
// TODO(zjl): support higher order derivatives
auto new_var = new VariableWrapper(each_var->Name());
auto* new_tensor =
new_var->MutableVar()->GetMutable<framework::LoDTensor>();
auto& old_tensor = var.Get<framework::LoDTensor>();
new_tensor->Resize(old_tensor.dims());
new_tensor->set_lod(old_tensor.lod());
each_var.reset(new_var);
}
}
}
std::shared_ptr<GradOpNode> CreateGradOpNode(
const framework::OperatorBase& op, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place) {
const auto& info = op.Info();
if (!info.dygraph_grad_op_maker_) {
return nullptr;
}
auto grad_node = info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs);
if (grad_node && !grad_node->empty()) {
for (auto& op : *grad_node) {
op.SetId(OpBase::GenerateUniqueId());
op.SetPlace(place);
ClearNoNeedBufferInputs(&op);
}
return grad_node;
} else {
return nullptr;
}
}
} // namespace imperative
} // namespace paddle
此差异已折叠。
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace imperative {
// TODO(zjl): to support py_func layer
class OpBase {
public:
OpBase() = default;
OpBase(const OpBase&) = delete;
OpBase(OpBase&&) = default;
OpBase& operator=(const OpBase&) = delete;
OpBase& operator=(OpBase&&) = default;
~OpBase() { VLOG(3) << "Destruct Op: " << Type(); }
const std::string& Type() const { return op_->Type(); }
const framework::AttributeMap& Attrs() const { return attrs_; }
const framework::OpInfo& Info() const { return op_->Info(); }
const framework::OperatorBase& InnerOp() const { return *op_; }
void ClearBackwardTrace();
NameVarMap<VariableWrapper>* GetMutableOutsMap() { return &outs_; }
NameVarMap<VariableWrapper>* GetMutableInsMap() { return &ins_; }
const NameVarMap<VariableWrapper>& GetInsMap() const { return ins_; }
const NameVarMap<VariableWrapper>& GetOutsMap() const { return outs_; }
void SetType(const std::string& type);
void CheckAttrs() {
auto& info = op_->Info();
if (info.Checker() != nullptr) {
info.Checker()->Check(&attrs_, true);
}
}
void SetInput(const std::string& name, VariableWrapperList vars,
bool is_grad) {
auto& in_vars = ins_[name];
*(in_vars.MutableVarList()) = std::move(vars);
in_vars.SetIsGrad(is_grad);
}
void SetOutput(const std::string& name, VariableWrapperList vars,
bool is_grad) {
auto& out_vars = outs_[name];
*(out_vars.MutableVarList()) = std::move(vars);
out_vars.SetIsGrad(is_grad);
}
void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; }
void SetAttr(const std::string& name, const framework::Attribute& v) {
attrs_[name] = v;
}
void SetBlockAttr(const std::string& name, framework::BlockDesc* block) {
PADDLE_THROW(platform::errors::PermissionDenied(
"SetBlockAttr is not support in dygraph OpBase"));
}
const framework::AttributeMap& Attrs() { return attrs_; }
bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; }
const framework::Attribute& GetAttr(const std::string& name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE_NE(
it, attrs_.end(),
platform::errors::NotFound("can not find attribute [%s]", name));
return it->second;
}
template <typename T>
inline const T& Attr(const std::string& name) const {
return boost::get<T>(GetAttr(name));
}
size_t id() const { return id_; }
void SetId(size_t id) { id_ = id; }
const platform::Place& place() const { return place_; }
void SetPlace(const platform::Place& place) { place_ = place; }
static size_t GenerateUniqueId() {
static std::atomic<size_t> unique_id{0};
return unique_id.fetch_add(1);
}
static void Run(const framework::OperatorBase& op,
const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs,
const platform::Place& place);
static void Run(const framework::OperatorBase& op,
const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs,
const platform::Place& place);
private:
NameVarMap<VariableWrapper> ins_;
NameVarMap<VariableWrapper> outs_;
framework::AttributeMap attrs_;
std::unique_ptr<framework::OperatorBase> op_;
platform::Place place_;
size_t id_{-1UL};
std::vector<std::function<void()>> backward_hooks_;
};
class GradOpNode {
public:
GradOpNode() = default;
void reserve(size_t size) { ops_.reserve(size); }
size_t size() const { return ops_.size(); }
bool empty() const { return ops_.empty(); }
void clear() { ops_.clear(); }
void pop_back() { ops_.pop_back(); }
template <typename... ARGS>
OpBase& emplace_back(ARGS&&... args) { // NOLINT
ops_.emplace_back(std::forward<ARGS>(args)...);
return ops_.back();
}
const OpBase& back() const { return ops_.back(); }
OpBase& back() { return ops_.back(); }
OpBase& operator[](size_t idx) { return ops_[idx]; }
const OpBase& operator[](size_t idx) const { return ops_[idx]; }
/* Iterator related */
using Iterator = std::vector<OpBase>::iterator;
using ConstIterator = std::vector<OpBase>::const_iterator;
Iterator begin() { return ops_.begin(); }
Iterator end() { return ops_.end(); }
ConstIterator begin() const { return ops_.begin(); }
ConstIterator end() const { return ops_.end(); }
void InsertGradPendingNode(const std::shared_ptr<GradOpNode>& node) {
if (node &&
std::find(grad_pending_nodes_.begin(), grad_pending_nodes_.end(),
node) == grad_pending_nodes_.end()) {
grad_pending_nodes_.emplace_back(node);
}
}
const std::vector<std::shared_ptr<GradOpNode>>& GradPendingNodes() const {
return grad_pending_nodes_;
}
private:
DISABLE_COPY_AND_ASSIGN(GradOpNode);
private:
std::vector<OpBase> ops_;
std::vector<std::shared_ptr<GradOpNode>> grad_pending_nodes_;
};
} // namespace imperative
} // namespace paddle
此差异已折叠。
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace imperative {
class VarBase;
class PartialGradEngine : public Engine {
public:
PartialGradEngine(const std::vector<std::shared_ptr<VarBase>> &input_targets,
const std::vector<std::shared_ptr<VarBase>> &output_targets,
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place,
const detail::BackwardStrategy &strategy,
bool create_graph);
void Execute() override;
std::vector<std::shared_ptr<VarBase>> GetResult() const;
private:
void Clear();
private:
std::vector<std::shared_ptr<VarBase>> input_targets_;
std::vector<std::shared_ptr<VarBase>> output_targets_;
std::vector<std::shared_ptr<VarBase>> output_grads_;
std::vector<std::shared_ptr<VarBase>> no_grad_vars_;
platform::Place place_;
detail::BackwardStrategy strategy_;
bool create_graph_;
std::vector<std::shared_ptr<VarBase>> results_;
};
} // namespace imperative
} // namespace paddle
......@@ -14,6 +14,9 @@
#include "paddle/fluid/imperative/prepared_operator.h"
#include <sstream>
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/infer_shape_context.h"
#include "paddle/fluid/imperative/infer_var_type_context.h"
namespace paddle {
namespace imperative {
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <utility>
#include <vector>
namespace paddle {
namespace imperative {
class VariableWrapper;
class SavedVariableWrapperList {
public:
SavedVariableWrapperList() : vars_(), is_grad_(false) {}
template <typename... Args>
explicit SavedVariableWrapperList(bool is_grad, Args&&... args)
: vars_(std::forward<Args>(args)...), is_grad_(is_grad) {}
bool IsGrad() const { return is_grad_; }
void SetIsGrad(bool is_grad) { is_grad_ = is_grad; }
const std::vector<std::shared_ptr<VariableWrapper>>& VarList() const {
return vars_;
}
std::vector<std::shared_ptr<VariableWrapper>>* MutableVarList() {
return &vars_;
}
/* Borrow method from std::vector */
size_t size() const { return vars_.size(); }
bool empty() const { return vars_.empty(); }
template <typename... ARGS>
void emplace_back(ARGS&&... args) {
vars_.emplace_back(std::forward<ARGS>(args)...);
}
using Iterator = std::vector<std::shared_ptr<VariableWrapper>>::iterator;
using ConstIterator =
std::vector<std::shared_ptr<VariableWrapper>>::const_iterator;
Iterator begin() { return vars_.begin(); }
Iterator end() { return vars_.end(); }
ConstIterator begin() const { return vars_.begin(); }
ConstIterator end() const { return vars_.end(); }
std::shared_ptr<VariableWrapper>& operator[](size_t idx) {
return vars_[idx];
}
const std::shared_ptr<VariableWrapper>& operator[](size_t idx) const {
return vars_[idx];
}
operator const std::vector<std::shared_ptr<VariableWrapper>>&() const {
return vars_;
}
private:
std::vector<std::shared_ptr<VariableWrapper>> vars_;
bool is_grad_;
};
} // namespace imperative
} // namespace paddle
......@@ -117,5 +117,202 @@ TEST(test_add_functor, add_functor) {
#endif
}
static void CopyVar(const framework::Variable& var,
framework::Variable* dst_ptr) {
auto& dst = *dst_ptr;
dst.Clear();
if (var.IsType<framework::LoDTensor>()) {
const auto& src_tensor = var.Get<framework::LoDTensor>();
auto* dst_tensor = dst.GetMutable<framework::LoDTensor>();
framework::TensorCopySync(src_tensor, src_tensor.place(), dst_tensor);
} else {
const auto& src_selected_rows = var.Get<framework::SelectedRows>();
auto* dst_selected_rows = dst.GetMutable<framework::SelectedRows>();
dst_selected_rows->set_rows(src_selected_rows.rows());
dst_selected_rows->set_height(src_selected_rows.height());
framework::TensorCopySync(src_selected_rows.value(),
src_selected_rows.value().place(),
dst_selected_rows->mutable_value());
}
}
static bool IsEqualVar(const framework::Variable& var1,
const framework::Variable& var2) {
if (var1.Type() != var2.Type()) {
return false;
}
framework::Tensor t1, t2;
if (var1.IsType<framework::LoDTensor>()) {
framework::TensorCopySync(var1.Get<framework::LoDTensor>(),
platform::CPUPlace(), &t1);
framework::TensorCopySync(var2.Get<framework::LoDTensor>(),
platform::CPUPlace(), &t2);
} else {
auto& s1 = var1.Get<framework::SelectedRows>();
auto& s2 = var2.Get<framework::SelectedRows>();
if (s1.height() != s2.height()) {
return false;
}
if (s1.rows().size() != s2.rows().size()) {
return false;
}
auto row1_data = s1.rows().data();
auto row2_data = s2.rows().data();
if (std::memcmp(row1_data, row2_data,
s1.rows().size() * sizeof(*row1_data)) != 0) {
return false;
}
framework::TensorCopySync(var1.Get<framework::SelectedRows>().value(),
platform::CPUPlace(), &t1);
framework::TensorCopySync(var2.Get<framework::SelectedRows>().value(),
platform::CPUPlace(), &t2);
}
if (t1.type() != t2.type() || t1.dims() != t2.dims()) {
return false;
}
auto* t1_p = t1.data<void>();
auto* t2_p = t2.data<void>();
return std::memcmp(t1_p, t2_p,
t1.numel() * framework::SizeOfType(t1.type())) == 0;
}
template <typename T>
static framework::Variable RandomTensor(const framework::DDim& dims,
const platform::Place& place,
int low = -10, int high = 10) {
framework::Tensor cpu_tensor;
cpu_tensor.Resize(dims);
auto* ptr = cpu_tensor.mutable_data<T>(platform::CPUPlace());
std::uniform_int_distribution<int> dist(low, high);
std::random_device rd;
std::mt19937 engine(rd());
for (int64_t i = 0; i < cpu_tensor.numel(); ++i) {
ptr[i] = dist(engine);
}
framework::Variable ret;
framework::TensorCopySync(cpu_tensor, place,
ret.GetMutable<framework::LoDTensor>());
return ret;
}
template <typename T>
static framework::Variable RandomSelectedRows(framework::DDim dims,
const platform::Place& place,
int64_t row_number, int low = -10,
int high = 10) {
auto height = dims[0];
dims[0] = row_number;
framework::Variable ret;
auto* sr = ret.GetMutable<framework::SelectedRows>();
auto tensor_var = RandomTensor<T>(dims, place, low, high);
sr->mutable_value()->ShareDataWith(
tensor_var.template Get<framework::LoDTensor>());
sr->set_height(height);
sr->mutable_rows()->resize(row_number);
auto* row_data = sr->mutable_rows()->data();
std::uniform_int_distribution<int64_t> dist(0, height - 1);
std::random_device rd;
std::mt19937 engine(rd());
for (int64_t i = 0; i < dims[0]; ++i) {
row_data[i] = dist(engine);
}
return ret;
}
static std::unique_ptr<GradientAccumulator> CreateAccumulator(
const std::shared_ptr<VariableWrapper>& var, bool sort_gradient) {
if (sort_gradient) {
return std::unique_ptr<GradientAccumulator>(
new SortedGradientAccumulator(var.get()));
} else {
return std::unique_ptr<GradientAccumulator>(
new EagerGradientAccumulator(var.get()));
}
}
static void TestGradientAccumulatorTestUnchangeInput(
const platform::Place& place, bool sort_gradient) {
framework::DDim dim{10, 20};
int64_t maximum_row_number = 100;
std::uniform_int_distribution<int64_t> dist(1, maximum_row_number);
int seed;
{
std::random_device rd;
seed = rd();
}
std::mt19937 engine(seed);
auto create_var = [&](bool use_tensor) {
if (use_tensor) {
return RandomTensor<float>(dim, place);
} else {
return RandomSelectedRows<float>(dim, place, dist(engine));
}
};
std::vector<bool> use_tensors = {false, true};
for (auto use_tensor1 : use_tensors) {
for (auto use_tensor2 : use_tensors) {
auto g_var1 = std::make_shared<VariableWrapper>("g_var1");
g_var1->SetOverridedStopGradient(false);
auto g_accum1 = CreateAccumulator(g_var1, sort_gradient);
g_accum1->IncreaseRefCnt();
g_accum1->IncreaseRefCnt();
auto g_var2 = std::make_shared<VariableWrapper>("g_var2");
g_var2->SetOverridedStopGradient(false);
auto g_accum2 = CreateAccumulator(g_var2, sort_gradient);
g_accum2->IncreaseRefCnt();
g_accum2->IncreaseRefCnt();
auto var1 = create_var(use_tensor1);
auto var_wrapper1_1 = std::make_shared<VariableWrapper>("tmp1_1");
auto var_wrapper2_1 = std::make_shared<VariableWrapper>("tmp2_1");
CopyVar(var1, var_wrapper1_1->MutableVar());
CopyVar(var1, var_wrapper2_1->MutableVar());
auto var2 = create_var(use_tensor2);
auto var_wrapper1_2 = std::make_shared<VariableWrapper>("tmp1_2");
auto var_wrapper2_2 = std::make_shared<VariableWrapper>("tmp2_2");
CopyVar(var2, var_wrapper1_2->MutableVar());
CopyVar(var2, var_wrapper2_2->MutableVar());
g_accum1->Add(var_wrapper1_1, 0, false);
g_accum1->Add(var_wrapper1_2, 1, false);
g_accum2->Add(var_wrapper2_1, 0, true);
g_accum2->Add(var_wrapper2_2, 1, true);
ASSERT_TRUE(IsEqualVar(var_wrapper2_1->Var(), var1));
ASSERT_TRUE(IsEqualVar(var_wrapper2_2->Var(), var2));
ASSERT_TRUE(IsEqualVar(g_var1->Var(), g_var2->Var()));
}
}
}
TEST(test_gradient_accumulator, test_unchange_input) {
for (auto sort_gradient : {false, true}) {
TestGradientAccumulatorTestUnchangeInput(platform::CPUPlace(),
sort_gradient);
#ifdef PADDLE_WITH_CUDA
TestGradientAccumulatorTestUnchangeInput(platform::CUDAPlace(0),
sort_gradient);
#endif
}
}
} // namespace imperative
} // namespace paddle
......@@ -21,6 +21,9 @@
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/infer_shape_context.h"
#include "paddle/fluid/imperative/infer_var_type_context.h"
#include "paddle/fluid/imperative/layer.h"
namespace imperative = paddle::imperative;
......@@ -45,7 +48,7 @@ TEST(test_layer, test_runtime_context) {
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap attrs;
auto *ctx = new imperative::RuntimeInferVarTypeContext<imperative::VarBase>(
ins, &outs, attrs);
ins, outs, attrs);
ASSERT_TRUE(ctx->HasVar("vin"));
ASSERT_TRUE(ctx->HasInput("X"));
ASSERT_TRUE(ctx->HasOutput("Out"));
......@@ -120,11 +123,12 @@ TEST(test_layer, test_debug_string) {
ASSERT_TRUE(res_sr.find("SelectedRows") != std::string::npos);
}
static std::shared_ptr<imperative::OpBase> CreateOpBase(
static std::shared_ptr<imperative::GradOpNode> CreateGradNode(
size_t id, const std::string &type, const imperative::NameVarBaseMap &ins,
const imperative::NameVarBaseMap &outs,
const framework::AttributeMap &attrs, const platform::Place &place) {
auto op = std::make_shared<imperative::OpBase>();
auto node = std::make_shared<imperative::GradOpNode>();
auto *op = &(node->emplace_back());
op->SetId(id);
op->SetPlace(place);
op->SetType(type);
......@@ -134,7 +138,7 @@ static std::shared_ptr<imperative::OpBase> CreateOpBase(
for (auto &var : pair.second) {
vars.emplace_back(var->SharedVar());
}
op->SetInput(pair.first, vars);
op->SetInput(pair.first, vars, false);
}
for (auto &pair : outs) {
......@@ -142,10 +146,10 @@ static std::shared_ptr<imperative::OpBase> CreateOpBase(
for (auto &var : pair.second) {
vars.emplace_back(var->SharedVar());
}
op->SetOutput(pair.first, vars);
op->SetOutput(pair.first, vars, false);
}
return op;
return node;
}
TEST(test_layer, test_clear_backward_info) {
......@@ -163,19 +167,21 @@ TEST(test_layer, test_clear_backward_info) {
framework::AttributeMap concat_att_map;
concat_att_map["axis"] = 1;
auto op = CreateOpBase(0, "mul", ins, outs, concat_att_map, place);
auto preceding_op = CreateOpBase(0, "mul", ins, outs, concat_att_map, place);
op->SetGradPendingOps({preceding_op});
auto node = CreateGradNode(0, "mul", ins, outs, concat_att_map, place);
auto pending_node =
CreateGradNode(0, "mul", ins, outs, concat_att_map, place);
node->InsertGradPendingNode(pending_node);
ASSERT_EQ(node->size(), 1UL);
auto *op = &(node->back());
ASSERT_GT(op->GetInsMap().size(), 0UL);
ASSERT_GT(op->GetOutsMap().size(), 0UL);
ASSERT_GT(op->GradPendingOps().size(), 0UL);
op->ClearBackwardTrace();
ASSERT_EQ(op->GetInsMap().size(), 0UL);
ASSERT_EQ(op->GetOutsMap().size(), 0UL);
ASSERT_EQ(op->GradPendingOps().size(), 0UL);
}
TEST(test_layer, test_varbase_basic) {
......
......@@ -22,6 +22,7 @@
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/memory/memcpy.h"
......@@ -148,9 +149,9 @@ TEST(test_tracer, test_track_backward_output) {
framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL);
ASSERT_EQ(x_in->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL);
}
TEST(test_tracer, test_track_backward_input) {
......@@ -188,9 +189,9 @@ TEST(test_tracer, test_track_backward_input) {
mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL);
ASSERT_EQ(x_in->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL);
}
#if defined(PADDLE_WITH_CUDA)
TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
......@@ -240,9 +241,9 @@ TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
tracer.TraceOp("reduce_sum", reduce_in, reduce_out, reduce_attr_map,
gpu_place, true);
detail::BackwardStrategy back_st;
imperative::Engine* engine = tracer.GetDefaultEngine();
engine->Init(reduce_sum_out.get(), back_st);
engine->Execute();
imperative::BasicEngine engine;
engine.Init(reduce_sum_out.get(), back_st);
engine.Execute();
framework::LoDTensor rlt;
framework::TensorCopySync(vout->Var().Get<framework::LoDTensor>(), place,
......@@ -346,14 +347,14 @@ TEST(test_tracer, test_var_without_grad_var) {
ASSERT_EQ(out_tensor.data<float>()[i], 20.0);
}
ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL);
ASSERT_EQ(x_in->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL);
detail::BackwardStrategy back_st;
imperative::Engine* engine = tracer.GetDefaultEngine();
engine->Init(vout.get(), back_st);
engine->Execute();
imperative::BasicEngine engine;
engine.Init(vout.get(), back_st);
engine.Execute();
// check the grad
framework::LoDTensor x_grad;
......@@ -382,7 +383,7 @@ static void TestVarOpDestructionMain(const platform::Place& place,
size_t loop_num = 10) {
WeakPtrSet<VariableWrapper> var_wrappers;
WeakPtrSet<VarBase> var_bases;
WeakPtrSet<OpBase> op_bases;
WeakPtrSet<GradOpNode> op_bases;
Tracer tracer;
......@@ -413,30 +414,31 @@ static void TestVarOpDestructionMain(const platform::Place& place,
NameVarBaseMap{{"Out", {z}}}, framework::AttributeMap{},
place, true);
ASSERT_EQ(z->GradOps().size(), 0UL);
ASSERT_EQ(z->GradVarBase()->GradOps().size(), 1UL);
auto new_op = z->GradVarBase()->GradOps()[0];
ASSERT_EQ(z->GradOpNum(), 0UL);
ASSERT_EQ(z->GradVarBase()->GradOpNum(), 1UL);
auto new_op = z->GradVarBase()->GradNode();
ASSERT_EQ(x->GradOps().size(), 0UL);
ASSERT_EQ(y->GradOps().size(), 0UL);
ASSERT_EQ(x->GradOpNum(), 0UL);
ASSERT_EQ(y->GradOpNum(), 0UL);
std::unordered_set<std::shared_ptr<OpBase>> expected_pending_ops;
std::unordered_set<std::shared_ptr<GradOpNode>> expected_pending_ops;
if (i == 0) {
ASSERT_EQ(x->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(x->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(y->GradVarBase()->GradOpNum(), 0UL);
} else {
ASSERT_EQ(x->GradVarBase()->GradOps().size(), 1UL);
ASSERT_EQ(y->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(x->GradVarBase()->GradOpNum(), 1UL);
ASSERT_EQ(y->GradVarBase()->GradOpNum(), 0UL);
for (auto& op : x->GradVarBase()->GradOps()) {
expected_pending_ops.emplace(op);
if (x->GradVarBase()->GradNode()) {
expected_pending_ops.emplace(x->GradVarBase()->GradNode());
}
for (auto& op : y->GradVarBase()->GradOps()) {
expected_pending_ops.emplace(op);
if (y->GradVarBase()->GradNode()) {
expected_pending_ops.emplace(y->GradVarBase()->GradNode());
}
std::unordered_set<std::shared_ptr<OpBase>> actual_pending_ops;
for (auto& op : new_op->GradPendingOps()) {
std::unordered_set<std::shared_ptr<GradOpNode>> actual_pending_ops;
for (auto& op : new_op->GradPendingNodes()) {
actual_pending_ops.emplace(op);
}
......
......@@ -16,6 +16,7 @@
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h"
......@@ -31,49 +32,6 @@ void SetCurrentTracer(const std::shared_ptr<Tracer>& tracer) {
VLOG(6) << "Set current tracer: " << g_current_tracer;
}
static void ClearNoNeedBufferInputs(OpBase* op) {
auto& inferer = op->Info().NoNeedBufferVarsInferer();
if (!inferer) return;
auto* ins = op->GetMutableInsMap();
const auto& no_need_buffer_slots =
inferer(*ins, op->GetOutsMap(), op->Attrs());
if (no_need_buffer_slots.empty()) return;
for (auto& slot : no_need_buffer_slots) {
auto iter = ins->find(slot);
if (iter == ins->end()) continue;
VLOG(2) << "Clear data buffer of " << slot << " in " << op->Type();
for (auto& each_var : iter->second) {
if (!each_var) continue;
auto& var = each_var->Var();
PADDLE_ENFORCE_EQ(var.IsType<framework::LoDTensor>(), true,
"Only support LoDTensor");
// TODO(zjl): support higher order derivatives
auto new_var = new VariableWrapper(each_var->Name());
auto* new_tensor =
new_var->MutableVar()->GetMutable<framework::LoDTensor>();
auto& old_tensor = var.Get<framework::LoDTensor>();
new_tensor->Resize(old_tensor.dims());
new_tensor->set_lod(old_tensor.lod());
each_var.reset(new_var);
op->AddAllowedEmptyVar(new_var);
}
}
}
static std::vector<std::shared_ptr<OpBase>> CreateGradOpBases(
const framework::OpInfo& info, const std::string& type,
const NameVarBaseMap& in, const NameVarBaseMap& out,
const framework::AttributeMap& attrs) {
if (info.dygraph_grad_op_maker_) {
return info.dygraph_grad_op_maker_(type, in, out, attrs);
} else {
return {};
}
}
static void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) {
for (const auto& name_pair : outs) {
for (const auto& vb : name_pair.second) {
......@@ -103,7 +61,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
}
if (ComputeRequiredGrad(ins, outs, trace_backward)) {
TraceBackward(op_info, type, ins, outs, attrs, place);
CreateGradOpNode(*op, ins, outs, attrs, place);
} else {
VLOG(3) << "No Grad to track for Op: " << type;
}
......@@ -133,22 +91,5 @@ bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
return false;
}
void Tracer::TraceBackward(const framework::OpInfo& info,
const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
const framework::AttributeMap& attrs,
const platform::Place& place) {
auto grad_op_bases = CreateGradOpBases(info, type, ins, outs, attrs);
auto grad_op_num = grad_op_bases.size();
if (grad_op_num == 0) return;
size_t trace_id = GenerateUniqueId();
for (auto& grad_op : grad_op_bases) {
grad_op->SetPlace(place);
grad_op->SetId(trace_id);
ClearNoNeedBufferInputs(grad_op.get());
}
}
} // namespace imperative
} // namespace paddle
......@@ -21,7 +21,7 @@
#include <unordered_map>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/jit/program_desc_tracer.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/platform/macros.h"
......@@ -46,7 +46,7 @@ class Tracer {
public:
Tracer()
: engine_(new BasicEngine()),
: basic_engine_(new BasicEngine()),
program_desc_tracer_(new jit::ProgramDescTracer()),
generator_(new UniqueNameGenerator()) {
expected_place_ = platform::CPUPlace();
......@@ -64,8 +64,6 @@ class Tracer {
bool ComputeRequiredGrad(const NameVarBaseMap& ins,
const NameVarBaseMap& outs, bool trace_backward);
Engine* GetDefaultEngine() const { return engine_.get(); }
void SetEnableProgramDescTracing(bool enabled) {
enable_program_desc_tracing_ = enabled;
}
......@@ -82,6 +80,8 @@ class Tracer {
return generator_->Generate(key);
}
BasicEngine* GetEngine() const { return basic_engine_.get(); }
platform::Place ExpectedPlace() const { return expected_place_; }
void SetExpectedPlace(platform::Place place) { expected_place_ = place; }
......@@ -91,18 +91,7 @@ class Tracer {
void SetNoGrad(bool no_grad) { no_grad_ = no_grad; }
private:
void TraceBackward(const framework::OpInfo& info, const std::string& type,
const NameVarBaseMap& ins, const NameVarBaseMap& outs,
const framework::AttributeMap& attrs,
const platform::Place& place);
static size_t GenerateUniqueId() {
static std::atomic<size_t> id{0};
return id.fetch_add(1);
}
private:
std::unique_ptr<Engine> engine_;
std::unique_ptr<BasicEngine> basic_engine_;
std::unique_ptr<jit::ProgramDescTracer> program_desc_tracer_;
bool enable_program_desc_tracing_{false};
std::unique_ptr<UniqueNameGenerator> generator_;
......
......@@ -23,18 +23,37 @@ namespace paddle {
namespace imperative {
class VariableWrapper;
class SavedVariableWrapperList;
class VarBase;
class OpBase;
class GradOpNode;
class Tracer;
using WeakNameVarBaseMap =
std::map<std::string, std::vector<std::weak_ptr<VarBase>>>;
namespace details {
template <typename T>
using NameVarMap = std::map<std::string, std::vector<std::shared_ptr<T>>>;
struct NameVarMapTrait {};
template <>
struct NameVarMapTrait<VarBase> {
using Type = std::map<std::string, std::vector<std::shared_ptr<VarBase>>>;
};
template <>
struct NameVarMapTrait<VariableWrapper> {
using Type = std::map<std::string, SavedVariableWrapperList>;
};
} // namespace details
template <typename T>
using NameVarMap = typename details::NameVarMapTrait<T>::Type;
using NameVarBaseMap = NameVarMap<VarBase>;
using NameVariableWrapperMap = NameVarMap<VariableWrapper>;
using WeakNameVarBaseMap =
std::map<std::string, std::vector<std::weak_ptr<VarBase>>>;
using VariableWrapperList = std::vector<std::shared_ptr<VariableWrapper>>;
} // namespace imperative
} // namespace paddle
......@@ -14,14 +14,20 @@
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace imperative {
class VarBase;
class GradOpNode;
class VariableWrapper {
public:
friend class VarBase;
explicit VariableWrapper(const std::string& name) : name_(name) {}
const framework::Variable& Var() const { return var_; }
......@@ -31,6 +37,10 @@ class VariableWrapper {
// This is used for python api
void SetOverridedStopGradient(bool stop_gradient) {
overrided_stop_gradient_ = static_cast<int>(stop_gradient);
if (auto grad_var = grad_var_.lock()) {
grad_var->SetOverridedStopGradient(stop_gradient);
}
}
// This is used for python api
......@@ -47,6 +57,10 @@ class VariableWrapper {
VLOG(6) << "Ignore Stop gradient conversion for Var: " << Name()
<< "Set value is: " << overrided_stop_gradient_;
}
if (auto grad_var = grad_var_.lock()) {
grad_var->InnerSetOverridedStopGradient(stop_gradient);
}
}
void SetPersistable(bool persistable) { persistable_ = persistable; }
......@@ -65,6 +79,18 @@ class VariableWrapper {
data_type_ = data_type;
}
std::shared_ptr<VariableWrapper> GetGradVar() const {
return grad_var_.lock();
}
const std::weak_ptr<VariableWrapper>& GetWeakGradVar() const {
return grad_var_;
}
std::shared_ptr<GradOpNode> GetGradNode() const { return grad_node_.lock(); }
bool HasGradNode() const { return !grad_node_.expired(); }
framework::proto::VarType::Type DataType() const {
const framework::Tensor* tensor = nullptr;
if (var_.IsInitialized()) {
......@@ -85,6 +111,32 @@ class VariableWrapper {
}
}
private:
void SetGradVar(const std::shared_ptr<VariableWrapper>& var) {
auto shared_var = grad_var_.lock();
if (shared_var != var) {
PADDLE_ENFORCE_EQ(shared_var, nullptr,
platform::errors::PermissionDenied(
"Cannot set gradient var wrapper twice"));
grad_var_ = var;
}
}
void SetGradNode(const std::shared_ptr<GradOpNode>& grad_node) {
if (!grad_node) {
grad_node_.reset();
return;
}
auto shared_node = grad_node_.lock();
if (shared_node != grad_node) {
PADDLE_ENFORCE_EQ(
shared_node, nullptr,
platform::errors::PermissionDenied("Cannot set gradient op twice"));
grad_node_ = grad_node;
}
}
private:
framework::Variable var_;
std::string name_;
......@@ -96,6 +148,9 @@ class VariableWrapper {
framework::proto::VarType::Type type_{framework::proto::VarType::LOD_TENSOR};
framework::proto::VarType::Type data_type_{framework::proto::VarType::FP32};
std::weak_ptr<VariableWrapper> grad_var_;
std::weak_ptr<GradOpNode> grad_node_;
};
} // namespace imperative
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace paddle {
namespace operators {
......
......@@ -109,31 +109,29 @@ class MinusGradMaker : public imperative::GradOpBaseMakerBase {
public:
using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase;
std::vector<std::shared_ptr<imperative::OpBase>> operator()() const override {
std::vector<std::shared_ptr<imperative::OpBase>> ops;
std::shared_ptr<imperative::GradOpNode> operator()() const override {
auto x_g = this->InputGrad("X");
auto y_g = this->InputGrad("Y");
auto node = this->NewGradNode();
if (!x_g.empty()) {
auto x_g_op = CreateOp();
imperative::TracedGradOp op(x_g_op);
imperative::TracedGradOp op(node);
op.SetType("scale");
op.SetInput("X", this->OutputGrad("Out"));
op.SetOutput("Out", x_g);
op.SetAttr("scale", 1.0f);
ops.emplace_back(x_g_op);
}
auto y_g = this->InputGrad("Y");
if (!y_g.empty()) {
auto y_g_op = CreateOp();
imperative::TracedGradOp op(y_g_op);
imperative::TracedGradOp op(node);
op.SetType("scale");
op.SetInput("X", this->OutputGrad("Out"));
op.SetOutput("Out", y_g);
op.SetAttr("scale", -1.0f);
ops.emplace_back(y_g_op);
}
return ops;
return node;
}
};
......
......@@ -64,21 +64,22 @@ class ReduceMeanDoubleGradOpBaseMaker : public imperative::GradOpBaseMakerBase {
public:
using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase;
std::vector<std::shared_ptr<imperative::OpBase>> operator()() const override {
std::vector<std::shared_ptr<imperative::OpBase>> ops;
auto x_gg = OutputGrad(framework::GradVarName("X")); // input ddx
std::shared_ptr<imperative::GradOpNode> operator()() const override {
auto out_grads = InputGrad(framework::GradVarName("Out"));
if (!out_grads.empty()) {
auto out_grad_op = CreateOp();
imperative::TracedGradOp op(out_grad_op);
op.SetType("reduce_mean");
op.SetInput("X", x_gg);
op.SetAttrMap(Attrs());
op.SetOutput("Out", out_grads);
ops.emplace_back(out_grad_op);
auto x_gg = OutputGrad(framework::GradVarName("X")); // input ddx
auto node = this->NewGradNode();
{
imperative::TracedGradOp op(node);
op.SetType("reduce_mean");
op.SetInput("X", x_gg);
op.SetAttrMap(Attrs());
op.SetOutput("Out", out_grads);
}
return node;
} else {
return nullptr;
}
return ops;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ReduceMeanGradNoNeedBufferVarInference,
......
......@@ -272,25 +272,25 @@ class SumGradOpBaseMaker : public imperative::GradOpBaseMakerBase {
public:
using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase;
std::vector<std::shared_ptr<imperative::OpBase>> operator()() const override {
std::shared_ptr<imperative::GradOpNode> operator()() const override {
auto x_grads = InputGrad("X", false);
using InputGradsType = decltype(x_grads);
std::vector<std::shared_ptr<imperative::OpBase>> grad_ops;
grad_ops.reserve(x_grads.size());
auto og = OutputGrad("Out");
std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
[&og](const std::shared_ptr<imperative::VarBase>& x_grad) {
auto grad_op = CreateOp();
imperative::TracedGradOp op(grad_op);
op.SetType("scale");
op.SetInput("X", og);
op.SetOutput("Out", InputGradsType{x_grad});
op.SetAttr("scale", 1.0f);
return grad_op;
});
return grad_ops;
if (!x_grads.empty()) {
auto node = this->NewGradNode();
node->reserve(x_grads.size());
auto og = OutputGrad("Out");
for (auto& x_grad : x_grads) {
imperative::TracedGradOp op(node);
op.SetType("scale");
op.SetInput("X", og);
op.SetOutput("Out", InputGradsType{x_grad});
op.SetAttr("scale", 1.0f);
}
return node;
} else {
return nullptr;
}
}
};
......
此差异已折叠。
......@@ -256,7 +256,9 @@ void BindOpDesc(pybind11::module *m) {
.def("set_is_target", &pd::OpDesc::SetIsTarget)
.def("serialize_to_string", SerializeMessage<pd::OpDesc>)
.def("block", [](pd::OpDesc &self) { return self.Block(); },
pybind11::return_value_policy::reference);
pybind11::return_value_policy::reference)
.def("inputs", &pd::OpDesc::Inputs)
.def("outputs", &pd::OpDesc::Outputs);
}
} // namespace pybind
......
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册