提交 4130e5fa 编写于 作者: Q qijun

Merge remote-tracking branch 'baidu/develop' into add_selected_rows_functor

...@@ -28,15 +28,15 @@ namespace paddle { ...@@ -28,15 +28,15 @@ namespace paddle {
namespace framework { namespace framework {
static inline std::unique_ptr<OperatorBase> CreateGradOp( static inline std::unique_ptr<OperatorBase> CreateGradOp(
const OperatorBase& op, const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set,
const std::unordered_set<std::string>& no_grad_set) { std::unordered_map<std::string, std::string>* grad_to_var) {
OpDescBind op_desc; OpDescBind op_desc;
op_desc.SetInputMap(op.Inputs()); op_desc.SetInputMap(op.Inputs());
op_desc.SetOutputMap(op.Outputs()); op_desc.SetOutputMap(op.Outputs());
op_desc.SetType(op.Type()); op_desc.SetType(op.Type());
op_desc.SetAttrMap(op.Attrs()); op_desc.SetAttrMap(op.Attrs());
auto& info = OpInfoMap::Instance().Get(op.Type()); auto& info = OpInfoMap::Instance().Get(op.Type());
auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set); auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set, grad_to_var);
std::vector<std::unique_ptr<OperatorBase>> grad_ops; std::vector<std::unique_ptr<OperatorBase>> grad_ops;
grad_ops.reserve(grad_descs.size()); grad_ops.reserve(grad_descs.size());
std::transform(grad_descs.begin(), grad_descs.end(), std::transform(grad_descs.begin(), grad_descs.end(),
...@@ -99,7 +99,9 @@ static std::unique_ptr<OperatorBase> NOP() { ...@@ -99,7 +99,9 @@ static std::unique_ptr<OperatorBase> NOP() {
// See Backward.h for details // See Backward.h for details
static std::unique_ptr<OperatorBase> BackwardRecursive( static std::unique_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) { std::unordered_set<std::string>& no_grad_names,
std::unordered_map<std::string, std::string>* grad_to_var,
size_t& uniq_id) {
// If all input gradients of forwarding operator do not need to calculate, // If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take // just return an NOP. Not return null ptr because NOP does not take
// too much time for calculation, but it is useful for simplifying logic. // too much time for calculation, but it is useful for simplifying logic.
...@@ -137,7 +139,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -137,7 +139,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it, ++local_op_id) { ++it, ++local_op_id) {
auto& fwd = *it; auto& fwd = *it;
auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id); auto bwd = BackwardRecursive(*fwd, no_grad_names, grad_to_var, uniq_id);
ForEachVarName(bwd->Outputs(), ForEachVarName(bwd->Outputs(),
[&dup_output_ops, local_op_id](const std::string& out) { [&dup_output_ops, local_op_id](const std::string& out) {
dup_output_ops[out].emplace_back(local_op_id); dup_output_ops[out].emplace_back(local_op_id);
...@@ -189,7 +191,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -189,7 +191,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
} }
} else { } else {
std::unique_ptr<OperatorBase> grad_op( std::unique_ptr<OperatorBase> grad_op(
CreateGradOp(forwardOp, no_grad_names)); CreateGradOp(forwardOp, no_grad_names, grad_to_var));
ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op]( ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
const std::string& grad_input) { const std::string& grad_input) {
...@@ -228,7 +230,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -228,7 +230,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
*static_cast<const OperatorBase*>(&rnnop.stepnet()); *static_cast<const OperatorBase*>(&rnnop.stepnet());
// create stepnet's gradient op // create stepnet's gradient op
rnn_grad_op->set_stepnet( rnn_grad_op->set_stepnet(
BackwardRecursive(stepnet_op, no_grad_names, uniq_id)); BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
} }
if (net->ops_.empty()) { // Current no aux op is added to network if (net->ops_.empty()) { // Current no aux op is added to network
...@@ -255,7 +257,8 @@ std::unique_ptr<OperatorBase> Backward( ...@@ -255,7 +257,8 @@ std::unique_ptr<OperatorBase> Backward(
no_grad_names.insert(name + kGradVarSuffix); no_grad_names.insert(name + kGradVarSuffix);
} }
size_t uid = 0; size_t uid = 0;
return BackwardRecursive(forwardOp, no_grad_names, uid); std::unordered_map<std::string, std::string> grad_to_var;
return BackwardRecursive(forwardOp, no_grad_names, &grad_to_var, uid);
} }
// ==================================== // // ==================================== //
...@@ -272,30 +275,31 @@ static bool AllGradInSet(const std::vector<std::string>& names, ...@@ -272,30 +275,31 @@ static bool AllGradInSet(const std::vector<std::string>& names,
std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad( std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
const std::unique_ptr<OpDescBind>& op_desc, const std::unique_ptr<OpDescBind>& op_desc,
std::unordered_set<std::string>& no_grad_vars) { std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var) {
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs; std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
// All input gradients of forwarding operator do not need to calculate. // All input gradients of forwarding operator do not need to calculate.
const std::vector<std::string>& inputs = op_desc->InputArgumentNames(); const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
if (AllGradInSet(inputs, no_grad_vars)) { if (AllGradInSet(inputs, *no_grad_vars)) {
return grad_op_descs; // empty vector return grad_op_descs; // empty vector
} }
// All output gradients of forwarding operator do not need to calculate. // All output gradients of forwarding operator do not need to calculate.
const std::vector<std::string>& outputs = op_desc->OutputArgumentNames(); const std::vector<std::string>& outputs = op_desc->OutputArgumentNames();
if (AllGradInSet(outputs, no_grad_vars)) { if (AllGradInSet(outputs, *no_grad_vars)) {
for (const std::string& name : inputs) { for (const std::string& name : inputs) {
no_grad_vars.insert(GradVarName(name)); no_grad_vars->insert(GradVarName(name));
} }
return grad_op_descs; // empty vector return grad_op_descs; // empty vector
} }
grad_op_descs = OpInfoMap::Instance() grad_op_descs = OpInfoMap::Instance()
.Get(op_desc->Type()) .Get(op_desc->Type())
.GradOpMaker()(*op_desc, no_grad_vars); .GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var);
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops; std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
for (auto& desc : grad_op_descs) { for (auto& desc : grad_op_descs) {
for (const std::string& in_name : desc->InputArgumentNames()) { for (const std::string& in_name : desc->InputArgumentNames()) {
if (no_grad_vars.count(in_name)) { if (no_grad_vars->count(in_name)) {
std::string prefix = in_name.substr( std::string prefix = in_name.substr(
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); 0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
std::string new_name = prefix + kZeroVarSuffix; std::string new_name = prefix + kZeroVarSuffix;
...@@ -315,7 +319,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad( ...@@ -315,7 +319,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind& program_desc, int block_idx, ProgramDescBind& program_desc, int block_idx,
std::unordered_set<std::string>& no_grad_vars) { std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var) {
BlockDescBind* cur_block = program_desc.Block(block_idx); BlockDescBind* cur_block = program_desc.Block(block_idx);
std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_; std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_;
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops; std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
...@@ -323,15 +328,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -323,15 +328,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
std::vector<std::unique_ptr<OpDescBind>> backward_descs; std::vector<std::unique_ptr<OpDescBind>> backward_descs;
for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) { for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
std::vector<std::unique_ptr<OpDescBind>> op_grads = std::vector<std::unique_ptr<OpDescBind>> op_grads =
MakeOpGrad(*it, no_grad_vars); MakeOpGrad(*it, no_grad_vars, grad_to_var);
if ((*it)->Type() == "recurrent") { if ((*it)->Type() == "recurrent") {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
op_grads.size(), size_t(1), op_grads.size(), size_t(1),
"rnn_op's gradient process should contain only one op."); "rnn_op's gradient process should contain only one op.");
int step_block_idx = (*it)->GetBlockAttr("stop_block"); int step_block_idx = (*it)->GetBlockAttr("stop_block");
auto backward_block_op_descs = auto backward_block_op_descs = MakeBlockBackward(
MakeBlockBackward(program_desc, step_block_idx, no_grad_vars); program_desc, step_block_idx, no_grad_vars, grad_to_var);
BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block); BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
for (auto& ptr : backward_block_op_descs) { for (auto& ptr : backward_block_op_descs) {
backward_block->ops_.push_back(std::move(ptr)); backward_block->ops_.push_back(std::move(ptr));
...@@ -387,8 +392,9 @@ void AppendBackward(ProgramDescBind& program_desc, ...@@ -387,8 +392,9 @@ void AppendBackward(ProgramDescBind& program_desc,
no_grad_var_names.insert(GradVarName(name)); no_grad_var_names.insert(GradVarName(name));
} }
const int root_block_idx = 0; const int root_block_idx = 0;
auto backward_op_descs = std::unordered_map<std::string, std::string> grad_to_var;
MakeBlockBackward(program_desc, root_block_idx, no_grad_var_names); auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx,
&no_grad_var_names, &grad_to_var);
auto& forw_op_descs = program_desc.Block(root_block_idx)->ops_; auto& forw_op_descs = program_desc.Block(root_block_idx)->ops_;
for (auto& ptr : backward_op_descs) { for (auto& ptr : backward_op_descs) {
forw_op_descs.push_back(std::move(ptr)); forw_op_descs.push_back(std::move(ptr));
......
...@@ -66,7 +66,7 @@ std::vector<OpDescBind *> BlockDescBind::AllOps() const { ...@@ -66,7 +66,7 @@ std::vector<OpDescBind *> BlockDescBind::AllOps() const {
return res; return res;
} }
void BlockDescBind::Sync() { void BlockDescBind::Flush() {
if (need_update_) { if (need_update_) {
auto &op_field = *this->desc_->mutable_ops(); auto &op_field = *this->desc_->mutable_ops();
op_field.Clear(); op_field.Clear();
...@@ -91,5 +91,10 @@ BlockDescBind *BlockDescBind::ParentBlock() const { ...@@ -91,5 +91,10 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx())); return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
} }
BlockDesc *BlockDescBind::Proto() {
Flush();
return desc_;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -35,7 +35,8 @@ class BlockDescBind { ...@@ -35,7 +35,8 @@ class BlockDescBind {
public: public:
friend std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( friend std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind &program_desc, int block_idx, ProgramDescBind &program_desc, int block_idx,
std::unordered_set<std::string> &no_grad_vars); std::unordered_set<std::string> *no_grad_vars,
std::unordered_map<std::string, std::string> *grad_to_var);
friend void AppendBackward( friend void AppendBackward(
ProgramDescBind &program_desc, ProgramDescBind &program_desc,
...@@ -64,9 +65,9 @@ class BlockDescBind { ...@@ -64,9 +65,9 @@ class BlockDescBind {
std::vector<OpDescBind *> AllOps() const; std::vector<OpDescBind *> AllOps() const;
void Sync(); void Flush();
BlockDesc *RawPtr() { return desc_; } BlockDesc *Proto();
private: private:
ProgramDescBind *prog_; // not_own ProgramDescBind *prog_; // not_own
......
...@@ -99,8 +99,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> { ...@@ -99,8 +99,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
void operator()(const char* op_type, OpInfo* info) const { void operator()(const char* op_type, OpInfo* info) const {
info->grad_op_maker_ = []( info->grad_op_maker_ = [](
const OpDescBind& fwd_op, const OpDescBind& fwd_op,
const std::unordered_set<std::string>& no_grad_set) { const std::unordered_set<std::string>& no_grad_set,
T maker(fwd_op, no_grad_set); std::unordered_map<std::string, std::string>* grad_to_var) {
T maker(fwd_op, no_grad_set, grad_to_var);
return maker(); return maker();
}; };
} }
......
...@@ -25,8 +25,9 @@ class GradOpDescMakerBase { ...@@ -25,8 +25,9 @@ class GradOpDescMakerBase {
public: public:
explicit GradOpDescMakerBase( explicit GradOpDescMakerBase(
const OpDescBind& fwd_op, const OpDescBind& fwd_op,
const std::unordered_set<std::string>& no_grad_set) const std::unordered_set<std::string>& no_grad_set,
: fwd_op_(fwd_op), no_grad_set_(no_grad_set) {} std::unordered_map<std::string, std::string>* grad_to_var)
: fwd_op_(fwd_op), no_grad_set_(no_grad_set), grad_to_var_(grad_to_var) {}
virtual ~GradOpDescMakerBase() = default; virtual ~GradOpDescMakerBase() = default;
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0; virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
...@@ -37,11 +38,16 @@ class GradOpDescMakerBase { ...@@ -37,11 +38,16 @@ class GradOpDescMakerBase {
std::vector<std::string> ret_val; std::vector<std::string> ret_val;
auto var_names = this->Input(name); auto var_names = this->Input(name);
ret_val.reserve(var_names.size()); ret_val.reserve(var_names.size());
std::transform( std::transform(var_names.begin(), var_names.end(),
var_names.begin(), var_names.end(), std::back_inserter(ret_val), std::back_inserter(ret_val),
[this](const std::string& fwd_var_name) -> std::string { [this](const std::string& fwd_var_name) -> std::string {
auto g_name = GradVarName(fwd_var_name); auto g_name = GradVarName(fwd_var_name);
return no_grad_set_.count(g_name) == 0 ? g_name : kEmptyVarName; if (no_grad_set_.count(g_name)) {
return kEmptyVarName;
} else {
(*this->grad_to_var_)[g_name] = fwd_var_name;
return g_name;
}
}); });
if (!drop_empty_grad) { if (!drop_empty_grad) {
return ret_val; return ret_val;
...@@ -95,6 +101,7 @@ class GradOpDescMakerBase { ...@@ -95,6 +101,7 @@ class GradOpDescMakerBase {
private: private:
const OpDescBind& fwd_op_; const OpDescBind& fwd_op_;
const std::unordered_set<std::string>& no_grad_set_; const std::unordered_set<std::string>& no_grad_set_;
std::unordered_map<std::string, std::string>* grad_to_var_;
}; };
class SingleGradOpDescMaker : public GradOpDescMakerBase { class SingleGradOpDescMaker : public GradOpDescMakerBase {
......
...@@ -32,7 +32,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs, ...@@ -32,7 +32,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
} }
OpDesc *OpDescBind::Proto() { OpDesc *OpDescBind::Proto() {
Sync(); Flush();
return &op_desc_; return &op_desc_;
} }
...@@ -101,7 +101,7 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) { ...@@ -101,7 +101,7 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
} }
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
BlockDesc *desc = block.RawPtr(); BlockDesc *desc = block.Proto();
this->attrs_[name] = desc; this->attrs_[name] = desc;
need_update_ = true; need_update_ = true;
} }
...@@ -165,7 +165,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -165,7 +165,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
}; };
void OpDescBind::Sync() { void OpDescBind::Flush() {
if (need_update_) { if (need_update_) {
this->op_desc_.mutable_inputs()->Clear(); this->op_desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) { for (auto &ipt : inputs_) {
......
...@@ -89,8 +89,6 @@ class OpDescBind { ...@@ -89,8 +89,6 @@ class OpDescBind {
this->need_update_ = true; this->need_update_ = true;
} }
void Sync();
const VariableNameMap &Inputs() const { return inputs_; } const VariableNameMap &Inputs() const { return inputs_; }
const VariableNameMap &Outputs() const { return outputs_; } const VariableNameMap &Outputs() const { return outputs_; }
...@@ -104,6 +102,8 @@ class OpDescBind { ...@@ -104,6 +102,8 @@ class OpDescBind {
void InferShape(const BlockDescBind &block) const; void InferShape(const BlockDescBind &block) const;
void Flush();
private: private:
template <typename MapType> template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) { static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
...@@ -45,7 +45,7 @@ BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) { ...@@ -45,7 +45,7 @@ BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
ProgramDesc *ProgramDescBind::Proto() { ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) { for (auto &block : blocks_) {
block->Sync(); block->Flush();
} }
return prog_; return prog_;
} }
......
...@@ -123,7 +123,18 @@ void BindProgramDesc(py::module &m) { ...@@ -123,7 +123,18 @@ void BindProgramDesc(py::module &m) {
AppendBackward(program_desc, no_grad_vars); AppendBackward(program_desc, no_grad_vars);
}) })
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference) .def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
.def("num_blocks", &ProgramDescBind::Size); .def("num_blocks", &ProgramDescBind::Size)
.def("serialize_to_string",
[](ProgramDescBind &program_desc) -> py::bytes {
const ProgramDesc *desc = program_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"ProgramDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize ProgramDesc Error. This could be a bug of Paddle.");
return res;
});
} }
void BindBlockDesc(py::module &m) { void BindBlockDesc(py::module &m) {
...@@ -149,7 +160,17 @@ void BindBlockDesc(py::module &m) { ...@@ -149,7 +160,17 @@ void BindBlockDesc(py::module &m) {
.def("all_vars", &BlockDescBind::AllVars, .def("all_vars", &BlockDescBind::AllVars,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("all_ops", &BlockDescBind::AllOps, .def("all_ops", &BlockDescBind::AllOps,
py::return_value_policy::reference); py::return_value_policy::reference)
.def("serialize_to_string", [](BlockDescBind &block_desc) -> py::bytes {
const BlockDesc *desc = block_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"BlockDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize BlockDesc Error. This could be a bug of Paddle.");
return res;
});
} }
void BindVarDsec(py::module &m) { void BindVarDsec(py::module &m) {
...@@ -177,7 +198,17 @@ void BindVarDsec(py::module &m) { ...@@ -177,7 +198,17 @@ void BindVarDsec(py::module &m) {
.def("lod_level", &VarDescBind::GetLodLevel) .def("lod_level", &VarDescBind::GetLodLevel)
.def("set_lod_level", &VarDescBind::SetLoDLevel) .def("set_lod_level", &VarDescBind::SetLoDLevel)
.def("type", &VarDescBind::GetType) .def("type", &VarDescBind::GetType)
.def("set_type", &VarDescBind::SetType); .def("set_type", &VarDescBind::SetType)
.def("serialize_to_string", [](VarDescBind &var_desc) -> py::bytes {
const VarDesc *desc = var_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"VarDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize VarDesc Error. This could be a bug of Paddle.");
return res;
});
py::enum_<VarDesc::VarType>(var_desc, "VarType", "") py::enum_<VarDesc::VarType>(var_desc, "VarType", "")
.value("LOD_TENSOR", VarDesc::LOD_TENSOR) .value("LOD_TENSOR", VarDesc::LOD_TENSOR)
...@@ -213,7 +244,17 @@ void BindOpDesc(py::module &m) { ...@@ -213,7 +244,17 @@ void BindOpDesc(py::module &m) {
.def("set_block_attr", &OpDescBind::SetBlockAttr) .def("set_block_attr", &OpDescBind::SetBlockAttr)
.def("block_attr", &OpDescBind::GetBlockAttr) .def("block_attr", &OpDescBind::GetBlockAttr)
.def("check_attrs", &OpDescBind::CheckAttrs) .def("check_attrs", &OpDescBind::CheckAttrs)
.def("infer_shape", &OpDescBind::InferShape); .def("infer_shape", &OpDescBind::InferShape)
.def("serialize_to_string", [](OpDescBind &op_desc) -> py::bytes {
const OpDesc *desc = op_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"OpDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize OpDesc Error. This could be a bug of Paddle.");
return res;
});
} }
} // namespace pybind } // namespace pybind
......
...@@ -73,6 +73,13 @@ class Variable(object): ...@@ -73,6 +73,13 @@ class Variable(object):
self.block.vars[name] = self self.block.vars[name] = self
self.op = None self.op = None
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.VarDesc.FromString(str(protostr))
return proto.__str__()
__repr__ = __str__
@property @property
def name(self): def name(self):
return self.desc.name() return self.desc.name()
...@@ -169,6 +176,18 @@ class Operator(object): ...@@ -169,6 +176,18 @@ class Operator(object):
proto = OpProtoHolder.instance().get_op_proto(type) proto = OpProtoHolder.instance().get_op_proto(type)
if inputs is not None: if inputs is not None:
given = set()
need = set()
for n in inputs:
given.add(n)
for m in proto.inputs:
need.add(m.name)
if not given == need:
raise ValueError(
"Incorrect setting for input(s) of operator \"%s\". Need: [%s] Given: [%s]"
% (type, ", ".join(str(e) for e in need), ", ".join(
str(e) for e in given)))
for in_proto in proto.inputs: for in_proto in proto.inputs:
in_argus = inputs[in_proto.name] in_argus = inputs[in_proto.name]
if not isinstance(in_argus, list): if not isinstance(in_argus, list):
...@@ -183,6 +202,18 @@ class Operator(object): ...@@ -183,6 +202,18 @@ class Operator(object):
self.desc.set_input(in_proto.name, in_argu_names) self.desc.set_input(in_proto.name, in_argu_names)
if outputs is not None: if outputs is not None:
given = set()
need = set()
for n in outputs:
given.add(n)
for m in proto.outputs:
need.add(m.name)
if not given == need:
raise ValueError(
"Incorrect setting for output(s) of operator \"%s\". Need: [%s] Given: [%s]"
% (type, ", ".join(str(e) for e in need), ", ".join(
str(e) for e in given)))
for out_proto in proto.outputs: for out_proto in proto.outputs:
out_argus = outputs[out_proto.name] out_argus = outputs[out_proto.name]
if not isinstance(out_argus, list): if not isinstance(out_argus, list):
...@@ -210,6 +241,13 @@ class Operator(object): ...@@ -210,6 +241,13 @@ class Operator(object):
self.desc.check_attrs() self.desc.check_attrs()
self.desc.infer_shape(self.block.desc) self.desc.infer_shape(self.block.desc)
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.OpDesc.FromString(str(protostr))
return proto.__str__()
__repr__ = __str__
@property @property
def type(self): def type(self):
return self.desc.type() return self.desc.type()
...@@ -252,6 +290,13 @@ class Block(object): ...@@ -252,6 +290,13 @@ class Block(object):
self.ops = collections.deque() # operator list self.ops = collections.deque() # operator list
self.program = program self.program = program
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.BlockDesc.FromString(str(protostr))
return proto.__str__()
__repr__ = __str__
@property @property
def parent_idx(self): def parent_idx(self):
return self.desc.parent return self.desc.parent
...@@ -296,6 +341,13 @@ class Program(object): ...@@ -296,6 +341,13 @@ class Program(object):
self.blocks = [Block(self, 0)] self.blocks = [Block(self, 0)]
self.current_block_idx = 0 self.current_block_idx = 0
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.ProgramDesc.FromString(str(protostr))
return proto.__str__()
__repr__ = __str__
def global_block(self): def global_block(self):
return self.blocks[0] return self.blocks[0]
......
...@@ -34,6 +34,8 @@ class TestOperator(unittest.TestCase): ...@@ -34,6 +34,8 @@ class TestOperator(unittest.TestCase):
"Y": mul_y}, "Y": mul_y},
outputs={"Out": [mul_out]}, outputs={"Out": [mul_out]},
attrs={"x_num_col_dims": 1}) attrs={"x_num_col_dims": 1})
self.assertNotEqual(str(mul_op), "")
self.assertEqual(mul_op.type, "mul") self.assertEqual(mul_op.type, "mul")
self.assertEqual(mul_op.input_names, ["X", "Y"]) self.assertEqual(mul_op.input_names, ["X", "Y"])
self.assertEqual(mul_op.input("X"), ["mul.x"]) self.assertEqual(mul_op.input("X"), ["mul.x"])
......
...@@ -21,6 +21,7 @@ class TestVariable(unittest.TestCase): ...@@ -21,6 +21,7 @@ class TestVariable(unittest.TestCase):
b = g_program.current_block() b = g_program.current_block()
w = b.create_var( w = b.create_var(
dtype="float64", shape=[784, 100], lod_level=0, name="fc.w") dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
self.assertNotEqual(str(w), "")
self.assertEqual(core.DataType.FP64, w.data_type) self.assertEqual(core.DataType.FP64, w.data_type)
self.assertEqual((784, 100), w.shape) self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name) self.assertEqual("fc.w", w.name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册