diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 98ef426b108848feab43e4163e08100496fc2f77..6ba0784f1b83fe47a05ac305d025ed82fbe563b8 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -228,6 +228,11 @@ class OpRegistry { } } + template + static void RegisterGradOp(const std::string& op_type) { + grad_creators()[op_type] = [] { return new OpType; }; + } + static OperatorPtr CreateOp(const std::string& type, const VarNameList& inputs, const VarNameList& outputs, @@ -240,6 +245,7 @@ class OpRegistry { op->type_ = type; op->inputs_ = inputs; op->outputs_ = outputs; + op->attrs_ = attrs; op_checkers().at(type).Check(op->attrs_); @@ -256,11 +262,6 @@ class OpRegistry { return OperatorPtr(op); } - template - static void RegisterGradOp(const std::string& op_type) { - grad_creators()[op_type] = [] { return new OpType; }; - } - static OperatorPtr CreateOp(const OpDesc& op_desc) { std::vector inputs; inputs.reserve((size_t)op_desc.inputs_size()); @@ -280,19 +281,16 @@ class OpRegistry { return CreateOp(op_desc.type(), inputs, outputs, attrs); } - static OperatorPtr CreateGradOp(std::shared_ptr op) { - OperatorPtr op_grad(grad_creators().at(op->type_)()); - op_grad->type_ = op->type_; - op_grad->inputs_.reserve(op->inputs_.size()); - for (auto& input : op->inputs_) { - op_grad->inputs_.emplace_back(input); - op_grad->outputs_.emplace_back(input + "@grad"); - } - for (auto& output : op->outputs_) { - op_grad->inputs_.emplace_back(output); - op_grad->inputs_.emplace_back(output + "@grad"); - } - return op_grad; + static OperatorPtr CreateGradOp(OperatorPtr op) { + OperatorPtr grad_op(grad_creators().at(op->type_)()); + grad_op->type_ = op->type_; + + AssembleGradInOut(op, grad_op); + GenerateGradArgOffset(op, grad_op); + GenerateGradAttr(op, grad_op); + + grad_op->Init(); + return grad_op; } static std::unordered_map& protos() { @@ -307,6 +305,21 @@ class OpRegistry { return maps_; } + static std::unordered_map& creators() { + static std::unordered_map creators_; + return creators_; + } + + static std::unordered_map& op_checkers() { + static std::unordered_map op_checkers_; + return op_checkers_; + }; + + static std::unordered_map& grad_creators() { + static std::unordered_map grad_creators_; + return grad_creators_; + } + static void GenerateTempVariableName(OperatorBase* op) { static std::atomic gUniqId(0UL); for (auto& outname : op->outputs_) { @@ -318,19 +331,98 @@ class OpRegistry { } } - static std::unordered_map& creators() { - static std::unordered_map creators_; - return creators_; + static void AssembleGradInOut(OperatorPtr op, OperatorPtr grad_op) { + size_t in_sz = op->inputs_.size() + op->outputs_.size() * 2; + grad_op->inputs_.reserve(in_sz); + size_t out_sz = op->inputs_.size(); + grad_op->outputs_.reserve(out_sz); + // copy op->inputs_ to grad_op->inputs_ + std::copy(op->inputs_.begin(), op->inputs_.end(), + std::back_inserter(grad_op->inputs_)); + // copy op->outputs_ to grad_op->inputs_ + std::copy(op->outputs_.begin(), op->outputs_.end(), + std::back_inserter(grad_op->inputs_)); + // add gradients of op->outputs_ to grad_op->inputs_ + for (const std::string& name : op->outputs_) { + grad_op->inputs_.emplace_back(name + OperatorBase::GRAD_VAR_SUFFIX()); + } + // add gradients of op->inputs_ to grad_op->outputs_ + for (const std::string& name : op->inputs_) { + grad_op->outputs_.emplace_back(name + OperatorBase::GRAD_VAR_SUFFIX()); + } } - static std::unordered_map& op_checkers() { - static std::unordered_map op_checkers_; - return op_checkers_; - }; + static void GenerateGradArgOffset(OperatorPtr op, OperatorPtr grad_op) { + VarIndexMap* grad_varmap = new VarIndexMap(); + const OpProto& op_proto = protos()[op->type_]; + int idx = 0; + // offset of op's inputs + for (const auto& var : op_proto.inputs()) { + (*grad_varmap)[var.name()] = idx++; + } + // offset of op's outputs + for (const auto& var : op_proto.outputs()) { + (*grad_varmap)[var.name()] = idx++; + } + // offset of gradients of op's output + for (const auto& var : op_proto.outputs()) { + (*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++; + } + idx = 0; + // offset of gradients of op's input + for (const auto& var : op_proto.inputs()) { + (*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++; + } + grad_op->in_out_idxs_.reset(grad_varmap); + } - static std::unordered_map& grad_creators() { - static std::unordered_map grad_creators_; - return grad_creators_; + static void GenerateGradAttr(OperatorPtr op, OperatorPtr grad_op) { + const OpProto& op_proto = protos()[op->type_]; + grad_op->attrs_ = op->attrs_; + grad_op->attrs_.erase("input_format"); + grad_op->attrs_.erase("output_format"); + bool has_in_format = op->attrs_.count("input_format"); + bool has_out_format = op->attrs_.count("output_format"); + // grad_op's inputs_ contains op's inputs_, outputs_ and gradients of + // outpus_. So grad_op's input_format is necessary when op has + // either input_format or output_format. + if (has_in_format || has_out_format) { + std::vector old_in_format; + std::vector old_out_format; + has_in_format + ? old_in_format = op->GetAttr>("input_format") + : old_in_format = std::vector(op_proto.inputs_size()), + std::iota(old_in_format.begin(), old_in_format.end(), 0); + has_out_format + ? old_out_format = op->GetAttr>("output_format") + : old_out_format = std::vector(op_proto.outputs_size()), + std::iota(old_out_format.begin(), old_out_format.end(), 0); + + std::vector in_format; + in_format.reserve(old_in_format.size() + old_out_format.size() * 2); + int base = 0; + for (const int& idx : old_in_format) { + in_format.emplace_back(idx + base); + } + base += op->inputs_.size(); + for (const int& idx : old_out_format) { + in_format.emplace_back(idx + base); + } + base += op->outputs_.size(); + for (const int& idx : old_in_format) { + in_format.emplace_back(idx + base); + } + grad_op->attrs_["input_format"] = in_format; + // grad_op's outputs_ contains gradients of op's inputs_. So grad_op's + // output_format is necessary only when op has input_format. + if (has_in_format) { + std::vector out_format; + out_format.reserve(op_proto.inputs_size()); + std::copy(old_in_format.begin(), old_in_format.end(), + std::back_inserter(out_format)); + grad_op->attrs_["output_format"] = out_format; + } + } } }; @@ -370,7 +462,7 @@ class GradOpRegisterHelper { int __op_register_##__op_type##_handle__() { return 0; } /** - * Macro to Register Operator. + * Macro to Register Gradient Operator. */ #define REGISTER_GRADIENT_OP(__op_type, __op_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 5f046d6293d5dbb9fd594b0c13aa8d62012cf915..31d7b2575cc122eb4e64bc3c95d5838f8e0a2bab 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -63,6 +63,11 @@ class OperatorBase { /// but it will be convert to a unique name in scope after OpCreator. static std::string TMP_VAR_NAME() { return "@TEMP@"; } + /// If a variable's name has a certain suffix, it means that the + /// variable is the gradient of another varibale. + /// e.g. Variable "x@GRAD" is the gradient of varibale "x". + static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; } + virtual ~OperatorBase() {} template