提交 ff7fdb7d 编写于 作者: F fengjiayi

Add `CreateBackwardOp` function

上级 a80e0100
......@@ -222,5 +222,61 @@ std::unique_ptr<OperatorBase> Backward(
return BackwardRecursive(forwardOp, no_grad_names, uid);
}
// ==================================== //
static bool AllGradInSet(const std::vector<std::string>& names,
const std::unordered_set<std::string>& set) {
for (const std::string& name : names) {
if (!set.count(GradVarName(name))) {
return false;
}
}
return true;
}
std::vector<OpDescBind> CreatBackwardOps(
const OpDescBind& op_desc, unordered_map<std::string>& no_grad_vars) {
std::vector<OpDescBind> grad_op_descs;
// All input gradients of forwarding operator do not need to calculat.
if (AllGradInSet(op_desc_.InputNames(), kGradVarSuffix, no_grad_vars)) {
return grad_op_descs; // empty vector
}
// All output gradients of forwarding operator do not need to calculate.
const std::vector<std::string>& outputs = op_desc_.OutputNames();
if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) {
for (const std::string& name : outputs) {
no_grad_vars.insert(GradVarName(name));
}
return grad_op_descs; // empty vector
}
grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc);
std::vector<OpDescBind> fill_zeros_ops;
for (OpDescBind& desc : grad_op_descs) {
for (const std::string& in_name : desc.InputNames()) {
if (no_grad_vars.count(in_name)) {
std::string prefix = in_name.substr(
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
std::string new_name = prefix + kZeroVarSuffix;
desc.Rename(in_name, new_name);
OpDescBind op_desc_bind(
{"fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}});
fill_zeros_ops.push_back(op_desc_bind);
}
}
for (const std::string& out_name : desc.OutputName()) {
if (no_grad_vars.count(out_name)) {
desc.Rename(out_name, kEmptyVarName);
}
}
}
grad_op_descs.insert(grad_op_descs.begin(), fill_zeros_ops.begin(),
fill_zeros_ops.end());
// TODO (fengjiayi): RNN op
return grad_op_descs;
}
} // namespace framework
} // namespace paddle
......@@ -23,5 +23,10 @@ namespace framework {
extern std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars);
extern void AppendBackwardOps(
BlockDescBind& block_desc,
const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework
} // namespace paddle
......@@ -18,6 +18,15 @@ limitations under the License. */
namespace paddle {
namespace framework {
OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs) {
op_desc_.set_type(type);
inputs_ = inputs;
outputs_ = outputs;
attrs_ = attrs;
}
OpDesc *OpDescBind::Proto() {
Sync();
return &op_desc_;
......@@ -112,6 +121,17 @@ const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
return attrs_;
}
void Rename(const std::string &old_name, const std::string &new_name) {
for (std : string &input : inputs_) {
std::replace(input.second.begin(), input.second.end(), old_name, new_name);
}
for (std::string &output : outputs_) {
std::repalce(output.second.begin(), output.second.end(), old_name,
new_name);
}
need_update_ = true;
}
void OpDescBind::Sync() {
if (need_update_) {
this->op_desc_.mutable_inputs()->Clear();
......
......@@ -26,6 +26,11 @@ class BlockDescBind;
class OpDescBind {
public:
OpDescBind() {}
OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs);
OpDesc *Proto();
std::string Type() const { return op_desc_.type(); }
......@@ -67,6 +72,8 @@ class OpDescBind {
int GetBlockAttr(const std::string &name) const;
void Rename(const std::string &old_name, const std::string &new_name);
// Only be used in C++
const std::unordered_map<std::string, Attribute> &GetAttrMap() const;
......
......@@ -57,5 +57,10 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
return std::unique_ptr<OperatorBase>(BuildGradOp(&op));
}
static std::vector<OpDescBind> CreateGradOpDescs(const OpDescBind& op_desc) {
auto& info = OpInfoMap::Instance().Get(op_desc.Type());
return info.grad_op_maker_(op_desc);
}
} // namespace framework
} // namespace paddle
......@@ -68,6 +68,8 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
static std::vector<OpDescBind> CreateGradOpDescs(const OpDescBind& op_desc);
};
class Registrar {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册