Unverified commit d33c4343 authored by Zeng Jinle, committed by GitHub

Imperative tracer refactoring (#22457)

* refine grad maker, test=develop

* refactor tracer stage 1, test=develop

* merge develop to resolve conflicts for the third time, test=develop
Parent 08a772cb
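Note on the interface change (illustration only, not part of this diff): the commit reworks the dygraph grad-op maker so that makers receive the forward op's type, input/output variable maps, and attributes directly instead of an `OpBase*`, and `SingleGradOpMaker<T>::Apply` now fills in a framework-allocated `GradOpPtr<T>` rather than returning a newly constructed op. A minimal sketch of a maker written against the refactored interface is below; the class name and the "Out"/"X" slot names are hypothetical, and it is assumed to live in `namespace paddle::operators`, where `GradOpPtr<T>` is aliased.

// Hypothetical example (not in this commit): a grad op maker under the
// refactored interface. The same code serves static graph (T = OpDesc)
// and dygraph (T = imperative::OpBase) modes.
template <typename T>
class MyReluGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // The framework allocates the grad op and passes it in; Apply only fills it.
  void Apply(GradOpPtr<T> grad) const override {
    grad->SetType("relu_grad");
    grad->SetInput("Out", this->Output("Out"));
    grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad->SetAttrMap(this->Attrs());
  }
};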
......@@ -242,10 +242,11 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
"GradOpBaseMaker of %s has been registered", op_type));
info->dygraph_grad_op_maker_ = [](
const imperative::OpBase* fw_op_base,
const std::string& type,
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out) {
T maker(fw_op_base, var_base_map_in, var_base_map_out);
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs) {
T maker(type, var_base_map_in, var_base_map_out, attrs);
return maker();
};
}
......
......@@ -28,6 +28,26 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace details {
template <typename T>
struct GradOpPtrTrait {};
template <>
struct GradOpPtrTrait<OpDesc> {
using Type = OpDesc*;
};
template <>
struct GradOpPtrTrait<imperative::OpBase> {
using Type = imperative::TracedGradOp*;
};
} // namespace details
template <typename T>
using GradOpPtr = typename details::GradOpPtrTrait<T>::Type;
/*
This functor class is responsible for creating the gradient ops for the given
operator fwd_op. After it is called (through operator()), the pairs of
......@@ -47,6 +67,10 @@ class GradOpDescMakerBase {
grad_to_var_(grad_to_var),
grad_block_(grad_block) {}
static std::unique_ptr<OpDesc> CreateOp() {
return std::unique_ptr<OpDesc>(new OpDesc());
}
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<std::unique_ptr<OpDesc>> operator()() const = 0;
......@@ -100,7 +124,13 @@ class GradOpDescMakerBase {
return ret_val;
}
std::vector<std::string> Empty() const { return {}; }
static std::vector<std::string> EmptyInput() { return {}; }
static std::vector<std::string> EmptyOutput() { return {}; }
static std::vector<std::string> EmptyInputGrad() { return {}; }
static std::vector<std::string> EmptyOutputGrad() { return {}; }
std::vector<std::string> InputNames() const {
return this->fwd_op_.InputNames();
......@@ -155,16 +185,7 @@ class GradOpDescMakerBase {
};
template <typename T>
class SingleGradOpMaker {
public:
std::vector<std::unique_ptr<T>> operator()() const {
PADDLE_ENFORCE(false, "should not call this function");
return {};
}
protected:
virtual std::unique_ptr<T> Apply() const = 0;
};
class SingleGradOpMaker {};
template <>
class SingleGradOpMaker<OpDesc> : public GradOpDescMakerBase {
......@@ -173,12 +194,13 @@ class SingleGradOpMaker<OpDesc> : public GradOpDescMakerBase {
std::vector<std::unique_ptr<OpDesc>> operator()() const {
std::vector<std::unique_ptr<OpDesc>> retv;
retv.emplace_back(this->Apply());
retv.emplace_back(new OpDesc());
this->Apply(retv.front().get());
return retv;
}
protected:
virtual std::unique_ptr<OpDesc> Apply() const = 0;
virtual void Apply(GradOpPtr<OpDesc> op) const = 0;
};
template <>
......@@ -187,16 +209,18 @@ class SingleGradOpMaker<imperative::OpBase>
public:
using GradOpBaseMakerBase::GradOpBaseMakerBase;
public:
std::vector<std::unique_ptr<imperative::OpBase>> operator()() const {
std::vector<std::unique_ptr<imperative::OpBase>> retv;
retv.emplace_back(this->Apply());
std::vector<std::shared_ptr<imperative::OpBase>> operator()() const {
std::vector<std::shared_ptr<imperative::OpBase>> retv{
std::make_shared<imperative::OpBase>()};
{
imperative::TracedGradOp grad_op(retv.front());
this->Apply(&grad_op);
}
return retv;
}
protected:
virtual std::unique_ptr<imperative::OpBase> Apply() const = 0;
virtual void Apply(GradOpPtr<imperative::OpBase> op) const = 0;
};
template <typename T, bool DropEmptyIG = true>
......@@ -205,8 +229,7 @@ class DefaultGradOpMaker final : public SingleGradOpMaker<T> {
using SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const final {
auto* grad = new T();
void Apply(GradOpPtr<T> grad) const final {
grad->SetType(this->ForwardOpType() + "_grad");
for (auto& input_param : this->InputNames()) {
......@@ -221,19 +244,11 @@ class DefaultGradOpMaker final : public SingleGradOpMaker<T> {
}
grad->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad);
}
};
template <typename T>
class EmptyGradOpMaker {
public:
virtual std::vector<std::unique_ptr<T>> operator()()
const final { /* NOLINT */
return {};
}
};
class EmptyGradOpMaker {};
template <>
class EmptyGradOpMaker<OpDesc> final : public GradOpDescMakerBase {
......@@ -247,10 +262,18 @@ class EmptyGradOpMaker<imperative::OpBase> final
: public imperative::GradOpBaseMakerBase {
public:
using GradOpBaseMakerBase::GradOpBaseMakerBase;
std::vector<std::unique_ptr<imperative::OpBase>> operator()() const final {
std::vector<std::shared_ptr<imperative::OpBase>> operator()() const final {
return {};
}
};
} // namespace framework
namespace operators {
template <typename T>
using GradOpPtr = framework::GradOpPtr<T>;
} // namespace operators
} // namespace paddle
......@@ -45,8 +45,9 @@ bool StaticGraphInferNoNeedBufferVarsContext::HasOutput(
}
DyGraphInferNoNeedBufferVarsContext::DyGraphInferNoNeedBufferVarsContext(
const imperative::NameVarBaseMap &inputs,
const imperative::NameVarBaseMap &outputs, const AttributeMap &attrs)
const imperative::NameVarMap<imperative::VariableWrapper> &inputs,
const imperative::NameVarMap<imperative::VariableWrapper> &outputs,
const AttributeMap &attrs)
: InferNoNeedBufferVarsContext(attrs), inputs_(inputs), outputs_(outputs) {}
bool DyGraphInferNoNeedBufferVarsContext::HasOutput(
......
......@@ -56,15 +56,16 @@ class StaticGraphInferNoNeedBufferVarsContext final
class DyGraphInferNoNeedBufferVarsContext final
: public InferNoNeedBufferVarsContext {
public:
DyGraphInferNoNeedBufferVarsContext(const imperative::NameVarBaseMap &inputs,
const imperative::NameVarBaseMap &outputs,
const AttributeMap &attr);
DyGraphInferNoNeedBufferVarsContext(
const imperative::NameVarMap<imperative::VariableWrapper> &inputs,
const imperative::NameVarMap<imperative::VariableWrapper> &outputs,
const AttributeMap &attrs);
bool HasOutput(const std::string &slot) const final;
private:
const imperative::NameVarBaseMap &inputs_;
const imperative::NameVarBaseMap &outputs_;
const imperative::NameVarMap<imperative::VariableWrapper> &inputs_;
const imperative::NameVarMap<imperative::VariableWrapper> &outputs_;
};
class NoNeedBufferVarsInference {
......@@ -106,8 +107,8 @@ class InferNoNeedBufferVarsFN {
}
inline const std::unordered_set<std::string> &operator()(
const imperative::NameVarBaseMap &inputs,
const imperative::NameVarBaseMap &outputs,
const imperative::NameVarMap<imperative::VariableWrapper> &inputs,
const imperative::NameVarMap<imperative::VariableWrapper> &outputs,
const AttributeMap &attrs) const {
PADDLE_ENFORCE_NOT_NULL(inferer_);
DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
......
......@@ -35,10 +35,10 @@ TEST(test_no_need_buffer_vars_inference, test_static_graph) {
TEST(test_no_need_buffer_vars_inference, test_dygraph) {
AttributeMap attrs{{"is_test", true}};
imperative::NameVarBaseMap inputs;
imperative::NameVarBaseMap outputs;
imperative::NameVarMap<imperative::VariableWrapper> inputs;
imperative::NameVarMap<imperative::VariableWrapper> outputs;
outputs["Out"].emplace_back(nullptr);
outputs["Out"].emplace_back(new imperative::VarBase("tmp_0"));
outputs["Out"].emplace_back(new imperative::VariableWrapper("tmp_0"));
DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
......
......@@ -1310,8 +1310,10 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
for (auto& input : ctx.InNameList()) {
ParseInputDataType(ctx, input, &data_type);
}
PADDLE_ENFORCE_NE(data_type, dafault_data_type,
"DataType should be indicated by input Variable.");
PADDLE_ENFORCE_NE(
data_type, dafault_data_type,
platform::errors::NotFound(
"DataType should be indicated by input Variable at %s.", Type()));
return data_type;
}
......
......@@ -56,10 +56,11 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
const std::vector<BlockDesc*>& grad_block)>;
using DygraphGradOpMakerFN =
std::function<std::vector<std::unique_ptr<imperative::OpBase>>(
const imperative::OpBase* fw_op_base,
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out)>;
std::function<std::vector<std::shared_ptr<imperative::OpBase>>(
const std::string& /*op_type*/,
const imperative::NameVarBaseMap& /*var_base_map_in*/,
const imperative::NameVarBaseMap& /*var_base_map_out*/,
const framework::AttributeMap& /*attributes*/)>;
using InferVarTypeFN =
std::function<void(framework::InferVarTypeContext* /*context*/)>;
......
......@@ -17,6 +17,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/imperative/layer.h"
......@@ -27,37 +28,70 @@
namespace paddle {
namespace imperative {
enum TracedVarRole { kForward = 0, kBackward = 1 };
template <typename T, TracedVarRole kRole>
class TracedVarList : public std::vector<std::shared_ptr<T>> {
private:
using BaseClass = std::vector<std::shared_ptr<T>>;
public:
using BaseClass::BaseClass;
};
class GradOpBaseMakerBase {
public:
explicit GradOpBaseMakerBase(const OpBase* fw_op_base,
explicit GradOpBaseMakerBase(const std::string& type,
const NameVarBaseMap& var_base_map_in,
const NameVarBaseMap& var_base_map_out)
: fw_op_base_(fw_op_base),
const NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs)
: type_(type),
var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out) {}
var_base_map_out_(var_base_map_out),
attrs_(attrs) {}
virtual ~GradOpBaseMakerBase() = default;
virtual std::vector<std::unique_ptr<OpBase>> operator()() const = 0;
virtual std::vector<std::shared_ptr<OpBase>> operator()() const = 0;
static std::shared_ptr<OpBase> CreateOp() {
return std::make_shared<OpBase>();
}
std::vector<std::shared_ptr<VarBase>> InputGrad(
TracedVarList<VarBase, TracedVarRole::kBackward> InputGrad(
const std::string& name, bool drop_empty_grad = true) const {
return GetVarBaseList(name, true, true);
return GetVarBaseList<TracedVarRole::kBackward>(name, /*is_input=*/true);
}
std::vector<std::shared_ptr<VarBase>> OutputGrad(
TracedVarList<VarBase, TracedVarRole::kBackward> OutputGrad(
const std::string& name) const {
return GetVarBaseList(name, true, false);
return GetVarBaseList<TracedVarRole::kBackward>(name, /*is_input=*/false);
}
TracedVarList<VarBase, TracedVarRole::kForward> Input(
const std::string& name) const {
return GetVarBaseList<TracedVarRole::kForward>(name, /*is_input=*/true);
}
TracedVarList<VarBase, TracedVarRole::kForward> Output(
const std::string& name) const {
return GetVarBaseList<TracedVarRole::kForward>(name, /*is_input=*/false);
}
static TracedVarList<VarBase, TracedVarRole::kForward> EmptyInput() {
return {};
}
std::vector<std::shared_ptr<VarBase>> Input(const std::string name) const {
return GetVarBaseList(name, false, true);
static TracedVarList<VarBase, TracedVarRole::kForward> EmptyOutput() {
return {};
}
std::vector<std::shared_ptr<VarBase>> Output(const std::string& name) const {
return GetVarBaseList(name, false, false);
static TracedVarList<VarBase, TracedVarRole::kBackward> EmptyOutputGrad() {
return {};
}
std::vector<std::shared_ptr<VarBase>> Empty() const { return {}; }
static TracedVarList<VarBase, TracedVarRole::kBackward> EmptyInputGrad() {
return {};
}
std::vector<std::string> InputNames() const {
std::vector<std::string> vec_temp;
......@@ -65,7 +99,6 @@ class GradOpBaseMakerBase {
for (auto& it : var_base_map_in_) {
vec_temp.emplace_back(it.first);
}
return vec_temp;
}
......@@ -75,21 +108,17 @@ class GradOpBaseMakerBase {
for (auto& it : var_base_map_out_) {
vec_temp.emplace_back(it.first);
}
return vec_temp;
}
const std::unordered_map<std::string, framework::Attribute>& Attrs() const {
return fw_op_base_->Attrs();
}
const framework::AttributeMap& Attrs() const { return attrs_; }
const framework::Attribute& GetAttr(const std::string& name) const {
auto& map = fw_op_base_->Attrs();
auto it = map.find(name);
PADDLE_ENFORCE(it != map.end(),
"Cannot find attribute [%s] in operator [%s]", name,
fw_op_base_->Type());
auto it = attrs_.find(name);
PADDLE_ENFORCE_EQ(
it != attrs_.end(), true,
platform::errors::NotFound(
"Cannot find attribute [%s] in operator [%s]", name, type_));
return it->second;
}
......@@ -98,41 +127,37 @@ class GradOpBaseMakerBase {
return boost::get<T>(GetAttr(name));
}
std::string ForwardOpType() const { return fw_op_base_->Type(); }
const std::string& ForwardOpType() const { return type_; }
protected:
bool HasInput(const std::string& name) const {
auto it = var_base_map_in_.find(name);
return it != var_base_map_in_.end();
return var_base_map_in_.count(name) > 0;
}
bool HasOutput(const std::string name) const {
auto it = var_base_map_out_.find(name);
return it != var_base_map_out_.end();
bool HasOutput(const std::string& name) const {
return var_base_map_out_.count(name) > 0;
}
private:
std::vector<std::shared_ptr<VarBase>> GetVarBaseList(const std::string& name,
bool is_grad,
bool is_input) const {
const NameVarBaseMap& data_map =
is_input ? var_base_map_in_ : var_base_map_out_;
template <TracedVarRole kRole>
TracedVarList<VarBase, kRole> GetVarBaseList(const std::string& name,
bool is_input) const {
const auto& data_map = is_input ? var_base_map_in_ : var_base_map_out_;
auto iterator = data_map.find(name);
std::vector<std::shared_ptr<imperative::VarBase>> vec_temp;
TracedVarList<VarBase, kRole> vec_temp;
if (iterator != data_map.end()) {
vec_temp.reserve(iterator->second.size());
for (auto& var_base_temp : iterator->second) {
if (is_grad) {
if (kRole == TracedVarRole::kBackward) {
if (!var_base_temp->HasGradVar()) {
VLOG(6) << "GradVarBase of var " << var_base_temp->Name()
<< " in OP " << fw_op_base_->Type() << " is null";
<< " in OP " << type_ << " is null";
var_base_temp->MutableGradVarBase();
}
auto grad_var_base_tmp = var_base_temp->GradVarBase();
if (!is_input) {
auto* tensor = grad_var_base_tmp->MutableVar()
->GetMutable<framework::LoDTensor>();
......@@ -150,12 +175,91 @@ class GradOpBaseMakerBase {
}
private:
const OpBase* fw_op_base_;
const std::string& type_;
const NameVarBaseMap& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_;
const framework::AttributeMap& attrs_;
};
protected:
std::vector<framework::BlockDesc*> grad_block_;
class TracedGradOp {
DISABLE_COPY_AND_ASSIGN(TracedGradOp);
public:
explicit TracedGradOp(const std::shared_ptr<OpBase>& op) : op_(op) {}
~TracedGradOp() {
op_->SetGradPendingOps(
{grad_pending_ops_.begin(), grad_pending_ops_.end()});
op_->CheckAttrs();
}
template <TracedVarRole kRole>
void SetInput(const std::string& name,
const TracedVarList<VarBase, kRole>& vars) {
if (kRole == TracedVarRole::kBackward) {
for (auto& var : vars) {
var->AddGradOp(op_);
}
}
op_->SetInput(name, ToVarWrapperList(vars));
}
template <TracedVarRole kRole>
void SetOutput(const std::string& name,
const TracedVarList<VarBase, kRole>& vars) {
if (kRole == TracedVarRole::kBackward) {
if (vars.size() == 1 && vars.front()->OverridedStopGradient()) {
op_->SetOutput(name, VariableWrapperList{});
return;
} else {
for (auto& var : vars) {
if (!var->OverridedStopGradient()) {
for (auto& op : var->GradOps()) {
grad_pending_ops_.emplace(op);
}
}
}
}
}
op_->SetOutput(name, ToVarWrapperList(vars));
}
void SetType(const std::string& type) { op_->SetType(type); }
void SetAttrMap(const framework::AttributeMap& attrs) {
return op_->SetAttrMap(attrs);
}
void SetAttr(const std::string& name, const framework::Attribute& v) {
op_->SetAttr(name, v);
}
bool HasAttr(const std::string& name) const { return op_->HasAttr(name); }
const framework::Attribute& GetAttr(const std::string& name) const {
return op_->GetAttr(name);
}
template <typename T>
inline const T& Attr(const std::string& name) const {
return op_->Attr<T>(name);
}
private:
static std::vector<std::shared_ptr<VariableWrapper>> ToVarWrapperList(
const std::vector<std::shared_ptr<VarBase>>& vars) {
std::vector<std::shared_ptr<VariableWrapper>> result;
result.reserve(vars.size());
for (auto& var : vars) {
result.emplace_back(var->SharedVar());
}
return result;
}
private:
const std::shared_ptr<OpBase>& op_;
std::unordered_set<std::shared_ptr<OpBase>> grad_pending_ops_;
};
} // namespace imperative
......
......@@ -30,16 +30,9 @@
namespace paddle {
namespace imperative {
void Engine::RunOp(paddle::imperative::OpBase* op,
const paddle::imperative::NameVarBaseMap& ins,
const paddle::imperative::NameVarBaseMap& outs,
const paddle::platform::Place& place) {
op->Run(ins, outs);
}
void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
backward_strategy_ = strategy;
const std::vector<OpBase*> ops = var->GradVarBase()->GradOps();
const auto& ops = var->GradVarBase()->GradOps();
var->ClearGradOps();
if (ops.empty() || var->OverridedStopGradient()) {
......@@ -59,7 +52,9 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
return;
}
}
init_ops_ = ops;
var->GradVarBase()->ClearGradOps();
VLOG(3) << "start backward";
PADDLE_ENFORCE_EQ(var->HasGradVar(), true,
......@@ -71,7 +66,6 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
VLOG(6) << "init loss grad:" << var->GradVarBase()->Name()
<< " as stop_gradient false";
var->GradVarBase()->InnerSetOverridedStopGradient(false);
var->GradVarBase()->SetGradGenerated(true);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(fwd_var.place());
grad_var->Resize(fwd_var.dims());
grad_var->mutable_data(fwd_var.place(), fwd_var.type());
......@@ -81,35 +75,29 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
void BasicEngine::CheckBackwardInputs(OpBase* op) {
for (auto& pair : op->GetInsMap()) {
for (auto& var : pair.second) {
if (var && IsGrad(var.get())) {
// if grad var has OverridedStopGradient skip this Op
if (!var->GradGenerated()) {
VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero";
auto* dev_ctx =
platform::DeviceContextPool::Instance().Get(op->place());
auto* tensor = var->MutableVar()->GetMutable<framework::LoDTensor>();
tensor->mutable_data(op->place(), var->DataType());
operators::math::set_constant(*dev_ctx, tensor, 0.0);
} else {
continue;
}
if (!var || op->IsAllowedEmptyVar(var.get())) {
continue;
}
}
}
}
void BasicEngine::SetBackwardOutputs(paddle::imperative::OpBase* op) {
for (auto& pair : op->GetOutsMap()) {
for (auto& var : pair.second) {
if (var) {
// Set Backward outputs's generate_grad as true
var->SetGradGenerated(true);
VLOG(6) << "Set backward output: " << var->Name()
<< "'s SetGeneratedGrad as True";
auto* inner_var = var->MutableVar();
framework::Tensor* tensor = nullptr;
if (!inner_var->IsInitialized() ||
inner_var->IsType<framework::LoDTensor>()) {
tensor = inner_var->GetMutable<framework::LoDTensor>();
}
if (tensor && !tensor->IsInitialized()) {
// if grad var has OverridedStopGradient skip this Op
VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero";
auto* dev_ctx =
platform::DeviceContextPool::Instance().Get(op->place());
tensor->mutable_data(op->place(), var->DataType());
operators::math::set_constant(*dev_ctx, tensor, 0.0);
}
}
}
}
void BasicEngine::PrepareGradAccumulators(OpBase* op) {
for (const auto& pair : op->GetOutsMap()) {
for (const auto& var : pair.second) {
......@@ -140,50 +128,63 @@ void BasicEngine::PrepareDeps() {
std::queue<OpBase*> q;
std::unordered_set<OpBase*> visited;
for (const auto& init_op : init_ops_) {
q.push(init_op);
visited.insert(init_op);
q.push(init_op.get());
visited.insert(init_op.get());
}
while (!q.empty()) {
auto* cur_op = q.front();
q.pop();
VLOG(3) << "Checking grads of op " << cur_op->Type();
SetBackwardOutputs(cur_op);
PADDLE_ENFORCE_NE(
cur_op->GetInsMap().empty() && cur_op->GetOutsMap().empty(), true,
platform::errors::NotFound(
"Inputs and outputs of %s do not exist. "
"This may be because you call \"backward()\" twice for the same "
"subgraph. Please try to call \"stop_gradient = True\" or "
"\"detach()\" if you use some same vars between two \"backward()\" "
"calls.",
cur_op->Type()));
PrepareGradAccumulators(cur_op);
auto& grad_pending_ops = cur_op->GradPendingOps();
for (auto* grad_pending_op : grad_pending_ops) {
const auto& grad_pending_ops = cur_op->GradPendingOps();
for (auto& grad_pending_op : grad_pending_ops) {
PADDLE_ENFORCE_NOT_NULL(grad_pending_op);
++op_deps_[grad_pending_op];
if (visited.count(grad_pending_op) == 0) {
visited.insert(grad_pending_op);
q.push(grad_pending_op);
++op_deps_[grad_pending_op.get()];
if (visited.count(grad_pending_op.get()) == 0) {
visited.insert(grad_pending_op.get());
q.push(grad_pending_op.get());
}
}
}
}
void BasicEngine::SumGradient(OpBase* op, std::shared_ptr<VarBase> src,
VarBase* dst) {
void BasicEngine::SumGradient(OpBase* op, std::shared_ptr<VariableWrapper> src,
VariableWrapper* dst) {
auto iter = accumulators_.find(dst);
PADDLE_ENFORCE_EQ(iter != accumulators_.end(), true,
"Cannot find gradient of variable %s", dst->Name());
iter->second->Add(std::move(src), op->id());
}
void BasicEngine::Execute() {
PrepareDeps();
// Start execute Computation graph
std::queue<OpBase*> q;
std::queue<std::shared_ptr<OpBase>> q;
for (const auto& init_op : init_ops_) {
q.push(init_op);
q.push(std::move(init_op));
}
size_t op_num = 0;
while (!q.empty()) {
OpBase* cur_op = q.front();
auto shared_cur_op = std::move(q.front());
q.pop();
auto* cur_op = shared_cur_op.get();
++op_num;
// CheckBackWardInput
CheckBackwardInputs(cur_op);
......@@ -191,26 +192,28 @@ void BasicEngine::Execute() {
auto& bwd_ins = cur_op->GetInsMap();
auto& bwd_outs = cur_op->GetOutsMap();
NameVarBaseMap tmp_outs(bwd_outs);
NameVarMap<VariableWrapper> tmp_outs(bwd_outs);
// 1. construct the output map 2. replace the element in the map
// A var may be coresponding to several grad var in one op
for (auto it = tmp_outs.begin(); it != tmp_outs.end(); ++it) {
for (size_t i = 0; i < it->second.size(); ++i) {
auto tmp_var =
std::make_shared<VarBase>(false, "Gtmp@"); // Do not need grad
std::make_shared<VariableWrapper>("Gtmp@"); // Do not need grad
auto var = it->second[i];
it->second[i] = tmp_var;
if (var) {
need_accu_var_list_.emplace_back(
make_pair(var.get(), std::move(tmp_var)));
var->ClearGradOps();
need_accu_var_list_.emplace_back(var.get(), std::move(tmp_var));
}
}
}
VLOG(3) << "Start to execute grad op " << cur_op->Type();
RunOp(cur_op, bwd_ins, tmp_outs, cur_op->place());
{
VLOG(3) << "Start to execute grad op " << cur_op->Type();
OpBase::Run(cur_op->InnerOp(), bwd_ins, tmp_outs, cur_op->Attrs(),
cur_op->place());
}
// Step 2: Sum Gradient
if (need_accu_var_list_.size() > 0) {
......@@ -223,9 +226,9 @@ void BasicEngine::Execute() {
// Step 3: Collect ready ops
for (auto* grad_pending_op : cur_op->GradPendingOps()) {
for (auto& grad_pending_op : cur_op->GradPendingOps()) {
PADDLE_ENFORCE_NOT_NULL(grad_pending_op);
auto iter = op_deps_.find(grad_pending_op);
auto iter = op_deps_.find(grad_pending_op.get());
if (iter == op_deps_.end()) {
continue;
}
......@@ -242,10 +245,11 @@ void BasicEngine::Execute() {
// Step 4: Delete op to collect unused variables
VLOG(3) << "Remove op after op " << cur_op->Type() << " runs";
RemoveOp(cur_op);
cur_op->ClearBackwardTrace();
}
VLOG(3) << "Clean properties of BasicEngine";
CleanEngine();
Clear();
VLOG(1) << "Backward op number: " << op_num;
}
} // namespace imperative
} // namespace paddle
......@@ -37,49 +37,12 @@ class Engine {
virtual ~Engine() = default;
virtual void Execute() = 0;
virtual void Init(VarBase* var, const detail::BackwardStrategy& strategy) = 0;
virtual void RunOp(imperative::OpBase* op, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const platform::Place& place);
virtual void RemoveOp(OpBase* op) {
PADDLE_ENFORCE_NOT_NULL(op, "Cannot remove null op");
auto iter = grad_ops_.find(op);
PADDLE_ENFORCE_EQ(iter != grad_ops_.end(), true, "Op is not inside tracer");
grad_ops_.erase(iter);
}
void InsertOp(OpBase* op, std::shared_ptr<OpBase> op_shared) {
grad_ops_[op] = std::move(op_shared);
}
const std::unordered_set<VarBase*>& GradVars() const { return grad_vars_; }
const std::unordered_map<OpBase*, std::shared_ptr<OpBase>>& GradOps() const {
return grad_ops_;
}
void InsertGradVar(VarBase* grad) { grad_vars_.emplace(grad); }
bool IsGrad(VarBase* var) { return grad_vars_.count(var) > 0; }
void Clear() {
grad_ops_.clear();
grad_vars_.clear();
}
private:
std::unordered_map<OpBase*, std::shared_ptr<OpBase>>
grad_ops_; // opBase for remove - grad_op
std::unordered_set<VarBase*> grad_vars_;
};
class BasicEngine : public Engine {
public:
BasicEngine() = default;
void Init(VarBase* var, const detail::BackwardStrategy& strategy) override;
~BasicEngine() override = default;
void Execute() override;
private:
......@@ -87,28 +50,26 @@ class BasicEngine : public Engine {
void CheckBackwardInputs(OpBase* op);
void SetBackwardOutputs(OpBase* op);
void PrepareGradAccumulators(OpBase* op);
void SumGradient(OpBase* op, std::shared_ptr<VarBase> src, VarBase* dst);
void SumGradient(OpBase* op, std::shared_ptr<VariableWrapper> src,
VariableWrapper* dst);
// TODO(jiabin): maybe we can optimize the performance of engine by cache the
// result
void CleanEngine() {
void Clear() {
init_ops_.clear();
op_deps_.clear();
accumulators_.clear();
Clear();
}
std::vector<OpBase*> init_ops_;
std::vector<std::shared_ptr<OpBase>> init_ops_;
detail::BackwardStrategy backward_strategy_;
std::unordered_map<OpBase*, size_t> op_deps_;
std::unordered_map<VarBase*, std::unique_ptr<GradientAccumulator>>
std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
accumulators_;
std::vector<std::pair<VarBase*, std::shared_ptr<VarBase>>>
std::vector<std::pair<VariableWrapper*, std::shared_ptr<VariableWrapper>>>
need_accu_var_list_;
};
......
......@@ -144,8 +144,8 @@ void SelectedRowsAddToTensor(const framework::Variable& src,
// Note(chenweihang): when two selected rows need to be added,
// adding one to another is not equal to merging two selected rows
// to one then add it to a empty selected rows, the after is correct
std::shared_ptr<VarBase> SelectedRowsMerge(const framework::Variable& src1,
const framework::Variable& src2) {
std::shared_ptr<VariableWrapper> SelectedRowsMerge(
const framework::Variable& src1, const framework::Variable& src2) {
auto& src_selected_rows1 = src1.Get<framework::SelectedRows>();
auto& src_selected_rows2 = src2.Get<framework::SelectedRows>();
auto place = src_selected_rows1.value().place();
......@@ -155,7 +155,7 @@ std::shared_ptr<VarBase> SelectedRowsMerge(const framework::Variable& src1,
std::vector<const framework::SelectedRows*> src_selected_rows;
src_selected_rows.emplace_back(&src_selected_rows1);
src_selected_rows.emplace_back(&src_selected_rows2);
auto dst_var = std::make_shared<VarBase>(false, "Temp");
auto dst_var = std::make_shared<VariableWrapper>("Temp");
auto* dst_selected_rows =
dst_var->MutableVar()->GetMutable<framework::SelectedRows>();
......@@ -188,7 +188,8 @@ std::shared_ptr<VarBase> SelectedRowsMerge(const framework::Variable& src1,
framework::DataTypeToString(data_type)));
}
void VarBaseAdd(std::shared_ptr<VarBase> var, VarBase* var_) {
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
VariableWrapper* var_) {
auto& src = var->Var();
auto* dst = var_->MutableVar();
if (dst->IsType<framework::LoDTensor>()) {
......@@ -208,7 +209,7 @@ void VarBaseAdd(std::shared_ptr<VarBase> var, VarBase* var_) {
*dst = std::move(*(var->MutableVar()));
var_->SetType(framework::proto::VarType::LOD_TENSOR);
} else if (src.IsType<framework::SelectedRows>()) {
std::shared_ptr<VarBase> temp = SelectedRowsMerge(src, *dst);
auto temp = SelectedRowsMerge(src, *dst);
*dst = std::move(*(temp->MutableVar()));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -218,7 +219,8 @@ void VarBaseAdd(std::shared_ptr<VarBase> var, VarBase* var_) {
}
}
platform::Place GetPlaceOfVarBase(const std::shared_ptr<VarBase>& var) {
static platform::Place GetPlaceOfVar(
const std::shared_ptr<VariableWrapper>& var) {
platform::Place place;
if (var->Var().IsType<framework::LoDTensor>()) {
place = var->Var().Get<framework::LoDTensor>().place();
......@@ -231,10 +233,10 @@ platform::Place GetPlaceOfVarBase(const std::shared_ptr<VarBase>& var) {
return place;
}
void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
void EagerGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
size_t trace_id) {
auto* dst_var = var_->MutableVar();
platform::Place place = GetPlaceOfVarBase(var);
platform::Place place = GetPlaceOfVar(var);
if (!var_->OverridedStopGradient()) {
VLOG(3) << "Sum Gradient for: " << var_->Name();
if (cur_cnt_ == 0) {
......@@ -243,7 +245,7 @@ void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
}
*dst_var = std::move(*(var->MutableVar()));
} else {
VarBaseAdd(var, var_);
VariableWrapperAdd(var, var_);
}
} else {
if (!var_->Var().IsInitialized() ||
......@@ -268,10 +270,10 @@ void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
++cur_cnt_;
}
void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
size_t trace_id) {
auto* dst_var = var_->MutableVar();
platform::Place place = GetPlaceOfVarBase(var);
platform::Place place = GetPlaceOfVar(var);
if (!var_->OverridedStopGradient()) {
if (ref_cnt_ == 1) {
if (var->Var().IsType<framework::SelectedRows>()) {
......@@ -291,11 +293,12 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
return;
}
std::sort(tmp_grad_vars_.begin(), tmp_grad_vars_.end(),
[](const std::pair<std::shared_ptr<VarBase>, size_t>& p1,
const std::pair<std::shared_ptr<VarBase>, size_t>& p2) {
return p1.second > p2.second;
});
std::sort(
tmp_grad_vars_.begin(), tmp_grad_vars_.end(),
[](const std::pair<std::shared_ptr<VariableWrapper>, size_t>& p1,
const std::pair<std::shared_ptr<VariableWrapper>, size_t>& p2) {
return p1.second > p2.second;
});
#ifdef PADDLE_WITH_CUDA
if (paddle::platform::is_gpu_place(place)) {
......@@ -310,7 +313,7 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
var_->SetType(framework::proto::VarType::SELECTED_ROWS);
*dst_var = std::move(*(tmp_grad_vars_[i].first->MutableVar()));
} else {
VarBaseAdd(tmp_grad_vars_[i].first, var_);
VariableWrapperAdd(tmp_grad_vars_[i].first, var_);
}
}
}
......@@ -321,7 +324,7 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
*dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
}
if (tmp_grad_vars_[i].first->Var().IsType<framework::LoDTensor>()) {
VarBaseAdd(tmp_grad_vars_[i].first, var_);
VariableWrapperAdd(tmp_grad_vars_[i].first, var_);
}
}
} else {
......@@ -333,7 +336,7 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
*dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
}
for (size_t i = 1; i < tmp_grad_vars_.size(); ++i) {
VarBaseAdd(tmp_grad_vars_[i].first, var_);
VariableWrapperAdd(tmp_grad_vars_[i].first, var_);
}
#ifdef PADDLE_WITH_CUDA
}
......
......@@ -24,9 +24,9 @@ namespace imperative {
class GradientAccumulator {
public:
explicit GradientAccumulator(VarBase* var) : var_(var) {}
explicit GradientAccumulator(VariableWrapper* var) : var_(var) {}
virtual void Add(std::shared_ptr<VarBase> var, size_t trace_id) = 0;
virtual void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id) = 0;
virtual ~GradientAccumulator() = default;
......@@ -35,7 +35,7 @@ class GradientAccumulator {
inline size_t RefCnt() const { return ref_cnt_; }
protected:
VarBase* var_;
VariableWrapper* var_;
size_t ref_cnt_{0};
};
......@@ -43,7 +43,7 @@ class EagerGradientAccumulator : public GradientAccumulator {
public:
using GradientAccumulator::GradientAccumulator;
void Add(std::shared_ptr<VarBase> var, size_t trace_id) override;
void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id) override;
private:
size_t cur_cnt_{0};
......@@ -53,10 +53,11 @@ class SortedGradientAccumulator : public GradientAccumulator {
public:
using GradientAccumulator::GradientAccumulator;
void Add(std::shared_ptr<VarBase> var, size_t trace_id) override;
void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id) override;
private:
std::vector<std::pair<std::shared_ptr<VarBase>, size_t>> tmp_grad_vars_;
std::vector<std::pair<std::shared_ptr<VariableWrapper>, size_t>>
tmp_grad_vars_;
};
} // namespace imperative
......
......@@ -113,9 +113,10 @@ static framework::RuntimeContext PrepareRuntimeContext(
return framework::RuntimeContext(std::move(inputs), std::move(outputs));
}
template <typename VarType>
static std::string DebugString(
const std::string& name,
const std::vector<std::shared_ptr<VarBase>>& vars) {
const std::vector<std::shared_ptr<VarType>>& vars) {
std::stringstream ss;
ss << name << "{";
......@@ -127,7 +128,7 @@ static std::string DebugString(
continue;
}
ss << vars[i]->Name() << "[";
auto& var = vars[i]->Var();
const framework::Variable& var = vars[i]->Var();
if (!var.IsInitialized()) {
ss << "NOT_INITED_VAR";
} else if (var.IsType<framework::LoDTensor>()) {
......@@ -167,9 +168,10 @@ static std::string DebugString(
return ss.str();
}
std::string LayerDebugString(const std::string& op_type,
const NameVarBaseMap& ins,
const NameVarBaseMap& outs) {
template <typename VarType>
static std::string LayerDebugStringImpl(const std::string& op_type,
const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs) {
std::stringstream ss;
ss << "Op(" << op_type << "): ";
......@@ -192,28 +194,30 @@ std::string LayerDebugString(const std::string& op_type,
return ss.str();
}
void VarBase::AddGradOps(const std::weak_ptr<OpBase>& op) {
if (op.lock() == nullptr) {
return;
}
for (const auto& cur_op : grad_ops_) {
if (cur_op.lock() == op.lock()) {
return;
}
}
grad_ops_.emplace_back(op);
std::string LayerDebugString(const std::string& op_type,
const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs) {
return LayerDebugStringImpl<VarBase>(op_type, ins, outs);
}
std::string LayerDebugString(const std::string& op_type,
const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs) {
return LayerDebugStringImpl<VariableWrapper>(op_type, ins, outs);
}
void VarBase::ClearGradient() {
if (grad_var_) {
if (grad_var_->var_.IsType<framework::SelectedRows>()) {
auto* grad_t = grad_var_->var_.GetMutable<framework::SelectedRows>();
if (grad_var_->Var().IsType<framework::SelectedRows>()) {
auto* grad_t =
grad_var_->MutableVar()->GetMutable<framework::SelectedRows>();
if (grad_t->mutable_value()->IsInitialized()) {
grad_t->mutable_rows()->clear();
grad_t->mutable_value()->clear();
}
} else {
auto* grad_t = grad_var_->var_.GetMutable<framework::LoDTensor>();
auto* grad_t =
grad_var_->MutableVar()->GetMutable<framework::LoDTensor>();
if (grad_t->IsInitialized()) {
auto* dev_ctx =
platform::DeviceContextPool::Instance().Get(grad_t->place());
......@@ -226,19 +230,20 @@ void VarBase::ClearGradient() {
std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
const bool blocking) const {
PADDLE_ENFORCE_EQ(
var_.IsInitialized() && (var_.IsType<framework::LoDTensor>() ||
var_.IsType<framework::SelectedRows>()),
Var().IsInitialized() && (Var().IsType<framework::LoDTensor>() ||
Var().IsType<framework::SelectedRows>()),
true, platform::errors::InvalidArgument(
"Variable is not initialized or Variable's type is not "
"LoDTensor or SelectedRows when getting numpy tensor"));
if (var_.IsType<framework::LoDTensor>()) {
auto& src_tensor = var_.Get<framework::LoDTensor>();
if (Var().IsType<framework::LoDTensor>()) {
auto& src_tensor = Var().Get<framework::LoDTensor>();
// TODO(Jiabin): change this after move unique_name generator to CXX
auto new_var = std::make_shared<VarBase>(
true, Name() + std::to_string(copied_counter_++));
auto* dst_tensor = new_var->var_.GetMutable<framework::LoDTensor>();
auto* dst_tensor =
new_var->MutableVar()->GetMutable<framework::LoDTensor>();
dst_tensor->set_lod(src_tensor.lod());
new_var->SetPersistable(Persistable());
new_var->SetDataType(DataType());
......@@ -257,12 +262,12 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
}
return new_var;
} else {
auto& src_selected_rows = var_.Get<framework::SelectedRows>();
auto& src_selected_rows = Var().Get<framework::SelectedRows>();
auto new_var = std::make_shared<VarBase>(
false, "Itmp" + std::to_string(copied_counter_++));
new_var->SetType(framework::proto::VarType::SELECTED_ROWS);
auto* dst_selected_rows =
new_var->var_.GetMutable<framework::SelectedRows>();
new_var->MutableVar()->GetMutable<framework::SelectedRows>();
framework::TensorCopy(src_selected_rows.value(), dst_place,
dst_selected_rows->mutable_value());
......@@ -281,39 +286,32 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
return new_var;
}
}
// create OpBase from optype
OpBase::OpBase(size_t id, const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place)
: id_(id), place_(place), attrs_(attrs) {
const auto& info = framework::OpInfoMap::Instance().Get(type);
// Step 1: Run forward
if (info.Checker() != nullptr) {
info.Checker()->Check(&attrs_, true);
}
void OpBase::SetType(const std::string& type) {
op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
VLOG(3) << "Construct Op: " << type << std::endl;
}
void OpBase::CreateOperatorBase() {
const auto& info = framework::OpInfoMap::Instance().Get(type_);
if (info.Checker() != nullptr) {
info.Checker()->Check(&attrs_, true);
}
op_ = framework::OpRegistry::CreateOp(type_, {}, {}, {}, false);
void OpBase::ClearBackwardTrace() {
grad_pending_ops_.clear();
allow_empty_vars_.clear();
ins_.clear();
outs_.clear();
}
void OpBase::Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
auto* op_kernel = dynamic_cast<framework::OperatorWithKernel*>(op_.get());
template <typename VarType>
static void OpBaseRunImpl(const framework::OperatorBase& op,
const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs,
const platform::Place& place) {
auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
auto& info = op_->Info();
auto& info = op.Info();
if (info.infer_var_type_) {
RuntimeInferVarTypeContext infer_var_type_ctx(ins, &outs, attrs_);
RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(ins, &outs, attrs);
info.infer_var_type_(&infer_var_type_ctx);
}
// Initialize output var type
for (auto& var_pair : outs) {
for (auto& var : var_pair.second) {
......@@ -321,20 +319,29 @@ void OpBase::Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
}
}
VLOG(3) << "Running Op " << Type();
VLOG(5) << LayerDebugString(Type(), ins, outs);
auto prepared_op =
PreparedOp::Prepare(ins, outs, *op_kernel, place(), &attrs_);
// VLOG(3) << "Running Op " << op.Type();
VLOG(5) << LayerDebugString(op.Type(), ins, outs);
auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs);
prepared_op.Run(&ins, &outs, &attrs_);
prepared_op.Run(ins, outs, attrs);
VLOG(4) << LayerDebugString(Type(), ins, outs);
VLOG(4) << LayerDebugString(op.Type(), ins, outs);
}
void OpBase::ClearBackwardTrace() {
grad_pending_ops_.clear();
ins_.clear();
outs_.clear();
void OpBase::Run(const framework::OperatorBase& op,
const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs,
const platform::Place& place) {
OpBaseRunImpl<VarBase>(op, ins, outs, attrs, place);
}
void OpBase::Run(const framework::OperatorBase& op,
const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs,
const platform::Place& place) {
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, place);
}
} // namespace imperative
......
......@@ -17,12 +17,13 @@
#include <atomic>
#include <cstdint>
#include <list>
#include <map> // NOLINT
#include <memory> // NOLINT
#include <mutex> // NOLINT
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <set>
#include <string> // NOLINT
#include <unordered_map> // NOLINT
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
......@@ -35,6 +36,7 @@
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/flags.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
......@@ -62,26 +64,28 @@ class VarBase {
public:
static std::vector<std::string> AliveVarNames();
explicit VarBase(bool has_grad, const std::string& name)
: name_(name),
: var_(std::make_shared<VariableWrapper>(name)),
grad_var_(has_grad ? new VarBase(false, GradVarName()) : nullptr) {
if (IsDebugEnabled()) {
VLOG(10) << "Construct VarBase: " << name;
name_set_.Insert(name_);
VLOG(10) << "Construct VarBase: " << Name();
name_set_.Insert(Name());
}
}
explicit VarBase(const std::string& name) : VarBase(true, name) {}
~VarBase() {
VLOG(10) << "Destruct VarBase: " << name_;
VLOG(10) << "Destruct VarBase: " << Name();
if (IsDebugEnabled()) {
name_set_.Remove(name_);
name_set_.Remove(Name());
}
}
const framework::Variable& Var() const { return var_; }
const std::shared_ptr<VariableWrapper>& SharedVar() const { return var_; }
framework::Variable* MutableVar() { return &var_; }
const framework::Variable& Var() const { return var_->Var(); }
framework::Variable* MutableVar() { return var_->MutableVar(); }
bool HasGradVar() const { return grad_var_ != nullptr; }
......@@ -94,119 +98,84 @@ class VarBase {
grad_var_ = std::make_shared<VarBase>(false, GradVarName());
// NOTE(zhiqiu): we should keep grad_var_'s stop_gradient property same as
// fwd varbase
grad_var_->SetOverridedStopGradient(overrided_stop_gradient_);
grad_var_->SetOverridedStopGradient(var_->InnerOverridedStopGradient());
}
return grad_var_;
}
const framework::Variable& GradVar() const {
PADDLE_ENFORCE_NOT_NULL(grad_var_, "Gradient of %s does not exist", name_);
return grad_var_->var_;
PADDLE_ENFORCE_NOT_NULL(
grad_var_,
platform::errors::NotFound("Gradient of %s does not exist", Name()));
return grad_var_->Var();
}
framework::Variable* MutableGradVar() {
PADDLE_ENFORCE_NOT_NULL(grad_var_, "Gradient of %s does not exist", name_);
return &(grad_var_->var_);
PADDLE_ENFORCE_NOT_NULL(
grad_var_,
platform::errors::NotFound("Gradient of %s does not exist", Name()));
return grad_var_->MutableVar();
}
// This is used for python api
void SetOverridedStopGradient(bool stop_gradient) {
if (stop_gradient) {
overrided_stop_gradient_ = 1;
} else {
overrided_stop_gradient_ = 0;
}
var_->SetOverridedStopGradient(stop_gradient);
if (grad_var_) {
grad_var_->SetOverridedStopGradient(stop_gradient);
}
}
// This is used for python api
bool OverridedStopGradient() const {
if (overrided_stop_gradient_ == 0) {
return false;
} else {
return true;
}
}
// This is used inside C++
int InnerOverridedStopGradient() const { return overrided_stop_gradient_; }
bool GradGenerated() const { return grad_generated_; }
bool OverridedStopGradient() const { return var_->OverridedStopGradient(); }
void SetGradGenerated(bool generated) { grad_generated_ = generated; }
// This is used inside C++
void InnerSetOverridedStopGradient(bool stop_gradient) {
if (overrided_stop_gradient_ == -1) {
overrided_stop_gradient_ = static_cast<int>(stop_gradient);
if (var_->InnerOverridedStopGradient() == -1) {
var_->InnerSetOverridedStopGradient(stop_gradient);
if (grad_var_) {
grad_var_->InnerSetOverridedStopGradient(stop_gradient);
}
} else {
VLOG(6) << "Ignore Stop gradient conversion for Var: " << Name()
<< "Set value is: " << overrided_stop_gradient_;
}
}
void SetPersistable(bool persistable) { persistable_ = persistable; }
bool Persistable() const { return persistable_; }
void SetPersistable(bool persistable) { var_->SetPersistable(persistable); }
void AddGradOps(const std::weak_ptr<OpBase>& op);
bool Persistable() const { return var_->Persistable(); }
std::vector<OpBase*> GradOps() {
std::vector<OpBase*> rlt;
// TODO(jiabin): use better data structure to remove nullptr when we find it
for (const auto& wk_ptr : grad_ops_) {
OpBase* tmp_op = wk_ptr.lock().get();
if (tmp_op) rlt.emplace_back(tmp_op);
// Only grad var is allowed to call these 2 methods
void AddGradOp(const std::shared_ptr<OpBase>& op) {
if (op &&
std::find(grad_ops_.begin(), grad_ops_.end(), op) == grad_ops_.end()) {
grad_ops_.emplace_back(op);
}
return rlt;
}
const std::vector<std::shared_ptr<OpBase>>& GradOps() const {
return grad_ops_;
}
void ClearGradOps() { grad_ops_.clear(); }
const std::string& Name() const { return name_; }
const std::string& Name() const { return var_->Name(); }
void SetName(const std::string& name) {
name_ = name;
var_->SetName(name);
if (grad_var_) {
grad_var_->SetName(GradVarName());
}
}
std::string GradVarName() { return framework::GradVarName(name_); }
std::string GradVarName() { return framework::GradVarName(Name()); }
void SetType(framework::proto::VarType::Type type) { type_ = type; }
void SetType(framework::proto::VarType::Type type) { var_->SetType(type); }
framework::proto::VarType::Type Type() const { return type_; }
framework::proto::VarType::Type Type() const { return var_->Type(); }
void SetDataType(framework::proto::VarType::Type data_type) {
data_type_ = data_type;
var_->SetDataType(data_type);
if (grad_var_) {
grad_var_->SetDataType(data_type_);
grad_var_->SetDataType(data_type);
}
}
framework::proto::VarType::Type DataType() const {
const framework::Tensor* tensor = nullptr;
if (var_.IsInitialized()) {
if (type_ == framework::proto::VarType::LOD_TENSOR) {
tensor = &(var_.Get<framework::LoDTensor>());
} else if (type_ == framework::proto::VarType::SELECTED_ROWS) {
tensor = &(var_.Get<framework::SelectedRows>().value());
} else {
VLOG(6) << "Variable " << name_ << " is not initialized";
return data_type_;
}
}
if (tensor && tensor->IsInitialized()) {
return tensor->type();
} else {
VLOG(6) << "The tensor of variable " << name_ << " is not initialized";
return data_type_;
}
}
framework::proto::VarType::Type DataType() const { return var_->DataType(); }
void ClearGradient();
......@@ -214,26 +183,23 @@ class VarBase {
const bool blocking) const;
private:
framework::Variable var_;
std::string name_;
/**
* NOTE(zengjinle): never remove the const qualifier of `var_` if you are
* not very familiar with the autograd idea (including the higher order
* derivative).
*/
const std::shared_ptr<VariableWrapper> var_;
std::shared_ptr<VarBase> grad_var_;
std::vector<std::shared_ptr<OpBase>> grad_ops_;
mutable size_t copied_counter_ = 0;
// grad_op indicates which grad_op will this var be used as input
std::vector<std::weak_ptr<OpBase>> grad_ops_;
// add this property for users may set stop_gradient themselves and this
// should override the
// frameworks setting (-1) unset, (1) true, (0) false
int overrided_stop_gradient_{-1};
bool grad_generated_{false};
bool persistable_{false};
framework::proto::VarType::Type type_{framework::proto::VarType::LOD_TENSOR};
framework::proto::VarType::Type data_type_{framework::proto::VarType::FP32};
static ThreadSafeNameSet name_set_;
};
using VariableWrapperList = std::vector<std::shared_ptr<VariableWrapper>>;
class Layer {
public:
virtual ~Layer() {}
......@@ -244,6 +210,7 @@ class Layer {
}
};
template <typename VarType>
class DygraphExecutionContext : public framework::ExecutionContext {
using Variable = framework::Variable;
......@@ -253,9 +220,9 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const platform::DeviceContext& device_context,
const framework::RuntimeContext& ctx,
std::vector<framework::KernelConfig>* configs,
const NameVarBaseMap& var_base_map_in,
const NameVarBaseMap& var_base_map_out,
const framework::AttributeMap* attrs)
const NameVarMap<VarType>& var_base_map_in,
const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap& attrs)
: ExecutionContext(op, scope, device_context, ctx, configs),
var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out),
......@@ -303,16 +270,16 @@ class DygraphExecutionContext : public framework::ExecutionContext {
}
bool HasAttr(const std::string& name) const override {
return attrs_->count(name);
return attrs_.count(name) != 0;
}
const framework::AttributeMap& Attrs() const override { return *attrs_; }
const framework::AttributeMap& Attrs() const override { return attrs_; }
const framework::Attribute& GetAttr(const std::string& name) const override {
auto it = attrs_->find(name);
auto it = attrs_.find(name);
PADDLE_ENFORCE_NE(
it, attrs_->end(),
it, attrs_.end(),
platform::errors::NotFound("can not find [%s] in attrs", name));
return it->second;
......@@ -395,16 +362,17 @@ class DygraphExecutionContext : public framework::ExecutionContext {
}
private:
const NameVarBaseMap& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_;
const framework::AttributeMap* attrs_;
const NameVarMap<VarType>& var_base_map_in_;
const NameVarMap<VarType>& var_base_map_out_;
const framework::AttributeMap& attrs_;
};
// infer var type context for imperative mode
template <typename VarType>
class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
public:
RuntimeInferVarTypeContext(const NameVarBaseMap& inputs,
const NameVarBaseMap* outputs,
RuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
const NameVarMap<VarType>* outputs,
const framework::AttributeMap& attrs_map)
: InferVarTypeContext(nullptr, nullptr),
inputs_(inputs),
......@@ -536,65 +504,51 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
}
private:
const NameVarBaseMap& inputs_;
const NameVarBaseMap* outputs_;
const NameVarMap<VarType>& inputs_;
const NameVarMap<VarType>* outputs_;
const framework::AttributeMap& attrs_;
std::unordered_map<std::string, std::vector<std::string>> input_names_;
std::unordered_map<std::string, std::vector<std::string>> output_names_;
std::unordered_map<std::string, VarBase*> var_set_;
std::unordered_map<std::string, VarType*> var_set_;
};
// TODO(zjl): to support py_func layer
class OpBase : public std::enable_shared_from_this<OpBase> {
class OpBase {
DISABLE_COPY_AND_ASSIGN(OpBase);
public:
~OpBase() { VLOG(3) << "Destruct Op: " << Type() << std::endl; }
OpBase() = default;
// Developer should not rely on this method to create OpBase.
// OpBase should be created in Tracer and managed by Tracer totally.
template <typename... Args>
static std::shared_ptr<OpBase> Create(Args&&... args) {
return std::shared_ptr<OpBase>(new OpBase(std::forward<Args>(args)...));
}
~OpBase() { VLOG(3) << "Destruct Op: " << Type(); }
size_t id() const { return id_; }
const std::string& Type() const { return op_->Type(); }
void Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs);
const framework::VariableNameMap& InputNameMap() const {
return op_->Inputs();
}
const framework::VariableNameMap& OutputNameMap() const {
return op_->Outputs();
}
const framework::AttributeMap& Attrs() const { return attrs_; }
const framework::OpInfo& Info() const { return op_->Info(); }
const framework::OperatorBase& InnerOp() const { return *op_; }
void ClearBackwardTrace();
const std::vector<OpBase*>& GradPendingOps() const {
const std::vector<std::shared_ptr<OpBase>>& GradPendingOps() const {
return grad_pending_ops_;
}
void SetGradPendingOps(std::vector<OpBase*> vec_temp) {
grad_pending_ops_.swap(vec_temp);
void SetGradPendingOps(std::vector<std::shared_ptr<OpBase>> pending_ops) {
grad_pending_ops_ = std::move(pending_ops);
}
void InsertGradPendingOps(OpBase* op) { grad_pending_ops_.emplace_back(op); }
NameVarMap<VariableWrapper>* GetMutableOutsMap() { return &outs_; }
NameVarMap<VariableWrapper>* GetMutableInsMap() { return &ins_; }
const NameVarMap<VariableWrapper>& GetInsMap() { return ins_; }
const NameVarMap<VariableWrapper>& GetOutsMap() { return outs_; }
void SortGradPendingOps() {
std::sort(grad_pending_ops_.begin(), grad_pending_ops_.end(),
[](OpBase* op1, OpBase* op2) { return op1->id() > op2->id(); });
}
NameVarBaseMap* GetMutableOutsMap() { return &outs_; }
NameVarBaseMap* GetMutableInsMap() { return &ins_; }
const NameVarBaseMap& GetInsMap() { return ins_; }
const NameVarBaseMap& GetOutsMap() { return outs_; }
const platform::Place& place() const { return place_; }
// TODO(jiabin) prepare for backward hook
......@@ -609,41 +563,40 @@ class OpBase : public std::enable_shared_from_this<OpBase> {
}
}
private:
OpBase(size_t id, const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place);
void SetType(const std::string& type);
public:
OpBase() {}
void CheckAttrs() {
auto& info = op_->Info();
if (info.Checker() != nullptr) {
info.Checker()->Check(&attrs_, true);
}
}
void SetType(const std::string& type) { type_ = type; }
void SetInput(const std::string& name,
std::vector<std::shared_ptr<VarBase>> vec_var_base) {
ins_[name] = std::move(vec_var_base);
void SetInput(const std::string& name, VariableWrapperList vars) {
ins_[name] = std::move(vars);
}
void SetOutput(const std::string& name,
std::vector<std::shared_ptr<VarBase>> vec_var_base) {
outs_[name] = std::move(vec_var_base);
void SetOutput(const std::string& name, VariableWrapperList vars) {
outs_[name] = std::move(vars);
}
void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; }
void SetAttr(const std::string& name, const framework::Attribute& v) {
attrs_[name] = v;
}
void SetBlockAttr(const std::string& name, framework::BlockDesc* block) {
PADDLE_THROW("SetBlockAttr is not support in dygraph OpBase");
}
const framework::AttributeMap& Attrs() { return attrs_; }
void CreateOperatorBase();
void SetId(size_t id) { id_ = id; }
void SetPlace(platform::Place place) { place_ = place; }
bool HasAttr(const std::string& name) const {
return attrs_.find(name) != attrs_.end();
}
void SetPlace(const platform::Place& place) { place_ = place; }
bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; }
const framework::Attribute& GetAttr(const std::string& name) const {
auto it = attrs_.find(name);
......@@ -657,31 +610,49 @@ class OpBase : public std::enable_shared_from_this<OpBase> {
return boost::get<T>(GetAttr(name));
}
private:
size_t id_;
void AddAllowedEmptyVar(const VariableWrapper* var) {
allow_empty_vars_.emplace(var);
}
bool IsAllowedEmptyVar(const VariableWrapper* var) {
return allow_empty_vars_.count(var) > 0;
}
static void Run(const framework::OperatorBase& op,
const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs,
const platform::Place& place);
static void Run(const framework::OperatorBase& op,
const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs,
const platform::Place& place);
private:
NameVarMap<VariableWrapper> ins_;
NameVarMap<VariableWrapper> outs_;
framework::AttributeMap attrs_;
std::unique_ptr<framework::OperatorBase> op_;
std::vector<std::function<void()>> backward_hooks_;
std::vector<std::shared_ptr<OpBase>> grad_pending_ops_;
platform::Place place_;
// Not need to be std::weak_ptr, because op is binded to a certain Tracer,
// and would not be used by a Tracer that does not create itself.
std::unordered_set<const VariableWrapper*> allow_empty_vars_;
std::vector<OpBase*> grad_pending_ops_;
size_t id_{-1UL};
// This part is only used for backward
NameVarBaseMap ins_;
NameVarBaseMap outs_;
std::string type_;
framework::AttributeMap attrs_;
std::vector<std::function<void()>> backward_hooks_;
};
template <typename VarType>
class DygraphInferShapeContext : public framework::InferShapeContext {
using DDim = framework::DDim;
public:
DygraphInferShapeContext(const NameVarBaseMap* in, const NameVarBaseMap* out,
DygraphInferShapeContext(const NameVarMap<VarType>* in,
const NameVarMap<VarType>* out,
const framework::AttributeMap* attr)
: var_base_map_in_(in), var_base_map_out_(out), attrs_(attr) {}
......@@ -909,9 +880,6 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override {
// auto& vars = OutputVars(name);
// SetDims(vars, dims);
auto it = var_base_map_out_->find(name);
PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(),
......@@ -992,9 +960,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
}
private:
const NameVarBaseMap* var_base_map_in_;
const NameVarBaseMap* var_base_map_out_;
std::string type_;
const NameVarMap<VarType>* var_base_map_in_;
const NameVarMap<VarType>* var_base_map_out_;
const framework::AttributeMap* attrs_;
};
......
......@@ -28,8 +28,9 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
}
}
void PreparedOp::PrepareData(
const platform::Place& place, const NameVarBaseMap& ins,
template <typename VarType>
static void PrepareDataImpl(
const platform::Place& place, const NameVarMap<VarType>& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key) {
for (const auto& name_pair : ins) {
......@@ -59,22 +60,37 @@ void PreparedOp::PrepareData(
}
}
void PreparedOp::PrepareData(
const platform::Place& place, const NameVarMap<VarBase>& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key) {
PrepareDataImpl<VarBase>(place, ins, op, expected_kernel_key);
}
void PreparedOp::PrepareData(
const platform::Place& place, const NameVarMap<VariableWrapper>& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key) {
PrepareDataImpl<VariableWrapper>(place, ins, op, expected_kernel_key);
}
PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
framework::OperatorWithKernel::OpKernelFunc func,
const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs)
: op_(op),
ctx_(ctx),
func_(std::move(func)),
func_(func),
dev_ctx_(dev_ctx),
kernel_configs_(kernel_configs) {}
PreparedOp PreparedOp::Prepare(const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
const framework::OperatorWithKernel& op,
platform::Place place,
const framework::AttributeMap* attrs) {
template <typename VarType>
PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
const framework::OperatorWithKernel& op,
platform::Place place,
const framework::AttributeMap& attrs) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
......@@ -90,8 +106,9 @@ PreparedOp PreparedOp::Prepare(const NameVarBaseMap& ins,
auto& kernels = kernels_iter->second;
framework::RuntimeContext ctx({}, {});
auto expected_kernel_key = op.GetExpectedKernelType(DygraphExecutionContext(
op, framework::Scope(), *dev_ctx, ctx, nullptr, ins, outs, attrs));
auto expected_kernel_key =
op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, nullptr, ins, outs, attrs));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
......@@ -108,24 +125,57 @@ PreparedOp PreparedOp::Prepare(const NameVarBaseMap& ins,
place = dev_ctx->GetPlace();
}
PrepareData(place, ins, op, expected_kernel_key);
PrepareDataImpl<VarType>(place, ins, op, expected_kernel_key);
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs);
}
void PreparedOp::Run(const NameVarBaseMap* in, const NameVarBaseMap* out,
const framework::AttributeMap* attrs) {
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs) {
return PrepareOpImpl<VarBase>(ins, outs, op, place, attrs);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs) {
return PrepareOpImpl<VariableWrapper>(ins, outs, op, place, attrs);
}
template <typename VarType>
static void PreparedOpRunImpl(
const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs,
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs) {
// TODO(zjl): remove scope in dygraph
framework::Scope scope;
DygraphInferShapeContext infer_shape_ctx(in, out, attrs);
DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs);
static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx);
framework::OperatorWithKernel* op_ker =
(framework::OperatorWithKernel*)(&op_);
func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx,
kernel_configs, ins, outs, attrs));
}
op_ker->InferShape(&infer_shape_ctx);
void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs) {
PreparedOpRunImpl<VarBase>(op_, ctx_, func_, dev_ctx_, kernel_configs_, ins,
outs, attrs);
}
func_(DygraphExecutionContext(op_, scope, *dev_ctx_, ctx_, kernel_configs_,
*in, *out, attrs));
void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs) {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, func_, dev_ctx_,
kernel_configs_, ins, outs, attrs);
}
} // namespace imperative
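The updated call sequence, sketched from the prepare-op tests further below (op, ins, outs, place and attrs are assumed to be in scope; attrs is now taken by const reference instead of by pointer):

auto prepared = imperative::PreparedOp::Prepare(
    ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), place, attrs);
prepared.Run(ins, outs, attrs);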
......
......@@ -30,28 +30,42 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
class PreparedOp {
public:
static PreparedOp Prepare(const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs);
static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs);
static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::OperatorWithKernel& op,
platform::Place place,
const framework::AttributeMap* attrs);
const platform::Place& place,
const framework::AttributeMap& attrs);
inline platform::DeviceContext* GetDeviceContext() const { return dev_ctx_; }
void Run(const NameVarBaseMap* in, const NameVarBaseMap* out,
const framework::AttributeMap* attrs);
void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out,
const framework::AttributeMap& attrs);
void Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs);
static void PrepareData(const platform::Place& place,
const NameVarBaseMap& ins,
const NameVarMap<VarBase>& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key);
private:
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
framework::OperatorWithKernel::OpKernelFunc func,
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs);
static void PrepareData(const platform::Place& place,
const NameVarMap<VariableWrapper>& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key);
private:
const framework::OperatorBase& op_;
......
......@@ -44,7 +44,8 @@ TEST(test_layer, test_runtime_context) {
imperative::NameVarBaseMap ins = {in_pair};
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap attrs;
auto* ctx = new imperative::RuntimeInferVarTypeContext(ins, &outs, attrs);
auto *ctx = new imperative::RuntimeInferVarTypeContext<imperative::VarBase>(
ins, &outs, attrs);
ASSERT_TRUE(ctx->HasVar("vin"));
ASSERT_TRUE(ctx->HasInput("X"));
ASSERT_TRUE(ctx->HasOutput("Out"));
......@@ -57,9 +58,9 @@ TEST(test_layer, test_runtime_context) {
ASSERT_ANY_THROW(ctx->SetLoDLevel("vin", 2));
}
std::string LayerDebugString(const std::string& op_type,
const NameVarBaseMap& ins,
const NameVarBaseMap& outs);
std::string LayerDebugString(const std::string &op_type,
const NameVarBaseMap &ins,
const NameVarBaseMap &outs);
TEST(test_layer, test_debug_string) {
platform::CPUPlace place;
......@@ -67,7 +68,7 @@ TEST(test_layer, test_debug_string) {
new imperative::VarBase(false, "vin"));
var_pair in_pair = var_pair("X", vb_vector(1, vin));
auto test_func = [&](std::shared_ptr<imperative::VarBase>& vout) {
auto test_func = [&](std::shared_ptr<imperative::VarBase> &vout) {
var_pair out_pair = var_pair("Out", vb_vector(1, vout));
imperative::NameVarBaseMap ins = {in_pair};
imperative::NameVarBaseMap outs = {out_pair};
......@@ -119,6 +120,34 @@ TEST(test_layer, test_debug_string) {
ASSERT_TRUE(res_sr.find("SelectedRows") != std::string::npos);
}
static std::shared_ptr<imperative::OpBase> CreateOpBase(
size_t id, const std::string &type, const imperative::NameVarBaseMap &ins,
const imperative::NameVarBaseMap &outs,
const framework::AttributeMap &attrs, const platform::Place &place) {
auto op = std::make_shared<imperative::OpBase>();
op->SetId(id);
op->SetPlace(place);
op->SetType(type);
op->SetAttrMap(attrs);
for (auto &pair : ins) {
std::vector<std::shared_ptr<VariableWrapper>> vars;
for (auto &var : pair.second) {
vars.emplace_back(var->SharedVar());
}
op->SetInput(pair.first, vars);
}
for (auto &pair : outs) {
std::vector<std::shared_ptr<VariableWrapper>> vars;
for (auto &var : pair.second) {
vars.emplace_back(var->SharedVar());
}
op->SetOutput(pair.first, vars);
}
return op;
}
TEST(test_layer, test_clear_backward_info) {
std::shared_ptr<imperative::VarBase> vin(
new imperative::VarBase(false, "vin"));
......@@ -133,13 +162,11 @@ TEST(test_layer, test_clear_backward_info) {
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap concat_att_map;
concat_att_map["axis"] = 1;
std::shared_ptr<imperative::OpBase> op(
OpBase::Create(0, "mul", ins, outs, concat_att_map, place));
std::shared_ptr<imperative::OpBase> preceding_op(
OpBase::Create(0, "mul", ins, outs, concat_att_map, place));
op->InsertGradPendingOps(preceding_op.get());
*(op->GetMutableInsMap()) = ins;
*(op->GetMutableOutsMap()) = outs;
auto op = CreateOpBase(0, "mul", ins, outs, concat_att_map, place);
auto preceding_op = CreateOpBase(0, "mul", ins, outs, concat_att_map, place);
op->SetGradPendingOps({preceding_op});
ASSERT_GT(op->GetInsMap().size(), 0UL);
ASSERT_GT(op->GetOutsMap().size(), 0UL);
ASSERT_GT(op->GradPendingOps().size(), 0UL);
......@@ -163,10 +190,10 @@ TEST(test_layer, test_varbase_basic) {
std::shared_ptr<imperative::VarBase> vin_with_grad(
new imperative::VarBase(true, "vin"));
ASSERT_ANY_THROW(vin->MutableGradVar());
ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable*>(
ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable *>(
vin_with_grad->MutableGradVar()) != 0));
ASSERT_TRUE(
dynamic_cast<framework::Variable*>(vin_with_grad->MutableGradVar()) != 0);
ASSERT_TRUE(dynamic_cast<framework::Variable *>(
vin_with_grad->MutableGradVar()) != 0);
vin_with_grad->SetOverridedStopGradient(false);
ASSERT_FALSE(vin_with_grad->OverridedStopGradient());
ASSERT_NO_FATAL_FAILURE(vin_with_grad->SetPersistable(true));
......@@ -195,14 +222,14 @@ TEST(test_layer, test_dygraph_execution_context) {
auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false);
paddle::platform::CPUPlace cpu_place;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(cpu_place);
auto *dev_ctx = pool.Get(cpu_place);
paddle::framework::RuntimeContext ctx({}, {});
framework::Scope scope;
DygraphExecutionContext dy_exe_context(*(op.get()), scope, *dev_ctx, ctx,
nullptr, ins, outs, &concat_att_map);
DygraphExecutionContext<imperative::VarBase> dy_exe_context(
*(op.get()), scope, *dev_ctx, ctx, nullptr, ins, outs, concat_att_map);
ASSERT_EQ(dy_exe_context.InputSize("X"), 1u);
ASSERT_EQ(dy_exe_context.InputName("X"), "vin");
......@@ -229,7 +256,8 @@ TEST(test_layer, test_dygraph_infershape_context) {
framework::AttributeMap concat_att_map;
concat_att_map["axis"] = 1;
DygraphInferShapeContext infer_shape_ctx(&ins, &outs, &concat_att_map);
DygraphInferShapeContext<imperative::VarBase> infer_shape_ctx(
&ins, &outs, &concat_att_map);
bool have_x = infer_shape_ctx.HasOutputs("Out");
ASSERT_EQ(have_x, true);
......
......@@ -114,7 +114,7 @@ TEST(test_prepare_op, test_prepare_op) {
ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare(
ins, outs,
dynamic_cast<framework::OperatorWithKernel&>(*op),
place, &split_attr_map));
place, split_attr_map));
}
const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
......@@ -165,7 +165,7 @@ TEST(test_prepare_op, test_prepare_data) {
// test if it can be transformed to GPU place
PreparedOp prepared_op = PreparedOp::Prepare(
ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), gpu_place,
&attr_map);
attr_map);
for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place(
......@@ -213,7 +213,7 @@ TEST(test_prepare_op, test_prepare_data_same_place) {
// test if it never transferred on GPU place
PreparedOp prepared_op = PreparedOp::Prepare(
ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), cpu_place,
&attr_map);
attr_map);
for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place(
......
......@@ -18,6 +18,7 @@
#include <paddle/fluid/framework/op_registry.h>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "gtest/gtest.h"
......@@ -147,9 +148,9 @@ TEST(test_tracer, test_track_backward_output) {
framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
auto* engine = tracer.GetDefaultEngine();
ASSERT_NE(engine->GradVars().size(), 0UL);
ASSERT_NE(engine->GradOps().size(), 0UL); // trace_backward already ran.
ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL);
}
TEST(test_tracer, test_track_backward_input) {
......@@ -186,9 +187,10 @@ TEST(test_tracer, test_track_backward_input) {
framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
auto* engine = tracer.GetDefaultEngine();
ASSERT_NE(engine->GradVars().size(), 0UL);
ASSERT_NE(engine->GradOps().size(), 0UL); // trace_backward already ran.
ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL);
}
#if defined(PADDLE_WITH_CUDA)
TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
......@@ -344,10 +346,12 @@ TEST(test_tracer, test_var_without_grad_var) {
ASSERT_EQ(out_tensor.data<float>()[i], 20.0);
}
ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL);
detail::BackwardStrategy back_st;
imperative::Engine* engine = tracer.GetDefaultEngine();
ASSERT_NE(engine->GradVars().size(), 0UL);
ASSERT_NE(engine->GradOps().size(), 0UL); // trace_backward already ran.
engine->Init(vout.get(), back_st);
engine->Execute();
......@@ -369,10 +373,137 @@ TEST(test_tracer, test_var_without_grad_var) {
}
}
template <typename T>
using WeakPtrSet =
std::set<std::weak_ptr<T>, std::owner_less<std::weak_ptr<T>>>;
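std::weak_ptr has no operator<, so the sets above are ordered by the owning control block via std::owner_less; once the tracer scope below ends, every recorded pointer must report expired(). A tiny self-contained sketch of that idiom (independent of Paddle types):

#include <cassert>
#include <memory>
#include <set>

int main() {
  std::set<std::weak_ptr<int>, std::owner_less<std::weak_ptr<int>>> ws;
  auto sp = std::make_shared<int>(1);
  ws.emplace(sp);  // keyed by owner, so re-inserting the same object is a no-op
  sp.reset();      // last strong reference released
  assert(ws.begin()->expired());
  return 0;
}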
static void TestVarOpDestructionMain(const platform::Place& place,
int64_t tensor_size = 10,
size_t loop_num = 10) {
WeakPtrSet<VariableWrapper> var_wrappers;
WeakPtrSet<VarBase> var_bases;
WeakPtrSet<OpBase> op_bases;
Tracer tracer;
{
auto x = std::make_shared<VarBase>("x");
auto y = std::make_shared<VarBase>("y");
x->MutableVar()
->GetMutable<framework::LoDTensor>()
->Resize({tensor_size, tensor_size})
.mutable_data<float>(place);
y->MutableVar()
->GetMutable<framework::LoDTensor>()
->Resize({tensor_size, tensor_size})
.mutable_data<float>(place);
x->SetOverridedStopGradient(false);
y->SetOverridedStopGradient(true);
for (size_t i = 0; i < loop_num; ++i) {
size_t var_wrapper_num = var_wrappers.size();
size_t var_base_num = var_bases.size();
size_t op_base_num = op_bases.size();
auto z = std::make_shared<VarBase>("z_" + std::to_string(i));
tracer.TraceOp("mul", NameVarBaseMap{{"X", {x}}, {"Y", {y}}},
NameVarBaseMap{{"Out", {z}}}, framework::AttributeMap{},
place, true);
ASSERT_EQ(z->GradOps().size(), 0UL);
ASSERT_EQ(z->GradVarBase()->GradOps().size(), 1UL);
auto new_op = z->GradVarBase()->GradOps()[0];
ASSERT_EQ(x->GradOps().size(), 0UL);
ASSERT_EQ(y->GradOps().size(), 0UL);
std::unordered_set<std::shared_ptr<OpBase>> expected_pending_ops;
if (i == 0) {
ASSERT_EQ(x->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y->GradVarBase()->GradOps().size(), 0UL);
} else {
ASSERT_EQ(x->GradVarBase()->GradOps().size(), 1UL);
ASSERT_EQ(y->GradVarBase()->GradOps().size(), 0UL);
for (auto& op : x->GradVarBase()->GradOps()) {
expected_pending_ops.emplace(op);
}
for (auto& op : y->GradVarBase()->GradOps()) {
expected_pending_ops.emplace(op);
}
std::unordered_set<std::shared_ptr<OpBase>> actual_pending_ops;
for (auto& op : new_op->GradPendingOps()) {
actual_pending_ops.emplace(op);
}
ASSERT_TRUE(expected_pending_ops == actual_pending_ops);
ASSERT_EQ(expected_pending_ops.count(new_op), 0UL);
}
var_wrappers.emplace(x->SharedVar());
var_wrappers.emplace(x->GradVarBase()->SharedVar());
var_wrappers.emplace(y->SharedVar());
var_wrappers.emplace(y->GradVarBase()->SharedVar());
var_wrappers.emplace(z->SharedVar());
var_wrappers.emplace(z->GradVarBase()->SharedVar());
var_bases.emplace(x);
var_bases.emplace(x->GradVarBase());
var_bases.emplace(y);
var_bases.emplace(y->GradVarBase());
var_bases.emplace(z);
var_bases.emplace(z->GradVarBase());
for (auto& op : expected_pending_ops) {
op_bases.emplace(op);
}
if (i == 0) {
ASSERT_EQ(var_wrapper_num, 0UL);
ASSERT_EQ(var_base_num, 0UL);
ASSERT_EQ(op_base_num, 0UL);
ASSERT_EQ(var_wrappers.size(), 6UL);
ASSERT_EQ(var_bases.size(), 6UL);
ASSERT_EQ(op_bases.size(), 0UL);
} else {
ASSERT_EQ(var_wrappers.size(), var_wrapper_num + 2);
ASSERT_EQ(var_bases.size(), var_base_num + 2);
ASSERT_EQ(op_bases.size(), op_base_num + 1);
}
x = z;  // feed this iteration's output back in as the next input
}
}
for (auto& var : var_wrappers) {
ASSERT_TRUE(var.expired());
}
for (auto& var : var_bases) {
ASSERT_TRUE(var.expired());
}
for (auto& op : op_bases) {
ASSERT_TRUE(op.expired());
}
}
TEST(test_tracer, test_var_op_destruction) {
TestVarOpDestructionMain(platform::CPUPlace());
#ifdef PADDLE_WITH_CUDA
TestVarOpDestructionMain(platform::CUDAPlace(0));
#endif
}
} // namespace imperative
} // namespace paddle
USE_OP(mul);
USE_OP(mul_grad);
USE_OP(reduce_sum);
USE_OP(reduce_sum_grad);
USE_OP(elementwise_add);
......@@ -15,7 +15,10 @@
#include <set>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace imperative {
......@@ -48,22 +51,24 @@ static void ClearNoNeedBufferInputs(OpBase* op) {
PADDLE_ENFORCE_EQ(var.IsType<framework::LoDTensor>(), true,
"Only support LoDTensor");
// TODO(zjl): support higher order derivatives
auto new_var = new VarBase(false, each_var->Name());
auto new_var = new VariableWrapper(each_var->Name());
auto* new_tensor =
new_var->MutableVar()->GetMutable<framework::LoDTensor>();
auto& old_tensor = var.Get<framework::LoDTensor>();
new_tensor->Resize(old_tensor.dims());
new_tensor->set_lod(old_tensor.lod());
each_var.reset(new_var);
op->AddAllowedEmptyVar(new_var);
}
}
}
static std::vector<std::unique_ptr<OpBase>> CreateGradOpBases(
const OpBase* fw_op_base, const NameVarBaseMap& in,
const NameVarBaseMap& out) {
if (fw_op_base->Info().dygraph_grad_op_maker_) {
return fw_op_base->Info().dygraph_grad_op_maker_(fw_op_base, in, out);
static std::vector<std::shared_ptr<OpBase>> CreateGradOpBases(
const framework::OpInfo& info, const std::string& type,
const NameVarBaseMap& in, const NameVarBaseMap& out,
const framework::AttributeMap& attrs) {
if (info.dygraph_grad_op_maker_) {
return info.dygraph_grad_op_maker_(type, in, out, attrs);
} else {
return {};
}
......@@ -83,17 +88,22 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs,
const platform::Place& place, bool trace_backward) {
VLOG(1) << "Trace Op: " << type;
size_t op_id = GenerateUniqueId();
auto op = OpBase::Create(op_id, type, ins, outs, attrs, place);
op->Run(ins, outs);
auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
const auto& op_info = op->Info();
auto* attr_checker = op_info.Checker();
if (attr_checker) {
attr_checker->Check(&attrs, true);
}
OpBase::Run(*op, ins, outs, attrs, place);
if (enable_program_desc_tracing_) {
VLOG(5) << "Trace op " << type << " into ProgramDesc";
program_desc_tracer_->InsertOp(type, ins, outs, op->Attrs());
program_desc_tracer_->InsertOp(type, ins, outs, attrs);
}
if (ComputeRequiredGrad(ins, outs, trace_backward)) {
TraceBackward(op, ins, outs);
TraceBackward(op_info, type, ins, outs, attrs, place);
} else {
VLOG(3) << "No Grad to track for Op: " << type;
}
......@@ -102,22 +112,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
framework::AttributeMap attrs) {
VLOG(1) << "Trace Op: " << type;
size_t op_id = GenerateUniqueId();
auto op =
OpBase::Create(op_id, type, ins, outs, std::move(attrs), expected_place_);
op->Run(ins, outs);
if (enable_program_desc_tracing_) {
VLOG(5) << "Trace op " << type << " into ProgramDesc";
program_desc_tracer_->InsertOp(type, ins, outs, op->Attrs());
}
if (ComputeRequiredGrad(ins, outs, no_grad_)) {
TraceBackward(op, ins, outs);
} else {
VLOG(3) << "No Grad to track for Op: " << type;
}
TraceOp(type, ins, outs, std::move(attrs), expected_place_, no_grad_);
}
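A usage sketch that mirrors the tracer tests (variable names are illustrative; the short TraceOp overload above simply forwards to the full one with expected_place_ and no_grad_):

platform::CPUPlace place;
imperative::Tracer tracer;
auto x = std::make_shared<imperative::VarBase>("x");
auto y = std::make_shared<imperative::VarBase>("y");
auto out = std::make_shared<imperative::VarBase>("out");
x->MutableVar()->GetMutable<framework::LoDTensor>()->Resize({4, 4}).mutable_data<float>(place);
y->MutableVar()->GetMutable<framework::LoDTensor>()->Resize({4, 4}).mutable_data<float>(place);
x->SetOverridedStopGradient(false);
framework::AttributeMap attrs;
attrs["use_mkldnn"] = false;
tracer.TraceOp("mul", imperative::NameVarBaseMap{{"X", {x}}, {"Y", {y}}},
               imperative::NameVarBaseMap{{"Out", {out}}}, attrs, place, true);
// Afterwards out->GradVarBase()->GradOps().size() == 1 (see the tracer tests).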
bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
......@@ -138,78 +133,19 @@ bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
return false;
}
void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
const NameVarBaseMap& ins,
const NameVarBaseMap& outs) {
// grad_to_var is a map of framework::GradVarName(in_var_name/out_var_name) ->
// in_var_name/out_var_name
std::unordered_map<std::string, std::string> grad_to_var;
// Get grad_op_desc using fwd_op_desc
std::vector<std::unique_ptr<OpBase>> grad_op_bases_ =
CreateGradOpBases(fwd_op.get(), ins, outs);
size_t grad_op_num = grad_op_bases_.size();
std::set<VarBase*> set_input_vars;
for (auto& fwd_in_it : ins) {
for (auto& var_base_it : fwd_in_it.second) {
set_input_vars.insert(var_base_it.get());
}
}
for (auto& fwd_out_it : outs) {
for (auto& var_base_it : fwd_out_it.second) {
set_input_vars.insert(var_base_it.get());
}
}
for (size_t i = 0; i < grad_op_num; ++i) {
size_t trace_id = fwd_op->id();
std::shared_ptr<OpBase> grad_op = std::move(grad_op_bases_[i]);
void Tracer::TraceBackward(const framework::OpInfo& info,
const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
const framework::AttributeMap& attrs,
const platform::Place& place) {
auto grad_op_bases = CreateGradOpBases(info, type, ins, outs, attrs);
auto grad_op_num = grad_op_bases.size();
if (grad_op_num == 0) return;
size_t trace_id = GenerateUniqueId();
for (auto& grad_op : grad_op_bases) {
grad_op->SetPlace(place);
grad_op->SetId(trace_id);
grad_op->SetPlace(fwd_op->place());
grad_op->CreateOperatorBase();
auto& grad_in = *(grad_op->GetMutableInsMap());
auto& grad_out = *(grad_op->GetMutableOutsMap());
for (auto& grad_in_it : grad_in) {
for (auto& var_base_it : grad_in_it.second) {
if (set_input_vars.count(var_base_it.get()) == 0) {
var_base_it->AddGradOps(grad_op);
engine_->InsertGradVar(var_base_it.get());
}
}
}
std::set<OpBase*> visited_preceding_ops;
for (auto& grad_out_it : grad_out) {
bool flag_clear_list = false;
for (auto& var_base_it : grad_out_it.second) {
if ((!var_base_it->OverridedStopGradient()) ||
(grad_out_it.second.size() > 1)) {
auto preceding_ops = var_base_it->GradOps();
if (!preceding_ops.empty()) {
for (const auto& op : preceding_ops) {
visited_preceding_ops.insert(op);
}
}
} else {
flag_clear_list = true;
}
}
if (flag_clear_list) {
grad_out_it.second.clear();
}
}
std::vector<OpBase*> vec_preceding_ops(visited_preceding_ops.begin(),
visited_preceding_ops.end());
grad_op->SetGradPendingOps(std::move(vec_preceding_ops));
// this OpBase* is just used to manage op's life time
engine_->InsertOp(grad_op.get(), grad_op);
ClearNoNeedBufferInputs(grad_op.get());
}
}
......
......@@ -64,9 +64,6 @@ class Tracer {
bool ComputeRequiredGrad(const NameVarBaseMap& ins,
const NameVarBaseMap& outs, bool trace_backward);
void TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
const NameVarBaseMap& ins, const NameVarBaseMap& outs);
Engine* GetDefaultEngine() const { return engine_.get(); }
void SetEnableProgramDescTracing(bool enabled) {
......@@ -94,6 +91,11 @@ class Tracer {
void SetNoGrad(bool no_grad) { no_grad_ = no_grad; }
private:
void TraceBackward(const framework::OpInfo& info, const std::string& type,
const NameVarBaseMap& ins, const NameVarBaseMap& outs,
const framework::AttributeMap& attrs,
const platform::Place& place);
static size_t GenerateUniqueId() {
static std::atomic<size_t> id{0};
return id.fetch_add(1);
......
......@@ -22,12 +22,16 @@ limitations under the License. */
namespace paddle {
namespace imperative {
class VariableWrapper;
class VarBase;
class OpBase;
class Tracer;
using NameVarBaseMap =
std::map<std::string, std::vector<std::shared_ptr<VarBase>>>;
template <typename T>
using NameVarMap = std::map<std::string, std::vector<std::shared_ptr<T>>>;
using NameVarBaseMap = NameVarMap<VarBase>;
using NameVariableWrapperMap = NameVarMap<VariableWrapper>;
using WeakNameVarBaseMap =
std::map<std::string, std::vector<std::weak_ptr<VarBase>>>;
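A sketch of the two instantiations (x is an illustrative std::shared_ptr<VarBase>): forward containers hold VarBase, while the traced backward graph holds the underlying VariableWrapper obtained through SharedVar().

imperative::NameVarBaseMap fwd_ins = {{"X", {x}}};
imperative::NameVariableWrapperMap grad_ins = {{"X", {x->SharedVar()}}};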
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace imperative {
class VariableWrapper {
public:
explicit VariableWrapper(const std::string& name) : name_(name) {}
const framework::Variable& Var() const { return var_; }
framework::Variable* MutableVar() { return &var_; }
// This is used for python api
void SetOverridedStopGradient(bool stop_gradient) {
overrided_stop_gradient_ = static_cast<int>(stop_gradient);
}
// This is used for python api
bool OverridedStopGradient() const { return overrided_stop_gradient_ != 0; }
// This is used inside C++
int InnerOverridedStopGradient() const { return overrided_stop_gradient_; }
// This is used inside C++
void InnerSetOverridedStopGradient(bool stop_gradient) {
if (overrided_stop_gradient_ == -1) {
overrided_stop_gradient_ = static_cast<int>(stop_gradient);
} else {
VLOG(6) << "Ignore Stop gradient conversion for Var: " << Name()
<< "Set value is: " << overrided_stop_gradient_;
}
}
void SetPersistable(bool persistable) { persistable_ = persistable; }
bool Persistable() const { return persistable_; }
const std::string& Name() const { return name_; }
void SetName(const std::string& name) { name_ = name; }
void SetType(framework::proto::VarType::Type type) { type_ = type; }
framework::proto::VarType::Type Type() const { return type_; }
void SetDataType(framework::proto::VarType::Type data_type) {
data_type_ = data_type;
}
framework::proto::VarType::Type DataType() const {
const framework::Tensor* tensor = nullptr;
if (var_.IsInitialized()) {
if (type_ == framework::proto::VarType::LOD_TENSOR) {
tensor = &(var_.Get<framework::LoDTensor>());
} else if (type_ == framework::proto::VarType::SELECTED_ROWS) {
tensor = &(var_.Get<framework::SelectedRows>().value());
} else {
VLOG(6) << "Variable " << name_ << " is not initialized";
return data_type_;
}
}
if (tensor && tensor->IsInitialized()) {
return tensor->type();
} else {
VLOG(6) << "The tensor of variable " << name_ << " is not initialized";
return data_type_;
}
}
private:
framework::Variable var_;
std::string name_;
// Users may set stop_gradient themselves, and this property should
// override the framework's setting: (-1) unset, (1) true, (0) false
int overrided_stop_gradient_{-1};
bool persistable_{false};
framework::proto::VarType::Type type_{framework::proto::VarType::LOD_TENSOR};
framework::proto::VarType::Type data_type_{framework::proto::VarType::FP32};
};
} // namespace imperative
} // namespace paddle
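A small usage sketch of the wrapper (illustrative; assumes the usual LoDTensor headers are available): it owns the framework::Variable plus the stop-gradient/persistable metadata that a VarBase and the traced grad ops share.

imperative::VariableWrapper w("x");
w.MutableVar()->GetMutable<framework::LoDTensor>()->Resize({2, 2});
w.SetOverridedStopGradient(false);  // user-facing setter
VLOG(3) << w.Name() << " stop_gradient=" << w.OverridedStopGradient()
        << " dtype=" << w.DataType();  // falls back to data_type_ until the tensor is initialized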
......@@ -68,8 +68,7 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
......@@ -86,8 +85,6 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
op->SetInput("Out", this->Output("Out"));
}
return op;
}
};
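To make the repeated mechanical change below concrete, here is a hedged sketch of a maker written against the new contract (the class name is hypothetical; it assumes the same namespace and includes as the surrounding operator code): the framework now allocates the grad op itself (an OpDesc in static-graph mode, a TracedGradOp in dygraph) and hands it to Apply() through GradOpPtr<T>, so makers fill it in place instead of newing a T and returning a std::unique_ptr<T>.

template <typename T>
class MyReluGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("relu_grad");
    op->SetInput("Out", this->Output("Out"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetAttrMap(this->Attrs());
    // No return value: ownership of the grad op stays with the framework.
  }
};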
......@@ -727,8 +724,7 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("relu_grad_grad");
// input1: Out
op->SetInput("Out", this->Input("Out"));
......@@ -737,7 +733,6 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
// output: ddy
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<T>(op);
}
};
......@@ -750,8 +745,7 @@ class LeakyReluDoubleGradMaker
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("leaky_relu_grad_grad");
// input1: Out
op->SetInput("Out", this->Input("Out"));
......@@ -760,7 +754,6 @@ class LeakyReluDoubleGradMaker
op->SetAttrMap(this->Attrs());
// Out@GRAD@GRAD: ddy
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<T>(op);
}
};
......@@ -772,8 +765,7 @@ class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("elu_grad_grad");
op->SetInput("X", this->Input("X"));
......@@ -785,7 +777,6 @@ class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
// Out@GRAD@GRAD: ddy
op->SetOutput("DX", this->InputGrad("X"));
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<T>(op);
}
};
......@@ -797,8 +788,7 @@ class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("sqrt_grad_grad");
op->SetInput("Out", this->Input("Out"));
op->SetInput("DX", this->Output(framework::GradVarName("X")));
......@@ -806,7 +796,6 @@ class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
op->SetOutput("DOut", this->InputGrad("Out"));
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<T>(op);
}
};
......@@ -818,8 +807,7 @@ class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("square_grad_grad");
op->SetInput("X", this->Input("X"));
// Out@GRAD: dy
......@@ -833,7 +821,6 @@ class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
op->SetOutput("DX", this->InputGrad("X"));
// Out@GRAD@GRAD: ddy
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<T>(op);
}
};
......@@ -849,16 +836,13 @@ class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("pow_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetInput("FactorTensor", this->Input("FactorTensor"));
op->SetAttrMap(this->Attrs());
return op;
}
};
class PowOp : public framework::OperatorWithKernel {
......
......@@ -93,13 +93,11 @@ class AddPositionEncodingGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("add_position_encoding_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -132,8 +132,7 @@ class AffineChannelGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("affine_channel_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
......@@ -144,8 +143,6 @@ class AffineChannelGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
return std::unique_ptr<T>(op);
}
};
......
......@@ -204,8 +204,7 @@ class AffineGridGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("affine_grid_grad");
op->SetInput("OutputShape", this->Input("OutputShape"));
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
......@@ -213,7 +212,6 @@ class AffineGridGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("Theta"), this->InputGrad("Theta"));
return std::unique_ptr<T>(op);
}
};
......
......@@ -108,15 +108,13 @@ class ArgsortGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("argsort_grad");
op->SetInput("Indices", this->Output("Indices"));
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -216,14 +216,12 @@ class ArrayToLoDTensorGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("lod_tensor_to_array");
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetInput("RankTable", this->Input("RankTable"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
......
......@@ -99,12 +99,10 @@ class AssignGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto *op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("assign");
op->SetInput("X", this->OutputGrad("Out"));
op->SetOutput("Out", this->InputGrad("X"));
return std::unique_ptr<T>(op);
}
};
......
......@@ -721,8 +721,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
};
template <typename T>
std::unique_ptr<T> BatchNormGradMaker<T>::Apply() const {
auto *op = new T();
void BatchNormGradMaker<T>::Apply(GradOpPtr<T> op) const {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
......@@ -746,8 +745,6 @@ std::unique_ptr<T> BatchNormGradMaker<T>::Apply() const {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
return std::unique_ptr<T>(op);
}
} // namespace operators
......
......@@ -165,7 +165,7 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override;
void Apply(GradOpPtr<T> op) const override;
};
class BatchNormOpInferVarType
......
......@@ -154,8 +154,7 @@ class BilinearTensorProductGradOpMaker
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("bilinear_tensor_product_grad");
op->SetAttrMap(this->Attrs());
op->SetInput("X", this->Input("X"));
......@@ -169,8 +168,6 @@ class BilinearTensorProductGradOpMaker
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
return op;
}
};
......
......@@ -141,15 +141,13 @@ class BprLossGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("bpr_loss_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Label", this->Input("Label"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
} // namespace operators
......
......@@ -44,14 +44,12 @@ class CastOpGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto grad = new T();
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("cast");
grad->SetInput("X", this->OutputGrad("Out"));
grad->SetOutput("Out", this->InputGrad("X"));
grad->SetAttr("out_dtype", this->GetAttr("in_dtype"));
grad->SetAttr("in_dtype", this->GetAttr("out_dtype"));
return std::unique_ptr<T>(grad);
}
};
......
......@@ -129,8 +129,7 @@ class CenterLossOpGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("center_loss_grad");
retv->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
retv->SetInput("SampleCenterDiff", this->Output("SampleCenterDiff"));
......@@ -138,7 +137,6 @@ class CenterLossOpGradMaker : public framework::SingleGradOpMaker<T> {
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
return retv;
}
};
......
......@@ -84,14 +84,12 @@ class ClipGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("clip_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -62,13 +62,11 @@ class CAllGatherOpGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("c_reducescatter");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetOutput("Out", this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
return retv;
}
};
......
......@@ -23,13 +23,11 @@ class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("c_allreduce_sum");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetOutput("Out", this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
return retv;
}
};
......
......@@ -66,13 +66,11 @@ class CReduceScatterOpGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("c_allgather");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetOutput("Out", this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
return retv;
}
};
......
......@@ -185,8 +185,7 @@ class ConcatGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("concat_grad");
op->SetInput("X", this->Input("X"));
if (this->HasInput("AxisTensor")) {
......@@ -195,7 +194,6 @@ class ConcatGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -260,8 +260,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("conditional_block_grad");
grad_op->SetInput(ConditionalOp::kCondition,
this->Input(ConditionalOp::kCondition));
......@@ -278,7 +277,6 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetBlockAttr("sub_block", this->grad_block_[0]);
grad_op->SetAttr("is_scalar_condition",
this->GetAttr("is_scalar_condition"));
return std::unique_ptr<T>(grad_op);
}
};
......
......@@ -215,15 +215,13 @@ class WriteToArrayGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("read_from_array");
grad_op->SetInput("I", this->Input("I"));
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetInput("X_W", this->Input("X"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
......@@ -233,14 +231,12 @@ class ReadFromArrayGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("write_to_array");
grad_op->SetInput("I", this->Input("I"));
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
......
......@@ -329,8 +329,7 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto *while_grad = new T();
void Apply(GradOpPtr<T> while_grad) const override {
while_grad->SetType("while_grad");
while_grad->SetInput(kX, this->Input(kX));
while_grad->SetInput(kOutputs, this->Output(kOutputs));
......@@ -402,8 +401,6 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
while_grad->SetAttr("original_output_grad", output_grads_list);
while_grad->SetAttr(kSkipEagerDeletionVars, std::vector<std::string>());
return std::unique_ptr<T>(while_grad);
}
};
......
......@@ -612,8 +612,7 @@ class Conv2DGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("Filter", this->Input("Filter"));
......@@ -624,8 +623,6 @@ class Conv2DGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
......@@ -634,8 +631,7 @@ class Conv3DGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("Filter", this->Input("Filter"));
......@@ -649,8 +645,6 @@ class Conv3DGradMaker : public framework::SingleGradOpMaker<T> {
}
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
......@@ -663,8 +657,7 @@ class Conv2DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
// I, W, dO, ddI, ddW
op->SetInput("Input", this->Input("Input"));
......@@ -682,16 +675,14 @@ class Conv2DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput("DDOutput",
ddx.empty()
? this->Empty()
? this->EmptyInputGrad()
: this->InputGrad(framework::GradVarName("Output")));
op->SetOutput("DFilter",
ddx.empty() ? this->Empty() : this->InputGrad("Filter"));
op->SetOutput("DInput",
ddw.empty() ? this->Empty() : this->InputGrad("Input"));
op->SetOutput("DFilter", ddx.empty() ? this->EmptyInputGrad()
: this->InputGrad("Filter"));
op->SetOutput("DInput", ddw.empty() ? this->EmptyInputGrad()
: this->InputGrad("Input"));
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
......@@ -704,8 +695,7 @@ class Conv3DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
// I, W, dO, ddI, ddW
op->SetInput("Input", this->Input("Input"));
......@@ -720,16 +710,14 @@ class Conv3DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput("DDOutput",
ddx.empty()
? this->Empty()
? this->EmptyInputGrad()
: this->InputGrad(framework::GradVarName("Output")));
op->SetOutput("DFilter",
ddx.empty() ? this->Empty() : this->InputGrad("Filter"));
op->SetOutput("DInput",
ddw.empty() ? this->Empty() : this->InputGrad("Input"));
op->SetOutput("DFilter", ddx.empty() ? this->EmptyInputGrad()
: this->InputGrad("Filter"));
op->SetOutput("DInput", ddw.empty() ? this->EmptyInputGrad()
: this->InputGrad("Input"));
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
......
......@@ -199,8 +199,7 @@ class ConvShiftGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("conv_shift_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
......@@ -208,7 +207,6 @@ class ConvShiftGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -441,8 +441,7 @@ class ConvTransposeGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("Filter", this->Input("Filter"));
......@@ -454,7 +453,6 @@ class ConvTransposeGradOpMaker : public framework::SingleGradOpMaker<T> {
}
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -172,8 +172,7 @@ class CosSimGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("cos_sim_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput("Y", this->Input("Y"));
......@@ -184,7 +183,6 @@ class CosSimGradOpMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
......
......@@ -187,8 +187,7 @@ class CropGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("crop_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("X", this->Input("X"));
......@@ -197,7 +196,6 @@ class CropGradOpMaker : public framework::SingleGradOpMaker<T> {
}
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -279,8 +279,7 @@ class CropTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("crop_tensor_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("X", this->Input("X"));
......@@ -292,7 +291,6 @@ class CropTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
}
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -263,15 +263,13 @@ class CrossEntropyGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("cross_entropy_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Label", this->Input("Label"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......@@ -372,8 +370,7 @@ class CrossEntropyGradOpMaker2 : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("cross_entropy_grad2");
op->SetInput("Label", this->Input("Label"));
op->SetInput("MatchX", this->Output("MatchX"));
......@@ -381,7 +378,6 @@ class CrossEntropyGradOpMaker2 : public framework::SingleGradOpMaker<T> {
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -203,8 +203,7 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("cudnn_lstm_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("InitH", this->Input("InitH"));
......@@ -223,7 +222,6 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("InitH"), this->InputGrad("InitH"));
op->SetOutput(framework::GradVarName("InitC"), this->InputGrad("InitC"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -58,15 +58,13 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("cumsum");
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttr("axis", boost::get<int>(this->GetAttr("axis")));
grad_op->SetAttr("reverse", !boost::get<bool>(this->GetAttr("reverse")));
grad_op->SetAttr("exclusive", boost::get<bool>(this->GetAttr("exclusive")));
return std::unique_ptr<T>(grad_op);
}
};
......
......@@ -131,15 +131,13 @@ class CVMGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("cvm_grad");
op->SetInput("CVM", this->Input("CVM"));
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -461,8 +461,7 @@ class DataNormGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto *op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("data_norm_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
......@@ -482,8 +481,6 @@ class DataNormGradMaker : public framework::SingleGradOpMaker<T> {
this->InputGrad("BatchSum"));
op->SetOutput(framework::GradVarName("BatchSquareSum"),
this->InputGrad("BatchSquareSum"));
return std::unique_ptr<T>(op);
}
};
......
......@@ -228,9 +228,7 @@ class DeformableConvGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("deformable_conv_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("Filter", this->Input("Filter"));
......@@ -244,7 +242,6 @@ class DeformableConvGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Mask"), this->InputGrad("Mask"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -211,9 +211,7 @@ class DeformableConvV1GradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("deformable_conv_v1_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("Filter", this->Input("Filter"));
......@@ -225,7 +223,6 @@ class DeformableConvV1GradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Offset"), this->InputGrad("Offset"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -211,9 +211,7 @@ class DeformablePSROIPoolGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("deformable_psroi_pooling_grad");
op->SetInput("Input", this->Input("Input"));
op->SetInput("Trans", this->Input("Trans"));
......@@ -225,7 +223,6 @@ class DeformablePSROIPoolGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Trans"), this->InputGrad("Trans"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -627,8 +627,7 @@ class ROIPerspectiveTransformGradMaker
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("roi_perspective_transform_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("ROIs", this->Input("ROIs"));
......@@ -637,7 +636,6 @@ class ROIPerspectiveTransformGradMaker
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -178,8 +178,7 @@ class SigmoidFocalLossGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("sigmoid_focal_loss_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Label", this->Input("Label"));
......@@ -187,7 +186,6 @@ class SigmoidFocalLossGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -269,8 +269,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("yolov3_loss_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("GTBox", this->Input("GTBox"));
......@@ -283,10 +282,9 @@ class Yolov3LossGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("GTBox"), {});
op->SetOutput(framework::GradVarName("GTLabel"), {});
op->SetOutput(framework::GradVarName("GTScore"), {});
return std::unique_ptr<T>(op);
op->SetOutput(framework::GradVarName("GTBox"), this->EmptyInputGrad());
op->SetOutput(framework::GradVarName("GTLabel"), this->EmptyInputGrad());
op->SetOutput(framework::GradVarName("GTScore"), this->EmptyInputGrad());
}
};
......
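In the yolov3_loss hunk above, the literal {} arguments are also replaced by this->EmptyInputGrad() for inputs that deliberately receive no gradient. The Empty*()/Empty*Grad() helpers return an explicitly typed empty name list, presumably so the same maker body stays unambiguous for both the OpDesc and the dygraph SetInput/SetOutput overloads. A sketch of the intent, again for a hypothetical op (the "Label" input here is an assumption, not taken from this diff):

void Apply(GradOpPtr<T> op) const override {
  op->SetType("foo_grad");
  op->SetInput("X", this->Input("X"));
  op->SetInput("Label", this->Input("Label"));
  op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
  op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
  // Label is an integer input with no gradient: declare the slot with a typed
  // empty list rather than a bare brace initializer.
  op->SetOutput(framework::GradVarName("Label"), this->EmptyInputGrad());
  op->SetAttrMap(this->Attrs());
}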
......@@ -144,14 +144,12 @@ class DropoutGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("dropout_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("Mask", this->Output("Mask"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -76,8 +76,7 @@ class ElementwiseAddDoubleGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_add_grad_grad");
op->SetInput("Y", this->Input("Y"));
op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
......@@ -87,7 +86,6 @@ class ElementwiseAddDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return op;
}
};
......
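The *_grad_grad makers here (elementwise_add above; elementwise_div, elementwise_mul, elementwise_sub and instance_norm below) share a naming convention: DOut is the first-order output gradient consumed as a plain input, DDX/DDY are the incoming gradients of the first-order input gradients, and DDOut is the produced gradient for DOut. Reduced to just that convention, such a maker reads roughly like this sketch (hypothetical foo, same shape as the hunks in this diff):

void Apply(GradOpPtr<T> op) const override {
  op->SetType("foo_grad_grad");
  // dL/dOut from the first backward pass, reused as a plain input.
  op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
  // Gradient flowing into dL/dX, i.e. the "grad of grad" input.
  op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
  op->SetAttrMap(this->Attrs());
  // Gradient produced for DOut.
  op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
}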
......@@ -73,8 +73,7 @@ class ElementwiseDivGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_div_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
......@@ -83,7 +82,6 @@ class ElementwiseDivGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......@@ -93,8 +91,7 @@ class ElementwiseDivDoubleGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_div_grad_grad");
op->SetInput("Y", this->Input("Y"));
op->SetInput("Out", this->Input("Out"));
......@@ -107,8 +104,6 @@ class ElementwiseDivDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetOutput("DOut", this->InputGrad("Out"));
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return op;
}
};
......
......@@ -49,8 +49,7 @@ class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_max_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
......@@ -58,7 +57,6 @@ class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -49,8 +49,7 @@ class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_min_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
......@@ -58,7 +57,6 @@ class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -76,8 +76,7 @@ class ElementwiseMulOpGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_mul_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
......@@ -85,7 +84,6 @@ class ElementwiseMulOpGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
return op;
}
};
......@@ -95,8 +93,7 @@ class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_mul_grad_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
......@@ -109,7 +106,6 @@ class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
return op;
}
};
......
......@@ -363,8 +363,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker; \
\
protected: \
std::unique_ptr<T> Apply() const override { \
auto *op = new T(); \
void Apply(::paddle::framework::GradOpPtr<T> op) const override { \
op->SetType(#kernel_type "_grad"); \
op->SetInput("X", this->Input("X")); \
op->SetInput("Y", this->Input("Y")); \
......@@ -375,7 +374,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
this->InputGrad("X")); \
op->SetOutput(::paddle::framework::GradVarName("Y"), \
this->InputGrad("Y")); \
return std::unique_ptr<T>(op); \
} \
}
......
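The maker generated by the macro above spells out ::paddle::framework::GradOpPtr<T> and ::paddle::framework::GradVarName in full, presumably because a macro can be expanded outside the paddle::operators namespace and cannot count on an in-scope alias. A stripped-down sketch of such a macro (hypothetical name, not the macro from this diff):

#define DEFINE_FOO_GRAD_MAKER(kernel_type)                                \
  template <typename T>                                                   \
  class kernel_type##GradMaker                                            \
      : public ::paddle::framework::SingleGradOpMaker<T> {                \
   public:                                                                \
    using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;   \
                                                                          \
   protected:                                                             \
    void Apply(::paddle::framework::GradOpPtr<T> op) const override {     \
      op->SetType(#kernel_type "_grad");                                  \
      op->SetInput("X", this->Input("X"));                                \
      op->SetInput(::paddle::framework::GradVarName("Out"),               \
                   this->OutputGrad("Out"));                              \
      op->SetOutput(::paddle::framework::GradVarName("X"),                \
                    this->InputGrad("X"));                                \
      op->SetAttrMap(this->Attrs());                                      \
    }                                                                     \
  }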
......@@ -23,8 +23,7 @@ class ElementwisePowOpGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_pow_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
......@@ -32,7 +31,6 @@ class ElementwisePowOpGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
return op;
}
};
class ElementwisePowOpMaker : public ElementwiseOpMaker {
......
......@@ -75,8 +75,7 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_sub_grad_grad");
op->SetInput("Y", this->Input("Y"));
op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
......@@ -86,7 +85,6 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return op;
}
};
......
......@@ -101,14 +101,12 @@ class ErfGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("erf_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
......
......@@ -103,15 +103,13 @@ class ExpandAsGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("expand_as_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("target_tensor", this->Input("target_tensor"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -203,8 +203,7 @@ class ExpandGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("expand_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
......@@ -212,7 +211,6 @@ class ExpandGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("expand_times_tensor", this->Input("expand_times_tensor"));
op->SetInput("ExpandTimes", this->Input("ExpandTimes"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -113,8 +113,7 @@ class FilterByInstagGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("filter_by_instag_grad");
op->SetInput("IndexMap", this->Output("IndexMap"));
op->SetInput("Ins", this->Input("Ins"));
......@@ -122,7 +121,6 @@ class FilterByInstagGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("LossWeight", this->Output("LossWeight"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("Ins"), this->InputGrad("Ins"));
return op;
}
};
} // namespace operators
......
......@@ -142,14 +142,12 @@ class FlattenGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("flatten_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
......@@ -211,14 +209,12 @@ class Flatten2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("flatten2_grad");
grad_op->SetInput("XShape", this->Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
......
......@@ -121,9 +121,7 @@ class FSPGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("fsp_grad");
op->SetInput("X", this->Input("X"));
......@@ -134,8 +132,6 @@ class FSPGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
return op;
}
};
......
......@@ -61,8 +61,7 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Output("Y"));
......@@ -79,8 +78,6 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
return op;
}
};
......
......@@ -227,8 +227,7 @@ class FusedElemwiseActivationGradMaker
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType(this->ForwardOpType() + "_grad");
for (auto &input_param : this->InputNames()) {
......@@ -255,11 +254,10 @@ class FusedElemwiseActivationGradMaker
grad_op->SetOutput(framework::GradVarName("IntermediateOut"),
this->OutputGrad("IntermediateOut"));
} else {
grad_op->SetInput("IntermediateOut", {});
grad_op->SetOutput(framework::GradVarName("IntermediateOut"), {});
grad_op->SetInput("IntermediateOut", this->EmptyOutput());
grad_op->SetOutput(framework::GradVarName("IntermediateOut"),
this->EmptyOutputGrad());
}
return std::unique_ptr<T>(grad_op);
}
};
......
......@@ -158,15 +158,13 @@ class FusedEmbeddingSeqPoolGradOpMaker
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("fused_embedding_seq_pool_grad");
op->SetInput("Ids", this->Input("Ids"));
op->SetInput("W", this->Input("W"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -151,15 +151,13 @@ class GatherNdGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("gather_nd_grad");
op->SetInput("Index", this->Input("Index"));
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -113,15 +113,13 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("gather_grad");
op->SetInput("Index", this->Input("Index"));
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -176,8 +176,7 @@ class GridSampleGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("grid_sampler_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Grid", this->Input("Grid"));
......@@ -187,7 +186,6 @@ class GridSampleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Grid"), this->InputGrad("Grid"));
return std::unique_ptr<T>(op);
}
};
......
......@@ -189,8 +189,7 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto *op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("group_norm_grad");
op->SetInput("Scale", this->Input("Scale"));
op->SetInput("Bias", this->Input("Bias"));
......@@ -203,8 +202,6 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
......
......@@ -390,8 +390,7 @@ class GRUGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("gru_grad");
grad_op->SetInput("Input", this->Input("Input"));
grad_op->SetInput("H0", this->Input("H0"));
......@@ -415,7 +414,6 @@ class GRUGradOpMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
......
......@@ -212,8 +212,7 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("gru_unit_grad");
op->SetInput("Input", this->Input("Input"));
......@@ -232,7 +231,6 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
this->InputGrad("HiddenPrev"));
op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
return std::unique_ptr<T>(op);
}
};
......
......@@ -189,8 +189,7 @@ class HierarchicalSigmoidGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
// Inputs: X, W, Label, PathTable, PathCode, PreOut, Out@GRAD
op->SetInput("X", this->Input("X"));
......@@ -207,8 +206,6 @@ class HierarchicalSigmoidGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
......
......@@ -106,15 +106,13 @@ class HingeLossGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("hinge_loss_grad");
op->SetInput("Logits", this->Input("Logits"));
op->SetInput("Labels", this->Input("Labels"));
op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
op->SetOutput(framework::GradVarName("Logits"), this->InputGrad("Logits"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -121,15 +121,13 @@ class HuberLossGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("huber_loss_grad");
op->SetInput("Residual", this->Output("Residual"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -159,14 +159,12 @@ class Im2SequenceGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("im2sequence_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -71,13 +71,11 @@ class IncrementGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("increment");
grad_op->SetInput("X", this->Output("Out"));
grad_op->SetOutput("Out", this->Input("X"));
grad_op->SetAttr("step", -boost::get<float>(this->GetAttr("step")));
return std::unique_ptr<T>(grad_op);
}
};
......
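Some backward passes in this diff are not dedicated *_grad kernels at all: cumsum and increment re-run the forward kernel with swapped inputs/outputs and a transformed attribute, read back through boost::get since GetAttr() returns the framework::Attribute variant. Generically (hypothetical op foo with a float attribute step; the sign flip is illustrative, mirroring the increment hunk above):

void Apply(GradOpPtr<T> grad_op) const override {
  // Reuse the forward kernel as its own gradient, running it "backwards".
  grad_op->SetType("foo");
  grad_op->SetInput("X", this->Output("Out"));
  grad_op->SetOutput("Out", this->Input("X"));
  // GetAttr() yields a boost::variant; extract the concrete value, transform
  // it, and set it on the reused op.
  grad_op->SetAttr("step", -boost::get<float>(this->GetAttr("step")));
}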
......@@ -80,8 +80,7 @@ class InstanceNormGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto *op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("instance_norm_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
......@@ -94,8 +93,6 @@ class InstanceNormGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
return std::unique_ptr<T>(op);
}
};
......@@ -105,8 +102,7 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto *op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("instance_norm_grad_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Scale", this->Input("Scale"));
......@@ -121,7 +117,6 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput("DX", this->InputGrad("X"));
op->SetOutput("DScale", this->InputGrad("Scale"));
op->SetOutput("DDY", this->InputGrad(framework::GradVarName("Y")));
return std::unique_ptr<T>(op);
}
};
......
......@@ -429,8 +429,7 @@ class InterpolateGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X"));
if (this->HasInput("SizeTensor") > 0) {
......@@ -445,7 +444,6 @@ class InterpolateGradMaker : public framework::SingleGradOpMaker<T> {
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -148,8 +148,7 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("kldiv_loss_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Target", this->Input("Target"));
......@@ -158,7 +157,6 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
return std::unique_ptr<T>(op);
}
};
......
......@@ -69,14 +69,12 @@ class L1NormGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("l1_norm_grad");
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -117,13 +117,11 @@ class LabelSmoothGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("label_smooth_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -171,8 +171,7 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("layer_norm_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Mean", this->Output("Mean"));
......@@ -189,7 +188,6 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -280,8 +280,7 @@ class LinearChainCRFGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("linear_chain_crf_grad");
op->SetAttrMap(this->Attrs());
op->SetInput("Emission", this->Input("Emission"));
......@@ -300,8 +299,6 @@ class LinearChainCRFGradMaker : public framework::SingleGradOpMaker<T> {
this->InputGrad("Emission"));
op->SetOutput(framework::GradVarName("Transition"),
this->InputGrad("Transition"));
return op;
}
};
......
......@@ -208,14 +208,12 @@ class LoDResetGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("lod_reset_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("X", this->Input("X"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
......@@ -232,14 +232,12 @@ class LoDTensorToArrayGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto *grad_op = new T();
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("array_to_lod_tensor");
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetInput("RankTable", this->Input("RankTable"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
}
};
......
......@@ -111,8 +111,7 @@ class LogLossGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
void Apply(GradOpPtr<T> op) const override {
op->SetType("log_loss_grad");
op->SetInput("Predicted", this->Input("Predicted"));
op->SetInput("Labels", this->Input("Labels"));
......@@ -120,7 +119,6 @@ class LogLossGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Predicted"),
this->InputGrad("Predicted"));
op->SetAttrMap(this->Attrs());
return op;
}
};
......
(The diffs for the remaining files in this commit are collapsed.)