提交 77af58f8 编写于 作者: F fengjiayi

Change gradient Op registry mechanism

OLD: op_type -> grad_op_creator

NEW: grad_op_type -> grad_op_creator
     op_type -> grad_op_type
上级 02cde244
......@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
OperatorBase* GradOpCreator::Create() {
OperatorBase* GradOpBuilder::Build() {
BuildOpInOutArgList();
OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type;
CompleteGradOp(grad_op);
return grad_op;
}
OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
const VarIndexMap& var_map,
const std::vector<int>& format,
InOutType type) {
......@@ -36,7 +38,7 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
end_idx);
}
void GradOpCreator::BuildOpInOutArgList() {
void GradOpBuilder::BuildOpInOutArgList() {
const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
const std::vector<int>& in_format =
......@@ -57,7 +59,7 @@ void GradOpCreator::BuildOpInOutArgList() {
}
}
void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
std::vector<std::string>& in_out,
std::vector<int>& format,
VarIndexMap* varmap, int& idx,
......@@ -80,8 +82,7 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
format.push_back(in_out.size());
}
void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->type_ = op_->type_ + "@GRAD"; // not necessary
void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->attrs_ = op_->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
......
......@@ -25,12 +25,12 @@ struct OpInOutArg {
size_t end_idx_;
};
class GradOpCreator {
class GradOpBuilder {
using VarIndexMap = std::unordered_map<std::string, int>;
public:
GradOpCreator(const OperatorBase* op) : op_(op) {}
OperatorBase* Create();
GradOpBuilder(const OperatorBase* op) : op_(op) {}
OperatorBase* Build();
private:
OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
......
......@@ -222,7 +222,7 @@ class OpRegistry {
public:
template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) {
creators()[op_type] = [] { return new OpType; };
op_creators()[op_type] = [] { return new OpType; };
OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = protos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker);
......@@ -245,17 +245,19 @@ class OpRegistry {
}
}
template <typename OpType>
static void RegisterGradOp(const std::string& op_type) {
grad_creators()[op_type] = [] { return new OpType; };
template <typename GradOpType>
static void RegisterGradOp(const std::string& op_type,
const std::string& grad_op_type) {
op_creators()[grad_op_type] = [] { return new GradOpType; };
grad_ops()[op_type] = grad_op_type;
}
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameList& inputs,
const VarNameList& outputs,
const AttributeMap& attrs) {
auto op_create_it = creators().find(type);
PADDLE_ENFORCE(op_create_it != creators().end(),
auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type);
auto op = op_create_it->second();
......@@ -300,8 +302,8 @@ class OpRegistry {
static std::shared_ptr<OperatorBase> CreateGradOp(
std::shared_ptr<OperatorBase> op) {
GradOpCreator creator(op.get());
std::shared_ptr<OperatorBase> grad_op(creator.Create());
GradOpBuilder builder(op.get());
std::shared_ptr<OperatorBase> grad_op(builder.Build());
grad_op->Init();
return grad_op;
}
......@@ -311,9 +313,9 @@ class OpRegistry {
return protos_;
};
static std::unordered_map<std::string, OpCreator>& grad_creators() {
static std::unordered_map<std::string, OpCreator> grad_creators_;
return grad_creators_;
static std::unordered_map<std::string, std::string>& grad_ops() {
static std::unordered_map<std::string, std::string> grad_ops_;
return grad_ops_;
}
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
......@@ -322,12 +324,12 @@ class OpRegistry {
return maps_;
}
private:
static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
static std::unordered_map<std::string, OpCreator>& op_creators() {
static std::unordered_map<std::string, OpCreator> op_creators_;
return op_creators_;
}
private:
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
......@@ -353,11 +355,11 @@ class OpRegisterHelper {
}
};
template <typename OpType>
template <typename GradOpType>
class GradOpRegisterHelper {
public:
GradOpRegisterHelper(const char* op_type) {
OpRegistry::RegisterGradOp<OpType>(op_type);
GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
}
};
......@@ -383,13 +385,16 @@ class GradOpRegisterHelper {
/**
* Macro to Register Gradient Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__op_class> \
__op_gradient_register_##__op_type##__(#__op_type); \
int __op_gradient_register_##__op_type##_handle__() { return 0; }
#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type##__grad_op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \
__op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \
#__grad_op_type); \
int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
return 0; \
}
/**
* Macro to Register OperatorKernel.
......
......@@ -65,6 +65,6 @@ protected:
} // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
REGISTER_OP_CPU_KERNEL(
add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册