Commit bf4da3d9 authored by fengjiayi

Refactor Registry::CreateGradOp()

We put the forward Op's inputs, outputs and output gradients into the Grad
Op's inputs, and put the forward Op's input gradients into the Grad Op's
outputs. So the Grad Op's `in_out_idx`, `input_format` and `output_format`
need to be rebuilt during Grad Op creation.
Parent e786746f
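For illustration, here is a minimal standalone sketch of the naming rule described above, assuming a hypothetical forward op with inputs {X, W} and outputs {Out}. The variable names and the `main` driver are invented for this example; the real logic is `OpRegistry::AssembleGradInOut()` in the diff below.

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::string kGradSuffix = "@GRAD";  // mirrors OperatorBase::GRAD_VAR_SUFFIX()
  const std::vector<std::string> inputs = {"X", "W"};  // forward op inputs
  const std::vector<std::string> outputs = {"Out"};    // forward op outputs

  // Grad op inputs: forward inputs + forward outputs + output gradients.
  std::vector<std::string> grad_inputs(inputs);
  grad_inputs.insert(grad_inputs.end(), outputs.begin(), outputs.end());
  for (const auto& name : outputs) grad_inputs.push_back(name + kGradSuffix);

  // Grad op outputs: gradients of forward inputs.
  std::vector<std::string> grad_outputs;
  for (const auto& name : inputs) grad_outputs.push_back(name + kGradSuffix);

  for (const auto& n : grad_inputs) std::cout << "grad in:  " << n << "\n";
  for (const auto& n : grad_outputs) std::cout << "grad out: " << n << "\n";
  // Prints X, W, Out, Out@GRAD as inputs and X@GRAD, W@GRAD as outputs.
}
```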
@@ -228,6 +228,11 @@ class OpRegistry {
     }
   }

+  template <typename OpType>
+  static void RegisterGradOp(const std::string& op_type) {
+    grad_creators()[op_type] = [] { return new OpType; };
+  }
+
   static OperatorPtr CreateOp(const std::string& type,
                               const VarNameList& inputs,
                               const VarNameList& outputs,
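A minimal standalone model of the creator-registry pattern that `RegisterGradOp` uses. The `Op` and `MulGradOp` classes here are hypothetical stand-ins; the real map stores creators of `OperatorBase` subclasses, and user code normally goes through the `REGISTER_GRADIENT_OP` macro instead of calling `RegisterGradOp` directly.

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Op {
  virtual ~Op() = default;
  virtual const char* name() const = 0;
};
struct MulGradOp : Op {
  const char* name() const override { return "mul_grad"; }
};

using OpCreator = std::function<Op*()>;

// Function-local static map, as in OpRegistry's accessor methods.
std::unordered_map<std::string, OpCreator>& grad_creators() {
  static std::unordered_map<std::string, OpCreator> grad_creators_;
  return grad_creators_;
}

template <typename OpType>
void RegisterGradOp(const std::string& op_type) {
  // Keyed by the *forward* op's type, so CreateGradOp can later look the
  // creator up via grad_creators().at(op->type_).
  grad_creators()[op_type] = [] { return new OpType; };
}

int main() {
  RegisterGradOp<MulGradOp>("mul");
  std::unique_ptr<Op> grad_op(grad_creators().at("mul")());
  std::cout << grad_op->name() << "\n";  // prints: mul_grad
}
```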
@@ -240,6 +245,7 @@ class OpRegistry {
     op->type_ = type;
     op->inputs_ = inputs;
     op->outputs_ = outputs;
     op->attrs_ = attrs;
     op_checkers().at(type).Check(op->attrs_);
@@ -256,11 +262,6 @@ class OpRegistry {
     return OperatorPtr(op);
   }

-  template <typename OpType>
-  static void RegisterGradOp(const std::string& op_type) {
-    grad_creators()[op_type] = [] { return new OpType; };
-  }
-
   static OperatorPtr CreateOp(const OpDesc& op_desc) {
     std::vector<std::string> inputs;
     inputs.reserve((size_t)op_desc.inputs_size());
@@ -280,19 +281,16 @@ class OpRegistry {
     return CreateOp(op_desc.type(), inputs, outputs, attrs);
   }

-  static OperatorPtr CreateGradOp(std::shared_ptr<OperatorBase> op) {
-    OperatorPtr op_grad(grad_creators().at(op->type_)());
-    op_grad->type_ = op->type_;
-    op_grad->inputs_.reserve(op->inputs_.size());
-    for (auto& input : op->inputs_) {
-      op_grad->inputs_.emplace_back(input);
-      op_grad->outputs_.emplace_back(input + "@grad");
-    }
-    for (auto& output : op->outputs_) {
-      op_grad->inputs_.emplace_back(output);
-      op_grad->inputs_.emplace_back(output + "@grad");
-    }
-    return op_grad;
-  }
+  static OperatorPtr CreateGradOp(OperatorPtr op) {
+    OperatorPtr grad_op(grad_creators().at(op->type_)());
+    grad_op->type_ = op->type_;
+
+    AssembleGradInOut(op, grad_op);
+    GenerateGradArgOffset(op, grad_op);
+    GenerateGradAttr(op, grad_op);
+
+    grad_op->Init();
+    return grad_op;
+  }

   static std::unordered_map<std::string, OpProto>& protos() {
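With this change, building a gradient operator is a single registry call. A hedged usage sketch follows; the "mul" op, its prior registration, and the empty attribute map are assumptions, not part of this diff.

```cpp
// Sketch only: assumes "mul" and its gradient op were registered beforehand.
OperatorPtr fwd = OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {});
OperatorPtr grad = OpRegistry::CreateGradOp(fwd);
// CreateGradOp assembles grad->inputs_/outputs_ (AssembleGradInOut),
// rebuilds the name-to-offset map (GenerateGradArgOffset), rewrites the
// format attributes (GenerateGradAttr), and finally calls grad->Init().
```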
@@ -307,6 +305,21 @@ class OpRegistry {
     return maps_;
   }

+  static std::unordered_map<std::string, OpCreator>& creators() {
+    static std::unordered_map<std::string, OpCreator> creators_;
+    return creators_;
+  }
+
+  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
+    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
+    return op_checkers_;
+  }
+
+  static std::unordered_map<std::string, OpCreator>& grad_creators() {
+    static std::unordered_map<std::string, OpCreator> grad_creators_;
+    return grad_creators_;
+  }
+
   static void GenerateTempVariableName(OperatorBase* op) {
     static std::atomic<size_t> gUniqId(0UL);
     for (auto& outname : op->outputs_) {
@@ -318,19 +331,98 @@ class OpRegistry {
     }
   }

-  static std::unordered_map<std::string, OpCreator>& creators() {
-    static std::unordered_map<std::string, OpCreator> creators_;
-    return creators_;
-  }
-
-  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
-    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
-    return op_checkers_;
-  }
-
-  static std::unordered_map<std::string, OpCreator>& grad_creators() {
-    static std::unordered_map<std::string, OpCreator> grad_creators_;
-    return grad_creators_;
-  }
+  static void AssembleGradInOut(OperatorPtr op, OperatorPtr grad_op) {
+    size_t in_sz = op->inputs_.size() + op->outputs_.size() * 2;
+    grad_op->inputs_.reserve(in_sz);
+    size_t out_sz = op->inputs_.size();
+    grad_op->outputs_.reserve(out_sz);
+    // Copy op->inputs_ to grad_op->inputs_.
+    std::copy(op->inputs_.begin(), op->inputs_.end(),
+              std::back_inserter(grad_op->inputs_));
+    // Copy op->outputs_ to grad_op->inputs_.
+    std::copy(op->outputs_.begin(), op->outputs_.end(),
+              std::back_inserter(grad_op->inputs_));
+    // Add gradients of op->outputs_ to grad_op->inputs_.
+    for (const std::string& name : op->outputs_) {
+      grad_op->inputs_.emplace_back(name + OperatorBase::GRAD_VAR_SUFFIX());
+    }
+    // Add gradients of op->inputs_ to grad_op->outputs_.
+    for (const std::string& name : op->inputs_) {
+      grad_op->outputs_.emplace_back(name + OperatorBase::GRAD_VAR_SUFFIX());
+    }
+  }
+
+  static void GenerateGradArgOffset(OperatorPtr op, OperatorPtr grad_op) {
+    VarIndexMap* grad_varmap = new VarIndexMap();
+    const OpProto& op_proto = protos()[op->type_];
+    int idx = 0;
+    // Offsets of op's inputs.
+    for (const auto& var : op_proto.inputs()) {
+      (*grad_varmap)[var.name()] = idx++;
+    }
+    // Offsets of op's outputs.
+    for (const auto& var : op_proto.outputs()) {
+      (*grad_varmap)[var.name()] = idx++;
+    }
+    // Offsets of the gradients of op's outputs.
+    for (const auto& var : op_proto.outputs()) {
+      (*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++;
+    }
+    // grad_op's outputs are indexed from zero again.
+    idx = 0;
+    // Offsets of the gradients of op's inputs.
+    for (const auto& var : op_proto.inputs()) {
+      (*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++;
+    }
+    grad_op->in_out_idxs_.reset(grad_varmap);
+  }
+
+  static void GenerateGradAttr(OperatorPtr op, OperatorPtr grad_op) {
+    const OpProto& op_proto = protos()[op->type_];
+    grad_op->attrs_ = op->attrs_;
+    grad_op->attrs_.erase("input_format");
+    grad_op->attrs_.erase("output_format");
+    bool has_in_format = op->attrs_.count("input_format");
+    bool has_out_format = op->attrs_.count("output_format");
+    // grad_op's inputs_ contains op's inputs_, outputs_ and gradients of
+    // outputs_, so grad_op needs an input_format whenever op has either
+    // an input_format or an output_format.
+    if (has_in_format || has_out_format) {
+      // Take the forward op's formats, or fall back to the trivial
+      // one-variable-per-slot format {0, 1, 2, ...}.
+      std::vector<int> old_in_format;
+      std::vector<int> old_out_format;
+      if (has_in_format) {
+        old_in_format = op->GetAttr<std::vector<int>>("input_format");
+      } else {
+        old_in_format.resize(op_proto.inputs_size());
+        std::iota(old_in_format.begin(), old_in_format.end(), 0);
+      }
+      if (has_out_format) {
+        old_out_format = op->GetAttr<std::vector<int>>("output_format");
+      } else {
+        old_out_format.resize(op_proto.outputs_size());
+        std::iota(old_out_format.begin(), old_out_format.end(), 0);
+      }

+      // grad_op's inputs are laid out as [op inputs][op outputs][output
+      // gradients], so shift each section's offsets past the sections
+      // before it.
+      std::vector<int> in_format;
+      in_format.reserve(old_in_format.size() + old_out_format.size() * 2);
+      int base = 0;
+      for (const int& idx : old_in_format) {
+        in_format.emplace_back(idx + base);
+      }
+      base += op->inputs_.size();
+      for (const int& idx : old_out_format) {
+        in_format.emplace_back(idx + base);
+      }
+      base += op->outputs_.size();
+      for (const int& idx : old_out_format) {
+        in_format.emplace_back(idx + base);
+      }
+      grad_op->attrs_["input_format"] = in_format;
+      // grad_op's outputs_ contains only gradients of op's inputs_, so
+      // grad_op's output_format is necessary only when op has an
+      // input_format, and it is simply a copy of it.
+      if (has_in_format) {
+        std::vector<int> out_format;
+        out_format.reserve(op_proto.inputs_size());
+        std::copy(old_in_format.begin(), old_in_format.end(),
+                  std::back_inserter(out_format));
+        grad_op->attrs_["output_format"] = out_format;
+      }
+    }
+  }
 };
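To make the offset layout of `GenerateGradArgOffset` concrete, here is a runnable sketch of the index map it builds, again for an assumed forward op with inputs {X, W} and outputs {Out} (names invented for illustration):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  const std::string kGradSuffix = "@GRAD";
  const std::vector<std::string> in_names = {"X", "W"};
  const std::vector<std::string> out_names = {"Out"};

  std::unordered_map<std::string, int> grad_varmap;
  int idx = 0;
  // Grad op inputs: forward inputs, then outputs, then output gradients.
  for (const auto& n : in_names) grad_varmap[n] = idx++;
  for (const auto& n : out_names) grad_varmap[n] = idx++;
  for (const auto& n : out_names) grad_varmap[n + kGradSuffix] = idx++;
  // Grad op outputs are indexed from zero again: the input gradients.
  idx = 0;
  for (const auto& n : in_names) grad_varmap[n + kGradSuffix] = idx++;

  // Expected: X->0, W->1, Out->2, Out@GRAD->3, X@GRAD->0, W@GRAD->1.
  for (const auto& kv : grad_varmap)
    std::cout << kv.first << " -> " << kv.second << "\n";
}
```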
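Similarly, a runnable sketch of the shift arithmetic in `GenerateGradAttr`, under the assumption (my reading of this code, not stated in the diff) that a format vector stores per-slot start offsets into the flattened variable list; the slot sizes below are invented:

```cpp
#include <iostream>
#include <vector>

int main() {
  // Assumed forward op: 3 input variables in 2 slots, 1 output variable.
  const std::vector<int> old_in_format = {0, 1};  // per-slot start offsets
  const std::vector<int> old_out_format = {0};
  const int num_in_vars = 3, num_out_vars = 1;

  // Grad input layout: [fwd inputs][fwd outputs][output gradients].
  std::vector<int> in_format;
  in_format.reserve(old_in_format.size() + old_out_format.size() * 2);
  int base = 0;
  for (int idx : old_in_format) in_format.push_back(idx + base);
  base += num_in_vars;
  for (int idx : old_out_format) in_format.push_back(idx + base);
  base += num_out_vars;
  for (int idx : old_out_format) in_format.push_back(idx + base);

  for (int v : in_format) std::cout << v << " ";  // prints: 0 1 3 4
  std::cout << "\n";
}
```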
@@ -370,7 +462,7 @@ class GradOpRegisterHelper {
   int __op_register_##__op_type##_handle__() { return 0; }

 /**
- * Macro to Register Operator.
+ * Macro to Register Gradient Operator.
  */
 #define REGISTER_GRADIENT_OP(__op_type, __op_class) \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                   \
...
@@ -63,6 +63,11 @@ class OperatorBase {
   /// but it will be converted to a unique name in scope after OpCreator.
   static std::string TMP_VAR_NAME() { return "@TEMP@"; }

+  /// If a variable's name has a certain suffix, it means that the
+  /// variable is the gradient of another variable,
+  /// e.g. variable "x@GRAD" is the gradient of variable "x".
+  static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
+
   virtual ~OperatorBase() {}

   template <typename T>
...