提交 e6fca658 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #3433 from wangkuiyi/refactorize_grad_op_builder.cc

Refactorize grad op builder.cc
...@@ -19,45 +19,44 @@ permissions and limitations under the License. */ ...@@ -19,45 +19,44 @@ permissions and limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OpRegistry; typedef std::vector<int> Ints;
using VarIndexMap = std::unordered_map<std::string, int>;
enum class OpArgType { IN, OUT }; enum class OpArgType { IN, OUT };
static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) { const Ints* AttrFormat(const AttributeMap& attrs, const std::string& key) {
std::string key = type == OpArgType::IN ? "input_format" : "output_format"; return (attrs.count(key) > 0) ? &boost::get<Ints>(attrs.at(key)) : nullptr;
return op->attrs_.count(key)
? &boost::get<std::vector<int>>(op->attrs_.at(key))
: nullptr;
} }
static const std::vector<int>* GetOpFormat(const OperatorBase* op, Ints* AttrFormat(AttributeMap& attrs, const std::string& key) {
const OpArgType& type) { return (attrs.count(key) > 0) ? &boost::get<Ints>(attrs.at(key)) : nullptr;
std::string key = type == OpArgType::IN ? "input_format" : "output_format";
return op->attrs_.count(key)
? &boost::get<std::vector<int>>(op->attrs_.at(key))
: nullptr;
} }
static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, static void TransOpArg(const OperatorBase* src_op,
const OpArgType& src_type, const OpArgType& dst_type, std::vector<std::string>& grad_inputs,
std::vector<std::string>& grad_outputs,
AttributeMap& grad_attrs,
std::unordered_map<std::string, int>& grad_idxs,
const std::string& src_type, const std::string& dst_type,
int& idx, bool is_grad) { int& idx, bool is_grad) {
const std::vector<std::string>& src_inout = const std::vector<std::string>& src_inout =
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; (src_type == "input_format") ? src_op->inputs_ : src_op->outputs_;
const std::vector<int>* src_format = GetOpFormat(src_op, src_type);
const std::vector<int>* src_format = AttrFormat(src_op->Attrs(), src_type);
std::vector<std::string>& dst_inout = std::vector<std::string>& dst_inout =
dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_; (dst_type == "input_format") ? grad_inputs : grad_outputs;
std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
std::vector<int>* dst_format = AttrFormat(grad_attrs, dst_type);
const OpProto& proto = OpRegistry::protos().at(src_op->type_); const OpProto& proto = OpRegistry::protos().at(src_op->type_);
const auto& src_arg_list = const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); (src_type == "input_format") ? proto.inputs() : proto.outputs();
for (const auto& arg : src_arg_list) { for (const auto& arg : src_arg_list) {
std::string src_name = arg.name(); std::string src_name = arg.name();
std::string dst_name = is_grad ? src_name + kGradVarSuffix : src_name; std::string dst_name = is_grad ? src_name + kGradVarSuffix : src_name;
(*dst_op->in_out_idxs_)[dst_name] = idx++; grad_idxs[dst_name] = idx++;
int src_arg_idx = src_op->in_out_idxs_->at(src_name); int src_arg_idx = src_op->in_out_idxs_->at(src_name);
int src_begin = int src_begin =
src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx); src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx);
...@@ -76,26 +75,42 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, ...@@ -76,26 +75,42 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
} }
OperatorBase* BuildGradOp(const OperatorBase* op) { OperatorBase* BuildGradOp(const OperatorBase* op) {
std::string grad_op_type = OpRegistry::grad_ops().at(op->type_); const std::string& grad_op_type = OpRegistry::grad_ops().at(op->Type());
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type; AttributeMap grad_attrs(op->Attrs());
grad_op->attrs_ = op->attrs_; grad_attrs.erase("input_format");
grad_op->attrs_.erase("input_format"); grad_attrs.erase("output_format");
grad_op->attrs_.erase("output_format"); if (op->Attrs().count("input_format") > 0) {
if (GetOpFormat(op, OpArgType::IN) != nullptr) { grad_attrs["output_format"] = std::vector<int>({0});
grad_op->attrs_["output_format"] = std::vector<int>({0});
} }
if (GetOpFormat(op, OpArgType::IN) != nullptr || if (op->Attrs().count("input_format") > 0 ||
GetOpFormat(op, OpArgType::OUT) != nullptr) { op->Attrs().count("output_format") > 0) {
grad_op->attrs_["input_format"] = std::vector<int>({0}); grad_attrs["input_format"] = std::vector<int>({0});
} }
grad_op->in_out_idxs_.reset(new VarIndexMap());
std::vector<std::string> grad_inputs, grad_outputs;
using VarIndexMap = std::unordered_map<std::string, int>;
VarIndexMap* grad_idxs = new VarIndexMap;
int in_idx = 0; int in_idx = 0;
int out_idx = 0; int out_idx = 0;
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, in_idx, false); // I TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, false); // G "input_format", "input_format", in_idx, false); // I
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, true); // OG TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true); // IG "output_format", "input_format", in_idx, false); // G
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"output_format", "input_format", in_idx, true); // OG
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"input_format", "output_format", out_idx, true); // IG
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type;
grad_op->inputs_ = grad_inputs;
grad_op->outputs_ = grad_outputs;
grad_op->attrs_ = grad_attrs;
grad_op->in_out_idxs_.reset(grad_idxs);
return grad_op; return grad_op;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册