Unverified commit d33c4343, authored by Zeng Jinle, committed by GitHub

Imperative tracer refactoring (#22457)

* refine grad maker, test=develop

* refactor tracer stage 1, test=develop

* merge develop to solve conflict third times, test=develop
Parent 08a772cb
@@ -242,10 +242,11 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
             "GradOpBaseMaker of %s has been registered", op_type));
     info->dygraph_grad_op_maker_ = [](
-        const imperative::OpBase* fw_op_base,
+        const std::string& type,
         const imperative::NameVarBaseMap& var_base_map_in,
-        const imperative::NameVarBaseMap& var_base_map_out) {
-      T maker(fw_op_base, var_base_map_in, var_base_map_out);
+        const imperative::NameVarBaseMap& var_base_map_out,
+        const framework::AttributeMap& attrs) {
+      T maker(type, var_base_map_in, var_base_map_out, attrs);
       return maker();
     };
   }
...
@@ -28,6 +28,26 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

+namespace details {
+
+template <typename T>
+struct GradOpPtrTrait {};
+
+template <>
+struct GradOpPtrTrait<OpDesc> {
+  using Type = OpDesc*;
+};
+
+template <>
+struct GradOpPtrTrait<imperative::OpBase> {
+  using Type = imperative::TracedGradOp*;
+};
+
+}  // namespace details
+
+template <typename T>
+using GradOpPtr = typename details::GradOpPtrTrait<T>::Type;
+
 /*
 This functor class is responsible for creating the gradient ops for the given
 operator fwd_op. After it is called (through operator()), the pairs of
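The GradOpPtrTrait indirection above is what lets a single templated grad-op maker body serve both backends: for the static graph a grad op is filled in through an OpDesc*, while the dygraph tracer hands out an imperative::TracedGradOp*. A self-contained sketch of the dispatch (OpDesc and TracedGradOp below are simplified stand-ins, not the real Paddle classes):

    #include <iostream>

    // Simplified stand-ins for framework::OpDesc and imperative::TracedGradOp.
    struct OpDesc {
      void SetType(const char* t) { std::cout << "OpDesc: " << t << "\n"; }
    };
    struct TracedGradOp {
      void SetType(const char* t) { std::cout << "TracedGradOp: " << t << "\n"; }
    };

    // Same trait pattern as in the hunk above: map the node type T to the
    // pointer type that a grad maker's Apply() receives.
    template <typename T>
    struct GradOpPtrTrait {};
    template <>
    struct GradOpPtrTrait<OpDesc> { using Type = OpDesc*; };
    template <>
    struct GradOpPtrTrait<TracedGradOp> { using Type = TracedGradOp*; };

    template <typename T>
    using GradOpPtr = typename GradOpPtrTrait<T>::Type;

    // One templated maker body now compiles against both backends.
    template <typename T>
    void ApplyDemo(GradOpPtr<T> op) { op->SetType("some_op_grad"); }

    int main() {
      OpDesc desc;
      TracedGradOp traced;
      ApplyDemo<OpDesc>(&desc);          // static-graph path
      ApplyDemo<TracedGradOp>(&traced);  // dygraph path
      return 0;
    }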
@@ -47,6 +67,10 @@ class GradOpDescMakerBase {
         grad_to_var_(grad_to_var),
         grad_block_(grad_block) {}

+  static std::unique_ptr<OpDesc> CreateOp() {
+    return std::unique_ptr<OpDesc>(new OpDesc());
+  }
+
   virtual ~GradOpDescMakerBase() = default;
   virtual std::vector<std::unique_ptr<OpDesc>> operator()() const = 0;

@@ -100,7 +124,13 @@ class GradOpDescMakerBase {
     return ret_val;
   }

-  std::vector<std::string> Empty() const { return {}; }
+  static std::vector<std::string> EmptyInput() { return {}; }
+
+  static std::vector<std::string> EmptyOutput() { return {}; }
+
+  static std::vector<std::string> EmptyInputGrad() { return {}; }
+
+  static std::vector<std::string> EmptyOutputGrad() { return {}; }

   std::vector<std::string> InputNames() const {
     return this->fwd_op_.InputNames();
@@ -155,16 +185,7 @@ class GradOpDescMakerBase {
 };

 template <typename T>
-class SingleGradOpMaker {
- public:
-  std::vector<std::unique_ptr<T>> operator()() const {
-    PADDLE_ENFORCE(false, "should not call this function");
-    return {};
-  }
-
- protected:
-  virtual std::unique_ptr<T> Apply() const = 0;
-};
+class SingleGradOpMaker {};

 template <>
 class SingleGradOpMaker<OpDesc> : public GradOpDescMakerBase {
@@ -173,12 +194,13 @@ class SingleGradOpMaker<OpDesc> : public GradOpDescMakerBase {
   std::vector<std::unique_ptr<OpDesc>> operator()() const {
     std::vector<std::unique_ptr<OpDesc>> retv;
-    retv.emplace_back(this->Apply());
+    retv.emplace_back(new OpDesc());
+    this->Apply(retv.front().get());
     return retv;
   }

  protected:
-  virtual std::unique_ptr<OpDesc> Apply() const = 0;
+  virtual void Apply(GradOpPtr<OpDesc> op) const = 0;
 };

 template <>
@@ -187,16 +209,18 @@ class SingleGradOpMaker<imperative::OpBase>
  public:
   using GradOpBaseMakerBase::GradOpBaseMakerBase;

- public:
-  std::vector<std::unique_ptr<imperative::OpBase>> operator()() const {
-    std::vector<std::unique_ptr<imperative::OpBase>> retv;
-    retv.emplace_back(this->Apply());
+  std::vector<std::shared_ptr<imperative::OpBase>> operator()() const {
+    std::vector<std::shared_ptr<imperative::OpBase>> retv{
+        std::make_shared<imperative::OpBase>()};
+    {
+      imperative::TracedGradOp grad_op(retv.front());
+      this->Apply(&grad_op);
+    }
     return retv;
   }

  protected:
-  virtual std::unique_ptr<imperative::OpBase> Apply() const = 0;
+  virtual void Apply(GradOpPtr<imperative::OpBase> op) const = 0;
 };
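With Apply() now receiving a framework-allocated GradOpPtr<T> instead of returning a freshly allocated op (see both operator() bodies above), a concrete operator's maker reduces to describing the node it is handed. A hypothetical maker in the new style; the op type and slot names are illustrative, not part of this commit:

    template <typename T>
    class MulGradMaker : public framework::SingleGradOpMaker<T> {
     public:
      using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

     protected:
      // No allocation and no return value: the framework owns the node,
      // Apply() only fills it in.
      void Apply(framework::GradOpPtr<T> grad) const override {
        grad->SetType("mul_grad");
        grad->SetInput("X", this->Input("X"));
        grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
        grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
        grad->SetAttrMap(this->Attrs());
      }
    };

Instantiated with T = OpDesc this builds a static-graph grad op; with T = imperative::OpBase the same body drives a TracedGradOp, which is the point of the refactoring.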
 template <typename T, bool DropEmptyIG = true>
@@ -205,8 +229,7 @@ class DefaultGradOpMaker final : public SingleGradOpMaker<T> {
   using SingleGradOpMaker<T>::SingleGradOpMaker;

  protected:
-  std::unique_ptr<T> Apply() const final {
-    auto* grad = new T();
+  void Apply(GradOpPtr<T> grad) const final {
     grad->SetType(this->ForwardOpType() + "_grad");

     for (auto& input_param : this->InputNames()) {

@@ -221,19 +244,11 @@ class DefaultGradOpMaker final : public SingleGradOpMaker<T> {
     }

     grad->SetAttrMap(this->Attrs());
-    return std::unique_ptr<T>(grad);
   }
 };

 template <typename T>
-class EmptyGradOpMaker {
- public:
-  virtual std::vector<std::unique_ptr<T>> operator()()
-      const final { /* NOLINT */
-    return {};
-  }
-};
+class EmptyGradOpMaker {};

 template <>
 class EmptyGradOpMaker<OpDesc> final : public GradOpDescMakerBase {
@@ -247,10 +262,18 @@ class EmptyGradOpMaker<imperative::OpBase> final
     : public imperative::GradOpBaseMakerBase {
  public:
   using GradOpBaseMakerBase::GradOpBaseMakerBase;
-  std::vector<std::unique_ptr<imperative::OpBase>> operator()() const final {
+  std::vector<std::shared_ptr<imperative::OpBase>> operator()() const final {
     return {};
   }
 };

 }  // namespace framework
+
+namespace operators {
+
+template <typename T>
+using GradOpPtr = framework::GradOpPtr<T>;
+
+}  // namespace operators
+
 }  // namespace paddle
@@ -45,8 +45,9 @@ bool StaticGraphInferNoNeedBufferVarsContext::HasOutput(
 }

 DyGraphInferNoNeedBufferVarsContext::DyGraphInferNoNeedBufferVarsContext(
-    const imperative::NameVarBaseMap &inputs,
-    const imperative::NameVarBaseMap &outputs, const AttributeMap &attrs)
+    const imperative::NameVarMap<imperative::VariableWrapper> &inputs,
+    const imperative::NameVarMap<imperative::VariableWrapper> &outputs,
+    const AttributeMap &attrs)
     : InferNoNeedBufferVarsContext(attrs), inputs_(inputs), outputs_(outputs) {}

 bool DyGraphInferNoNeedBufferVarsContext::HasOutput(
...
@@ -56,15 +56,16 @@ class StaticGraphInferNoNeedBufferVarsContext final
 class DyGraphInferNoNeedBufferVarsContext final
     : public InferNoNeedBufferVarsContext {
  public:
-  DyGraphInferNoNeedBufferVarsContext(const imperative::NameVarBaseMap &inputs,
-                                      const imperative::NameVarBaseMap &outputs,
-                                      const AttributeMap &attr);
+  DyGraphInferNoNeedBufferVarsContext(
+      const imperative::NameVarMap<imperative::VariableWrapper> &inputs,
+      const imperative::NameVarMap<imperative::VariableWrapper> &outputs,
+      const AttributeMap &attrs);

   bool HasOutput(const std::string &slot) const final;

  private:
-  const imperative::NameVarBaseMap &inputs_;
-  const imperative::NameVarBaseMap &outputs_;
+  const imperative::NameVarMap<imperative::VariableWrapper> &inputs_;
+  const imperative::NameVarMap<imperative::VariableWrapper> &outputs_;
 };

 class NoNeedBufferVarsInference {

@@ -106,8 +107,8 @@ class InferNoNeedBufferVarsFN {
   }

   inline const std::unordered_set<std::string> &operator()(
-      const imperative::NameVarBaseMap &inputs,
-      const imperative::NameVarBaseMap &outputs,
+      const imperative::NameVarMap<imperative::VariableWrapper> &inputs,
+      const imperative::NameVarMap<imperative::VariableWrapper> &outputs,
       const AttributeMap &attrs) const {
     PADDLE_ENFORCE_NOT_NULL(inferer_);
     DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
...
@@ -35,10 +35,10 @@ TEST(test_no_need_buffer_vars_inference, test_static_graph) {
 TEST(test_no_need_buffer_vars_inference, test_dygraph) {
   AttributeMap attrs{{"is_test", true}};
-  imperative::NameVarBaseMap inputs;
-  imperative::NameVarBaseMap outputs;
+  imperative::NameVarMap<imperative::VariableWrapper> inputs;
+  imperative::NameVarMap<imperative::VariableWrapper> outputs;

   outputs["Out"].emplace_back(nullptr);
-  outputs["Out"].emplace_back(new imperative::VarBase("tmp_0"));
+  outputs["Out"].emplace_back(new imperative::VariableWrapper("tmp_0"));

   DyGraphInferNoNeedBufferVarsContext ctx(inputs, outputs, attrs);
...
@@ -1310,8 +1310,10 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
   for (auto& input : ctx.InNameList()) {
     ParseInputDataType(ctx, input, &data_type);
   }
-  PADDLE_ENFORCE_NE(data_type, dafault_data_type,
-                    "DataType should be indicated by input Variable.");
+  PADDLE_ENFORCE_NE(
+      data_type, dafault_data_type,
+      platform::errors::NotFound(
+          "DataType should be indicated by input Variable at %s.", Type()));
   return data_type;
 }
...
@@ -56,10 +56,11 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
     const std::vector<BlockDesc*>& grad_block)>;

 using DygraphGradOpMakerFN =
-    std::function<std::vector<std::unique_ptr<imperative::OpBase>>(
-        const imperative::OpBase* fw_op_base,
-        const imperative::NameVarBaseMap& var_base_map_in,
-        const imperative::NameVarBaseMap& var_base_map_out)>;
+    std::function<std::vector<std::shared_ptr<imperative::OpBase>>(
+        const std::string& /*op_type*/,
+        const imperative::NameVarBaseMap& /*var_base_map_in*/,
+        const imperative::NameVarBaseMap& /*var_base_map_out*/,
+        const framework::AttributeMap& /*attributes*/)>;

 using InferVarTypeFN =
     std::function<void(framework::InferVarTypeContext* /*context*/)>;
...
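The practical effect of the new DygraphGradOpMakerFN signature is that backward nodes can be created from the op type string plus the traced inputs, outputs, and attributes alone; no forward OpBase has to be kept alive for grad creation. A minimal self-contained illustration of the signature shape (all types here are simplified stand-ins, not the real Paddle classes):

    #include <functional>
    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>
    #include <vector>

    // Stand-ins for imperative::OpBase / NameVarBaseMap / AttributeMap.
    struct OpBase { std::string type; };
    using NameVarBaseMap = std::map<std::string, std::vector<int>>;
    using AttributeMap = std::map<std::string, float>;

    // Mirrors the new DygraphGradOpMakerFN: shared_ptr results, and the
    // forward op is described by its type string plus ins/outs/attrs.
    using DygraphGradOpMakerFN =
        std::function<std::vector<std::shared_ptr<OpBase>>(
            const std::string&, const NameVarBaseMap&, const NameVarBaseMap&,
            const AttributeMap&)>;

    int main() {
      DygraphGradOpMakerFN maker = [](const std::string& type,
                                      const NameVarBaseMap&,
                                      const NameVarBaseMap&,
                                      const AttributeMap&) {
        auto grad = std::make_shared<OpBase>();
        grad->type = type + "_grad";
        return std::vector<std::shared_ptr<OpBase>>{grad};
      };
      auto ops = maker("mul", {}, {}, {});
      std::cout << ops.front()->type << "\n";  // prints "mul_grad"
      return 0;
    }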
@@ -17,6 +17,7 @@
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>

 #include "paddle/fluid/imperative/layer.h"

@@ -27,37 +28,70 @@
 namespace paddle {
 namespace imperative {

+enum TracedVarRole { kForward = 0, kBackward = 1 };
+
+template <typename T, TracedVarRole kRole>
+class TracedVarList : public std::vector<std::shared_ptr<T>> {
+ private:
+  using BaseClass = std::vector<std::shared_ptr<T>>;
+
+ public:
+  using BaseClass::BaseClass;
+};
+
 class GradOpBaseMakerBase {
  public:
-  explicit GradOpBaseMakerBase(const OpBase* fw_op_base,
+  explicit GradOpBaseMakerBase(const std::string& type,
                                const NameVarBaseMap& var_base_map_in,
-                               const NameVarBaseMap& var_base_map_out)
-      : fw_op_base_(fw_op_base),
+                               const NameVarBaseMap& var_base_map_out,
+                               const framework::AttributeMap& attrs)
+      : type_(type),
         var_base_map_in_(var_base_map_in),
-        var_base_map_out_(var_base_map_out) {}
+        var_base_map_out_(var_base_map_out),
+        attrs_(attrs) {}

   virtual ~GradOpBaseMakerBase() = default;
-  virtual std::vector<std::unique_ptr<OpBase>> operator()() const = 0;
+  virtual std::vector<std::shared_ptr<OpBase>> operator()() const = 0;
+
+  static std::shared_ptr<OpBase> CreateOp() {
+    return std::make_shared<OpBase>();
+  }

-  std::vector<std::shared_ptr<VarBase>> InputGrad(
+  TracedVarList<VarBase, TracedVarRole::kBackward> InputGrad(
       const std::string& name, bool drop_empty_grad = true) const {
-    return GetVarBaseList(name, true, true);
+    return GetVarBaseList<TracedVarRole::kBackward>(name, /*is_input=*/true);
   }

-  std::vector<std::shared_ptr<VarBase>> OutputGrad(
+  TracedVarList<VarBase, TracedVarRole::kBackward> OutputGrad(
       const std::string& name) const {
-    return GetVarBaseList(name, true, false);
+    return GetVarBaseList<TracedVarRole::kBackward>(name, /*is_input=*/false);
   }

-  std::vector<std::shared_ptr<VarBase>> Input(const std::string name) const {
-    return GetVarBaseList(name, false, true);
+  TracedVarList<VarBase, TracedVarRole::kForward> Input(
+      const std::string& name) const {
+    return GetVarBaseList<TracedVarRole::kForward>(name, /*is_input=*/true);
   }

-  std::vector<std::shared_ptr<VarBase>> Output(const std::string& name) const {
-    return GetVarBaseList(name, false, false);
+  TracedVarList<VarBase, TracedVarRole::kForward> Output(
+      const std::string& name) const {
+    return GetVarBaseList<TracedVarRole::kForward>(name, /*is_input=*/false);
   }

-  std::vector<std::shared_ptr<VarBase>> Empty() const { return {}; }
+  static TracedVarList<VarBase, TracedVarRole::kForward> EmptyInput() {
+    return {};
+  }
+
+  static TracedVarList<VarBase, TracedVarRole::kForward> EmptyOutput() {
+    return {};
+  }
+
+  static TracedVarList<VarBase, TracedVarRole::kBackward> EmptyOutputGrad() {
+    return {};
+  }
+
+  static TracedVarList<VarBase, TracedVarRole::kBackward> EmptyInputGrad() {
+    return {};
+  }

   std::vector<std::string> InputNames() const {
     std::vector<std::string> vec_temp;
@@ -65,7 +99,6 @@ class GradOpBaseMakerBase {
     for (auto& it : var_base_map_in_) {
       vec_temp.emplace_back(it.first);
     }
-
     return vec_temp;
   }

@@ -75,21 +108,17 @@ class GradOpBaseMakerBase {
     for (auto& it : var_base_map_out_) {
       vec_temp.emplace_back(it.first);
     }
-
     return vec_temp;
   }

-  const std::unordered_map<std::string, framework::Attribute>& Attrs() const {
-    return fw_op_base_->Attrs();
-  }
+  const framework::AttributeMap& Attrs() const { return attrs_; }

   const framework::Attribute& GetAttr(const std::string& name) const {
-    auto& map = fw_op_base_->Attrs();
-    auto it = map.find(name);
-    PADDLE_ENFORCE(it != map.end(),
-                   "Cannot find attribute [%s] in operator [%s]", name,
-                   fw_op_base_->Type());
+    auto it = attrs_.find(name);
+    PADDLE_ENFORCE_EQ(
+        it != attrs_.end(), true,
+        platform::errors::NotFound(
+            "Cannot find attribute [%s] in operator [%s]", name, type_));
     return it->second;
   }
@@ -98,41 +127,37 @@ class GradOpBaseMakerBase {
     return boost::get<T>(GetAttr(name));
   }

-  std::string ForwardOpType() const { return fw_op_base_->Type(); }
+  const std::string& ForwardOpType() const { return type_; }

  protected:
   bool HasInput(const std::string& name) const {
-    auto it = var_base_map_in_.find(name);
-    return it != var_base_map_in_.end();
+    return var_base_map_in_.count(name) > 0;
   }

-  bool HasOutput(const std::string name) const {
-    auto it = var_base_map_out_.find(name);
-    return it != var_base_map_out_.end();
+  bool HasOutput(const std::string& name) const {
+    return var_base_map_out_.count(name) > 0;
   }

  private:
-  std::vector<std::shared_ptr<VarBase>> GetVarBaseList(const std::string& name,
-                                                       bool is_grad,
-                                                       bool is_input) const {
-    const NameVarBaseMap& data_map =
-        is_input ? var_base_map_in_ : var_base_map_out_;
+  template <TracedVarRole kRole>
+  TracedVarList<VarBase, kRole> GetVarBaseList(const std::string& name,
+                                               bool is_input) const {
+    const auto& data_map = is_input ? var_base_map_in_ : var_base_map_out_;
     auto iterator = data_map.find(name);

-    std::vector<std::shared_ptr<imperative::VarBase>> vec_temp;
+    TracedVarList<VarBase, kRole> vec_temp;
     if (iterator != data_map.end()) {
       vec_temp.reserve(iterator->second.size());

       for (auto& var_base_temp : iterator->second) {
-        if (is_grad) {
+        if (kRole == TracedVarRole::kBackward) {
           if (!var_base_temp->HasGradVar()) {
             VLOG(6) << "GradVarBase of var " << var_base_temp->Name()
-                    << " in OP " << fw_op_base_->Type() << " is null";
+                    << " in OP " << type_ << " is null";
             var_base_temp->MutableGradVarBase();
           }
           auto grad_var_base_tmp = var_base_temp->GradVarBase();
           if (!is_input) {
             auto* tensor = grad_var_base_tmp->MutableVar()
                                ->GetMutable<framework::LoDTensor>();
@@ -150,12 +175,91 @@ class GradOpBaseMakerBase {
   }

  private:
-  const OpBase* fw_op_base_;
+  const std::string& type_;
   const NameVarBaseMap& var_base_map_in_;
   const NameVarBaseMap& var_base_map_out_;
-
- protected:
-  std::vector<framework::BlockDesc*> grad_block_;
+  const framework::AttributeMap& attrs_;
+};
+
+class TracedGradOp {
+  DISABLE_COPY_AND_ASSIGN(TracedGradOp);
+
+ public:
+  explicit TracedGradOp(const std::shared_ptr<OpBase>& op) : op_(op) {}
+
+  ~TracedGradOp() {
+    op_->SetGradPendingOps(
+        {grad_pending_ops_.begin(), grad_pending_ops_.end()});
+    op_->CheckAttrs();
+  }
+
+  template <TracedVarRole kRole>
+  void SetInput(const std::string& name,
+                const TracedVarList<VarBase, kRole>& vars) {
+    if (kRole == TracedVarRole::kBackward) {
+      for (auto& var : vars) {
+        var->AddGradOp(op_);
+      }
+    }
+    op_->SetInput(name, ToVarWrapperList(vars));
+  }
+
+  template <TracedVarRole kRole>
+  void SetOutput(const std::string& name,
+                 const TracedVarList<VarBase, kRole>& vars) {
+    if (kRole == TracedVarRole::kBackward) {
+      if (vars.size() == 1 && vars.front()->OverridedStopGradient()) {
+        op_->SetOutput(name, VariableWrapperList{});
+        return;
+      } else {
+        for (auto& var : vars) {
+          if (!var->OverridedStopGradient()) {
+            for (auto& op : var->GradOps()) {
+              grad_pending_ops_.emplace(op);
+            }
+          }
+        }
+      }
+    }
+    op_->SetOutput(name, ToVarWrapperList(vars));
+  }
+
+  void SetType(const std::string& type) { op_->SetType(type); }
+
+  void SetAttrMap(const framework::AttributeMap& attrs) {
+    return op_->SetAttrMap(attrs);
+  }
+
+  void SetAttr(const std::string& name, const framework::Attribute& v) {
+    op_->SetAttr(name, v);
+  }
+
+  bool HasAttr(const std::string& name) const { return op_->HasAttr(name); }
+
+  const framework::Attribute& GetAttr(const std::string& name) const {
+    return op_->GetAttr(name);
+  }
+
+  template <typename T>
+  inline const T& Attr(const std::string& name) const {
+    return op_->Attr<T>(name);
+  }
+
+ private:
+  static std::vector<std::shared_ptr<VariableWrapper>> ToVarWrapperList(
+      const std::vector<std::shared_ptr<VarBase>>& vars) {
+    std::vector<std::shared_ptr<VariableWrapper>> result;
+    result.reserve(vars.size());
+    for (auto& var : vars) {
+      result.emplace_back(var->SharedVar());
+    }
+    return result;
+  }
+
+ private:
+  const std::shared_ptr<OpBase>& op_;
+  std::unordered_set<std::shared_ptr<OpBase>> grad_pending_ops_;
 };

 }  // namespace imperative
...
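Two details of TracedGradOp are worth keeping in mind when reading the class above: the TracedVarRole tag decides whether SetInput/SetOutput also wire the autograd graph (backward vars register the op via AddGradOp, and their existing GradOps() feed grad_pending_ops_), and the object is a scoped builder whose destructor flushes the deduplicated pending set and runs CheckAttrs() exactly once. The scoped-builder part in miniature (Node and TracedNode are stand-ins, not Paddle types):

    #include <iostream>
    #include <memory>
    #include <set>
    #include <string>

    // Stand-in for OpBase: receives its pending ops once, at the end.
    struct Node {
      std::string type;
      std::set<std::shared_ptr<Node>> pending;
    };

    // Same shape as TracedGradOp: mutations buffer into the builder, and the
    // destructor flushes them to the underlying node exactly once.
    class TracedNode {
     public:
      explicit TracedNode(const std::shared_ptr<Node>& n) : node_(n) {}
      ~TracedNode() { node_->pending = pending_; }  // flush on scope exit
      void SetType(const std::string& t) { node_->type = t; }
      void AddPending(const std::shared_ptr<Node>& n) { pending_.insert(n); }

     private:
      std::shared_ptr<Node> node_;
      std::set<std::shared_ptr<Node>> pending_;
    };

    int main() {
      auto grad = std::make_shared<Node>();
      auto earlier = std::make_shared<Node>();
      {
        TracedNode traced(grad);  // plays the role of TracedGradOp
        traced.SetType("mul_grad");
        traced.AddPending(earlier);
      }  // destructor wires grad->pending here
      std::cout << grad->type << " pending=" << grad->pending.size() << "\n";
      return 0;
    }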
@@ -30,16 +30,9 @@
 namespace paddle {
 namespace imperative {

-void Engine::RunOp(paddle::imperative::OpBase* op,
-                   const paddle::imperative::NameVarBaseMap& ins,
-                   const paddle::imperative::NameVarBaseMap& outs,
-                   const paddle::platform::Place& place) {
-  op->Run(ins, outs);
-}
-
 void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
   backward_strategy_ = strategy;
-  const std::vector<OpBase*> ops = var->GradVarBase()->GradOps();
+  const auto& ops = var->GradVarBase()->GradOps();
   var->ClearGradOps();

   if (ops.empty() || var->OverridedStopGradient()) {

@@ -59,7 +52,9 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
       return;
     }
   }

   init_ops_ = ops;
+  var->GradVarBase()->ClearGradOps();
+
   VLOG(3) << "start backward";

   PADDLE_ENFORCE_EQ(var->HasGradVar(), true,

@@ -71,7 +66,6 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
   VLOG(6) << "init loss grad:" << var->GradVarBase()->Name()
           << " as stop_gradient false";
   var->GradVarBase()->InnerSetOverridedStopGradient(false);
-  var->GradVarBase()->SetGradGenerated(true);
   auto* dev_ctx = platform::DeviceContextPool::Instance().Get(fwd_var.place());
   grad_var->Resize(fwd_var.dims());
   grad_var->mutable_data(fwd_var.place(), fwd_var.type());

@@ -81,35 +75,29 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
 void BasicEngine::CheckBackwardInputs(OpBase* op) {
   for (auto& pair : op->GetInsMap()) {
     for (auto& var : pair.second) {
-      if (var && IsGrad(var.get())) {
-        // if grad var has OverridedStopGradient skip this Op
-        if (!var->GradGenerated()) {
-          VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero";
-          auto* dev_ctx =
-              platform::DeviceContextPool::Instance().Get(op->place());
-          auto* tensor = var->MutableVar()->GetMutable<framework::LoDTensor>();
-          tensor->mutable_data(op->place(), var->DataType());
-          operators::math::set_constant(*dev_ctx, tensor, 0.0);
-        } else {
-          continue;
-        }
-      }
-    }
-  }
-}
-
-void BasicEngine::SetBackwardOutputs(paddle::imperative::OpBase* op) {
-  for (auto& pair : op->GetOutsMap()) {
-    for (auto& var : pair.second) {
-      if (var) {
-        // Set Backward outputs's generate_grad as true
-        var->SetGradGenerated(true);
-        VLOG(6) << "Set backward output: " << var->Name()
-                << "'s SetGeneratedGrad as True";
+      if (!var || op->IsAllowedEmptyVar(var.get())) {
+        continue;
+      }
+
+      auto* inner_var = var->MutableVar();
+      framework::Tensor* tensor = nullptr;
+      if (!inner_var->IsInitialized() ||
+          inner_var->IsType<framework::LoDTensor>()) {
+        tensor = inner_var->GetMutable<framework::LoDTensor>();
+      }
+
+      if (tensor && !tensor->IsInitialized()) {
+        // if grad var has OverridedStopGradient skip this Op
+        VLOG(6) << "Set ungenerated Grad: " << var->Name() << " as zero";
+        auto* dev_ctx =
+            platform::DeviceContextPool::Instance().Get(op->place());
+        tensor->mutable_data(op->place(), var->DataType());
+        operators::math::set_constant(*dev_ctx, tensor, 0.0);
       }
     }
   }
 }

 void BasicEngine::PrepareGradAccumulators(OpBase* op) {
   for (const auto& pair : op->GetOutsMap()) {
     for (const auto& var : pair.second) {

@@ -140,50 +128,63 @@ void BasicEngine::PrepareDeps() {
   std::queue<OpBase*> q;
   std::unordered_set<OpBase*> visited;
   for (const auto& init_op : init_ops_) {
-    q.push(init_op);
-    visited.insert(init_op);
+    q.push(init_op.get());
+    visited.insert(init_op.get());
   }

   while (!q.empty()) {
     auto* cur_op = q.front();
     q.pop();
-    VLOG(3) << "Checking grads of op " << cur_op->Type();

-    SetBackwardOutputs(cur_op);
+    PADDLE_ENFORCE_NE(
+        cur_op->GetInsMap().empty() && cur_op->GetOutsMap().empty(), true,
+        platform::errors::NotFound(
+            "Inputs and outputs of %s do not exist. "
+            "This may be because you call \"backward()\" twice for the same "
+            "subgraph. Please try to call \"stop_gradient = True\" or "
+            "\"detach()\" if you use some same vars between two \"backward()\" "
+            "calls.",
+            cur_op->Type()));

     PrepareGradAccumulators(cur_op);

-    auto& grad_pending_ops = cur_op->GradPendingOps();
-    for (auto* grad_pending_op : grad_pending_ops) {
+    const auto& grad_pending_ops = cur_op->GradPendingOps();
+    for (auto& grad_pending_op : grad_pending_ops) {
       PADDLE_ENFORCE_NOT_NULL(grad_pending_op);
-      ++op_deps_[grad_pending_op];
-      if (visited.count(grad_pending_op) == 0) {
-        visited.insert(grad_pending_op);
-        q.push(grad_pending_op);
+      ++op_deps_[grad_pending_op.get()];
+      if (visited.count(grad_pending_op.get()) == 0) {
+        visited.insert(grad_pending_op.get());
+        q.push(grad_pending_op.get());
       }
     }
   }
 }

-void BasicEngine::SumGradient(OpBase* op, std::shared_ptr<VarBase> src,
-                              VarBase* dst) {
+void BasicEngine::SumGradient(OpBase* op, std::shared_ptr<VariableWrapper> src,
+                              VariableWrapper* dst) {
   auto iter = accumulators_.find(dst);
   PADDLE_ENFORCE_EQ(iter != accumulators_.end(), true,
                     "Cannot find gradient of variable %s", dst->Name());
   iter->second->Add(std::move(src), op->id());
 }

 void BasicEngine::Execute() {
   PrepareDeps();
   // Start execute Computation graph
-  std::queue<OpBase*> q;
+  std::queue<std::shared_ptr<OpBase>> q;
   for (const auto& init_op : init_ops_) {
-    q.push(init_op);
+    q.push(std::move(init_op));
   }

+  size_t op_num = 0;
+
   while (!q.empty()) {
-    OpBase* cur_op = q.front();
+    auto shared_cur_op = std::move(q.front());
     q.pop();

+    auto* cur_op = shared_cur_op.get();
+    ++op_num;
+
     // CheckBackWardInput
     CheckBackwardInputs(cur_op);

@@ -191,26 +192,28 @@ void BasicEngine::Execute() {
     auto& bwd_ins = cur_op->GetInsMap();
     auto& bwd_outs = cur_op->GetOutsMap();

-    NameVarBaseMap tmp_outs(bwd_outs);
+    NameVarMap<VariableWrapper> tmp_outs(bwd_outs);
     // 1. construct the output map 2. replace the element in the map
     // A var may be coresponding to several grad var in one op
     for (auto it = tmp_outs.begin(); it != tmp_outs.end(); ++it) {
       for (size_t i = 0; i < it->second.size(); ++i) {
         auto tmp_var =
-            std::make_shared<VarBase>(false, "Gtmp@");  // Do not need grad
+            std::make_shared<VariableWrapper>("Gtmp@");  // Do not need grad

         auto var = it->second[i];
         it->second[i] = tmp_var;

         if (var) {
-          need_accu_var_list_.emplace_back(
-              make_pair(var.get(), std::move(tmp_var)));
-          var->ClearGradOps();
+          need_accu_var_list_.emplace_back(var.get(), std::move(tmp_var));
         }
       }
     }

-    VLOG(3) << "Start to execute grad op " << cur_op->Type();
-    RunOp(cur_op, bwd_ins, tmp_outs, cur_op->place());
+    {
+      VLOG(3) << "Start to execute grad op " << cur_op->Type();
+      OpBase::Run(cur_op->InnerOp(), bwd_ins, tmp_outs, cur_op->Attrs(),
+                  cur_op->place());
+    }

     // Step 2: Sum Gradient
     if (need_accu_var_list_.size() > 0) {

@@ -223,9 +226,9 @@ void BasicEngine::Execute() {
     // Step 3: Collect ready ops

-    for (auto* grad_pending_op : cur_op->GradPendingOps()) {
+    for (auto& grad_pending_op : cur_op->GradPendingOps()) {
       PADDLE_ENFORCE_NOT_NULL(grad_pending_op);
-      auto iter = op_deps_.find(grad_pending_op);
+      auto iter = op_deps_.find(grad_pending_op.get());
       if (iter == op_deps_.end()) {
         continue;
       }

@@ -242,10 +245,11 @@ void BasicEngine::Execute() {
     // Step 4: Delete op to collect unused variables
     VLOG(3) << "Remove op after op " << cur_op->Type() << " runs";
-    RemoveOp(cur_op);
+    cur_op->ClearBackwardTrace();
   }

-  VLOG(3) << "Clean properties of BasicEngine";
-  CleanEngine();
+  Clear();
+
+  VLOG(1) << "Backward op number: " << op_num;
 }

 }  // namespace imperative
 }  // namespace paddle
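PrepareDeps() and Execute() above implement classic dependency-count scheduling: a BFS first counts, for every reachable backward op, how many predecessors feed it, and execution then enqueues an op only when its counter reaches zero, so shared nodes run exactly once and only after all of their gradient inputs have been accumulated. The same scheme in miniature, with ints standing in for OpBase* (illustrative only, not Paddle code):

    #include <iostream>
    #include <queue>
    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    // edges[i] lists the nodes that become runnable after node i runs,
    // mirroring OpBase::GradPendingOps().
    void RunBackward(int start, const std::vector<std::vector<int>>& edges) {
      std::unordered_map<int, int> deps;  // node -> unfinished predecessors
      std::unordered_set<int> visited{start};
      std::queue<int> q;
      q.push(start);
      while (!q.empty()) {  // PrepareDeps(): BFS that fills the counters
        int cur = q.front();
        q.pop();
        for (int nxt : edges[cur]) {
          ++deps[nxt];
          if (visited.insert(nxt).second) q.push(nxt);
        }
      }
      q.push(start);
      while (!q.empty()) {  // Execute(): run a node once all inputs are ready
        int cur = q.front();
        q.pop();
        std::cout << "run op " << cur << "\n";
        for (int nxt : edges[cur]) {
          if (--deps[nxt] == 0) q.push(nxt);
        }
      }
    }

    int main() {
      // 0 feeds 1 and 2; both feed 3, so 3 runs last, exactly once.
      RunBackward(0, {{1, 2}, {3}, {3}, {}});
      return 0;
    }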
@@ -37,49 +37,12 @@ class Engine {
   virtual ~Engine() = default;
   virtual void Execute() = 0;
   virtual void Init(VarBase* var, const detail::BackwardStrategy& strategy) = 0;
-
-  virtual void RunOp(imperative::OpBase* op, const NameVarBaseMap& ins,
-                     const NameVarBaseMap& outs, const platform::Place& place);
-
-  virtual void RemoveOp(OpBase* op) {
-    PADDLE_ENFORCE_NOT_NULL(op, "Cannot remove null op");
-    auto iter = grad_ops_.find(op);
-    PADDLE_ENFORCE_EQ(iter != grad_ops_.end(), true, "Op is not inside tracer");
-    grad_ops_.erase(iter);
-  }
-
-  void InsertOp(OpBase* op, std::shared_ptr<OpBase> op_shared) {
-    grad_ops_[op] = std::move(op_shared);
-  }
-
-  const std::unordered_set<VarBase*>& GradVars() const { return grad_vars_; }
-
-  const std::unordered_map<OpBase*, std::shared_ptr<OpBase>>& GradOps() const {
-    return grad_ops_;
-  }
-
-  void InsertGradVar(VarBase* grad) { grad_vars_.emplace(grad); }
-
-  bool IsGrad(VarBase* var) { return grad_vars_.count(var) > 0; }
-
-  void Clear() {
-    grad_ops_.clear();
-    grad_vars_.clear();
-  }
-
- private:
-  std::unordered_map<OpBase*, std::shared_ptr<OpBase>>
-      grad_ops_;  // opBase for remove - grad_op
-  std::unordered_set<VarBase*> grad_vars_;
 };

 class BasicEngine : public Engine {
  public:
-  BasicEngine() = default;
-
   void Init(VarBase* var, const detail::BackwardStrategy& strategy) override;

-  ~BasicEngine() override = default;
-
   void Execute() override;

  private:

@@ -87,28 +50,26 @@ class BasicEngine : public Engine {
   void CheckBackwardInputs(OpBase* op);

-  void SetBackwardOutputs(OpBase* op);
-
   void PrepareGradAccumulators(OpBase* op);

-  void SumGradient(OpBase* op, std::shared_ptr<VarBase> src, VarBase* dst);
+  void SumGradient(OpBase* op, std::shared_ptr<VariableWrapper> src,
+                   VariableWrapper* dst);

   // TODO(jiabin): maybe we can optimize the performance of engine by cache the
   // result
-  void CleanEngine() {
+  void Clear() {
     init_ops_.clear();
     op_deps_.clear();
     accumulators_.clear();
-    Clear();
   }

-  std::vector<OpBase*> init_ops_;
+  std::vector<std::shared_ptr<OpBase>> init_ops_;
   detail::BackwardStrategy backward_strategy_;
   std::unordered_map<OpBase*, size_t> op_deps_;
-  std::unordered_map<VarBase*, std::unique_ptr<GradientAccumulator>>
+  std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
       accumulators_;
-  std::vector<std::pair<VarBase*, std::shared_ptr<VarBase>>>
+  std::vector<std::pair<VariableWrapper*, std::shared_ptr<VariableWrapper>>>
       need_accu_var_list_;
 };
...
@@ -144,8 +144,8 @@ void SelectedRowsAddToTensor(const framework::Variable& src,
 // Note(chenweihang): when two selected rows need to be added,
 // adding one to another is not equal to merging two selected rows
 // to one then add it to a empty selected rows, the after is correct
-std::shared_ptr<VarBase> SelectedRowsMerge(const framework::Variable& src1,
-                                           const framework::Variable& src2) {
+std::shared_ptr<VariableWrapper> SelectedRowsMerge(
+    const framework::Variable& src1, const framework::Variable& src2) {
   auto& src_selected_rows1 = src1.Get<framework::SelectedRows>();
   auto& src_selected_rows2 = src2.Get<framework::SelectedRows>();
   auto place = src_selected_rows1.value().place();

@@ -155,7 +155,7 @@ std::shared_ptr<VarBase> SelectedRowsMerge(const framework::Variable& src1,
   std::vector<const framework::SelectedRows*> src_selected_rows;
   src_selected_rows.emplace_back(&src_selected_rows1);
   src_selected_rows.emplace_back(&src_selected_rows2);
-  auto dst_var = std::make_shared<VarBase>(false, "Temp");
+  auto dst_var = std::make_shared<VariableWrapper>("Temp");
   auto* dst_selected_rows =
       dst_var->MutableVar()->GetMutable<framework::SelectedRows>();

@@ -188,7 +188,8 @@ std::shared_ptr<VarBase> SelectedRowsMerge(const framework::Variable& src1,
         framework::DataTypeToString(data_type)));
 }

-void VarBaseAdd(std::shared_ptr<VarBase> var, VarBase* var_) {
+void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
+                        VariableWrapper* var_) {
   auto& src = var->Var();
   auto* dst = var_->MutableVar();
   if (dst->IsType<framework::LoDTensor>()) {

@@ -208,7 +209,7 @@ void VarBaseAdd(std::shared_ptr<VarBase> var, VarBase* var_) {
     *dst = std::move(*(var->MutableVar()));
     var_->SetType(framework::proto::VarType::LOD_TENSOR);
   } else if (src.IsType<framework::SelectedRows>()) {
-    std::shared_ptr<VarBase> temp = SelectedRowsMerge(src, *dst);
+    auto temp = SelectedRowsMerge(src, *dst);
     *dst = std::move(*(temp->MutableVar()));
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(

@@ -218,7 +219,8 @@ void VarBaseAdd(std::shared_ptr<VarBase> var, VarBase* var_) {
   }
 }

-platform::Place GetPlaceOfVarBase(const std::shared_ptr<VarBase>& var) {
+static platform::Place GetPlaceOfVar(
+    const std::shared_ptr<VariableWrapper>& var) {
   platform::Place place;
   if (var->Var().IsType<framework::LoDTensor>()) {
     place = var->Var().Get<framework::LoDTensor>().place();

@@ -231,10 +233,10 @@ platform::Place GetPlaceOfVarBase(const std::shared_ptr<VarBase>& var) {
   return place;
 }

-void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
+void EagerGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
                                    size_t trace_id) {
   auto* dst_var = var_->MutableVar();
-  platform::Place place = GetPlaceOfVarBase(var);
+  platform::Place place = GetPlaceOfVar(var);
   if (!var_->OverridedStopGradient()) {
     VLOG(3) << "Sum Gradient for: " << var_->Name();
     if (cur_cnt_ == 0) {

@@ -243,7 +245,7 @@ void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
       }
       *dst_var = std::move(*(var->MutableVar()));
     } else {
-      VarBaseAdd(var, var_);
+      VariableWrapperAdd(var, var_);
     }
   } else {
     if (!var_->Var().IsInitialized() ||

@@ -268,10 +270,10 @@ void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
   ++cur_cnt_;
 }

-void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
+void SortedGradientAccumulator::Add(std::shared_ptr<VariableWrapper> var,
                                     size_t trace_id) {
   auto* dst_var = var_->MutableVar();
-  platform::Place place = GetPlaceOfVarBase(var);
+  platform::Place place = GetPlaceOfVar(var);
   if (!var_->OverridedStopGradient()) {
     if (ref_cnt_ == 1) {
       if (var->Var().IsType<framework::SelectedRows>()) {

@@ -291,11 +293,12 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
       return;
     }

-    std::sort(tmp_grad_vars_.begin(), tmp_grad_vars_.end(),
-              [](const std::pair<std::shared_ptr<VarBase>, size_t>& p1,
-                 const std::pair<std::shared_ptr<VarBase>, size_t>& p2) {
-                return p1.second > p2.second;
-              });
+    std::sort(
+        tmp_grad_vars_.begin(), tmp_grad_vars_.end(),
+        [](const std::pair<std::shared_ptr<VariableWrapper>, size_t>& p1,
+           const std::pair<std::shared_ptr<VariableWrapper>, size_t>& p2) {
+          return p1.second > p2.second;
+        });

 #ifdef PADDLE_WITH_CUDA
     if (paddle::platform::is_gpu_place(place)) {

@@ -310,7 +313,7 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
           var_->SetType(framework::proto::VarType::SELECTED_ROWS);
           *dst_var = std::move(*(tmp_grad_vars_[i].first->MutableVar()));
         } else {
-          VarBaseAdd(tmp_grad_vars_[i].first, var_);
+          VariableWrapperAdd(tmp_grad_vars_[i].first, var_);
         }
       }
     }

@@ -321,7 +324,7 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
         *dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
       }
       if (tmp_grad_vars_[i].first->Var().IsType<framework::LoDTensor>()) {
-        VarBaseAdd(tmp_grad_vars_[i].first, var_);
+        VariableWrapperAdd(tmp_grad_vars_[i].first, var_);
       }
     }
   } else {

@@ -333,7 +336,7 @@ void SortedGradientAccumulator::Add(std::shared_ptr<VarBase> var,
       *dst_var = std::move(*(tmp_grad_vars_[0].first->MutableVar()));
     }
     for (size_t i = 1; i < tmp_grad_vars_.size(); ++i) {
-      VarBaseAdd(tmp_grad_vars_[i].first, var_);
+      VariableWrapperAdd(tmp_grad_vars_[i].first, var_);
     }
 #ifdef PADDLE_WITH_CUDA
   }
...
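Aside from the VarBase to VariableWrapper renames, the logic of SortedGradientAccumulator is unchanged: contributions are buffered as (grad, trace_id) pairs until all ref_cnt_ of them have arrived, then folded into the destination in a deterministic order (sorted by trace id, descending). A toy analogue with floats in place of variables (illustrative only):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <utility>
    #include <vector>

    // Toy analogue of SortedGradientAccumulator: floats stand in for vars.
    class SortedAccumulator {
     public:
      explicit SortedAccumulator(size_t ref_cnt) : ref_cnt_(ref_cnt) {}

      void Add(float grad, size_t trace_id) {
        tmp_.emplace_back(grad, trace_id);
        if (tmp_.size() < ref_cnt_) return;  // wait for every contribution
        // Deterministic order regardless of arrival order, as in the diff.
        std::sort(tmp_.begin(), tmp_.end(),
                  [](const std::pair<float, size_t>& p1,
                     const std::pair<float, size_t>& p2) {
                    return p1.second > p2.second;
                  });
        for (auto& p : tmp_) sum_ += p.first;
        tmp_.clear();
      }

      float Value() const { return sum_; }

     private:
      size_t ref_cnt_;
      float sum_ = 0.f;
      std::vector<std::pair<float, size_t>> tmp_;
    };

    int main() {
      SortedAccumulator acc(3);
      acc.Add(0.5f, 7);  // contributions may arrive out of trace order
      acc.Add(1.0f, 2);
      acc.Add(2.0f, 5);
      std::cout << acc.Value() << "\n";  // 3.5
      return 0;
    }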
@@ -24,9 +24,9 @@ namespace imperative {
 class GradientAccumulator {
  public:
-  explicit GradientAccumulator(VarBase* var) : var_(var) {}
+  explicit GradientAccumulator(VariableWrapper* var) : var_(var) {}

-  virtual void Add(std::shared_ptr<VarBase> var, size_t trace_id) = 0;
+  virtual void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id) = 0;

   virtual ~GradientAccumulator() = default;

@@ -35,7 +35,7 @@ class GradientAccumulator {
   inline size_t RefCnt() const { return ref_cnt_; }

  protected:
-  VarBase* var_;
+  VariableWrapper* var_;
   size_t ref_cnt_{0};
 };

@@ -43,7 +43,7 @@ class EagerGradientAccumulator : public GradientAccumulator {
  public:
   using GradientAccumulator::GradientAccumulator;

-  void Add(std::shared_ptr<VarBase> var, size_t trace_id) override;
+  void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id) override;

  private:
   size_t cur_cnt_{0};

@@ -53,10 +53,11 @@ class SortedGradientAccumulator : public GradientAccumulator {
  public:
   using GradientAccumulator::GradientAccumulator;

-  void Add(std::shared_ptr<VarBase> var, size_t trace_id) override;
+  void Add(std::shared_ptr<VariableWrapper> var, size_t trace_id) override;

  private:
-  std::vector<std::pair<std::shared_ptr<VarBase>, size_t>> tmp_grad_vars_;
+  std::vector<std::pair<std::shared_ptr<VariableWrapper>, size_t>>
+      tmp_grad_vars_;
 };

 }  // namespace imperative
...
@@ -113,9 +113,10 @@ static framework::RuntimeContext PrepareRuntimeContext(
   return framework::RuntimeContext(std::move(inputs), std::move(outputs));
 }

+template <typename VarType>
 static std::string DebugString(
     const std::string& name,
-    const std::vector<std::shared_ptr<VarBase>>& vars) {
+    const std::vector<std::shared_ptr<VarType>>& vars) {
   std::stringstream ss;
   ss << name << "{";

@@ -127,7 +128,7 @@ static std::string DebugString(
       continue;
     }
     ss << vars[i]->Name() << "[";
-    auto& var = vars[i]->Var();
+    const framework::Variable& var = vars[i]->Var();
     if (!var.IsInitialized()) {
       ss << "NOT_INITED_VAR";
     } else if (var.IsType<framework::LoDTensor>()) {

@@ -167,9 +168,10 @@ static std::string DebugString(
   return ss.str();
 }

-std::string LayerDebugString(const std::string& op_type,
-                             const NameVarBaseMap& ins,
-                             const NameVarBaseMap& outs) {
+template <typename VarType>
+static std::string LayerDebugStringImpl(const std::string& op_type,
+                                        const NameVarMap<VarType>& ins,
+                                        const NameVarMap<VarType>& outs) {
   std::stringstream ss;
   ss << "Op(" << op_type << "): ";

@@ -192,28 +194,30 @@ std::string LayerDebugString(const std::string& op_type,
   return ss.str();
 }

-void VarBase::AddGradOps(const std::weak_ptr<OpBase>& op) {
-  if (op.lock() == nullptr) {
-    return;
-  }
-  for (const auto& cur_op : grad_ops_) {
-    if (cur_op.lock() == op.lock()) {
-      return;
-    }
-  }
-  grad_ops_.emplace_back(op);
+std::string LayerDebugString(const std::string& op_type,
+                             const NameVarMap<VarBase>& ins,
+                             const NameVarMap<VarBase>& outs) {
+  return LayerDebugStringImpl<VarBase>(op_type, ins, outs);
+}
+
+std::string LayerDebugString(const std::string& op_type,
+                             const NameVarMap<VariableWrapper>& ins,
+                             const NameVarMap<VariableWrapper>& outs) {
+  return LayerDebugStringImpl<VariableWrapper>(op_type, ins, outs);
 }

 void VarBase::ClearGradient() {
   if (grad_var_) {
-    if (grad_var_->var_.IsType<framework::SelectedRows>()) {
-      auto* grad_t = grad_var_->var_.GetMutable<framework::SelectedRows>();
+    if (grad_var_->Var().IsType<framework::SelectedRows>()) {
+      auto* grad_t =
+          grad_var_->MutableVar()->GetMutable<framework::SelectedRows>();
       if (grad_t->mutable_value()->IsInitialized()) {
         grad_t->mutable_rows()->clear();
         grad_t->mutable_value()->clear();
       }
     } else {
-      auto* grad_t = grad_var_->var_.GetMutable<framework::LoDTensor>();
+      auto* grad_t =
+          grad_var_->MutableVar()->GetMutable<framework::LoDTensor>();
       if (grad_t->IsInitialized()) {
         auto* dev_ctx =
             platform::DeviceContextPool::Instance().Get(grad_t->place());

@@ -226,19 +230,20 @@ void VarBase::ClearGradient() {
 std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                              const bool blocking) const {
   PADDLE_ENFORCE_EQ(
-      var_.IsInitialized() && (var_.IsType<framework::LoDTensor>() ||
-                               var_.IsType<framework::SelectedRows>()),
+      Var().IsInitialized() && (Var().IsType<framework::LoDTensor>() ||
+                                Var().IsType<framework::SelectedRows>()),
       true, platform::errors::InvalidArgument(
                 "Variable is not initialized or Variable's type is not "
                 "LoDTensor or SelectedRows when getting numpy tensor"));
-  if (var_.IsType<framework::LoDTensor>()) {
-    auto& src_tensor = var_.Get<framework::LoDTensor>();
+  if (Var().IsType<framework::LoDTensor>()) {
+    auto& src_tensor = Var().Get<framework::LoDTensor>();
     // TODO(Jiabin): change this after move unique_name generator to CXX
     auto new_var = std::make_shared<VarBase>(
         true, Name() + std::to_string(copied_counter_++));
-    auto* dst_tensor = new_var->var_.GetMutable<framework::LoDTensor>();
+    auto* dst_tensor =
+        new_var->MutableVar()->GetMutable<framework::LoDTensor>();
     dst_tensor->set_lod(src_tensor.lod());
     new_var->SetPersistable(Persistable());
     new_var->SetDataType(DataType());

@@ -257,12 +262,12 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
     }
     return new_var;
   } else {
-    auto& src_selected_rows = var_.Get<framework::SelectedRows>();
+    auto& src_selected_rows = Var().Get<framework::SelectedRows>();
     auto new_var = std::make_shared<VarBase>(
         false, "Itmp" + std::to_string(copied_counter_++));
     new_var->SetType(framework::proto::VarType::SELECTED_ROWS);
     auto* dst_selected_rows =
-        new_var->var_.GetMutable<framework::SelectedRows>();
+        new_var->MutableVar()->GetMutable<framework::SelectedRows>();

     framework::TensorCopy(src_selected_rows.value(), dst_place,
                           dst_selected_rows->mutable_value());

@@ -281,39 +286,32 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
     return new_var;
   }
 }

-// create OpBase from optype
-OpBase::OpBase(size_t id, const std::string& type, const NameVarBaseMap& ins,
-               const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
-               const platform::Place& place)
-    : id_(id), place_(place), attrs_(attrs) {
-  const auto& info = framework::OpInfoMap::Instance().Get(type);
-
-  // Step 1: Run forward
-  if (info.Checker() != nullptr) {
-    info.Checker()->Check(&attrs_, true);
-  }
+void OpBase::SetType(const std::string& type) {
   op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
-  VLOG(3) << "Construct Op: " << type << std::endl;
 }

-void OpBase::CreateOperatorBase() {
-  const auto& info = framework::OpInfoMap::Instance().Get(type_);
-  if (info.Checker() != nullptr) {
-    info.Checker()->Check(&attrs_, true);
-  }
-  op_ = framework::OpRegistry::CreateOp(type_, {}, {}, {}, false);
+void OpBase::ClearBackwardTrace() {
+  grad_pending_ops_.clear();
+  allow_empty_vars_.clear();
+  ins_.clear();
+  outs_.clear();
 }

-void OpBase::Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
-  auto* op_kernel = dynamic_cast<framework::OperatorWithKernel*>(op_.get());
+template <typename VarType>
+static void OpBaseRunImpl(const framework::OperatorBase& op,
+                          const NameVarMap<VarType>& ins,
+                          const NameVarMap<VarType>& outs,
+                          const framework::AttributeMap& attrs,
+                          const platform::Place& place) {
+  auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
   PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
-  auto& info = op_->Info();
+  auto& info = op.Info();
   if (info.infer_var_type_) {
-    RuntimeInferVarTypeContext infer_var_type_ctx(ins, &outs, attrs_);
+    RuntimeInferVarTypeContext<VarType> infer_var_type_ctx(ins, &outs, attrs);
     info.infer_var_type_(&infer_var_type_ctx);
   }

   // Initialize output var type
   for (auto& var_pair : outs) {
     for (auto& var : var_pair.second) {

@@ -321,20 +319,29 @@ void OpBase::Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
     }
   }

-  VLOG(3) << "Running Op " << Type();
-  VLOG(5) << LayerDebugString(Type(), ins, outs);
+  // VLOG(3) << "Running Op " << op.Type();
+  VLOG(5) << LayerDebugString(op.Type(), ins, outs);

-  auto prepared_op =
-      PreparedOp::Prepare(ins, outs, *op_kernel, place(), &attrs_);
-  prepared_op.Run(&ins, &outs, &attrs_);
+  auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs);
+  prepared_op.Run(ins, outs, attrs);

-  VLOG(4) << LayerDebugString(Type(), ins, outs);
+  VLOG(4) << LayerDebugString(op.Type(), ins, outs);
 }

-void OpBase::ClearBackwardTrace() {
-  grad_pending_ops_.clear();
-  ins_.clear();
-  outs_.clear();
+void OpBase::Run(const framework::OperatorBase& op,
+                 const NameVarMap<VarBase>& ins,
+                 const NameVarMap<VarBase>& outs,
+                 const framework::AttributeMap& attrs,
+                 const platform::Place& place) {
+  OpBaseRunImpl<VarBase>(op, ins, outs, attrs, place);
+}
+
+void OpBase::Run(const framework::OperatorBase& op,
+                 const NameVarMap<VariableWrapper>& ins,
+                 const NameVarMap<VariableWrapper>& outs,
+                 const framework::AttributeMap& attrs,
+                 const platform::Place& place) {
+  OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, place);
 }

 }  // namespace imperative
...
...@@ -17,12 +17,13 @@ ...@@ -17,12 +17,13 @@
#include <atomic> #include <atomic>
#include <cstdint> #include <cstdint>
#include <list> #include <list>
#include <map> // NOLINT #include <map>
#include <memory> // NOLINT #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <set> #include <set>
#include <string> // NOLINT #include <string>
#include <unordered_map> // NOLINT #include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
...@@ -35,6 +36,7 @@ ...@@ -35,6 +36,7 @@
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/flags.h" #include "paddle/fluid/imperative/flags.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
...@@ -62,26 +64,28 @@ class VarBase { ...@@ -62,26 +64,28 @@ class VarBase {
public: public:
static std::vector<std::string> AliveVarNames(); static std::vector<std::string> AliveVarNames();
explicit VarBase(bool has_grad, const std::string& name) explicit VarBase(bool has_grad, const std::string& name)
: name_(name), : var_(std::make_shared<VariableWrapper>(name)),
grad_var_(has_grad ? new VarBase(false, GradVarName()) : nullptr) { grad_var_(has_grad ? new VarBase(false, GradVarName()) : nullptr) {
if (IsDebugEnabled()) { if (IsDebugEnabled()) {
VLOG(10) << "Construct VarBase: " << name; VLOG(10) << "Construct VarBase: " << Name();
name_set_.Insert(name_); name_set_.Insert(Name());
} }
} }
explicit VarBase(const std::string& name) : VarBase(true, name) {} explicit VarBase(const std::string& name) : VarBase(true, name) {}
~VarBase() { ~VarBase() {
VLOG(10) << "Destruct VarBase: " << name_; VLOG(10) << "Destruct VarBase: " << Name();
if (IsDebugEnabled()) { if (IsDebugEnabled()) {
name_set_.Remove(name_); name_set_.Remove(Name());
} }
} }
const framework::Variable& Var() const { return var_; } const std::shared_ptr<VariableWrapper>& SharedVar() const { return var_; }
framework::Variable* MutableVar() { return &var_; } const framework::Variable& Var() const { return var_->Var(); }
framework::Variable* MutableVar() { return var_->MutableVar(); }
bool HasGradVar() const { return grad_var_ != nullptr; } bool HasGradVar() const { return grad_var_ != nullptr; }
...@@ -94,119 +98,84 @@ class VarBase {
grad_var_ = std::make_shared<VarBase>(false, GradVarName()); grad_var_ = std::make_shared<VarBase>(false, GradVarName());
// NOTE(zhiqiu): we should keep grad_var_'s stop_gradient property same as // NOTE(zhiqiu): we should keep grad_var_'s stop_gradient property same as
// fwd varbase // fwd varbase
grad_var_->SetOverridedStopGradient(overrided_stop_gradient_); grad_var_->SetOverridedStopGradient(var_->InnerOverridedStopGradient());
} }
return grad_var_; return grad_var_;
} }
const framework::Variable& GradVar() const { const framework::Variable& GradVar() const {
PADDLE_ENFORCE_NOT_NULL(grad_var_, "Gradient of %s does not exist", name_); PADDLE_ENFORCE_NOT_NULL(
return grad_var_->var_; grad_var_,
platform::errors::NotFound("Gradient of %s does not exist", Name()));
return grad_var_->Var();
} }
framework::Variable* MutableGradVar() { framework::Variable* MutableGradVar() {
PADDLE_ENFORCE_NOT_NULL(grad_var_, "Gradient of %s does not exist", name_); PADDLE_ENFORCE_NOT_NULL(
return &(grad_var_->var_); grad_var_,
platform::errors::NotFound("Gradient of %s does not exist", Name()));
return grad_var_->MutableVar();
} }
// This is used for python api
void SetOverridedStopGradient(bool stop_gradient) { void SetOverridedStopGradient(bool stop_gradient) {
if (stop_gradient) { var_->SetOverridedStopGradient(stop_gradient);
overrided_stop_gradient_ = 1;
} else {
overrided_stop_gradient_ = 0;
}
if (grad_var_) { if (grad_var_) {
grad_var_->SetOverridedStopGradient(stop_gradient); grad_var_->SetOverridedStopGradient(stop_gradient);
} }
} }
// This is used for python api
bool OverridedStopGradient() const {
if (overrided_stop_gradient_ == 0) {
return false;
} else {
return true;
}
}
// This is used inside C++
int InnerOverridedStopGradient() const { return overrided_stop_gradient_; }
bool GradGenerated() const { return grad_generated_; } bool OverridedStopGradient() const { return var_->OverridedStopGradient(); }
void SetGradGenerated(bool generated) { grad_generated_ = generated; }
// This is used inside C++
void InnerSetOverridedStopGradient(bool stop_gradient) { void InnerSetOverridedStopGradient(bool stop_gradient) {
if (overrided_stop_gradient_ == -1) { if (var_->InnerOverridedStopGradient() == -1) {
overrided_stop_gradient_ = static_cast<int>(stop_gradient); var_->InnerSetOverridedStopGradient(stop_gradient);
if (grad_var_) { if (grad_var_) {
grad_var_->InnerSetOverridedStopGradient(stop_gradient); grad_var_->InnerSetOverridedStopGradient(stop_gradient);
} }
} else {
VLOG(6) << "Ignore Stop gradient conversion for Var: " << Name()
<< "Set value is: " << overrided_stop_gradient_;
} }
} }
void SetPersistable(bool persistable) { persistable_ = persistable; } void SetPersistable(bool persistable) { var_->SetPersistable(persistable); }
bool Persistable() const { return persistable_; }
void AddGradOps(const std::weak_ptr<OpBase>& op); bool Persistable() const { return var_->Persistable(); }
std::vector<OpBase*> GradOps() { // Only grad var is allowed to call these 2 methods
std::vector<OpBase*> rlt; void AddGradOp(const std::shared_ptr<OpBase>& op) {
// TODO(jiabin): use better data structure to remove nullptr when we find it if (op &&
for (const auto& wk_ptr : grad_ops_) { std::find(grad_ops_.begin(), grad_ops_.end(), op) == grad_ops_.end()) {
OpBase* tmp_op = wk_ptr.lock().get(); grad_ops_.emplace_back(op);
if (tmp_op) rlt.emplace_back(tmp_op);
} }
return rlt; }
const std::vector<std::shared_ptr<OpBase>>& GradOps() const {
return grad_ops_;
} }
void ClearGradOps() { grad_ops_.clear(); } void ClearGradOps() { grad_ops_.clear(); }
const std::string& Name() const { return name_; } const std::string& Name() const { return var_->Name(); }
void SetName(const std::string& name) { void SetName(const std::string& name) {
name_ = name; var_->SetName(name);
if (grad_var_) { if (grad_var_) {
grad_var_->SetName(GradVarName()); grad_var_->SetName(GradVarName());
} }
} }
std::string GradVarName() { return framework::GradVarName(name_); } std::string GradVarName() { return framework::GradVarName(Name()); }
void SetType(framework::proto::VarType::Type type) { type_ = type; } void SetType(framework::proto::VarType::Type type) { var_->SetType(type); }
framework::proto::VarType::Type Type() const { return type_; } framework::proto::VarType::Type Type() const { return var_->Type(); }
void SetDataType(framework::proto::VarType::Type data_type) { void SetDataType(framework::proto::VarType::Type data_type) {
data_type_ = data_type; var_->SetDataType(data_type);
if (grad_var_) { if (grad_var_) {
grad_var_->SetDataType(data_type_); grad_var_->SetDataType(data_type);
} }
} }
framework::proto::VarType::Type DataType() const { framework::proto::VarType::Type DataType() const { return var_->DataType(); }
const framework::Tensor* tensor = nullptr;
if (var_.IsInitialized()) {
if (type_ == framework::proto::VarType::LOD_TENSOR) {
tensor = &(var_.Get<framework::LoDTensor>());
} else if (type_ == framework::proto::VarType::SELECTED_ROWS) {
tensor = &(var_.Get<framework::SelectedRows>().value());
} else {
VLOG(6) << "Variable " << name_ << " is not initialized";
return data_type_;
}
}
if (tensor && tensor->IsInitialized()) {
return tensor->type();
} else {
VLOG(6) << "The tensor of variable " << name_ << " is not initialized";
return data_type_;
}
}
void ClearGradient(); void ClearGradient();
...@@ -214,26 +183,23 @@ class VarBase {
const bool blocking) const; const bool blocking) const;
private: private:
framework::Variable var_; /**
std::string name_; * NOTE(zengjinle): never remove the const qualifier of `var_` if you are
* not very familiar with the autograd idea (including the higher order
* derivative).
*/
const std::shared_ptr<VariableWrapper> var_;
std::shared_ptr<VarBase> grad_var_; std::shared_ptr<VarBase> grad_var_;
std::vector<std::shared_ptr<OpBase>> grad_ops_;
mutable size_t copied_counter_ = 0; mutable size_t copied_counter_ = 0;
// grad_op indicates which grad_op will this var be used as input
std::vector<std::weak_ptr<OpBase>> grad_ops_;
// add this property for users may set stop_gradient themselves and this
// should override the
// frameworks setting (-1) unset, (1) true, (0) false
int overrided_stop_gradient_{-1};
bool grad_generated_{false};
bool persistable_{false};
framework::proto::VarType::Type type_{framework::proto::VarType::LOD_TENSOR};
framework::proto::VarType::Type data_type_{framework::proto::VarType::FP32};
static ThreadSafeNameSet name_set_; static ThreadSafeNameSet name_set_;
}; };
using VariableWrapperList = std::vector<std::shared_ptr<VariableWrapper>>;
class Layer { class Layer {
public: public:
virtual ~Layer() {} virtual ~Layer() {}
...@@ -244,6 +210,7 @@ class Layer {
} }
}; };
template <typename VarType>
class DygraphExecutionContext : public framework::ExecutionContext { class DygraphExecutionContext : public framework::ExecutionContext {
using Variable = framework::Variable; using Variable = framework::Variable;
...@@ -253,9 +220,9 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const platform::DeviceContext& device_context, const platform::DeviceContext& device_context,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
std::vector<framework::KernelConfig>* configs, std::vector<framework::KernelConfig>* configs,
const NameVarBaseMap& var_base_map_in, const NameVarMap<VarType>& var_base_map_in,
const NameVarBaseMap& var_base_map_out, const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap* attrs) const framework::AttributeMap& attrs)
: ExecutionContext(op, scope, device_context, ctx, configs), : ExecutionContext(op, scope, device_context, ctx, configs),
var_base_map_in_(var_base_map_in), var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out), var_base_map_out_(var_base_map_out),
...@@ -303,16 +270,16 @@ class DygraphExecutionContext : public framework::ExecutionContext {
} }
bool HasAttr(const std::string& name) const override { bool HasAttr(const std::string& name) const override {
return attrs_->count(name); return attrs_.count(name) != 0;
} }
const framework::AttributeMap& Attrs() const override { return *attrs_; } const framework::AttributeMap& Attrs() const override { return attrs_; }
const framework::Attribute& GetAttr(const std::string& name) const override { const framework::Attribute& GetAttr(const std::string& name) const override {
auto it = attrs_->find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, attrs_->end(), it, attrs_.end(),
platform::errors::NotFound("can not find [%s] in attrs", name)); platform::errors::NotFound("can not find [%s] in attrs", name));
return it->second; return it->second;
...@@ -395,16 +362,17 @@ class DygraphExecutionContext : public framework::ExecutionContext {
} }
private: private:
const NameVarBaseMap& var_base_map_in_; const NameVarMap<VarType>& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_; const NameVarMap<VarType>& var_base_map_out_;
const framework::AttributeMap* attrs_; const framework::AttributeMap& attrs_;
}; };
// infer var type context for imperative mode // infer var type context for imperative mode
template <typename VarType>
class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
public: public:
RuntimeInferVarTypeContext(const NameVarBaseMap& inputs, RuntimeInferVarTypeContext(const NameVarMap<VarType>& inputs,
const NameVarBaseMap* outputs, const NameVarMap<VarType>* outputs,
const framework::AttributeMap& attrs_map) const framework::AttributeMap& attrs_map)
: InferVarTypeContext(nullptr, nullptr), : InferVarTypeContext(nullptr, nullptr),
inputs_(inputs), inputs_(inputs),
...@@ -536,65 +504,51 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext {
} }
private: private:
const NameVarBaseMap& inputs_; const NameVarMap<VarType>& inputs_;
const NameVarBaseMap* outputs_; const NameVarMap<VarType>* outputs_;
const framework::AttributeMap& attrs_; const framework::AttributeMap& attrs_;
std::unordered_map<std::string, std::vector<std::string>> input_names_; std::unordered_map<std::string, std::vector<std::string>> input_names_;
std::unordered_map<std::string, std::vector<std::string>> output_names_; std::unordered_map<std::string, std::vector<std::string>> output_names_;
std::unordered_map<std::string, VarBase*> var_set_; std::unordered_map<std::string, VarType*> var_set_;
}; };
// TODO(zjl): to support py_func layer // TODO(zjl): to support py_func layer
class OpBase : public std::enable_shared_from_this<OpBase> { class OpBase {
DISABLE_COPY_AND_ASSIGN(OpBase); DISABLE_COPY_AND_ASSIGN(OpBase);
public: public:
~OpBase() { VLOG(3) << "Destruct Op: " << Type() << std::endl; } OpBase() = default;
// Developer should not rely on this method to create OpBase. ~OpBase() { VLOG(3) << "Destruct Op: " << Type(); }
// OpBase should be created in Tracer and managed by Tracer totally.
template <typename... Args>
static std::shared_ptr<OpBase> Create(Args&&... args) {
return std::shared_ptr<OpBase>(new OpBase(std::forward<Args>(args)...));
}
size_t id() const { return id_; } size_t id() const { return id_; }
const std::string& Type() const { return op_->Type(); } const std::string& Type() const { return op_->Type(); }
void Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs);
const framework::VariableNameMap& InputNameMap() const {
return op_->Inputs();
}
const framework::VariableNameMap& OutputNameMap() const {
return op_->Outputs();
}
const framework::AttributeMap& Attrs() const { return attrs_; } const framework::AttributeMap& Attrs() const { return attrs_; }
const framework::OpInfo& Info() const { return op_->Info(); } const framework::OpInfo& Info() const { return op_->Info(); }
const framework::OperatorBase& InnerOp() const { return *op_; }
void ClearBackwardTrace(); void ClearBackwardTrace();
const std::vector<OpBase*>& GradPendingOps() const { const std::vector<std::shared_ptr<OpBase>>& GradPendingOps() const {
return grad_pending_ops_; return grad_pending_ops_;
} }
void SetGradPendingOps(std::vector<OpBase*> vec_temp) { void SetGradPendingOps(std::vector<std::shared_ptr<OpBase>> pending_ops) {
grad_pending_ops_.swap(vec_temp); grad_pending_ops_ = std::move(pending_ops);
} }
void InsertGradPendingOps(OpBase* op) { grad_pending_ops_.emplace_back(op); } NameVarMap<VariableWrapper>* GetMutableOutsMap() { return &outs_; }
NameVarMap<VariableWrapper>* GetMutableInsMap() { return &ins_; }
const NameVarMap<VariableWrapper>& GetInsMap() { return ins_; }
const NameVarMap<VariableWrapper>& GetOutsMap() { return outs_; }
void SortGradPendingOps() {
std::sort(grad_pending_ops_.begin(), grad_pending_ops_.end(),
[](OpBase* op1, OpBase* op2) { return op1->id() > op2->id(); });
}
NameVarBaseMap* GetMutableOutsMap() { return &outs_; }
NameVarBaseMap* GetMutableInsMap() { return &ins_; }
const NameVarBaseMap& GetInsMap() { return ins_; }
const NameVarBaseMap& GetOutsMap() { return outs_; }
const platform::Place& place() const { return place_; } const platform::Place& place() const { return place_; }
// TODO(jiabin) prepare for backward hook // TODO(jiabin) prepare for backward hook
...@@ -609,41 +563,40 @@ class OpBase : public std::enable_shared_from_this<OpBase> {
} }
} }
private: void SetType(const std::string& type);
OpBase(size_t id, const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place);
public: void CheckAttrs() {
OpBase() {} auto& info = op_->Info();
if (info.Checker() != nullptr) {
info.Checker()->Check(&attrs_, true);
}
}
void SetType(const std::string& type) { type_ = type; } void SetInput(const std::string& name, VariableWrapperList vars) {
void SetInput(const std::string& name, ins_[name] = std::move(vars);
std::vector<std::shared_ptr<VarBase>> vec_var_base) {
ins_[name] = std::move(vec_var_base);
} }
void SetOutput(const std::string& name,
std::vector<std::shared_ptr<VarBase>> vec_var_base) { void SetOutput(const std::string& name, VariableWrapperList vars) {
outs_[name] = std::move(vec_var_base); outs_[name] = std::move(vars);
} }
void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; } void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; }
void SetAttr(const std::string& name, const framework::Attribute& v) { void SetAttr(const std::string& name, const framework::Attribute& v) {
attrs_[name] = v; attrs_[name] = v;
} }
void SetBlockAttr(const std::string& name, framework::BlockDesc* block) { void SetBlockAttr(const std::string& name, framework::BlockDesc* block) {
PADDLE_THROW("SetBlockAttr is not support in dygraph OpBase"); PADDLE_THROW("SetBlockAttr is not support in dygraph OpBase");
} }
const framework::AttributeMap& Attrs() { return attrs_; } const framework::AttributeMap& Attrs() { return attrs_; }
void CreateOperatorBase();
void SetId(size_t id) { id_ = id; } void SetId(size_t id) { id_ = id; }
void SetPlace(platform::Place place) { place_ = place; }
bool HasAttr(const std::string& name) const { void SetPlace(const platform::Place& place) { place_ = place; }
return attrs_.find(name) != attrs_.end();
} bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; }
const framework::Attribute& GetAttr(const std::string& name) const { const framework::Attribute& GetAttr(const std::string& name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
...@@ -657,31 +610,49 @@ class OpBase : public std::enable_shared_from_this<OpBase> {
return boost::get<T>(GetAttr(name)); return boost::get<T>(GetAttr(name));
} }
private: void AddAllowedEmptyVar(const VariableWrapper* var) {
size_t id_; allow_empty_vars_.emplace(var);
}
bool IsAllowedEmptyVar(const VariableWrapper* var) {
return allow_empty_vars_.count(var) > 0;
}
static void Run(const framework::OperatorBase& op,
const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs,
const platform::Place& place);
static void Run(const framework::OperatorBase& op,
const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs,
const platform::Place& place);
private:
NameVarMap<VariableWrapper> ins_;
NameVarMap<VariableWrapper> outs_;
framework::AttributeMap attrs_;
std::unique_ptr<framework::OperatorBase> op_; std::unique_ptr<framework::OperatorBase> op_;
std::vector<std::function<void()>> backward_hooks_; std::vector<std::shared_ptr<OpBase>> grad_pending_ops_;
platform::Place place_; platform::Place place_;
// Not need to be std::weak_ptr, because op is binded to a certain Tracer, std::unordered_set<const VariableWrapper*> allow_empty_vars_;
// and would not be used by a Tracer that does not create itself.
std::vector<OpBase*> grad_pending_ops_; size_t id_{-1UL};
// This part is only used for backward std::vector<std::function<void()>> backward_hooks_;
NameVarBaseMap ins_;
NameVarBaseMap outs_;
std::string type_;
framework::AttributeMap attrs_;
}; };
template <typename VarType>
class DygraphInferShapeContext : public framework::InferShapeContext { class DygraphInferShapeContext : public framework::InferShapeContext {
using DDim = framework::DDim; using DDim = framework::DDim;
public: public:
DygraphInferShapeContext(const NameVarBaseMap* in, const NameVarBaseMap* out, DygraphInferShapeContext(const NameVarMap<VarType>* in,
const NameVarMap<VarType>* out,
const framework::AttributeMap* attr) const framework::AttributeMap* attr)
: var_base_map_in_(in), var_base_map_out_(out), attrs_(attr) {} : var_base_map_in_(in), var_base_map_out_(out), attrs_(attr) {}
...@@ -909,9 +880,6 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
void SetOutputsDim(const std::string& name, void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override { const std::vector<DDim>& dims) override {
// auto& vars = OutputVars(name);
// SetDims(vars, dims);
auto it = var_base_map_out_->find(name); auto it = var_base_map_out_->find(name);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
it, var_base_map_out_->end(), it, var_base_map_out_->end(),
...@@ -992,9 +960,8 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
} }
private: private:
const NameVarBaseMap* var_base_map_in_; const NameVarMap<VarType>* var_base_map_in_;
const NameVarBaseMap* var_base_map_out_; const NameVarMap<VarType>* var_base_map_out_;
std::string type_;
const framework::AttributeMap* attrs_; const framework::AttributeMap* attrs_;
}; };
......
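The layer.h changes above move all of VarBase's state (name, type, stop_gradient flags, the underlying framework::Variable) into a shared VariableWrapper, leaving VarBase as a thin handle over a const std::shared_ptr. A hedged sketch of why that handle/payload split helps the backward graph; the class names here are illustrative only.

#include <iostream>
#include <memory>
#include <string>

// Payload: holds the state that must be able to outlive the Python handle.
class Payload {
 public:
  explicit Payload(std::string name) : name_(std::move(name)) {}
  const std::string& Name() const { return name_; }

 private:
  std::string name_;
};

// Handle: keeps a const shared_ptr and forwards accessors, mirroring how
// VarBase::Name()/Type()/Var() now delegate to the wrapped VariableWrapper.
class Handle {
 public:
  explicit Handle(const std::string& name)
      : var_(std::make_shared<Payload>(name)) {}
  const std::shared_ptr<Payload>& SharedVar() const { return var_; }
  const std::string& Name() const { return var_->Name(); }

 private:
  const std::shared_ptr<Payload> var_;  // pointer is const, payload is shared
};

int main() {
  std::shared_ptr<Payload> kept_by_grad_op;
  {
    Handle x("x");
    kept_by_grad_op = x.SharedVar();  // e.g. a traced grad op captures it
  }  // the handle dies here
  std::cout << kept_by_grad_op->Name() << " survives the handle\n";
  return 0;
}

Because grad ops capture the wrapper instead of the VarBase itself, the destruction test added in test_tracer.cc further below can assert that wrappers expire exactly when the backward graph releases them.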
...@@ -28,8 +28,9 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
} }
} }
void PreparedOp::PrepareData( template <typename VarType>
const platform::Place& place, const NameVarBaseMap& ins, static void PrepareDataImpl(
const platform::Place& place, const NameVarMap<VarType>& ins,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key) { const framework::OpKernelType& expected_kernel_key) {
for (const auto& name_pair : ins) { for (const auto& name_pair : ins) {
...@@ -59,22 +60,37 @@ void PreparedOp::PrepareData(
} }
} }
void PreparedOp::PrepareData(
const platform::Place& place, const NameVarMap<VarBase>& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key) {
PrepareDataImpl<VarBase>(place, ins, op, expected_kernel_key);
}
void PreparedOp::PrepareData(
const platform::Place& place, const NameVarMap<VariableWrapper>& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key) {
PrepareDataImpl<VariableWrapper>(place, ins, op, expected_kernel_key);
}
PreparedOp::PreparedOp(const framework::OperatorBase& op, PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
framework::OperatorWithKernel::OpKernelFunc func, const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx, platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs) std::vector<framework::KernelConfig>* kernel_configs)
: op_(op), : op_(op),
ctx_(ctx), ctx_(ctx),
func_(std::move(func)), func_(func),
dev_ctx_(dev_ctx), dev_ctx_(dev_ctx),
kernel_configs_(kernel_configs) {} kernel_configs_(kernel_configs) {}
PreparedOp PreparedOp::Prepare(const NameVarBaseMap& ins, template <typename VarType>
const NameVarBaseMap& outs, PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
const framework::OperatorWithKernel& op, const NameVarMap<VarType>& outs,
platform::Place place, const framework::OperatorWithKernel& op,
const framework::AttributeMap* attrs) { platform::Place place,
const framework::AttributeMap& attrs) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
...@@ -90,8 +106,9 @@ PreparedOp PreparedOp::Prepare(const NameVarBaseMap& ins,
auto& kernels = kernels_iter->second; auto& kernels = kernels_iter->second;
framework::RuntimeContext ctx({}, {}); framework::RuntimeContext ctx({}, {});
auto expected_kernel_key = op.GetExpectedKernelType(DygraphExecutionContext( auto expected_kernel_key =
op, framework::Scope(), *dev_ctx, ctx, nullptr, ins, outs, attrs)); op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, nullptr, ins, outs, attrs));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key; VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
...@@ -108,24 +125,57 @@ PreparedOp PreparedOp::Prepare(const NameVarBaseMap& ins,
place = dev_ctx->GetPlace(); place = dev_ctx->GetPlace();
} }
PrepareData(place, ins, op, expected_kernel_key); PrepareDataImpl<VarType>(place, ins, op, expected_kernel_key);
return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs); return PreparedOp(op, ctx, kernel_iter->second, dev_ctx, kernel_configs);
} }
void PreparedOp::Run(const NameVarBaseMap* in, const NameVarBaseMap* out, PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
const framework::AttributeMap* attrs) { const NameVarMap<VarBase>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs) {
return PrepareOpImpl<VarBase>(ins, outs, op, place, attrs);
}
PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs) {
return PrepareOpImpl<VariableWrapper>(ins, outs, op, place, attrs);
}
template <typename VarType>
static void PreparedOpRunImpl(
const framework::OperatorBase& op, const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs,
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs) {
// TODO(zjl): remove scope in dygraph // TODO(zjl): remove scope in dygraph
framework::Scope scope; framework::Scope scope;
DygraphInferShapeContext infer_shape_ctx(in, out, attrs); DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs);
static_cast<const framework::OperatorWithKernel&>(op).InferShape(
&infer_shape_ctx);
framework::OperatorWithKernel* op_ker = func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx,
(framework::OperatorWithKernel*)(&op_); kernel_configs, ins, outs, attrs));
}
op_ker->InferShape(&infer_shape_ctx); void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs) {
PreparedOpRunImpl<VarBase>(op_, ctx_, func_, dev_ctx_, kernel_configs_, ins,
outs, attrs);
}
func_(DygraphExecutionContext(op_, scope, *dev_ctx_, ctx_, kernel_configs_, void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
*in, *out, attrs)); const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs) {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, func_, dev_ctx_,
kernel_configs_, ins, outs, attrs);
} }
} // namespace imperative } // namespace imperative
......
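PrepareOpImpl above selects a kernel with a two-level lookup: all kernels registered for the op type first, then the one matching the expected kernel key. A self-contained sketch of that registry shape, with the key simplified to a plain string (the real key combines place, data type and layout):

#include <functional>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

using Kernel = std::function<void()>;
using KernelMap = std::map<std::string, Kernel>;  // kernel key -> kernel

// Two-level registry: op type -> (kernel key -> kernel).
static std::map<std::string, KernelMap>& AllKernels() {
  static std::map<std::string, KernelMap> kernels;
  return kernels;
}

static Kernel PrepareKernel(const std::string& op_type,
                            const std::string& expected_key) {
  auto kernels_iter = AllKernels().find(op_type);
  if (kernels_iter == AllKernels().end()) {
    throw std::runtime_error("no kernels registered for op " + op_type);
  }
  auto kernel_iter = kernels_iter->second.find(expected_key);
  if (kernel_iter == kernels_iter->second.end()) {
    throw std::runtime_error("no " + expected_key + " kernel for " + op_type);
  }
  return kernel_iter->second;
}

int main() {
  AllKernels()["mul"]["CPU/FP32"] = [] { std::cout << "mul CPU kernel\n"; };
  PrepareKernel("mul", "CPU/FP32")();  // found and executed
  return 0;
}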
...@@ -30,28 +30,42 @@ const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
class PreparedOp { class PreparedOp {
public: public:
static PreparedOp Prepare(const NameVarBaseMap& ins, PreparedOp(const framework::OperatorBase& op,
const NameVarBaseMap& outs, const framework::RuntimeContext& ctx,
const framework::OperatorWithKernel::OpKernelFunc& func,
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs);
static PreparedOp Prepare(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs);
static PreparedOp Prepare(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
platform::Place place, const platform::Place& place,
const framework::AttributeMap* attrs); const framework::AttributeMap& attrs);
inline platform::DeviceContext* GetDeviceContext() const { return dev_ctx_; } inline platform::DeviceContext* GetDeviceContext() const { return dev_ctx_; }
void Run(const NameVarBaseMap* in, const NameVarBaseMap* out, void Run(const NameVarMap<VarBase>& in, const NameVarMap<VarBase>& out,
const framework::AttributeMap* attrs); const framework::AttributeMap& attrs);
void Run(const NameVarMap<VariableWrapper>& ins,
const NameVarMap<VariableWrapper>& outs,
const framework::AttributeMap& attrs);
static void PrepareData(const platform::Place& place, static void PrepareData(const platform::Place& place,
const NameVarBaseMap& ins, const NameVarMap<VarBase>& ins,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key); const framework::OpKernelType& expected_kernel_key);
private: static void PrepareData(const platform::Place& place,
PreparedOp(const framework::OperatorBase& op, const NameVarMap<VariableWrapper>& ins,
const framework::RuntimeContext& ctx, const framework::OperatorWithKernel& op,
framework::OperatorWithKernel::OpKernelFunc func, const framework::OpKernelType& expected_kernel_key);
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs);
private: private:
const framework::OperatorBase& op_; const framework::OperatorBase& op_;
......
...@@ -44,7 +44,8 @@ TEST(test_layer, test_runtime_context) {
imperative::NameVarBaseMap ins = {in_pair}; imperative::NameVarBaseMap ins = {in_pair};
imperative::NameVarBaseMap outs = {out_pair}; imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap attrs; framework::AttributeMap attrs;
auto* ctx = new imperative::RuntimeInferVarTypeContext(ins, &outs, attrs); auto *ctx = new imperative::RuntimeInferVarTypeContext<imperative::VarBase>(
ins, &outs, attrs);
ASSERT_TRUE(ctx->HasVar("vin")); ASSERT_TRUE(ctx->HasVar("vin"));
ASSERT_TRUE(ctx->HasInput("X")); ASSERT_TRUE(ctx->HasInput("X"));
ASSERT_TRUE(ctx->HasOutput("Out")); ASSERT_TRUE(ctx->HasOutput("Out"));
...@@ -57,9 +58,9 @@ TEST(test_layer, test_runtime_context) {
ASSERT_ANY_THROW(ctx->SetLoDLevel("vin", 2)); ASSERT_ANY_THROW(ctx->SetLoDLevel("vin", 2));
} }
std::string LayerDebugString(const std::string& op_type, std::string LayerDebugString(const std::string &op_type,
const NameVarBaseMap& ins, const NameVarBaseMap &ins,
const NameVarBaseMap& outs); const NameVarBaseMap &outs);
TEST(test_layer, test_debug_string) { TEST(test_layer, test_debug_string) {
platform::CPUPlace place; platform::CPUPlace place;
...@@ -67,7 +68,7 @@ TEST(test_layer, test_debug_string) {
new imperative::VarBase(false, "vin")); new imperative::VarBase(false, "vin"));
var_pair in_pair = var_pair("X", vb_vector(1, vin)); var_pair in_pair = var_pair("X", vb_vector(1, vin));
auto test_func = [&](std::shared_ptr<imperative::VarBase>& vout) { auto test_func = [&](std::shared_ptr<imperative::VarBase> &vout) {
var_pair out_pair = var_pair("Out", vb_vector(1, vout)); var_pair out_pair = var_pair("Out", vb_vector(1, vout));
imperative::NameVarBaseMap ins = {in_pair}; imperative::NameVarBaseMap ins = {in_pair};
imperative::NameVarBaseMap outs = {out_pair}; imperative::NameVarBaseMap outs = {out_pair};
...@@ -119,6 +120,34 @@ TEST(test_layer, test_debug_string) {
ASSERT_TRUE(res_sr.find("SelectedRows") != std::string::npos); ASSERT_TRUE(res_sr.find("SelectedRows") != std::string::npos);
} }
static std::shared_ptr<imperative::OpBase> CreateOpBase(
size_t id, const std::string &type, const imperative::NameVarBaseMap &ins,
const imperative::NameVarBaseMap &outs,
const framework::AttributeMap &attrs, const platform::Place &place) {
auto op = std::make_shared<imperative::OpBase>();
op->SetId(id);
op->SetPlace(place);
op->SetType(type);
op->SetAttrMap(attrs);
for (auto &pair : ins) {
std::vector<std::shared_ptr<VariableWrapper>> vars;
for (auto &var : pair.second) {
vars.emplace_back(var->SharedVar());
}
op->SetInput(pair.first, vars);
}
for (auto &pair : outs) {
std::vector<std::shared_ptr<VariableWrapper>> vars;
for (auto &var : pair.second) {
vars.emplace_back(var->SharedVar());
}
op->SetOutput(pair.first, vars);
}
return op;
}
TEST(test_layer, test_clear_backward_info) { TEST(test_layer, test_clear_backward_info) {
std::shared_ptr<imperative::VarBase> vin( std::shared_ptr<imperative::VarBase> vin(
new imperative::VarBase(false, "vin")); new imperative::VarBase(false, "vin"));
...@@ -133,13 +162,11 @@ TEST(test_layer, test_clear_backward_info) {
imperative::NameVarBaseMap outs = {out_pair}; imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap concat_att_map; framework::AttributeMap concat_att_map;
concat_att_map["axis"] = 1; concat_att_map["axis"] = 1;
std::shared_ptr<imperative::OpBase> op(
OpBase::Create(0, "mul", ins, outs, concat_att_map, place)); auto op = CreateOpBase(0, "mul", ins, outs, concat_att_map, place);
std::shared_ptr<imperative::OpBase> preceding_op( auto preceding_op = CreateOpBase(0, "mul", ins, outs, concat_att_map, place);
OpBase::Create(0, "mul", ins, outs, concat_att_map, place)); op->SetGradPendingOps({preceding_op});
op->InsertGradPendingOps(preceding_op.get());
*(op->GetMutableInsMap()) = ins;
*(op->GetMutableOutsMap()) = outs;
ASSERT_GT(op->GetInsMap().size(), 0UL); ASSERT_GT(op->GetInsMap().size(), 0UL);
ASSERT_GT(op->GetOutsMap().size(), 0UL); ASSERT_GT(op->GetOutsMap().size(), 0UL);
ASSERT_GT(op->GradPendingOps().size(), 0UL); ASSERT_GT(op->GradPendingOps().size(), 0UL);
...@@ -163,10 +190,10 @@ TEST(test_layer, test_varbase_basic) {
std::shared_ptr<imperative::VarBase> vin_with_grad( std::shared_ptr<imperative::VarBase> vin_with_grad(
new imperative::VarBase(true, "vin")); new imperative::VarBase(true, "vin"));
ASSERT_ANY_THROW(vin->MutableGradVar()); ASSERT_ANY_THROW(vin->MutableGradVar());
ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable*>( ASSERT_NO_THROW(ASSERT_TRUE(dynamic_cast<framework::Variable *>(
vin_with_grad->MutableGradVar()) != 0)); vin_with_grad->MutableGradVar()) != 0));
ASSERT_TRUE( ASSERT_TRUE(dynamic_cast<framework::Variable *>(
dynamic_cast<framework::Variable*>(vin_with_grad->MutableGradVar()) != 0); vin_with_grad->MutableGradVar()) != 0);
vin_with_grad->SetOverridedStopGradient(false); vin_with_grad->SetOverridedStopGradient(false);
ASSERT_FALSE(vin_with_grad->OverridedStopGradient()); ASSERT_FALSE(vin_with_grad->OverridedStopGradient());
ASSERT_NO_FATAL_FAILURE(vin_with_grad->SetPersistable(true)); ASSERT_NO_FATAL_FAILURE(vin_with_grad->SetPersistable(true));
...@@ -195,14 +222,14 @@ TEST(test_layer, test_dygraph_execution_context) {
auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false); auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false);
paddle::platform::CPUPlace cpu_place; paddle::platform::CPUPlace cpu_place;
paddle::platform::DeviceContextPool& pool = paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance(); paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(cpu_place); auto *dev_ctx = pool.Get(cpu_place);
paddle::framework::RuntimeContext ctx({}, {}); paddle::framework::RuntimeContext ctx({}, {});
framework::Scope scope; framework::Scope scope;
DygraphExecutionContext dy_exe_context(*(op.get()), scope, *dev_ctx, ctx, DygraphExecutionContext<imperative::VarBase> dy_exe_context(
nullptr, ins, outs, &concat_att_map); *(op.get()), scope, *dev_ctx, ctx, nullptr, ins, outs, concat_att_map);
ASSERT_EQ(dy_exe_context.InputSize("X"), 1u); ASSERT_EQ(dy_exe_context.InputSize("X"), 1u);
ASSERT_EQ(dy_exe_context.InputName("X"), "vin"); ASSERT_EQ(dy_exe_context.InputName("X"), "vin");
...@@ -229,7 +256,8 @@ TEST(test_layer, test_dygraph_infershape_context) {
framework::AttributeMap concat_att_map; framework::AttributeMap concat_att_map;
concat_att_map["axis"] = 1; concat_att_map["axis"] = 1;
DygraphInferShapeContext infer_shape_ctx(&ins, &outs, &concat_att_map); DygraphInferShapeContext<imperative::VarBase> infer_shape_ctx(
&ins, &outs, &concat_att_map);
bool have_x = infer_shape_ctx.HasOutputs("Out"); bool have_x = infer_shape_ctx.HasOutputs("Out");
ASSERT_EQ(have_x, true); ASSERT_EQ(have_x, true);
......
...@@ -114,7 +114,7 @@ TEST(test_prepare_op, test_prepare_op) {
ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare( ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare(
ins, outs, ins, outs,
dynamic_cast<framework::OperatorWithKernel&>(*op), dynamic_cast<framework::OperatorWithKernel&>(*op),
place, &split_attr_map)); place, split_attr_map));
} }
const framework::Tensor* GetTensorFromVar(const framework::Variable& var); const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
...@@ -165,7 +165,7 @@ TEST(test_prepare_op, test_prepare_data) {
// test if it can be transformed to GPU place // test if it can be transformed to GPU place
PreparedOp prepared_op = PreparedOp::Prepare( PreparedOp prepared_op = PreparedOp::Prepare(
ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), gpu_place, ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), gpu_place,
&attr_map); attr_map);
for (const auto& name_pair : ins) { for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) { for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place( ASSERT_TRUE(platform::is_same_place(
...@@ -213,7 +213,7 @@ TEST(test_prepare_op, test_prepare_data_same_place) {
// test if it never transferred on GPU place // test if it never transferred on GPU place
PreparedOp prepared_op = PreparedOp::Prepare( PreparedOp prepared_op = PreparedOp::Prepare(
ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), cpu_place, ins, outs, dynamic_cast<framework::OperatorWithKernel&>(*op), cpu_place,
&attr_map); attr_map);
for (const auto& name_pair : ins) { for (const auto& name_pair : ins) {
for (const auto& vb : name_pair.second) { for (const auto& vb : name_pair.second) {
ASSERT_TRUE(platform::is_same_place( ASSERT_TRUE(platform::is_same_place(
......
...@@ -18,6 +18,7 @@
#include <paddle/fluid/framework/op_registry.h> #include <paddle/fluid/framework/op_registry.h>
#include <memory> #include <memory>
#include <set>
#include <string> #include <string>
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -147,9 +148,9 @@ TEST(test_tracer, test_track_backward_output) {
framework::AttributeMap mul_attr_map; framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false; mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
auto* engine = tracer.GetDefaultEngine(); ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_NE(engine->GradVars().size(), 0UL); ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_NE(engine->GradOps().size(), 0UL); // trace_backward already ran. ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL);
} }
TEST(test_tracer, test_track_backward_input) { TEST(test_tracer, test_track_backward_input) {
...@@ -186,9 +187,10 @@ TEST(test_tracer, test_track_backward_input) {
framework::AttributeMap mul_attr_map; framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false; mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true); tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);
auto* engine = tracer.GetDefaultEngine();
ASSERT_NE(engine->GradVars().size(), 0UL); ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_NE(engine->GradOps().size(), 0UL); // trace_backward already ran. ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL);
} }
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
TEST(test_tracer, test_trace_op_with_multi_device_inputs) { TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
...@@ -344,10 +346,12 @@ TEST(test_tracer, test_var_without_grad_var) {
ASSERT_EQ(out_tensor.data<float>()[i], 20.0); ASSERT_EQ(out_tensor.data<float>()[i], 20.0);
} }
ASSERT_EQ(x_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y_in->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOps().size(), 1UL);
detail::BackwardStrategy back_st; detail::BackwardStrategy back_st;
imperative::Engine* engine = tracer.GetDefaultEngine(); imperative::Engine* engine = tracer.GetDefaultEngine();
ASSERT_NE(engine->GradVars().size(), 0UL);
ASSERT_NE(engine->GradOps().size(), 0UL); // trace_backward already ran.
engine->Init(vout.get(), back_st); engine->Init(vout.get(), back_st);
engine->Execute(); engine->Execute();
...@@ -369,10 +373,137 @@ TEST(test_tracer, test_var_without_grad_var) {
} }
} }
template <typename T>
using WeakPtrSet =
std::set<std::weak_ptr<T>, std::owner_less<std::weak_ptr<T>>>;
static void TestVarOpDestructionMain(const platform::Place& place,
int64_t tensor_size = 10,
size_t loop_num = 10) {
WeakPtrSet<VariableWrapper> var_wrappers;
WeakPtrSet<VarBase> var_bases;
WeakPtrSet<OpBase> op_bases;
Tracer tracer;
{
auto x = std::make_shared<VarBase>("x");
auto y = std::make_shared<VarBase>("y");
x->MutableVar()
->GetMutable<framework::LoDTensor>()
->Resize({tensor_size, tensor_size})
.mutable_data<float>(place);
y->MutableVar()
->GetMutable<framework::LoDTensor>()
->Resize({tensor_size, tensor_size})
.mutable_data<float>(place);
x->SetOverridedStopGradient(false);
y->SetOverridedStopGradient(true);
for (size_t i = 0; i < loop_num; ++i) {
size_t var_wrapper_num = var_wrappers.size();
size_t var_base_num = var_bases.size();
size_t op_base_num = op_bases.size();
auto z = std::make_shared<VarBase>("z_" + std::to_string(i));
tracer.TraceOp("mul", NameVarBaseMap{{"X", {x}}, {"Y", {y}}},
NameVarBaseMap{{"Out", {z}}}, framework::AttributeMap{},
place, true);
ASSERT_EQ(z->GradOps().size(), 0UL);
ASSERT_EQ(z->GradVarBase()->GradOps().size(), 1UL);
auto new_op = z->GradVarBase()->GradOps()[0];
ASSERT_EQ(x->GradOps().size(), 0UL);
ASSERT_EQ(y->GradOps().size(), 0UL);
std::unordered_set<std::shared_ptr<OpBase>> expected_pending_ops;
if (i == 0) {
ASSERT_EQ(x->GradVarBase()->GradOps().size(), 0UL);
ASSERT_EQ(y->GradVarBase()->GradOps().size(), 0UL);
} else {
ASSERT_EQ(x->GradVarBase()->GradOps().size(), 1UL);
ASSERT_EQ(y->GradVarBase()->GradOps().size(), 0UL);
for (auto& op : x->GradVarBase()->GradOps()) {
expected_pending_ops.emplace(op);
}
for (auto& op : y->GradVarBase()->GradOps()) {
expected_pending_ops.emplace(op);
}
std::unordered_set<std::shared_ptr<OpBase>> actual_pending_ops;
for (auto& op : new_op->GradPendingOps()) {
actual_pending_ops.emplace(op);
}
ASSERT_TRUE(expected_pending_ops == actual_pending_ops);
ASSERT_EQ(expected_pending_ops.count(new_op), 0UL);
}
var_wrappers.emplace(x->SharedVar());
var_wrappers.emplace(x->GradVarBase()->SharedVar());
var_wrappers.emplace(y->SharedVar());
var_wrappers.emplace(y->GradVarBase()->SharedVar());
var_wrappers.emplace(z->SharedVar());
var_wrappers.emplace(z->GradVarBase()->SharedVar());
var_bases.emplace(x);
var_bases.emplace(x->GradVarBase());
var_bases.emplace(y);
var_bases.emplace(y->GradVarBase());
var_bases.emplace(z);
var_bases.emplace(z->GradVarBase());
for (auto& op : expected_pending_ops) {
op_bases.emplace(op);
}
if (i == 0) {
ASSERT_EQ(var_wrapper_num, 0UL);
ASSERT_EQ(var_base_num, 0UL);
ASSERT_EQ(op_base_num, 0UL);
ASSERT_EQ(var_wrappers.size(), 6UL);
ASSERT_EQ(var_bases.size(), 6UL);
ASSERT_EQ(op_bases.size(), 0UL);
} else {
ASSERT_EQ(var_wrappers.size(), var_wrapper_num + 2);
ASSERT_EQ(var_bases.size(), var_base_num + 2);
ASSERT_EQ(op_bases.size(), op_base_num + 1);
}
x = z; // recurrent usage: this iteration's output feeds the next
}
}
for (auto& var : var_wrappers) {
ASSERT_TRUE(var.expired());
}
for (auto& var : var_bases) {
ASSERT_TRUE(var.expired());
}
for (auto& op : op_bases) {
ASSERT_TRUE(op.expired());
}
}
TEST(test_tracer, test_var_op_destruction) {
TestVarOpDestructionMain(platform::CPUPlace());
#ifdef PADDLE_WITH_CUDA
TestVarOpDestructionMain(platform::CUDAPlace(0));
#endif
}
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
USE_OP(mul); USE_OP(mul);
USE_OP(mul_grad);
USE_OP(reduce_sum); USE_OP(reduce_sum);
USE_OP(reduce_sum_grad); USE_OP(reduce_sum_grad);
USE_OP(elementwise_add); USE_OP(elementwise_add);
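TestVarOpDestructionMain above verifies teardown by parking std::weak_ptrs to every wrapper, VarBase and OpBase in sets ordered by std::owner_less, then asserting they all expire once the forward handles go out of scope. The idiom in isolation, as a runnable standalone snippet:

#include <cassert>
#include <memory>
#include <set>

template <typename T>
using WeakPtrSet =
    std::set<std::weak_ptr<T>, std::owner_less<std::weak_ptr<T>>>;

struct Var { int id{0}; };

int main() {
  WeakPtrSet<Var> tracked;
  {
    auto v = std::make_shared<Var>();
    tracked.emplace(v);                   // weak_ptr takes no ownership
    assert(!tracked.begin()->expired());  // alive while v is in scope
  }                                       // last shared_ptr released here
  assert(tracked.begin()->expired());     // the entry remains but is expired
  return 0;
}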
...@@ -15,7 +15,10 @@
#include <set> #include <set>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
...@@ -48,22 +51,24 @@ static void ClearNoNeedBufferInputs(OpBase* op) {
PADDLE_ENFORCE_EQ(var.IsType<framework::LoDTensor>(), true, PADDLE_ENFORCE_EQ(var.IsType<framework::LoDTensor>(), true,
"Only support LoDTensor"); "Only support LoDTensor");
// TODO(zjl): support higher order derivatives // TODO(zjl): support higher order derivatives
auto new_var = new VarBase(false, each_var->Name()); auto new_var = new VariableWrapper(each_var->Name());
auto* new_tensor = auto* new_tensor =
new_var->MutableVar()->GetMutable<framework::LoDTensor>(); new_var->MutableVar()->GetMutable<framework::LoDTensor>();
auto& old_tensor = var.Get<framework::LoDTensor>(); auto& old_tensor = var.Get<framework::LoDTensor>();
new_tensor->Resize(old_tensor.dims()); new_tensor->Resize(old_tensor.dims());
new_tensor->set_lod(old_tensor.lod()); new_tensor->set_lod(old_tensor.lod());
each_var.reset(new_var); each_var.reset(new_var);
op->AddAllowedEmptyVar(new_var);
} }
} }
} }
static std::vector<std::unique_ptr<OpBase>> CreateGradOpBases( static std::vector<std::shared_ptr<OpBase>> CreateGradOpBases(
const OpBase* fw_op_base, const NameVarBaseMap& in, const framework::OpInfo& info, const std::string& type,
const NameVarBaseMap& out) { const NameVarBaseMap& in, const NameVarBaseMap& out,
if (fw_op_base->Info().dygraph_grad_op_maker_) { const framework::AttributeMap& attrs) {
return fw_op_base->Info().dygraph_grad_op_maker_(fw_op_base, in, out); if (info.dygraph_grad_op_maker_) {
return info.dygraph_grad_op_maker_(type, in, out, attrs);
} else { } else {
return {}; return {};
} }
...@@ -83,17 +88,22 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs, const NameVarBaseMap& outs, framework::AttributeMap attrs,
const platform::Place& place, bool trace_backward) { const platform::Place& place, bool trace_backward) {
VLOG(1) << "Trace Op: " << type; VLOG(1) << "Trace Op: " << type;
size_t op_id = GenerateUniqueId(); auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
auto op = OpBase::Create(op_id, type, ins, outs, attrs, place); const auto& op_info = op->Info();
op->Run(ins, outs); auto* attr_checker = op_info.Checker();
if (attr_checker) {
attr_checker->Check(&attrs, true);
}
OpBase::Run(*op, ins, outs, attrs, place);
if (enable_program_desc_tracing_) { if (enable_program_desc_tracing_) {
VLOG(5) << "Trace op " << type << " into ProgramDesc"; VLOG(5) << "Trace op " << type << " into ProgramDesc";
program_desc_tracer_->InsertOp(type, ins, outs, op->Attrs()); program_desc_tracer_->InsertOp(type, ins, outs, attrs);
} }
if (ComputeRequiredGrad(ins, outs, trace_backward)) { if (ComputeRequiredGrad(ins, outs, trace_backward)) {
TraceBackward(op, ins, outs); TraceBackward(op_info, type, ins, outs, attrs, place);
} else { } else {
VLOG(3) << "No Grad to track for Op: " << type; VLOG(3) << "No Grad to track for Op: " << type;
} }
...@@ -102,22 +112,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const NameVarBaseMap& outs,
framework::AttributeMap attrs) { framework::AttributeMap attrs) {
VLOG(1) << "Trace Op: " << type; TraceOp(type, ins, outs, std::move(attrs), expected_place_, no_grad_);
size_t op_id = GenerateUniqueId();
auto op =
OpBase::Create(op_id, type, ins, outs, std::move(attrs), expected_place_);
op->Run(ins, outs);
if (enable_program_desc_tracing_) {
VLOG(5) << "Trace op " << type << " into ProgramDesc";
program_desc_tracer_->InsertOp(type, ins, outs, op->Attrs());
}
if (ComputeRequiredGrad(ins, outs, no_grad_)) {
TraceBackward(op, ins, outs);
} else {
VLOG(3) << "No Grad to track for Op: " << type;
}
} }
bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins, bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
...@@ -138,78 +133,19 @@ bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
return false; return false;
} }
void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op, void Tracer::TraceBackward(const framework::OpInfo& info,
const NameVarBaseMap& ins, const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs) { const NameVarBaseMap& outs,
// grad_to_var is a map of framework::GradVarName(in_var_name/out_var_name) -> const framework::AttributeMap& attrs,
// in_var_name/out_var_name const platform::Place& place) {
std::unordered_map<std::string, std::string> grad_to_var; auto grad_op_bases = CreateGradOpBases(info, type, ins, outs, attrs);
auto grad_op_num = grad_op_bases.size();
// Get grad_op_desc using fwd_op_desc if (grad_op_num == 0) return;
std::vector<std::unique_ptr<OpBase>> grad_op_bases_ =
CreateGradOpBases(fwd_op.get(), ins, outs); size_t trace_id = GenerateUniqueId();
for (auto& grad_op : grad_op_bases) {
size_t grad_op_num = grad_op_bases_.size(); grad_op->SetPlace(place);
std::set<VarBase*> set_input_vars;
for (auto& fwd_in_it : ins) {
for (auto& var_base_it : fwd_in_it.second) {
set_input_vars.insert(var_base_it.get());
}
}
for (auto& fwd_out_it : outs) {
for (auto& var_base_it : fwd_out_it.second) {
set_input_vars.insert(var_base_it.get());
}
}
for (size_t i = 0; i < grad_op_num; ++i) {
size_t trace_id = fwd_op->id();
std::shared_ptr<OpBase> grad_op = std::move(grad_op_bases_[i]);
grad_op->SetId(trace_id); grad_op->SetId(trace_id);
grad_op->SetPlace(fwd_op->place());
grad_op->CreateOperatorBase();
auto& grad_in = *(grad_op->GetMutableInsMap());
auto& grad_out = *(grad_op->GetMutableOutsMap());
for (auto& grad_in_it : grad_in) {
for (auto& var_base_it : grad_in_it.second) {
if (set_input_vars.count(var_base_it.get()) == 0) {
var_base_it->AddGradOps(grad_op);
engine_->InsertGradVar(var_base_it.get());
}
}
}
std::set<OpBase*> visited_preceding_ops;
for (auto& grad_out_it : grad_out) {
bool flag_clear_list = false;
for (auto& var_base_it : grad_out_it.second) {
if ((!var_base_it->OverridedStopGradient()) ||
(grad_out_it.second.size() > 1)) {
auto preceding_ops = var_base_it->GradOps();
if (!preceding_ops.empty()) {
for (const auto& op : preceding_ops) {
visited_preceding_ops.insert(op);
}
}
} else {
flag_clear_list = true;
}
}
if (flag_clear_list) {
grad_out_it.second.clear();
}
}
std::vector<OpBase*> vec_preceding_ops(visited_preceding_ops.begin(),
visited_preceding_ops.end());
grad_op->SetGradPendingOps(std::move(vec_preceding_ops));
// this OpBase* is just used to manage op's life time
engine_->InsertOp(grad_op.get(), grad_op);
ClearNoNeedBufferInputs(grad_op.get()); ClearNoNeedBufferInputs(grad_op.get());
} }
} }
......
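ClearNoNeedBufferInputs above swaps grad-op inputs whose buffers are never read for empty shells that keep only dims and LoD, and whitelists them via AddAllowedEmptyVar. A rough standalone sketch of the buffer-dropping half of that trick, using a toy tensor type of my own rather than framework::LoDTensor:

#include <cstdint>
#include <iostream>
#include <vector>

// Toy tensor: shape metadata plus storage.
struct ToyTensor {
  std::vector<int64_t> dims;
  std::vector<float> data;
};

// Keep what shape inference needs, drop the buffer itself; this is the
// per-input operation the real pass applies to each no-need-buffer input.
static ToyTensor DropBuffer(const ToyTensor& t) {
  ToyTensor shell;
  shell.dims = t.dims;  // shape (and, in Paddle, LoD) survive
  // shell.data stays empty: the forward activation can now be freed
  return shell;
}

int main() {
  ToyTensor big{{1024, 1024}, std::vector<float>(1024 * 1024, 1.0f)};
  ToyTensor shell = DropBuffer(big);
  std::cout << "dims kept: " << shell.dims[0] << "x" << shell.dims[1]
            << ", elements held: " << shell.data.size() << "\n";
  return 0;
}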
...@@ -64,9 +64,6 @@ class Tracer {
bool ComputeRequiredGrad(const NameVarBaseMap& ins, bool ComputeRequiredGrad(const NameVarBaseMap& ins,
const NameVarBaseMap& outs, bool trace_backward); const NameVarBaseMap& outs, bool trace_backward);
void TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
const NameVarBaseMap& ins, const NameVarBaseMap& outs);
Engine* GetDefaultEngine() const { return engine_.get(); } Engine* GetDefaultEngine() const { return engine_.get(); }
void SetEnableProgramDescTracing(bool enabled) { void SetEnableProgramDescTracing(bool enabled) {
...@@ -94,6 +91,11 @@ class Tracer {
void SetNoGrad(bool no_grad) { no_grad_ = no_grad; } void SetNoGrad(bool no_grad) { no_grad_ = no_grad; }
private: private:
void TraceBackward(const framework::OpInfo& info, const std::string& type,
const NameVarBaseMap& ins, const NameVarBaseMap& outs,
const framework::AttributeMap& attrs,
const platform::Place& place);
static size_t GenerateUniqueId() { static size_t GenerateUniqueId() {
static std::atomic<size_t> id{0}; static std::atomic<size_t> id{0};
return id.fetch_add(1); return id.fetch_add(1);
......
...@@ -22,12 +22,16 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
class VariableWrapper;
class VarBase; class VarBase;
class OpBase; class OpBase;
class Tracer; class Tracer;
using NameVarBaseMap = template <typename T>
std::map<std::string, std::vector<std::shared_ptr<VarBase>>>; using NameVarMap = std::map<std::string, std::vector<std::shared_ptr<T>>>;
using NameVarBaseMap = NameVarMap<VarBase>;
using NameVariableWrapperMap = NameVarMap<VariableWrapper>;
using WeakNameVarBaseMap = using WeakNameVarBaseMap =
std::map<std::string, std::vector<std::weak_ptr<VarBase>>>; std::map<std::string, std::vector<std::weak_ptr<VarBase>>>;
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace imperative {
class VariableWrapper {
public:
explicit VariableWrapper(const std::string& name) : name_(name) {}
const framework::Variable& Var() const { return var_; }
framework::Variable* MutableVar() { return &var_; }
// This is used for python api
void SetOverridedStopGradient(bool stop_gradient) {
overrided_stop_gradient_ = static_cast<int>(stop_gradient);
}
// This is used for python api
bool OverridedStopGradient() const { return overrided_stop_gradient_ != 0; }
// This is used inside C++
int InnerOverridedStopGradient() const { return overrided_stop_gradient_; }
// This is used inside C++
void InnerSetOverridedStopGradient(bool stop_gradient) {
if (overrided_stop_gradient_ == -1) {
overrided_stop_gradient_ = static_cast<int>(stop_gradient);
} else {
VLOG(6) << "Ignore Stop gradient conversion for Var: " << Name()
<< "Set value is: " << overrided_stop_gradient_;
}
}
void SetPersistable(bool persistable) { persistable_ = persistable; }
bool Persistable() const { return persistable_; }
const std::string& Name() const { return name_; }
void SetName(const std::string& name) { name_ = name; }
void SetType(framework::proto::VarType::Type type) { type_ = type; }
framework::proto::VarType::Type Type() const { return type_; }
void SetDataType(framework::proto::VarType::Type data_type) {
data_type_ = data_type;
}
framework::proto::VarType::Type DataType() const {
const framework::Tensor* tensor = nullptr;
if (var_.IsInitialized()) {
if (type_ == framework::proto::VarType::LOD_TENSOR) {
tensor = &(var_.Get<framework::LoDTensor>());
} else if (type_ == framework::proto::VarType::SELECTED_ROWS) {
tensor = &(var_.Get<framework::SelectedRows>().value());
} else {
VLOG(6) << "Variable " << name_ << " is not initialized";
return data_type_;
}
}
if (tensor && tensor->IsInitialized()) {
return tensor->type();
} else {
VLOG(6) << "The tensor of variable " << name_ << " is not initialized";
return data_type_;
}
}
private:
framework::Variable var_;
std::string name_;
// add this property for users may set stop_gradient themselves and this
// should override the frameworks setting (-1) unset, (1) true, (0) false
int overrided_stop_gradient_{-1};
bool persistable_{false};
framework::proto::VarType::Type type_{framework::proto::VarType::LOD_TENSOR};
framework::proto::VarType::Type data_type_{framework::proto::VarType::FP32};
};
} // namespace imperative
} // namespace paddle
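For reference, a hedged walk-through of the stop-gradient override above: -1 means "unset", and the inner (C++) setter only takes effect while the flag is unset, so a decision already recorded is never silently overwritten by the framework. The sequence below is illustrative only:

// Illustrative sequence against the VariableWrapper defined above.
using paddle::imperative::VariableWrapper;

void StopGradientExample() {
  VariableWrapper w("tmp_var");
  // Framework default is applied because the flag is still unset (-1).
  w.InnerSetOverridedStopGradient(false);  // flag becomes 0 (false)
  // A second framework-side write is ignored and only VLOG(6)-logged,
  // because the flag already holds a decision.
  w.InnerSetOverridedStopGradient(true);   // flag stays 0
  // The Python-facing setter writes unconditionally and wins.
  w.SetOverridedStopGradient(true);        // flag becomes 1 (true)
}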
...@@ -68,8 +68,7 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType(this->ForwardOpType() + "_grad");
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
...@@ -86,8 +85,6 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
         static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
       op->SetInput("Out", this->Output("Out"));
     }
-    return op;
   }
 };
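Every hunk below repeats one mechanical change: `Apply` no longer allocates the grad op and returns ownership; the framework constructs the node and passes it in, and the maker only fills in type, inputs, outputs, and attributes. A simplified sketch of the before/after shape, treating `GradOpPtr<T>` as a plain `T*` and using an illustrative op type string:

// Before the refactor: each maker heap-allocated its grad op.
template <typename T>
std::unique_ptr<T> ApplyOld() {
  std::unique_ptr<T> op(new T());
  op->SetType("some_op_grad");  // plus SetInput/SetOutput/SetAttrMap calls
  return op;                    // ownership handed back to the caller
}

// After the refactor: the caller owns construction and lifetime, which lets
// the same maker body populate either a static-graph op description or an
// imperative traced grad op behind the GradOpPtr<T> handle.
template <typename T>
void ApplyNew(T* op /* GradOpPtr<T> */) {
  op->SetType("some_op_grad");  // same population calls, no allocation
}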
...@@ -727,8 +724,7 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
   using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto* op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("relu_grad_grad");
     // input1: Out
     op->SetInput("Out", this->Input("Out"));
...@@ -737,7 +733,6 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
     op->SetAttrMap(this->Attrs());
     // output: ddy
     op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
-    return std::unique_ptr<T>(op);
   }
 };
...@@ -750,8 +745,7 @@ class LeakyReluDoubleGradMaker
   using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto* op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("leaky_relu_grad_grad");
     // input1: Out
     op->SetInput("Out", this->Input("Out"));
...@@ -760,7 +754,6 @@ class LeakyReluDoubleGradMaker
     op->SetAttrMap(this->Attrs());
     // Out@GRAD@GRAD: ddy
     op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
-    return std::unique_ptr<T>(op);
   }
 };
...@@ -772,8 +765,7 @@ class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
   using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto* op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("elu_grad_grad");
     op->SetInput("X", this->Input("X"));
...@@ -785,7 +777,6 @@ class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
     // Out@GRAD@GRAD: ddy
     op->SetOutput("DX", this->InputGrad("X"));
     op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
-    return std::unique_ptr<T>(op);
   }
 };
...@@ -797,8 +788,7 @@ class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
   using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto* op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("sqrt_grad_grad");
     op->SetInput("Out", this->Input("Out"));
     op->SetInput("DX", this->Output(framework::GradVarName("X")));
...@@ -806,7 +796,6 @@ class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
     op->SetAttrMap(this->Attrs());
     op->SetOutput("DOut", this->InputGrad("Out"));
     op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
-    return std::unique_ptr<T>(op);
   }
 };
...@@ -818,8 +807,7 @@ class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
   using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto* op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("square_grad_grad");
     op->SetInput("X", this->Input("X"));
     // Out@GRAD: dy
...@@ -833,7 +821,6 @@ class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
     op->SetOutput("DX", this->InputGrad("X"));
     // Out@GRAD@GRAD: ddy
     op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
-    return std::unique_ptr<T>(op);
   }
 };
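These double-grad makers wire `DDX` (the grad-of-grad flowing into X@GRAD) through the op's second derivative, producing `DDOut` and, where needed, corrections such as `DX` or `DOut`. For relu the rule is simply to pass `DDX` where the forward output was positive; a scalar sketch of that element-wise rule (illustrative, not the actual kernel):

#include <cstdio>

// relu: out = max(x, 0); relu_grad: dx = dout * (out > 0);
// relu_grad_grad: ddout = ddx * (out > 0).
double ReluGradGrad(double out, double ddx) {
  return out > 0.0 ? ddx : 0.0;
}

int main() {
  std::printf("%g\n", ReluGradGrad(2.0, 0.5));  // prints 0.5
  std::printf("%g\n", ReluGradGrad(0.0, 0.5));  // prints 0
}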
...@@ -849,16 +836,13 @@ class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("pow_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetInput("FactorTensor", this->Input("FactorTensor"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
 class PowOp : public framework::OperatorWithKernel {
...
...@@ -93,13 +93,11 @@ class AddPositionEncodingGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("add_position_encoding_grad");
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -132,8 +132,7 @@ class AffineChannelGradMaker : public framework::SingleGradOpMaker<T> {
  public:
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-  std::unique_ptr<T> Apply() const override {
-    auto* op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("affine_channel_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
...@@ -144,8 +143,6 @@ class AffineChannelGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
     op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
-    return std::unique_ptr<T>(op);
   }
 };
...
...@@ -204,8 +204,7 @@ class AffineGridGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto* op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("affine_grid_grad");
     op->SetInput("OutputShape", this->Input("OutputShape"));
     op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
...@@ -213,7 +212,6 @@ class AffineGridGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetAttrMap(this->Attrs());
     op->SetOutput(framework::GradVarName("Theta"), this->InputGrad("Theta"));
-    return std::unique_ptr<T>(op);
   }
 };
...
...@@ -108,15 +108,13 @@ class ArgsortGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("argsort_grad");
     op->SetInput("Indices", this->Output("Indices"));
     op->SetInput("X", this->Input("X"));
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -216,14 +216,12 @@ class ArrayToLoDTensorGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto *grad_op = new T();
+  void Apply(GradOpPtr<T> grad_op) const override {
     grad_op->SetType("lod_tensor_to_array");
     grad_op->SetInput("X", this->OutputGrad("Out"));
     grad_op->SetInput("RankTable", this->Input("RankTable"));
     grad_op->SetOutput("Out", this->InputGrad("X"));
     grad_op->SetAttrMap(this->Attrs());
-    return std::unique_ptr<T>(grad_op);
   }
 };
...
...@@ -99,12 +99,10 @@ class AssignGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto *op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("assign");
     op->SetInput("X", this->OutputGrad("Out"));
     op->SetOutput("Out", this->InputGrad("X"));
-    return std::unique_ptr<T>(op);
   }
 };
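assign is the one maker here whose backward op is the forward op itself: a copy is linear with derivative 1, so Out@GRAD is copied straight into X@GRAD. In scalar form (a trivial sketch):

// Sketch: out = x  =>  dL/dx = dL/dout, i.e. the backward pass is
// another assign, wired from the output grad to the input grad.
double AssignForward(double x) { return x; }
double AssignBackward(double dout) { return dout; }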
...
...@@ -721,8 +721,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
 };
 template <typename T>
-std::unique_ptr<T> BatchNormGradMaker<T>::Apply() const {
-  auto *op = new T();
+void BatchNormGradMaker<T>::Apply(GradOpPtr<T> op) const {
   op->SetType(this->ForwardOpType() + "_grad");
   op->SetInput("X", this->Input("X"));
   op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
...@@ -746,8 +745,6 @@ std::unique_ptr<T> BatchNormGradMaker<T>::Apply() const {
   op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
   op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
   op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
-  return std::unique_ptr<T>(op);
 }
 } // namespace operators
...
...@@ -165,7 +165,7 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override;
+  void Apply(GradOpPtr<T> op) const override;
 };
 class BatchNormOpInferVarType
...
...@@ -154,8 +154,7 @@ class BilinearTensorProductGradOpMaker
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("bilinear_tensor_product_grad");
     op->SetAttrMap(this->Attrs());
     op->SetInput("X", this->Input("X"));
...@@ -169,8 +168,6 @@ class BilinearTensorProductGradOpMaker
     op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
     op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight"));
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    return op;
   }
 };
...
...@@ -141,15 +141,13 @@ class BprLossGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("bpr_loss_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput("Label", this->Input("Label"));
     op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
 } // namespace operators
...
...@@ -44,14 +44,12 @@ class CastOpGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto grad = new T();
+  void Apply(GradOpPtr<T> grad) const override {
     grad->SetType("cast");
     grad->SetInput("X", this->OutputGrad("Out"));
     grad->SetOutput("Out", this->InputGrad("X"));
     grad->SetAttr("out_dtype", this->GetAttr("in_dtype"));
     grad->SetAttr("in_dtype", this->GetAttr("out_dtype"));
-    return std::unique_ptr<T>(grad);
   }
 };
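The cast grad reuses the forward `cast` op with `in_dtype` and `out_dtype` swapped, so the backward pass converts Out@GRAD back into X's original dtype. A sketch with an fp32-to-fp64 cast standing in for the tensor kernel (types illustrative):

// Forward casts x up; backward casts the incoming grad back down,
// mirroring the SetAttr("out_dtype", in_dtype) / SetAttr("in_dtype",
// out_dtype) swap in CastOpGradMaker above.
double CastForward(float x) { return static_cast<double>(x); }
float CastGrad(double dout) { return static_cast<float>(dout); }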
...
...@@ -129,8 +129,7 @@ class CenterLossOpGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> retv(new T());
+  void Apply(GradOpPtr<T> retv) const override {
     retv->SetType("center_loss_grad");
     retv->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
     retv->SetInput("SampleCenterDiff", this->Output("SampleCenterDiff"));
...@@ -138,7 +137,6 @@ class CenterLossOpGradMaker : public framework::SingleGradOpMaker<T> {
     retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     retv->SetAttrMap(this->Attrs());
-    return retv;
   }
 };
...
...@@ -84,14 +84,12 @@ class ClipGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("clip_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -62,13 +62,11 @@ class CAllGatherOpGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> retv(new T());
+  void Apply(GradOpPtr<T> retv) const override {
     retv->SetType("c_reducescatter");
     retv->SetInput("X", this->OutputGrad("Out"));
     retv->SetOutput("Out", this->InputGrad("X"));
     retv->SetAttrMap(this->Attrs());
-    return retv;
   }
 };
...
...@@ -23,13 +23,11 @@ class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> retv(new T());
+  void Apply(GradOpPtr<T> retv) const override {
     retv->SetType("c_allreduce_sum");
     retv->SetInput("X", this->OutputGrad("Out"));
     retv->SetOutput("Out", this->InputGrad("X"));
     retv->SetAttrMap(this->Attrs());
-    return retv;
   }
 };
...
...@@ -66,13 +66,11 @@ class CReduceScatterOpGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> retv(new T());
+  void Apply(GradOpPtr<T> retv) const override {
     retv->SetType("c_allgather");
     retv->SetInput("X", this->OutputGrad("Out"));
     retv->SetOutput("Out", this->InputGrad("X"));
     retv->SetAttrMap(this->Attrs());
-    return retv;
   }
 };
...
...@@ -185,8 +185,7 @@ class ConcatGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("concat_grad");
     op->SetInput("X", this->Input("X"));
     if (this->HasInput("AxisTensor")) {
...@@ -195,7 +194,6 @@ class ConcatGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -260,8 +260,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto grad_op = new T();
+  void Apply(GradOpPtr<T> grad_op) const override {
     grad_op->SetType("conditional_block_grad");
     grad_op->SetInput(ConditionalOp::kCondition,
                       this->Input(ConditionalOp::kCondition));
...@@ -278,7 +277,6 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
     grad_op->SetBlockAttr("sub_block", this->grad_block_[0]);
     grad_op->SetAttr("is_scalar_condition",
                      this->GetAttr("is_scalar_condition"));
-    return std::unique_ptr<T>(grad_op);
   }
 };
...
...@@ -215,15 +215,13 @@ class WriteToArrayGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto *grad_op = new T();
+  void Apply(GradOpPtr<T> grad_op) const override {
     grad_op->SetType("read_from_array");
     grad_op->SetInput("I", this->Input("I"));
     grad_op->SetInput("X", this->OutputGrad("Out"));
     grad_op->SetInput("X_W", this->Input("X"));
     grad_op->SetOutput("Out", this->InputGrad("X"));
     grad_op->SetAttrMap(this->Attrs());
-    return std::unique_ptr<T>(grad_op);
   }
 };
...@@ -233,14 +231,12 @@ class ReadFromArrayGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto *grad_op = new T();
+  void Apply(GradOpPtr<T> grad_op) const override {
     grad_op->SetType("write_to_array");
     grad_op->SetInput("I", this->Input("I"));
     grad_op->SetInput("X", this->OutputGrad("Out"));
     grad_op->SetOutput("Out", this->InputGrad("X"));
     grad_op->SetAttrMap(this->Attrs());
-    return std::unique_ptr<T>(grad_op);
   }
 };
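The two array makers above are each other's inverses: the gradient of writing slot I into the array is reading slot I from the array's gradient, and the gradient of a read is a write into the corresponding slot, which is why each maker emits the other op. An index-level sketch:

#include <vector>

// write_to_array forward: arr[i] = x.
// Its gradient (read_from_array over the grad array): dx = darr[i].
double WriteToArrayGrad(const std::vector<double>& darr, size_t i) {
  return i < darr.size() ? darr[i] : 0.0;
}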
...
...@@ -329,8 +329,7 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto *while_grad = new T();
+  void Apply(GradOpPtr<T> while_grad) const override {
     while_grad->SetType("while_grad");
     while_grad->SetInput(kX, this->Input(kX));
     while_grad->SetInput(kOutputs, this->Output(kOutputs));
...@@ -402,8 +401,6 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
     while_grad->SetAttr("original_output_grad", output_grads_list);
     while_grad->SetAttr(kSkipEagerDeletionVars, std::vector<std::string>());
-    return std::unique_ptr<T>(while_grad);
   }
 };
...
...@@ -612,8 +612,7 @@ class Conv2DGradMaker : public framework::SingleGradOpMaker<T> {
  public:
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-  std::unique_ptr<T> Apply() const override {
-    auto* op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType(this->ForwardOpType() + "_grad");
     op->SetInput("Input", this->Input("Input"));
     op->SetInput("Filter", this->Input("Filter"));
...@@ -624,8 +623,6 @@ class Conv2DGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput(framework::GradVarName("Filter"), this->InputGrad("Filter"));
     op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
     op->SetAttrMap(this->Attrs());
-    return std::unique_ptr<T>(op);
   }
 };
...@@ -634,8 +631,7 @@ class Conv3DGradMaker : public framework::SingleGradOpMaker<T> {
  public:
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-  std::unique_ptr<T> Apply() const override {
-    auto* op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType(this->ForwardOpType() + "_grad");
     op->SetInput("Input", this->Input("Input"));
     op->SetInput("Filter", this->Input("Filter"));
...@@ -649,8 +645,6 @@ class Conv3DGradMaker : public framework::SingleGradOpMaker<T> {
     }
     op->SetAttrMap(this->Attrs());
-    return std::unique_ptr<T>(op);
   }
 };
...@@ -663,8 +657,7 @@ class Conv2DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
  public:
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-  std::unique_ptr<T> Apply() const override {
-    auto* op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType(this->ForwardOpType() + "_grad");
     // I, W, dO, ddI, ddW
     op->SetInput("Input", this->Input("Input"));
...@@ -682,16 +675,14 @@ class Conv2DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput("DDOutput",
                   ddx.empty()
-                      ? this->Empty()
+                      ? this->EmptyInputGrad()
                       : this->InputGrad(framework::GradVarName("Output")));
-    op->SetOutput("DFilter",
-                  ddx.empty() ? this->Empty() : this->InputGrad("Filter"));
-    op->SetOutput("DInput",
-                  ddw.empty() ? this->Empty() : this->InputGrad("Input"));
+    op->SetOutput("DFilter", ddx.empty() ? this->EmptyInputGrad()
+                                         : this->InputGrad("Filter"));
+    op->SetOutput("DInput", ddw.empty() ? this->EmptyInputGrad()
+                                        : this->InputGrad("Input"));
     op->SetAttrMap(this->Attrs());
-    return std::unique_ptr<T>(op);
   }
 };
...@@ -704,8 +695,7 @@ class Conv3DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
  public:
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-  std::unique_ptr<T> Apply() const override {
-    auto* op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType(this->ForwardOpType() + "_grad");
     // I, W, dO, ddI, ddW
     op->SetInput("Input", this->Input("Input"));
...@@ -720,16 +710,14 @@ class Conv3DDoubleGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput("DDOutput",
                   ddx.empty()
-                      ? this->Empty()
+                      ? this->EmptyInputGrad()
                       : this->InputGrad(framework::GradVarName("Output")));
-    op->SetOutput("DFilter",
-                  ddx.empty() ? this->Empty() : this->InputGrad("Filter"));
-    op->SetOutput("DInput",
-                  ddw.empty() ? this->Empty() : this->InputGrad("Input"));
+    op->SetOutput("DFilter", ddx.empty() ? this->EmptyInputGrad()
+                                         : this->InputGrad("Filter"));
+    op->SetOutput("DInput", ddw.empty() ? this->EmptyInputGrad()
+                                        : this->InputGrad("Input"));
     op->SetAttrMap(this->Attrs());
-    return std::unique_ptr<T>(op);
   }
 };
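`EmptyInputGrad()` (replacing the vaguely named `Empty()`) binds an op output to an empty variable-name list, which tells the framework not to materialize that grad output at all, e.g. `DFilter` when no `ddx` flows in. A sketch of the marker's shape:

#include <string>
#include <vector>

// An output mapped to an empty name list is simply skipped; returning {}
// is cheaper and clearer than wiring up a dummy variable.
static std::vector<std::string> EmptyInputGrad() { return {}; }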
...
...@@ -199,8 +199,7 @@ class ConvShiftGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("conv_shift_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput("Y", this->Input("Y"));
...@@ -208,7 +207,6 @@ class ConvShiftGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -441,8 +441,7 @@ class ConvTransposeGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType(this->ForwardOpType() + "_grad");
     op->SetInput("Input", this->Input("Input"));
     op->SetInput("Filter", this->Input("Filter"));
...@@ -454,7 +453,6 @@ class ConvTransposeGradOpMaker : public framework::SingleGradOpMaker<T> {
     }
     op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -172,8 +172,7 @@ class CosSimGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto* grad_op = new T();
+  void Apply(GradOpPtr<T> grad_op) const override {
     grad_op->SetType("cos_sim_grad");
     grad_op->SetInput("X", this->Input("X"));
     grad_op->SetInput("Y", this->Input("Y"));
...@@ -184,7 +183,6 @@ class CosSimGradOpMaker : public framework::SingleGradOpMaker<T> {
     grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     grad_op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
     grad_op->SetAttrMap(this->Attrs());
-    return std::unique_ptr<T>(grad_op);
   }
 };
...
...@@ -187,8 +187,7 @@ class CropGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("crop_grad");
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetInput("X", this->Input("X"));
...@@ -197,7 +196,6 @@ class CropGradOpMaker : public framework::SingleGradOpMaker<T> {
     }
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -279,8 +279,7 @@ class CropTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("crop_tensor_grad");
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetInput("X", this->Input("X"));
...@@ -292,7 +291,6 @@ class CropTensorGradOpMaker : public framework::SingleGradOpMaker<T> {
     }
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -263,15 +263,13 @@ class CrossEntropyGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("cross_entropy_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput("Label", this->Input("Label"));
     op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...@@ -372,8 +370,7 @@ class CrossEntropyGradOpMaker2 : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("cross_entropy_grad2");
     op->SetInput("Label", this->Input("Label"));
     op->SetInput("MatchX", this->Output("MatchX"));
...@@ -381,7 +378,6 @@ class CrossEntropyGradOpMaker2 : public framework::SingleGradOpMaker<T> {
     op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -203,8 +203,7 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("cudnn_lstm_grad");
     op->SetInput("Input", this->Input("Input"));
     op->SetInput("InitH", this->Input("InitH"));
...@@ -223,7 +222,6 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput(framework::GradVarName("InitH"), this->InputGrad("InitH"));
     op->SetOutput(framework::GradVarName("InitC"), this->InputGrad("InitC"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -58,15 +58,13 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto *grad_op = new T();
+  void Apply(GradOpPtr<T> grad_op) const override {
     grad_op->SetType("cumsum");
     grad_op->SetInput("X", this->OutputGrad("Out"));
     grad_op->SetOutput("Out", this->InputGrad("X"));
     grad_op->SetAttr("axis", boost::get<int>(this->GetAttr("axis")));
     grad_op->SetAttr("reverse", !boost::get<bool>(this->GetAttr("reverse")));
     grad_op->SetAttr("exclusive", boost::get<bool>(this->GetAttr("exclusive")));
-    return std::unique_ptr<T>(grad_op);
   }
 };
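The cumsum maker flips the `reverse` attribute because the adjoint of a forward cumulative sum is a reversed cumulative sum: each x_i feeds every y_j with j >= i, so dx_i = dout_i + ... + dout_{n-1}. A numeric sketch of that identity (1-D, non-exclusive case):

#include <vector>

// Forward cumsum: y_i = x_0 + ... + x_i. Its gradient is a cumsum run in
// reverse over dout, matching SetAttr("reverse", !reverse) above.
std::vector<double> CumsumGrad(const std::vector<double>& dout) {
  std::vector<double> dx(dout.size());
  double acc = 0.0;
  for (size_t k = dout.size(); k-- > 0;) {
    acc += dout[k];
    dx[k] = acc;
  }
  return dx;
}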
...
...@@ -131,15 +131,13 @@ class CVMGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("cvm_grad");
     op->SetInput("CVM", this->Input("CVM"));
     op->SetInput("X", this->Input("X"));
     op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -461,8 +461,7 @@ class DataNormGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto *op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("data_norm_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
...@@ -482,8 +481,6 @@ class DataNormGradMaker : public framework::SingleGradOpMaker<T> {
                   this->InputGrad("BatchSum"));
     op->SetOutput(framework::GradVarName("BatchSquareSum"),
                   this->InputGrad("BatchSquareSum"));
-    return std::unique_ptr<T>(op);
   }
 };
...
...@@ -228,9 +228,7 @@ class DeformableConvGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("deformable_conv_grad");
     op->SetInput("Input", this->Input("Input"));
     op->SetInput("Filter", this->Input("Filter"));
...@@ -244,7 +242,6 @@ class DeformableConvGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput(framework::GradVarName("Mask"), this->InputGrad("Mask"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -211,9 +211,7 @@ class DeformableConvV1GradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("deformable_conv_v1_grad");
     op->SetInput("Input", this->Input("Input"));
     op->SetInput("Filter", this->Input("Filter"));
...@@ -225,7 +223,6 @@ class DeformableConvV1GradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput(framework::GradVarName("Offset"), this->InputGrad("Offset"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -211,9 +211,7 @@ class DeformablePSROIPoolGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("deformable_psroi_pooling_grad");
     op->SetInput("Input", this->Input("Input"));
     op->SetInput("Trans", this->Input("Trans"));
...@@ -225,7 +223,6 @@ class DeformablePSROIPoolGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput(framework::GradVarName("Trans"), this->InputGrad("Trans"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -627,8 +627,7 @@ class ROIPerspectiveTransformGradMaker
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("roi_perspective_transform_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput("ROIs", this->Input("ROIs"));
...@@ -637,7 +636,6 @@ class ROIPerspectiveTransformGradMaker
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -178,8 +178,7 @@ class SigmoidFocalLossGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("sigmoid_focal_loss_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput("Label", this->Input("Label"));
...@@ -187,7 +186,6 @@ class SigmoidFocalLossGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -269,8 +269,7 @@ class Yolov3LossGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    auto* op = new T();
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("yolov3_loss_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput("GTBox", this->Input("GTBox"));
...@@ -283,10 +282,9 @@ class Yolov3LossGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetAttrMap(this->Attrs());
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    op->SetOutput(framework::GradVarName("GTBox"), {});
-    op->SetOutput(framework::GradVarName("GTLabel"), {});
-    op->SetOutput(framework::GradVarName("GTScore"), {});
-    return std::unique_ptr<T>(op);
+    op->SetOutput(framework::GradVarName("GTBox"), this->EmptyInputGrad());
+    op->SetOutput(framework::GradVarName("GTLabel"), this->EmptyInputGrad());
+    op->SetOutput(framework::GradVarName("GTScore"), this->EmptyInputGrad());
   }
 };
...
...@@ -144,14 +144,12 @@ class DropoutGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("dropout_grad");
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     op->SetInput("Mask", this->Output("Mask"));
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...
...@@ -76,8 +76,7 @@ class ElementwiseAddDoubleGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("elementwise_add_grad_grad");
     op->SetInput("Y", this->Input("Y"));
     op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
...@@ -87,7 +86,6 @@ class ElementwiseAddDoubleGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetAttrMap(this->Attrs());
     op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
-    return op;
   }
 };
...
...@@ -73,8 +73,7 @@ class ElementwiseDivGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("elementwise_div_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput("Y", this->Input("Y"));
...@@ -83,7 +82,6 @@ class ElementwiseDivGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
     op->SetAttrMap(this->Attrs());
-    return op;
   }
 };
...@@ -93,8 +91,7 @@ class ElementwiseDivDoubleGradMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("elementwise_div_grad_grad");
     op->SetInput("Y", this->Input("Y"));
     op->SetInput("Out", this->Input("Out"));
...@@ -107,8 +104,6 @@ class ElementwiseDivDoubleGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
     op->SetOutput("DOut", this->InputGrad("Out"));
     op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
-    return op;
   }
 };
...
...@@ -49,8 +49,7 @@ class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> op(new T());
+  void Apply(GradOpPtr<T> op) const override {
     op->SetType("elementwise_max_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput("Y", this->Input("Y"));
...@@ -58,7 +57,6 @@ class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -58,7 +57,6 @@ class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -49,8 +49,7 @@ class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -49,8 +49,7 @@ class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("elementwise_min_grad"); op->SetType("elementwise_min_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y")); op->SetInput("Y", this->Input("Y"));
...@@ -58,7 +57,6 @@ class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -58,7 +57,6 @@ class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -76,8 +76,7 @@ class ElementwiseMulOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -76,8 +76,7 @@ class ElementwiseMulOpGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("elementwise_mul_grad"); op->SetType("elementwise_mul_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y")); op->SetInput("Y", this->Input("Y"));
...@@ -85,7 +84,6 @@ class ElementwiseMulOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -85,7 +84,6 @@ class ElementwiseMulOpGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
return op;
} }
}; };
...@@ -95,8 +93,7 @@ class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -95,8 +93,7 @@ class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("elementwise_mul_grad_grad"); op->SetType("elementwise_mul_grad_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y")); op->SetInput("Y", this->Input("Y"));
...@@ -109,7 +106,6 @@ class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -109,7 +106,6 @@ class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out"))); op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
return op;
} }
}; };
......
...@@ -363,8 +363,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference, ...@@ -363,8 +363,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker; \ using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker; \
\ \
protected: \ protected: \
std::unique_ptr<T> Apply() const override { \ void Apply(::paddle::framework::GradOpPtr<T> op) const override { \
auto *op = new T(); \
op->SetType(#kernel_type "_grad"); \ op->SetType(#kernel_type "_grad"); \
op->SetInput("X", this->Input("X")); \ op->SetInput("X", this->Input("X")); \
op->SetInput("Y", this->Input("Y")); \ op->SetInput("Y", this->Input("Y")); \
...@@ -375,7 +374,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference, ...@@ -375,7 +374,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
this->InputGrad("X")); \ this->InputGrad("X")); \
op->SetOutput(::paddle::framework::GradVarName("Y"), \ op->SetOutput(::paddle::framework::GradVarName("Y"), \
this->InputGrad("Y")); \ this->InputGrad("Y")); \
return std::unique_ptr<T>(op); \
} \ } \
} }
......
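One detail of the macro hunk above: since a macro body may be expanded in any namespace, the parameter type is spelled with its fully qualified name there. A sketch of the distinction:

    void Apply(GradOpPtr<T> op) const override;  // fine where the framework:: names are visible
    void Apply(::paddle::framework::GradOpPtr<T> op) const override;  // safe in any macro expansion context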
...@@ -23,8 +23,7 @@ class ElementwisePowOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -23,8 +23,7 @@ class ElementwisePowOpGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("elementwise_pow_grad"); op->SetType("elementwise_pow_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y")); op->SetInput("Y", this->Input("Y"));
...@@ -32,7 +31,6 @@ class ElementwisePowOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -32,7 +31,6 @@ class ElementwisePowOpGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
return op;
} }
}; };
class ElementwisePowOpMaker : public ElementwiseOpMaker { class ElementwisePowOpMaker : public ElementwiseOpMaker {
......
...@@ -75,8 +75,7 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -75,8 +75,7 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("elementwise_sub_grad_grad"); op->SetType("elementwise_sub_grad_grad");
op->SetInput("Y", this->Input("Y")); op->SetInput("Y", this->Input("Y"));
op->SetInput("DOut", this->Input(framework::GradVarName("Out"))); op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
...@@ -86,7 +85,6 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -86,7 +85,6 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out"))); op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
return op;
} }
}; };
......
...@@ -101,14 +101,12 @@ class ErfGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -101,14 +101,12 @@ class ErfGradOpMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad_op) const override {
auto *grad_op = new T();
grad_op->SetType("erf_grad"); grad_op->SetType("erf_grad");
grad_op->SetInput("X", this->Input("X")); grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs()); grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
} }
}; };
......
...@@ -103,15 +103,13 @@ class ExpandAsGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -103,15 +103,13 @@ class ExpandAsGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("expand_as_grad"); op->SetType("expand_as_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("target_tensor", this->Input("target_tensor")); op->SetInput("target_tensor", this->Input("target_tensor"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -203,8 +203,7 @@ class ExpandGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -203,8 +203,7 @@ class ExpandGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("expand_grad"); op->SetType("expand_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
...@@ -212,7 +211,6 @@ class ExpandGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -212,7 +211,6 @@ class ExpandGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("expand_times_tensor", this->Input("expand_times_tensor")); op->SetInput("expand_times_tensor", this->Input("expand_times_tensor"));
op->SetInput("ExpandTimes", this->Input("ExpandTimes")); op->SetInput("ExpandTimes", this->Input("ExpandTimes"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -113,8 +113,7 @@ class FilterByInstagGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -113,8 +113,7 @@ class FilterByInstagGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("filter_by_instag_grad"); op->SetType("filter_by_instag_grad");
op->SetInput("IndexMap", this->Output("IndexMap")); op->SetInput("IndexMap", this->Output("IndexMap"));
op->SetInput("Ins", this->Input("Ins")); op->SetInput("Ins", this->Input("Ins"));
...@@ -122,7 +121,6 @@ class FilterByInstagGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -122,7 +121,6 @@ class FilterByInstagGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("LossWeight", this->Output("LossWeight")); op->SetInput("LossWeight", this->Output("LossWeight"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("Ins"), this->InputGrad("Ins")); op->SetOutput(framework::GradVarName("Ins"), this->InputGrad("Ins"));
return op;
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -142,14 +142,12 @@ class FlattenGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -142,14 +142,12 @@ class FlattenGradOpMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad_op) const override {
auto *grad_op = new T();
grad_op->SetType("flatten_grad"); grad_op->SetType("flatten_grad");
grad_op->SetInput("X", this->Input("X")); grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs()); grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
} }
}; };
...@@ -211,14 +209,12 @@ class Flatten2GradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -211,14 +209,12 @@ class Flatten2GradOpMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad_op) const override {
auto *grad_op = new T();
grad_op->SetType("flatten2_grad"); grad_op->SetType("flatten2_grad");
grad_op->SetInput("XShape", this->Output("XShape")); grad_op->SetInput("XShape", this->Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs()); grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
} }
}; };
......
...@@ -121,9 +121,7 @@ class FSPGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -121,9 +121,7 @@ class FSPGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("fsp_grad"); op->SetType("fsp_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
...@@ -134,8 +132,6 @@ class FSPGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -134,8 +132,6 @@ class FSPGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
return op;
} }
}; };
......
...@@ -61,8 +61,7 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -61,8 +61,7 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType(this->ForwardOpType() + "_grad"); op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Output("Y")); op->SetInput("Y", this->Output("Y"));
...@@ -79,8 +78,6 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -79,8 +78,6 @@ class FusedBatchNormActGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale")); op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
return op;
} }
}; };
......
...@@ -227,8 +227,7 @@ class FusedElemwiseActivationGradMaker ...@@ -227,8 +227,7 @@ class FusedElemwiseActivationGradMaker
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad_op) const override {
auto *grad_op = new T();
grad_op->SetType(this->ForwardOpType() + "_grad"); grad_op->SetType(this->ForwardOpType() + "_grad");
for (auto &input_param : this->InputNames()) { for (auto &input_param : this->InputNames()) {
...@@ -255,11 +254,10 @@ class FusedElemwiseActivationGradMaker ...@@ -255,11 +254,10 @@ class FusedElemwiseActivationGradMaker
grad_op->SetOutput(framework::GradVarName("IntermediateOut"), grad_op->SetOutput(framework::GradVarName("IntermediateOut"),
this->OutputGrad("IntermediateOut")); this->OutputGrad("IntermediateOut"));
} else { } else {
grad_op->SetInput("IntermediateOut", {}); grad_op->SetInput("IntermediateOut", this->EmptyOutput());
grad_op->SetOutput(framework::GradVarName("IntermediateOut"), {}); grad_op->SetOutput(framework::GradVarName("IntermediateOut"),
this->EmptyOutputGrad());
} }
return std::unique_ptr<T>(grad_op);
} }
}; };
......
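The FusedElemwiseActivation hunk above shows the complement: the same helpers mark slots that are only conditionally wired. A reduced sketch of that branch, with the condition name save_intermediate_out assumed for illustration:

    if (save_intermediate_out) {
      grad_op->SetInput("IntermediateOut", this->Output("IntermediateOut"));
      grad_op->SetOutput(framework::GradVarName("IntermediateOut"),
                         this->OutputGrad("IntermediateOut"));
    } else {
      // The slot exists in the grad op signature but carries no variables.
      grad_op->SetInput("IntermediateOut", this->EmptyOutput());
      grad_op->SetOutput(framework::GradVarName("IntermediateOut"),
                         this->EmptyOutputGrad());
    }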
...@@ -158,15 +158,13 @@ class FusedEmbeddingSeqPoolGradOpMaker ...@@ -158,15 +158,13 @@ class FusedEmbeddingSeqPoolGradOpMaker
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("fused_embedding_seq_pool_grad"); op->SetType("fused_embedding_seq_pool_grad");
op->SetInput("Ids", this->Input("Ids")); op->SetInput("Ids", this->Input("Ids"));
op->SetInput("W", this->Input("W")); op->SetInput("W", this->Input("W"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W")); op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -151,15 +151,13 @@ class GatherNdGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -151,15 +151,13 @@ class GatherNdGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("gather_nd_grad"); op->SetType("gather_nd_grad");
op->SetInput("Index", this->Input("Index")); op->SetInput("Index", this->Input("Index"));
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -113,15 +113,13 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -113,15 +113,13 @@ class GatherGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("gather_grad"); op->SetType("gather_grad");
op->SetInput("Index", this->Input("Index")); op->SetInput("Index", this->Input("Index"));
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -176,8 +176,7 @@ class GridSampleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -176,8 +176,7 @@ class GridSampleGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
auto* op = new T();
op->SetType("grid_sampler_grad"); op->SetType("grid_sampler_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Grid", this->Input("Grid")); op->SetInput("Grid", this->Input("Grid"));
...@@ -187,7 +186,6 @@ class GridSampleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -187,7 +186,6 @@ class GridSampleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Grid"), this->InputGrad("Grid")); op->SetOutput(framework::GradVarName("Grid"), this->InputGrad("Grid"));
return std::unique_ptr<T>(op);
} }
}; };
......
...@@ -189,8 +189,7 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -189,8 +189,7 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
auto *op = new T();
op->SetType("group_norm_grad"); op->SetType("group_norm_grad");
op->SetInput("Scale", this->Input("Scale")); op->SetInput("Scale", this->Input("Scale"));
op->SetInput("Bias", this->Input("Bias")); op->SetInput("Bias", this->Input("Bias"));
...@@ -203,8 +202,6 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -203,8 +202,6 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale")); op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
} }
}; };
......
...@@ -390,8 +390,7 @@ class GRUGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -390,8 +390,7 @@ class GRUGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad_op) const override {
auto* grad_op = new T();
grad_op->SetType("gru_grad"); grad_op->SetType("gru_grad");
grad_op->SetInput("Input", this->Input("Input")); grad_op->SetInput("Input", this->Input("Input"));
grad_op->SetInput("H0", this->Input("H0")); grad_op->SetInput("H0", this->Input("H0"));
...@@ -415,7 +414,6 @@ class GRUGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -415,7 +414,6 @@ class GRUGradOpMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); grad_op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
grad_op->SetAttrMap(this->Attrs()); grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
} }
}; };
......
...@@ -212,8 +212,7 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -212,8 +212,7 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
auto* op = new T();
op->SetType("gru_unit_grad"); op->SetType("gru_unit_grad");
op->SetInput("Input", this->Input("Input")); op->SetInput("Input", this->Input("Input"));
...@@ -232,7 +231,6 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -232,7 +231,6 @@ class GRUUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
this->InputGrad("HiddenPrev")); this->InputGrad("HiddenPrev"));
op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight")); op->SetOutput(framework::GradVarName("Weight"), this->InputGrad("Weight"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
return std::unique_ptr<T>(op);
} }
}; };
......
...@@ -189,8 +189,7 @@ class HierarchicalSigmoidGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -189,8 +189,7 @@ class HierarchicalSigmoidGradMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
auto* op = new T();
op->SetType(this->ForwardOpType() + "_grad"); op->SetType(this->ForwardOpType() + "_grad");
// Inputs: X, W, Label, PathTable, PathCode, PreOut, Out@GRAD // Inputs: X, W, Label, PathTable, PathCode, PreOut, Out@GRAD
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
...@@ -207,8 +206,6 @@ class HierarchicalSigmoidGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -207,8 +206,6 @@ class HierarchicalSigmoidGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W")); op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
} }
}; };
......
...@@ -106,15 +106,13 @@ class HingeLossGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -106,15 +106,13 @@ class HingeLossGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("hinge_loss_grad"); op->SetType("hinge_loss_grad");
op->SetInput("Logits", this->Input("Logits")); op->SetInput("Logits", this->Input("Logits"));
op->SetInput("Labels", this->Input("Labels")); op->SetInput("Labels", this->Input("Labels"));
op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss")); op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss"));
op->SetOutput(framework::GradVarName("Logits"), this->InputGrad("Logits")); op->SetOutput(framework::GradVarName("Logits"), this->InputGrad("Logits"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -121,15 +121,13 @@ class HuberLossGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -121,15 +121,13 @@ class HuberLossGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("huber_loss_grad"); op->SetType("huber_loss_grad");
op->SetInput("Residual", this->Output("Residual")); op->SetInput("Residual", this->Output("Residual"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -159,14 +159,12 @@ class Im2SequenceGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -159,14 +159,12 @@ class Im2SequenceGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("im2sequence_grad"); op->SetType("im2sequence_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -71,13 +71,11 @@ class IncrementGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -71,13 +71,11 @@ class IncrementGradOpMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad_op) const override {
auto *grad_op = new T();
grad_op->SetType("increment"); grad_op->SetType("increment");
grad_op->SetInput("X", this->Output("Out")); grad_op->SetInput("X", this->Output("Out"));
grad_op->SetOutput("Out", this->Input("X")); grad_op->SetOutput("Out", this->Input("X"));
grad_op->SetAttr("step", -boost::get<float>(this->GetAttr("step"))); grad_op->SetAttr("step", -boost::get<float>(this->GetAttr("step")));
return std::unique_ptr<T>(grad_op);
} }
}; };
......
...@@ -80,8 +80,7 @@ class InstanceNormGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -80,8 +80,7 @@ class InstanceNormGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
auto *op = new T();
op->SetType("instance_norm_grad"); op->SetType("instance_norm_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
...@@ -94,8 +93,6 @@ class InstanceNormGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -94,8 +93,6 @@ class InstanceNormGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale")); op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
return std::unique_ptr<T>(op);
} }
}; };
...@@ -105,8 +102,7 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -105,8 +102,7 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
auto *op = new T();
op->SetType("instance_norm_grad_grad"); op->SetType("instance_norm_grad_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Scale", this->Input("Scale")); op->SetInput("Scale", this->Input("Scale"));
...@@ -121,7 +117,6 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -121,7 +117,6 @@ class InstanceNormDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput("DX", this->InputGrad("X")); op->SetOutput("DX", this->InputGrad("X"));
op->SetOutput("DScale", this->InputGrad("Scale")); op->SetOutput("DScale", this->InputGrad("Scale"));
op->SetOutput("DDY", this->InputGrad(framework::GradVarName("Y"))); op->SetOutput("DDY", this->InputGrad(framework::GradVarName("Y")));
return std::unique_ptr<T>(op);
} }
}; };
......
...@@ -429,8 +429,7 @@ class InterpolateGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -429,8 +429,7 @@ class InterpolateGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType(this->ForwardOpType() + "_grad"); op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
if (this->HasInput("SizeTensor") > 0) { if (this->HasInput("SizeTensor") > 0) {
...@@ -445,7 +444,6 @@ class InterpolateGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -445,7 +444,6 @@ class InterpolateGradMaker : public framework::SingleGradOpMaker<T> {
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -148,8 +148,7 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -148,8 +148,7 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
auto* op = new T();
op->SetType("kldiv_loss_grad"); op->SetType("kldiv_loss_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Target", this->Input("Target")); op->SetInput("Target", this->Input("Target"));
...@@ -158,7 +157,6 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -158,7 +157,6 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker<T> {
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
return std::unique_ptr<T>(op);
} }
}; };
......
...@@ -69,14 +69,12 @@ class L1NormGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -69,14 +69,12 @@ class L1NormGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("l1_norm_grad"); op->SetType("l1_norm_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -117,13 +117,11 @@ class LabelSmoothGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -117,13 +117,11 @@ class LabelSmoothGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("label_smooth_grad"); op->SetType("label_smooth_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -171,8 +171,7 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -171,8 +171,7 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("layer_norm_grad"); op->SetType("layer_norm_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Mean", this->Output("Mean")); op->SetInput("Mean", this->Output("Mean"));
...@@ -189,7 +188,6 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -189,7 +188,6 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -280,8 +280,7 @@ class LinearChainCRFGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -280,8 +280,7 @@ class LinearChainCRFGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("linear_chain_crf_grad"); op->SetType("linear_chain_crf_grad");
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
op->SetInput("Emission", this->Input("Emission")); op->SetInput("Emission", this->Input("Emission"));
...@@ -300,8 +299,6 @@ class LinearChainCRFGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -300,8 +299,6 @@ class LinearChainCRFGradMaker : public framework::SingleGradOpMaker<T> {
this->InputGrad("Emission")); this->InputGrad("Emission"));
op->SetOutput(framework::GradVarName("Transition"), op->SetOutput(framework::GradVarName("Transition"),
this->InputGrad("Transition")); this->InputGrad("Transition"));
return op;
} }
}; };
......
...@@ -208,14 +208,12 @@ class LoDResetGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -208,14 +208,12 @@ class LoDResetGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("lod_reset_grad"); op->SetType("lod_reset_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
...@@ -232,14 +232,12 @@ class LoDTensorToArrayGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -232,14 +232,12 @@ class LoDTensorToArrayGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad_op) const override {
auto *grad_op = new T();
grad_op->SetType("array_to_lod_tensor"); grad_op->SetType("array_to_lod_tensor");
grad_op->SetInput("X", this->OutputGrad("Out")); grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetInput("RankTable", this->Input("RankTable")); grad_op->SetInput("RankTable", this->Input("RankTable"));
grad_op->SetOutput("Out", this->InputGrad("X")); grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs()); grad_op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(grad_op);
} }
}; };
......
...@@ -111,8 +111,7 @@ class LogLossGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -111,8 +111,7 @@ class LogLossGradMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
std::unique_ptr<T> op(new T());
op->SetType("log_loss_grad"); op->SetType("log_loss_grad");
op->SetInput("Predicted", this->Input("Predicted")); op->SetInput("Predicted", this->Input("Predicted"));
op->SetInput("Labels", this->Input("Labels")); op->SetInput("Labels", this->Input("Labels"));
...@@ -120,7 +119,6 @@ class LogLossGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -120,7 +119,6 @@ class LogLossGradMaker : public framework::SingleGradOpMaker<T> {
op->SetOutput(framework::GradVarName("Predicted"), op->SetOutput(framework::GradVarName("Predicted"),
this->InputGrad("Predicted")); this->InputGrad("Predicted"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return op;
} }
}; };
......
(24 more file diffs in this commit are collapsed and not shown)