提交 bf4da3d9 编写于 作者: F fengjiayi

Refactor Rigistry::CreateGradOp()

We put forward Op's inputs, outputs and output gradients into Grad
Op's inputs, and put forward Op's input gradients into Grad Op's output.
So Grad Op's `in_out_idx`, `input_format` and 'output format' need to be
rebuilt during Op creating.
上级 e786746f
......@@ -228,6 +228,11 @@ class OpRegistry {
}
}
template <typename OpType>
static void RegisterGradOp(const std::string& op_type) {
grad_creators()[op_type] = [] { return new OpType; };
}
static OperatorPtr CreateOp(const std::string& type,
const VarNameList& inputs,
const VarNameList& outputs,
......@@ -240,6 +245,7 @@ class OpRegistry {
op->type_ = type;
op->inputs_ = inputs;
op->outputs_ = outputs;
op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);
......@@ -256,11 +262,6 @@ class OpRegistry {
return OperatorPtr(op);
}
template <typename OpType>
static void RegisterGradOp(const std::string& op_type) {
grad_creators()[op_type] = [] { return new OpType; };
}
static OperatorPtr CreateOp(const OpDesc& op_desc) {
std::vector<std::string> inputs;
inputs.reserve((size_t)op_desc.inputs_size());
......@@ -280,19 +281,16 @@ class OpRegistry {
return CreateOp(op_desc.type(), inputs, outputs, attrs);
}
static OperatorPtr CreateGradOp(std::shared_ptr<OperatorBase> op) {
OperatorPtr op_grad(grad_creators().at(op->type_)());
op_grad->type_ = op->type_;
op_grad->inputs_.reserve(op->inputs_.size());
for (auto& input : op->inputs_) {
op_grad->inputs_.emplace_back(input);
op_grad->outputs_.emplace_back(input + "@grad");
}
for (auto& output : op->outputs_) {
op_grad->inputs_.emplace_back(output);
op_grad->inputs_.emplace_back(output + "@grad");
}
return op_grad;
static OperatorPtr CreateGradOp(OperatorPtr op) {
OperatorPtr grad_op(grad_creators().at(op->type_)());
grad_op->type_ = op->type_;
AssembleGradInOut(op, grad_op);
GenerateGradArgOffset(op, grad_op);
GenerateGradAttr(op, grad_op);
grad_op->Init();
return grad_op;
}
static std::unordered_map<std::string, OpProto>& protos() {
......@@ -307,6 +305,21 @@ class OpRegistry {
return maps_;
}
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
}
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
};
static std::unordered_map<std::string, OpCreator>& grad_creators() {
static std::unordered_map<std::string, OpCreator> grad_creators_;
return grad_creators_;
}
static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& outname : op->outputs_) {
......@@ -318,19 +331,98 @@ class OpRegistry {
}
}
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
static void AssembleGradInOut(OperatorPtr op, OperatorPtr grad_op) {
size_t in_sz = op->inputs_.size() + op->outputs_.size() * 2;
grad_op->inputs_.reserve(in_sz);
size_t out_sz = op->inputs_.size();
grad_op->outputs_.reserve(out_sz);
// copy op->inputs_ to grad_op->inputs_
std::copy(op->inputs_.begin(), op->inputs_.end(),
std::back_inserter(grad_op->inputs_));
// copy op->outputs_ to grad_op->inputs_
std::copy(op->outputs_.begin(), op->outputs_.end(),
std::back_inserter(grad_op->inputs_));
// add gradients of op->outputs_ to grad_op->inputs_
for (const std::string& name : op->outputs_) {
grad_op->inputs_.emplace_back(name + OperatorBase::GRAD_VAR_SUFFIX());
}
// add gradients of op->inputs_ to grad_op->outputs_
for (const std::string& name : op->inputs_) {
grad_op->outputs_.emplace_back(name + OperatorBase::GRAD_VAR_SUFFIX());
}
}
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
};
static void GenerateGradArgOffset(OperatorPtr op, OperatorPtr grad_op) {
VarIndexMap* grad_varmap = new VarIndexMap();
const OpProto& op_proto = protos()[op->type_];
int idx = 0;
// offset of op's inputs
for (const auto& var : op_proto.inputs()) {
(*grad_varmap)[var.name()] = idx++;
}
// offset of op's outputs
for (const auto& var : op_proto.outputs()) {
(*grad_varmap)[var.name()] = idx++;
}
// offset of gradients of op's output
for (const auto& var : op_proto.outputs()) {
(*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++;
}
idx = 0;
// offset of gradients of op's input
for (const auto& var : op_proto.inputs()) {
(*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++;
}
grad_op->in_out_idxs_.reset(grad_varmap);
}
static std::unordered_map<std::string, OpCreator>& grad_creators() {
static std::unordered_map<std::string, OpCreator> grad_creators_;
return grad_creators_;
static void GenerateGradAttr(OperatorPtr op, OperatorPtr grad_op) {
const OpProto& op_proto = protos()[op->type_];
grad_op->attrs_ = op->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
bool has_in_format = op->attrs_.count("input_format");
bool has_out_format = op->attrs_.count("output_format");
// grad_op's inputs_ contains op's inputs_, outputs_ and gradients of
// outpus_. So grad_op's input_format is necessary when op has
// either input_format or output_format.
if (has_in_format || has_out_format) {
std::vector<int> old_in_format;
std::vector<int> old_out_format;
has_in_format
? old_in_format = op->GetAttr<std::vector<int>>("input_format")
: old_in_format = std::vector<int>(op_proto.inputs_size()),
std::iota(old_in_format.begin(), old_in_format.end(), 0);
has_out_format
? old_out_format = op->GetAttr<std::vector<int>>("output_format")
: old_out_format = std::vector<int>(op_proto.outputs_size()),
std::iota(old_out_format.begin(), old_out_format.end(), 0);
std::vector<int> in_format;
in_format.reserve(old_in_format.size() + old_out_format.size() * 2);
int base = 0;
for (const int& idx : old_in_format) {
in_format.emplace_back(idx + base);
}
base += op->inputs_.size();
for (const int& idx : old_out_format) {
in_format.emplace_back(idx + base);
}
base += op->outputs_.size();
for (const int& idx : old_in_format) {
in_format.emplace_back(idx + base);
}
grad_op->attrs_["input_format"] = in_format;
// grad_op's outputs_ contains gradients of op's inputs_. So grad_op's
// output_format is necessary only when op has input_format.
if (has_in_format) {
std::vector<int> out_format;
out_format.reserve(op_proto.inputs_size());
std::copy(old_in_format.begin(), old_in_format.end(),
std::back_inserter(out_format));
grad_op->attrs_["output_format"] = out_format;
}
}
}
};
......@@ -370,7 +462,7 @@ class GradOpRegisterHelper {
int __op_register_##__op_type##_handle__() { return 0; }
/**
* Macro to Register Operator.
* Macro to Register Gradient Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
......
......@@ -63,6 +63,11 @@ class OperatorBase {
/// but it will be convert to a unique name in scope after OpCreator.
static std::string TMP_VAR_NAME() { return "@TEMP@"; }
/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another varibale.
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
virtual ~OperatorBase() {}
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册