提交 4130e5fa 编写于 作者: Q qijun

Merge remote-tracking branch 'baidu/develop' into add_selected_rows_functor

......@@ -28,15 +28,15 @@ namespace paddle {
namespace framework {
static inline std::unique_ptr<OperatorBase> CreateGradOp(
const OperatorBase& op,
const std::unordered_set<std::string>& no_grad_set) {
const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var) {
OpDescBind op_desc;
op_desc.SetInputMap(op.Inputs());
op_desc.SetOutputMap(op.Outputs());
op_desc.SetType(op.Type());
op_desc.SetAttrMap(op.Attrs());
auto& info = OpInfoMap::Instance().Get(op.Type());
auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set);
auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set, grad_to_var);
std::vector<std::unique_ptr<OperatorBase>> grad_ops;
grad_ops.reserve(grad_descs.size());
std::transform(grad_descs.begin(), grad_descs.end(),
......@@ -99,7 +99,9 @@ static std::unique_ptr<OperatorBase> NOP() {
// See Backward.h for details
static std::unique_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
std::unordered_set<std::string>& no_grad_names,
std::unordered_map<std::string, std::string>* grad_to_var,
size_t& uniq_id) {
// If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take
// too much time for calculation, but it is useful for simplifying logic.
......@@ -137,7 +139,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it, ++local_op_id) {
auto& fwd = *it;
auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
auto bwd = BackwardRecursive(*fwd, no_grad_names, grad_to_var, uniq_id);
ForEachVarName(bwd->Outputs(),
[&dup_output_ops, local_op_id](const std::string& out) {
dup_output_ops[out].emplace_back(local_op_id);
......@@ -189,7 +191,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
}
} else {
std::unique_ptr<OperatorBase> grad_op(
CreateGradOp(forwardOp, no_grad_names));
CreateGradOp(forwardOp, no_grad_names, grad_to_var));
ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
const std::string& grad_input) {
......@@ -228,7 +230,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
*static_cast<const OperatorBase*>(&rnnop.stepnet());
// create stepnet's gradient op
rnn_grad_op->set_stepnet(
BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
}
if (net->ops_.empty()) { // Current no aux op is added to network
......@@ -255,7 +257,8 @@ std::unique_ptr<OperatorBase> Backward(
no_grad_names.insert(name + kGradVarSuffix);
}
size_t uid = 0;
return BackwardRecursive(forwardOp, no_grad_names, uid);
std::unordered_map<std::string, std::string> grad_to_var;
return BackwardRecursive(forwardOp, no_grad_names, &grad_to_var, uid);
}
// ==================================== //
......@@ -272,30 +275,31 @@ static bool AllGradInSet(const std::vector<std::string>& names,
std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
const std::unique_ptr<OpDescBind>& op_desc,
std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var) {
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
// All input gradients of forwarding operator do not need to calculate.
const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
if (AllGradInSet(inputs, no_grad_vars)) {
if (AllGradInSet(inputs, *no_grad_vars)) {
return grad_op_descs; // empty vector
}
// All output gradients of forwarding operator do not need to calculate.
const std::vector<std::string>& outputs = op_desc->OutputArgumentNames();
if (AllGradInSet(outputs, no_grad_vars)) {
if (AllGradInSet(outputs, *no_grad_vars)) {
for (const std::string& name : inputs) {
no_grad_vars.insert(GradVarName(name));
no_grad_vars->insert(GradVarName(name));
}
return grad_op_descs; // empty vector
}
grad_op_descs = OpInfoMap::Instance()
.Get(op_desc->Type())
.GradOpMaker()(*op_desc, no_grad_vars);
.GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var);
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
for (auto& desc : grad_op_descs) {
for (const std::string& in_name : desc->InputArgumentNames()) {
if (no_grad_vars.count(in_name)) {
if (no_grad_vars->count(in_name)) {
std::string prefix = in_name.substr(
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
std::string new_name = prefix + kZeroVarSuffix;
......@@ -315,7 +319,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind& program_desc, int block_idx,
std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var) {
BlockDescBind* cur_block = program_desc.Block(block_idx);
std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_;
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
......@@ -323,15 +328,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
std::vector<std::unique_ptr<OpDescBind>> backward_descs;
for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
std::vector<std::unique_ptr<OpDescBind>> op_grads =
MakeOpGrad(*it, no_grad_vars);
MakeOpGrad(*it, no_grad_vars, grad_to_var);
if ((*it)->Type() == "recurrent") {
PADDLE_ENFORCE_EQ(
op_grads.size(), size_t(1),
"rnn_op's gradient process should contain only one op.");
int step_block_idx = (*it)->GetBlockAttr("stop_block");
auto backward_block_op_descs =
MakeBlockBackward(program_desc, step_block_idx, no_grad_vars);
auto backward_block_op_descs = MakeBlockBackward(
program_desc, step_block_idx, no_grad_vars, grad_to_var);
BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
for (auto& ptr : backward_block_op_descs) {
backward_block->ops_.push_back(std::move(ptr));
......@@ -387,8 +392,9 @@ void AppendBackward(ProgramDescBind& program_desc,
no_grad_var_names.insert(GradVarName(name));
}
const int root_block_idx = 0;
auto backward_op_descs =
MakeBlockBackward(program_desc, root_block_idx, no_grad_var_names);
std::unordered_map<std::string, std::string> grad_to_var;
auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx,
&no_grad_var_names, &grad_to_var);
auto& forw_op_descs = program_desc.Block(root_block_idx)->ops_;
for (auto& ptr : backward_op_descs) {
forw_op_descs.push_back(std::move(ptr));
......
......@@ -66,7 +66,7 @@ std::vector<OpDescBind *> BlockDescBind::AllOps() const {
return res;
}
void BlockDescBind::Sync() {
void BlockDescBind::Flush() {
if (need_update_) {
auto &op_field = *this->desc_->mutable_ops();
op_field.Clear();
......@@ -91,5 +91,10 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
}
BlockDesc *BlockDescBind::Proto() {
Flush();
return desc_;
}
} // namespace framework
} // namespace paddle
......@@ -35,7 +35,8 @@ class BlockDescBind {
public:
friend std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind &program_desc, int block_idx,
std::unordered_set<std::string> &no_grad_vars);
std::unordered_set<std::string> *no_grad_vars,
std::unordered_map<std::string, std::string> *grad_to_var);
friend void AppendBackward(
ProgramDescBind &program_desc,
......@@ -64,9 +65,9 @@ class BlockDescBind {
std::vector<OpDescBind *> AllOps() const;
void Sync();
void Flush();
BlockDesc *RawPtr() { return desc_; }
BlockDesc *Proto();
private:
ProgramDescBind *prog_; // not_own
......
......@@ -99,8 +99,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->grad_op_maker_ = [](
const OpDescBind& fwd_op,
const std::unordered_set<std::string>& no_grad_set) {
T maker(fwd_op, no_grad_set);
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var) {
T maker(fwd_op, no_grad_set, grad_to_var);
return maker();
};
}
......
......@@ -25,8 +25,9 @@ class GradOpDescMakerBase {
public:
explicit GradOpDescMakerBase(
const OpDescBind& fwd_op,
const std::unordered_set<std::string>& no_grad_set)
: fwd_op_(fwd_op), no_grad_set_(no_grad_set) {}
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var)
: fwd_op_(fwd_op), no_grad_set_(no_grad_set), grad_to_var_(grad_to_var) {}
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
......@@ -37,11 +38,16 @@ class GradOpDescMakerBase {
std::vector<std::string> ret_val;
auto var_names = this->Input(name);
ret_val.reserve(var_names.size());
std::transform(
var_names.begin(), var_names.end(), std::back_inserter(ret_val),
std::transform(var_names.begin(), var_names.end(),
std::back_inserter(ret_val),
[this](const std::string& fwd_var_name) -> std::string {
auto g_name = GradVarName(fwd_var_name);
return no_grad_set_.count(g_name) == 0 ? g_name : kEmptyVarName;
if (no_grad_set_.count(g_name)) {
return kEmptyVarName;
} else {
(*this->grad_to_var_)[g_name] = fwd_var_name;
return g_name;
}
});
if (!drop_empty_grad) {
return ret_val;
......@@ -95,6 +101,7 @@ class GradOpDescMakerBase {
private:
const OpDescBind& fwd_op_;
const std::unordered_set<std::string>& no_grad_set_;
std::unordered_map<std::string, std::string>* grad_to_var_;
};
class SingleGradOpDescMaker : public GradOpDescMakerBase {
......
......@@ -32,7 +32,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
}
OpDesc *OpDescBind::Proto() {
Sync();
Flush();
return &op_desc_;
}
......@@ -101,7 +101,7 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
}
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
BlockDesc *desc = block.RawPtr();
BlockDesc *desc = block.Proto();
this->attrs_[name] = desc;
need_update_ = true;
}
......@@ -165,7 +165,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
void OpDescBind::Sync() {
void OpDescBind::Flush() {
if (need_update_) {
this->op_desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) {
......
......@@ -89,8 +89,6 @@ class OpDescBind {
this->need_update_ = true;
}
void Sync();
const VariableNameMap &Inputs() const { return inputs_; }
const VariableNameMap &Outputs() const { return outputs_; }
......@@ -104,6 +102,8 @@ class OpDescBind {
void InferShape(const BlockDescBind &block) const;
void Flush();
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
......@@ -45,7 +45,7 @@ BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) {
block->Sync();
block->Flush();
}
return prog_;
}
......
......@@ -123,7 +123,18 @@ void BindProgramDesc(py::module &m) {
AppendBackward(program_desc, no_grad_vars);
})
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
.def("num_blocks", &ProgramDescBind::Size);
.def("num_blocks", &ProgramDescBind::Size)
.def("serialize_to_string",
[](ProgramDescBind &program_desc) -> py::bytes {
const ProgramDesc *desc = program_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"ProgramDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize ProgramDesc Error. This could be a bug of Paddle.");
return res;
});
}
void BindBlockDesc(py::module &m) {
......@@ -149,7 +160,17 @@ void BindBlockDesc(py::module &m) {
.def("all_vars", &BlockDescBind::AllVars,
py::return_value_policy::reference)
.def("all_ops", &BlockDescBind::AllOps,
py::return_value_policy::reference);
py::return_value_policy::reference)
.def("serialize_to_string", [](BlockDescBind &block_desc) -> py::bytes {
const BlockDesc *desc = block_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"BlockDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize BlockDesc Error. This could be a bug of Paddle.");
return res;
});
}
void BindVarDsec(py::module &m) {
......@@ -177,7 +198,17 @@ void BindVarDsec(py::module &m) {
.def("lod_level", &VarDescBind::GetLodLevel)
.def("set_lod_level", &VarDescBind::SetLoDLevel)
.def("type", &VarDescBind::GetType)
.def("set_type", &VarDescBind::SetType);
.def("set_type", &VarDescBind::SetType)
.def("serialize_to_string", [](VarDescBind &var_desc) -> py::bytes {
const VarDesc *desc = var_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"VarDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize VarDesc Error. This could be a bug of Paddle.");
return res;
});
py::enum_<VarDesc::VarType>(var_desc, "VarType", "")
.value("LOD_TENSOR", VarDesc::LOD_TENSOR)
......@@ -213,7 +244,17 @@ void BindOpDesc(py::module &m) {
.def("set_block_attr", &OpDescBind::SetBlockAttr)
.def("block_attr", &OpDescBind::GetBlockAttr)
.def("check_attrs", &OpDescBind::CheckAttrs)
.def("infer_shape", &OpDescBind::InferShape);
.def("infer_shape", &OpDescBind::InferShape)
.def("serialize_to_string", [](OpDescBind &op_desc) -> py::bytes {
const OpDesc *desc = op_desc.Proto();
PADDLE_ENFORCE(desc->IsInitialized(),
"OpDesc has not been initialized.");
std::string res;
PADDLE_ENFORCE(
desc->SerializeToString(&res),
"Serialize OpDesc Error. This could be a bug of Paddle.");
return res;
});
}
} // namespace pybind
......
......@@ -73,6 +73,13 @@ class Variable(object):
self.block.vars[name] = self
self.op = None
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.VarDesc.FromString(str(protostr))
return proto.__str__()
__repr__ = __str__
@property
def name(self):
return self.desc.name()
......@@ -169,6 +176,18 @@ class Operator(object):
proto = OpProtoHolder.instance().get_op_proto(type)
if inputs is not None:
given = set()
need = set()
for n in inputs:
given.add(n)
for m in proto.inputs:
need.add(m.name)
if not given == need:
raise ValueError(
"Incorrect setting for input(s) of operator \"%s\". Need: [%s] Given: [%s]"
% (type, ", ".join(str(e) for e in need), ", ".join(
str(e) for e in given)))
for in_proto in proto.inputs:
in_argus = inputs[in_proto.name]
if not isinstance(in_argus, list):
......@@ -183,6 +202,18 @@ class Operator(object):
self.desc.set_input(in_proto.name, in_argu_names)
if outputs is not None:
given = set()
need = set()
for n in outputs:
given.add(n)
for m in proto.outputs:
need.add(m.name)
if not given == need:
raise ValueError(
"Incorrect setting for output(s) of operator \"%s\". Need: [%s] Given: [%s]"
% (type, ", ".join(str(e) for e in need), ", ".join(
str(e) for e in given)))
for out_proto in proto.outputs:
out_argus = outputs[out_proto.name]
if not isinstance(out_argus, list):
......@@ -210,6 +241,13 @@ class Operator(object):
self.desc.check_attrs()
self.desc.infer_shape(self.block.desc)
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.OpDesc.FromString(str(protostr))
return proto.__str__()
__repr__ = __str__
@property
def type(self):
return self.desc.type()
......@@ -252,6 +290,13 @@ class Block(object):
self.ops = collections.deque() # operator list
self.program = program
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.BlockDesc.FromString(str(protostr))
return proto.__str__()
__repr__ = __str__
@property
def parent_idx(self):
return self.desc.parent
......@@ -296,6 +341,13 @@ class Program(object):
self.blocks = [Block(self, 0)]
self.current_block_idx = 0
def __str__(self):
protostr = self.desc.serialize_to_string()
proto = framework_pb2.ProgramDesc.FromString(str(protostr))
return proto.__str__()
__repr__ = __str__
def global_block(self):
return self.blocks[0]
......
......@@ -34,6 +34,8 @@ class TestOperator(unittest.TestCase):
"Y": mul_y},
outputs={"Out": [mul_out]},
attrs={"x_num_col_dims": 1})
self.assertNotEqual(str(mul_op), "")
self.assertEqual(mul_op.type, "mul")
self.assertEqual(mul_op.input_names, ["X", "Y"])
self.assertEqual(mul_op.input("X"), ["mul.x"])
......
......@@ -21,6 +21,7 @@ class TestVariable(unittest.TestCase):
b = g_program.current_block()
w = b.create_var(
dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
self.assertNotEqual(str(w), "")
self.assertEqual(core.DataType.FP64, w.data_type)
self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册