提交 e192d0fd 编写于 作者: F fengjiayi

Refactor the implementation of gradient Op creating

上级 3dc70ff2
#include "paddle/framework/grad_op_creator.h"
namespace paddle {
namespace framework {
OperatorBase* GradOpCreator::Create() {
BuildOpInOutArgList();
OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
CompleteGradOp(grad_op);
return grad_op;
}
OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
const VarIndexMap& var_map,
const vector<int>& format, InOutType type) {
int idx = var_map.at(var.name());
int begin_idx = format.empty() ? idx : format.at(idx);
int end_idx = format.empty() ? idx + 1 : format.at(idx + 1);
return new OpInOutArg(var.name(), type, !var.ignore_gradient(), begin_idx,
end_idx);
}
void GradOpCreator::BuildOpInOutArgList() {
const OpProto& op_proto = OpRegistry::protos().at(op_->type);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op->type_));
const vector<int>& in_format =
op_->attrs_.count("input_format")
? op->GetAttr<std::vector<int>>("input_format")
: std::vector<int>();
const vector<int>& out_format =
op_->attrs_.count("output_format")
? op->GetAttr<std::vector<int>>("output_format")
: std::vector<int>();
for (const auto& var : op_proto.inputs()) {
arg_list_.emplace_back(
std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, in_format, IN)));
}
for (const auto& var : op_proto.outputs()) {
arg_list_.emplace_back(
std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, out_format, OUT)));
}
}
void GradOpCreator::PushArgIntoGradOp(const OpInOutArg* arg,
vector<std::string>& in_out,
vector<int>& format, VarIndexMap* varmap,
int& idx, bool is_grad) {
std::string var_name = arg->proto_name_;
if (is_grad) {
var_name += OperatorBase::GRAD_VAR_SUFFIX();
}
*(varmap)[var_name] = idx++;
size_t pre_sz = in_out.size();
auto base_it = arg->type == IN ? op_->inputs_.begin() : op_->outputs_.begin();
std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
std::back_inserter(in_out));
if (is_grad) {
for (size_t i = pre_sz; i < in_out.size(); ++i) {
in_out[i] += OperatorBase::GRAD_VAR_SUFFIX();
}
}
format.push_back(in_out.size());
}
void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->type_ = op_->type_ + "@GRAD"; // not necessary
grad_op->attrs_ = op_->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
VarIndexMap* grad_varmap = new VarIndexMap();
int in_idx = 0;
int out_idx = 0;
vector<int> in_format({0});
vector<int> out_format({0});
for (const auto& arg : arg_list_) {
// op_'s inputs_ and outputs_
if (arg->needed_in_grad_) {
PushArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
in_idx, false);
}
if (arg->type_ == IN) {
// gradients of op_'s inputs_
PushArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap,
out_idx, true);
} else {
// gradients of op_'s outputs_
PushArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
in_idx, true);
}
}
grad_op->attrs_["input_format"] = in_format;
grad_op->attrs_["output_format"] = out_format;
grad_op->in_out_idxs_.reset(grad_varmap);
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
#pragma once
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
class OpRegistry;
class GradOpCreator {
public:
GradOpCreator(const OperatorBase* op) : op_(op) {}
OperatorBase* Create();
private:
enum InOutType { IN, OUT };
struct OpInOutArg {
OpInOutArg(const std::string& proto_name, const InOutType& type,
bool needed_in_grad, size_t begin_idx, size_t end_idx)
: proto_name_(proto_name),
type_(type),
needed_in_grad_(needed_in_grad),
begin_idx_(begin_idx),
end_idx_(end_idx) {}
std::string proto_name_;
InOutType type_;
bool needed_in_grad_;
size_t begin_idx_;
size_t end_idx_;
};
OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
const vector<int>& format, InOutType type);
void BuildOpInOutArgList();
void PushArgIntoGradOp(const OpInOutArg* arg, vector<std::string>& in_out,
vector<int>& format, VarIndexMap* varmap, int& idx,
bool is_grad);
void CompleteGradOp(OperatorBase* grad_op) const;
const OperatorBase* op_;
std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
}
} // namespace framework
} // namespace paddle
......@@ -6,9 +6,8 @@
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/grad_op_creater.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
namespace paddle {
......@@ -286,13 +285,8 @@ class OpRegistry {
}
static OperatorPtr CreateGradOp(OperatorPtr op) {
OperatorPtr grad_op(grad_creators().at(op->type_)());
grad_op->type_ = op->type_;
AssembleGradInOut(op, grad_op);
GenerateGradArgOffset(op, grad_op);
GenerateGradAttr(op, grad_op);
GradOpCreator creator(op.get());
OperatorPtr grad_op(creator.Create());
grad_op->Init();
return grad_op;
}
......@@ -302,13 +296,18 @@ class OpRegistry {
return protos_;
};
private:
static std::unordered_map<std::string, OpCreator>& grad_creators() {
static std::unordered_map<std::string, OpCreator> grad_creators_;
return grad_creators_;
}
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
VarIndexMaps() {
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
return maps_;
}
private:
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
......@@ -319,11 +318,6 @@ class OpRegistry {
return op_checkers_;
};
static std::unordered_map<std::string, OpCreator>& grad_creators() {
static std::unordered_map<std::string, OpCreator> grad_creators_;
return grad_creators_;
}
static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& outname : op->outputs_) {
......@@ -334,100 +328,6 @@ class OpRegistry {
}
}
}
static void AssembleGradInOut(OperatorPtr op, OperatorPtr grad_op) {
size_t in_sz = op->inputs_.size() + op->outputs_.size() * 2;
grad_op->inputs_.reserve(in_sz);
size_t out_sz = op->inputs_.size();
grad_op->outputs_.reserve(out_sz);
// copy op->inputs_ to grad_op->inputs_
std::copy(op->inputs_.begin(), op->inputs_.end(),
std::back_inserter(grad_op->inputs_));
// copy op->outputs_ to grad_op->inputs_
std::copy(op->outputs_.begin(), op->outputs_.end(),
std::back_inserter(grad_op->inputs_));
// add gradients of op->outputs_ to grad_op->inputs_
for (const std::string& name : op->outputs_) {
grad_op->inputs_.emplace_back(name + OperatorBase::GRAD_VAR_SUFFIX());
}
// add gradients of op->inputs_ to grad_op->outputs_
for (const std::string& name : op->inputs_) {
grad_op->outputs_.emplace_back(name + OperatorBase::GRAD_VAR_SUFFIX());
}
}
static void GenerateGradArgOffset(OperatorPtr op, OperatorPtr grad_op) {
VarIndexMap* grad_varmap = new VarIndexMap();
const OpProto& op_proto = protos()[op->type_];
int idx = 0;
// offset of op's inputs
for (const auto& var : op_proto.inputs()) {
(*grad_varmap)[var.name()] = idx++;
}
// offset of op's outputs
for (const auto& var : op_proto.outputs()) {
(*grad_varmap)[var.name()] = idx++;
}
// offset of gradients of op's output
for (const auto& var : op_proto.outputs()) {
(*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++;
}
idx = 0;
// offset of gradients of op's input
for (const auto& var : op_proto.inputs()) {
(*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++;
}
grad_op->in_out_idxs_.reset(grad_varmap);
}
static void GenerateGradAttr(OperatorPtr op, OperatorPtr grad_op) {
const OpProto& op_proto = protos()[op->type_];
grad_op->attrs_ = op->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
bool has_in_format = op->attrs_.count("input_format");
bool has_out_format = op->attrs_.count("output_format");
// grad_op's inputs_ contains op's inputs_, outputs_ and gradients of
// outpus_. So grad_op's input_format is necessary when op has
// either input_format or output_format.
if (has_in_format || has_out_format) {
std::vector<int> old_in_format;
std::vector<int> old_out_format;
has_in_format
? old_in_format = op->GetAttr<std::vector<int>>("input_format")
: old_in_format = std::vector<int>(op_proto.inputs_size()),
std::iota(old_in_format.begin(), old_in_format.end(), 0);
has_out_format
? old_out_format = op->GetAttr<std::vector<int>>("output_format")
: old_out_format = std::vector<int>(op_proto.outputs_size()),
std::iota(old_out_format.begin(), old_out_format.end(), 0);
std::vector<int> in_format;
in_format.reserve(old_in_format.size() + old_out_format.size() * 2);
int base = 0;
for (const int& idx : old_in_format) {
in_format.emplace_back(idx + base);
}
base += op->inputs_.size();
for (const int& idx : old_out_format) {
in_format.emplace_back(idx + base);
}
base += op->outputs_.size();
for (const int& idx : old_in_format) {
in_format.emplace_back(idx + base);
}
grad_op->attrs_["input_format"] = in_format;
// grad_op's outputs_ contains gradients of op's inputs_. So grad_op's
// output_format is necessary only when op has input_format.
if (has_in_format) {
std::vector<int> out_format;
out_format.reserve(op_proto.inputs_size());
std::copy(old_in_format.begin(), old_in_format.end(),
std::back_inserter(out_format));
grad_op->attrs_["output_format"] = out_format;
}
}
}
};
template <typename OpType, typename ProtoMakerType>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册