Commit f15e0830 authored by Yu Yang

Remove std::shared_ptr in Python & C++

* Also simplify pybind implementation by using OperatorBase as holder
  type.
Parent 7f5338a7
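
Note on the second bullet: with pybind11 it is enough to bind the shared methods once on the base class and register each concrete operator with that base, so the per-class ExposeOperator helper removed below becomes unnecessary, and the default std::unique_ptr holder carries ownership instead of std::shared_ptr. A minimal, self-contained sketch of that idiom follows; Op and Net are toy stand-ins, not the Paddle classes.

// Minimal sketch, not Paddle code: Op and Net are hypothetical stand-ins.
#include <memory>
#include <string>

#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Op {
  virtual ~Op() = default;
  virtual std::string Type() const { return "op"; }
};

struct Net : Op {
  std::string Type() const override { return "net"; }
};

PYBIND11_MODULE(example, m) {
  // The default holder is std::unique_ptr<Op>; no shared_ptr is involved.
  py::class_<Op>(m, "Op")
      .def_static("create",
                  [] { return std::make_unique<Op>(); })  // ownership moves to Python
      .def("type", &Op::Type);

  // Registering Op as the C++ base lets Net reuse Op's bindings in Python,
  // which is what removes the need for a per-class ExposeOperator helper.
  py::class_<Net, Op>(m, "Net").def(py::init<>());
}

In Python, example.Net().type() then resolves through the binding inherited from Op.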
@@ -89,7 +89,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
   }
   // Returned gradient network
-  auto net = std::unique_ptr<operators::NetOp>();
+  auto net = std::unique_ptr<operators::NetOp>(new operators::NetOp());
 
   if (forwardOp.IsNetOp()) {
     // Because forwardOp is a net op, it can static_cast.
@@ -204,7 +204,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
 }
 
 // See header for comments
-std::shared_ptr<OperatorBase> Backward(
+std::unique_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars) {
   std::unordered_set<std::string> no_grad_names;
......
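
Two things happen in this hunk: Backward now returns std::unique_ptr, and the local net is constructed with an actual object instead of being left empty. A default-constructed std::unique_ptr holds nullptr, which is why the new operators::NetOp() argument matters. A generic reminder follows (toy NetOp type, not the Paddle class):

// Generic reminder, toy NetOp type (not the Paddle class).
#include <cassert>
#include <memory>
#include <utility>

struct NetOp {
  int size() const { return 0; }
};

int main() {
  auto empty = std::unique_ptr<NetOp>();           // default-constructed: holds nullptr
  assert(empty == nullptr);                        // dereferencing it would be undefined behaviour

  auto net = std::unique_ptr<NetOp>(new NetOp());  // owns a real object
  assert(net->size() == 0);

  std::unique_ptr<NetOp> moved = std::move(net);   // ownership is transferred, never shared
  assert(net == nullptr && moved != nullptr);
  return 0;
}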
@@ -20,7 +20,7 @@ namespace framework {
 // Create the backward operator from a forward operator.
 // TODO(yuyang18): Add more API reference comment.
-extern std::shared_ptr<OperatorBase> Backward(
+extern std::unique_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars);
 }  // namespace framework
......
@@ -180,8 +180,7 @@ TEST(Backward, simple_op_not_need_grad) {
   auto no_input_gop = f::Backward(*fwd, {"x", "b"});
   ASSERT_NE(no_input_gop, nullptr);
   ASSERT_TRUE(no_input_gop->IsNetOp());
-  ASSERT_EQ(0UL,
-            std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
+  ASSERT_EQ(0UL, static_cast<ops::NetOp *>(no_input_gop.get())->ops_.size());
 }
 
 TEST(Backward, net_fc_backward_normal) {
......
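
Because the result is no longer a shared_ptr, the test drops std::static_pointer_cast and instead downcasts the raw pointer obtained from unique_ptr::get(), leaving ownership where it was. A generic sketch of that pattern with toy Base/Net types (not the Paddle classes):

// Generic sketch with toy Base/Net types (not the Paddle classes).
#include <cassert>
#include <memory>
#include <vector>

struct Base {
  virtual ~Base() = default;
  virtual bool IsNet() const { return false; }
};

struct Net : Base {
  bool IsNet() const override { return true; }
  std::vector<int> ops_;
};

int main() {
  std::unique_ptr<Base> op(new Net());
  if (op->IsNet()) {
    // Downcast the raw pointer only; ownership stays with `op`.
    auto *net = static_cast<Net *>(op.get());
    assert(net->ops_.empty());
  }
  return 0;
}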
@@ -48,29 +48,6 @@ namespace framework {
 
 using Tensor = framework::Tensor;
 
-template <typename ClassType>
-void ExposeOperator(ClassType &m) {
-  m.def("infer_shape", &ClassType::type::InferShape)
-      .def("run", &ClassType::type::Run)
-      .def("type",
-           [](const typename ClassType::type &op) -> std::string {
-             return op.Type();
-           })
-      .def("outputs",
-           [](const typename ClassType::type &op)
-               -> std::map<std::string, std::vector<std::string>> {
-                 return op.Outputs();
-               })
-      .def("inputs",
-           [](const typename ClassType::type &op) { return op.Inputs(); })
-      .def("__str__", &ClassType::type::DebugString)
-      .def("no_intermediate_outputs",
-           [](const typename ClassType::type &op) {
-             return op.OutputVars(false);
-           })
-      .def("support_gpu", &ClassType::type::SupportGPU);
-}
-
 static size_t UniqueIntegerGenerator() {
   static std::atomic<size_t> generator;
   return generator.fetch_add(1);
@@ -207,70 +184,69 @@ All parameter, weight, gradient are variables in Paddle.
       .def(py::init<>())
       .def("__str__", string::to_string<const platform::CPUPlace &>);
 
-  py::class_<OperatorBase> operator_base(m, "Operator");
-
-  operator_base.def_static("create", [](py::bytes protobin) {
-    OpDesc desc;
-    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
-                   "Cannot parse user input to OpDesc");
-    PADDLE_ENFORCE(desc.IsInitialized(),
-                   "User OpDesc is not initialized, reason %s",
-                   desc.InitializationErrorString());
-    return OpRegistry::CreateOp(desc);
-  });
-
-  operator_base.def("backward",
-                    [](const OperatorBase &forwardOp,
-                       const std::unordered_set<std::string> &no_grad_vars) {
-                      return Backward(forwardOp, no_grad_vars);
-                    });
-
-  ExposeOperator(operator_base);
-
-  py::class_<operators::NetOp> net(m, "Net");
-
-  net.def_static("create",
-                 []() -> operators::NetOp * {
-                   auto *retv = new operators::NetOp;
-                   retv->SetType("plain_net");
-                   return retv;
-                 })
+  py::class_<OperatorBase>(m, "Operator")
+      .def_static("create",
+                  [](py::bytes protobin) {
+                    OpDesc desc;
+                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+                                   "Cannot parse user input to OpDesc");
+                    PADDLE_ENFORCE(desc.IsInitialized(),
+                                   "User OpDesc is not initialized, reason %s",
+                                   desc.InitializationErrorString());
+                    return OpRegistry::CreateOp(desc);
+                  })
+      .def("backward",
+           [](const OperatorBase &forwardOp,
+              const std::unordered_set<std::string> &no_grad_vars) {
+             return Backward(forwardOp, no_grad_vars).release();
+           })
+      .def("infer_shape", &OperatorBase::InferShape)
+      .def("run", &OperatorBase::Run)
+      .def("type",
+           [](const OperatorBase &op) -> std::string { return op.Type(); })
+      .def("outputs",
+           [](const OperatorBase &op)
+               -> std::map<std::string, std::vector<std::string>> {
+                 return op.Outputs();
+               })
+      .def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
+      .def("__str__", &OperatorBase::DebugString)
+      .def("no_intermediate_outputs",
+           [](const OperatorBase &op) { return op.OutputVars(false); })
+      .def("support_gpu", &OperatorBase::SupportGPU);
+
+  py::class_<operators::NetOp, OperatorBase>(m, "Net")
+      .def_static("create",
+                  []() -> operators::NetOp * {
+                    auto *retv = new operators::NetOp;
+                    retv->SetType("plain_net");
+                    return retv;
+                  })
       .def("add_op", [](operators::NetOp &self,
                         const OperatorBase &op) { self.AddOp(op); })
-      .def("add_op",
-           [](operators::NetOp &self, const operators::NetOp &net) -> void {
-             self.AddOp(net);
-           })
-      .def("add_op",
-           [](operators::NetOp &self,
-              const operators::RecurrentOp &rnn) -> void { self.AddOp(rnn); })
      .def("complete_add_op", &operators::NetOp::CompleteAddOp)
      .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
        self->CompleteAddOp();
      });
-  ExposeOperator(net);
 
   // recurrent_op
-  py::class_<operators::RecurrentOp> rnn(m, "RecurrentOp");
-
-  rnn.def_static(
-      "create",
-      [](py::bytes protobin) -> operators::RecurrentOp * {
-        OpDesc desc;
-        PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
-                       "Cannot parse user input to OpDesc");
-        PADDLE_ENFORCE(desc.IsInitialized(),
-                       "User OpDesc is not initialized, reason %s",
-                       desc.InitializationErrorString());
-        auto rnn_op = OpRegistry::CreateOp(desc);
-        return static_cast<operators::RecurrentOp *>(rnn_op.release());
-      })
+  py::class_<operators::RecurrentOp, OperatorBase>(m, "RecurrentOp")
+      .def_static(
+          "create",
+          [](py::bytes protobin) -> operators::RecurrentOp * {
+            OpDesc desc;
+            PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+                           "Cannot parse user input to OpDesc");
+            PADDLE_ENFORCE(desc.IsInitialized(),
+                           "User OpDesc is not initialized, reason %s",
+                           desc.InitializationErrorString());
+            auto rnn_op = OpRegistry::CreateOp(desc);
+            return static_cast<operators::RecurrentOp *>(rnn_op.release());
+          })
      .def("set_stepnet", [](operators::RecurrentOp &self,
                             const operators::NetOp &net) -> void {
        self.set_stepnet(net.Clone());
      });
-  ExposeOperator(rnn);
 
   m.def("unique_integer", UniqueIntegerGenerator);
......
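
The rebound backward method returns Backward(...).release(), i.e. a raw pointer. Under pybind11's default return-value policy a raw pointer returned from a bound function is adopted by Python (take_ownership), so the gradient operator is still destroyed exactly once, just from the Python side. A hedged, generic sketch of the same idiom (toy Op type and MakeGrad helper, not the Paddle bindings):

// Generic sketch, not the Paddle bindings; Op and MakeGrad are made up.
#include <memory>

#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Op {
  virtual ~Op() = default;
};

std::unique_ptr<Op> MakeGrad(const Op& /*forward*/) { return std::make_unique<Op>(); }

PYBIND11_MODULE(example_backward, m) {
  py::class_<Op>(m, "Op")
      .def(py::init<>())
      // release() hands pybind11 a raw pointer; the default policy
      // (take_ownership) makes the returned Python object delete it.
      .def("backward", [](const Op& self) { return MakeGrad(self).release(); });
}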
@@ -41,9 +41,7 @@ class NetOp : public framework::OperatorBase {
   NetOp(const std::string& type, const VarNameMap& inputs,
         const VarNameMap& outputs, const framework::AttributeMap& attrs);
 
-  NetOp(const NetOp& o)
-      : framework::OperatorBase(
-            static_cast<const framework::OperatorBase&>(o)) {
+  NetOp(const NetOp& o) : framework::OperatorBase(o.type_, {}, {}, o.attrs_) {
     this->ops_.reserve(o.ops_.size());
     std::transform(
         o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_),
......
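
With ops_ holding single-ownership pointers, copying a NetOp means cloning every contained operator rather than bumping reference counts. A generic sketch of that clone-and-std::transform pattern (toy Op/Net types, not the Paddle classes):

// Generic sketch with toy Op/Net types (not the Paddle classes).
#include <algorithm>
#include <iterator>
#include <memory>
#include <vector>

struct Op {
  virtual ~Op() = default;
  virtual std::unique_ptr<Op> Clone() const { return std::make_unique<Op>(*this); }
};

struct Net {
  std::vector<std::unique_ptr<Op>> ops_;

  Net() = default;
  // Deep copy: every element is cloned, so the two nets never share an operator.
  Net(const Net& o) {
    ops_.reserve(o.ops_.size());
    std::transform(o.ops_.begin(), o.ops_.end(), std::back_inserter(ops_),
                   [](const std::unique_ptr<Op>& op) { return op->Clone(); });
  }
};

int main() {
  Net a;
  a.ops_.push_back(std::make_unique<Op>());
  Net b(a);  // b owns its own clone of a's operator
  return b.ops_.size() == 1 ? 0 : 1;
}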
@@ -42,7 +42,7 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
       rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
                         true /*infer_shape_mode*/);
     }
-    stepnet_->InferShape(*step_scopes[i]);
+    (*stepnet_)->InferShape(*step_scopes[i]);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      true /*infer_shape_mode*/);
@@ -61,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
                         false /*infer_shape_mode*/);
     }
-    stepnet_->Run(*step_scopes[step_id], dev_ctx);
+    (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      false /*infer_shape_mode*/);
@@ -76,15 +76,15 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
   // Now all variables in scope must be created outside of op.
   PADDLE_ENFORCE_NOT_NULL(stepnet_);
-  PADDLE_ENFORCE(!stepnet_->Outputs().empty(), "stepnet_ op has no outputs");
-  PADDLE_ENFORCE(!stepnet_->Outputs().empty(), "net_op has no outputs");
+  PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
+  PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs");
 
   if (seq_len_ > step_scopes->size()) {
     for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
       auto& step_scope = scope.NewScope();
 
       // create step net's temp inputs
-      for (auto& input : stepnet_->Inputs()) {
+      for (auto& input : (*stepnet_)->Inputs()) {
         // the weight are located in parent scope
         for (auto& var_name : input.second) {
           if (!step_scope.FindVar(var_name)) {
@@ -93,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
         }
       }
       // create stepnet's outputs
-      for (const auto& output : stepnet_->Outputs()) {
+      for (const auto& output : (*stepnet_)->Outputs()) {
         for (auto& var_name : output.second) {
           step_scope.NewVar(var_name);
         }
@@ -136,7 +136,7 @@ RecurrentOp::RecurrentOp(const std::string& type,
                          const framework::AttributeMap& attrs)
     : OperatorBase(type, inputs, outputs, attrs) {
   rnn::InitArgument(kArgName, &arg_, *this);
-  alg_.Init(&arg_, stepnet_.get());
+  alg_.Init(&arg_, &stepnet_);
 }
 
 class RecurrentAlgorithmProtoAndCheckerMaker
@@ -178,7 +178,7 @@ void RecurrentGradientAlgorithm::Run(
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                         false /*infer_shape_mode*/);
     }
-    stepnet_->Run(*step_scopes[step_id], dev_ctx);
+    (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
   LinkBootMemoryGradients(step_scopes[0], false);
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
@@ -215,7 +215,7 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
                         true /*infer_shape_mode*/);
     }
-    stepnet_->InferShape(*step_scopes[step_id]);
+    (*stepnet_)->InferShape(*step_scopes[step_id]);
   }
   rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
                      true /*infer_shape_mode*/);
@@ -228,7 +228,7 @@ RecurrentGradientOp::RecurrentGradientOp(
     const framework::AttributeMap& attrs)
     : OperatorBase(type, inputs, outputs, attrs) {
   rnn::InitArgument(kArgName, &arg_, *this);
-  alg_.Init(&arg_, stepnet_.get());
+  alg_.Init(&arg_, &stepnet_);
 }
 
 }  // namespace operators
......
@@ -34,7 +34,8 @@ class RecurrentAlgorithm {
   void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const;
 
-  void Init(rnn::Argument* arg, framework::OperatorBase* stepnet) {
+  void Init(rnn::Argument* arg,
+            std::unique_ptr<framework::OperatorBase>* stepnet) {
     PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
     arg_ = arg;
     stepnet_ = stepnet;
@@ -63,7 +64,7 @@ class RecurrentAlgorithm {
   void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
 
  private:
-  framework::OperatorBase* stepnet_;
+  std::unique_ptr<framework::OperatorBase>* stepnet_;
   rnn::Argument* arg_;
   mutable size_t seq_len_;
 };
@@ -80,7 +81,8 @@ class RecurrentGradientAlgorithm {
    * operator.
    */
  public:
-  void Init(rnn::Argument* arg, framework::OperatorBase* stepnet) {
+  void Init(rnn::Argument* arg,
+            std::unique_ptr<framework::OperatorBase>* stepnet) {
     PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
     arg_ = std::move(arg);
     stepnet_ = stepnet;
@@ -107,7 +109,7 @@ class RecurrentGradientAlgorithm {
  private:
   rnn::Argument* arg_;
   mutable size_t seq_len_;
-  framework::OperatorBase* stepnet_;
+  std::unique_ptr<framework::OperatorBase>* stepnet_;
 };
 
 class RecurrentOp : public framework::OperatorBase {
......
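
The algorithm objects now keep a std::unique_ptr<framework::OperatorBase>* rather than a raw OperatorBase*: they point at the owner itself, so when RecurrentOp swaps its step net (for example through set_stepnet) the algorithm sees the replacement, at the cost of the extra (*stepnet_)-> indirection used in the hunks above. A small generic sketch of that double indirection, with hypothetical names:

// Generic sketch of the pointer-to-unique_ptr pattern; names are hypothetical.
#include <cassert>
#include <memory>
#include <string>
#include <utility>

struct Op {
  explicit Op(std::string name) : name_(std::move(name)) {}
  std::string name_;
};

// Non-owning observer of somebody else's unique_ptr.
struct Algorithm {
  void Init(std::unique_ptr<Op>* stepnet) { stepnet_ = stepnet; }
  const std::string& StepName() const { return (*stepnet_)->name_; }
  std::unique_ptr<Op>* stepnet_ = nullptr;
};

int main() {
  std::unique_ptr<Op> stepnet(new Op("first"));
  Algorithm alg;
  alg.Init(&stepnet);
  assert(alg.StepName() == "first");

  stepnet.reset(new Op("second"));     // the owner replaces the step net in place
  assert(alg.StepName() == "second");  // the observer sees the new operator
  return 0;
}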
@@ -165,7 +165,6 @@ class GradientChecker(unittest.TestCase):
         for no_grad in no_grad_set:
             if no_grad not in in_names:
                 raise ValueError("no_grad should be in in_names")
-
         backward_op = core.Operator.backward(forward_op, no_grad_set)
         bwd_outputs = backward_op.outputs()
......