提交 240ed5e6 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #4543 from reyoung/feature/change_grad_reg_mechanism

Add GradOpDescMaker to OpInfo and complete OperatorRegistrar method
...@@ -22,7 +22,7 @@ cc_library(attribute SRCS attribute.cc DEPS framework_proto) ...@@ -22,7 +22,7 @@ cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_proto_maker.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
namespace details {
enum OpInfoFillType {
kOperator = 0,
kOpProtoAndCheckerMaker = 1,
kGradOpDescMaker = 2
};
template <typename T>
struct OpInfoFillTypeID {
static constexpr OpInfoFillType ID() {
return std::is_base_of<OperatorBase, T>::value
? kOperator
: (std::is_base_of<OpProtoAndCheckerMaker, T>::value
? kOpProtoAndCheckerMaker
: (std::is_base_of<GradOpDescMakerBase, T>::value
? kGradOpDescMaker
: static_cast<OpInfoFillType>(-1)));
}
};
template <typename T, OpInfoFillType = OpInfoFillTypeID<T>::ID()>
struct OpInfoFiller;
template <size_t I, bool at_end, typename... ARGS>
class OperatorRegistrarRecursive;
template <size_t I, typename... ARGS>
class OperatorRegistrarRecursive<I, false, ARGS...> {
public:
using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {
OpInfoFiller<T> fill;
fill(op_type, info);
constexpr auto size = sizeof...(ARGS);
OperatorRegistrarRecursive<I + 1, I + 1 == size, ARGS...> reg(op_type,
info);
(void)(reg);
}
};
template <size_t I, typename... ARGS>
class OperatorRegistrarRecursive<I, true, ARGS...> {
public:
OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {}
};
template <typename T>
struct OpInfoFiller<T, kOperator> {
void operator()(const char* op_type, OpInfo* info) const {
info->creator_ = [](const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs) {
return new T(type, inputs, outputs, attrs);
};
}
};
template <typename T>
struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->proto_ = new OpProto;
info->checker_ = new OpAttrChecker();
auto maker = T(info->proto_, info->checker_);
maker.Validate();
info->proto_->set_type(op_type);
PADDLE_ENFORCE(
info->proto_->IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, info->proto_->InitializationErrorString());
}
};
template <typename T>
struct OpInfoFiller<T, kGradOpDescMaker> {
void operator()(const char* op_type, OpInfo* info) const {
info->grad_op_maker_ = new T();
}
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#include <map> #include <map>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -29,11 +29,18 @@ using OpCreator = std::function<OperatorBase*( ...@@ -29,11 +29,18 @@ using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/, const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>; const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
class GradOpDescMakerBase {
public:
virtual ~GradOpDescMakerBase() = default;
virtual std::vector<OpDescBind> operator()(const OpDescBind&) const = 0;
};
struct OpInfo { struct OpInfo {
OpCreator creator_; OpCreator creator_;
std::string grad_op_type_; std::string grad_op_type_;
OpProto* proto_; GradOpDescMakerBase* grad_op_maker_{nullptr};
OpAttrChecker* checker_; OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
bool HasOpProtoAndChecker() const { bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr; return proto_ != nullptr && checker_ != nullptr;
......
...@@ -21,49 +21,42 @@ limitations under the License. */ ...@@ -21,49 +21,42 @@ limitations under the License. */
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/details/op_registry.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_proto_maker.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename... ARGS>
struct OperatorRegistrar {
explicit OperatorRegistrar(const char* op_type) : op_type(op_type) {
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
"'%s' is registered more than once.", op_type);
static_assert(sizeof...(ARGS) != 0,
"OperatorRegistrar should be invoked at least by OpClass");
details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info);
}
~OperatorRegistrar() { OpInfoMap::Instance().Insert(op_type, info); }
const char* op_type;
OpInfo info;
};
class OpRegistry { class OpRegistry {
public: public:
template <typename OpType, typename ProtoMakerType, typename GradOpType> template <typename OpType, typename ProtoMakerType, typename GradOpType>
static void RegisterOp(const std::string& op_type, static void RegisterOp(const std::string& op_type,
const std::string& grad_op_type) { const std::string& grad_op_type) {
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), OperatorRegistrar<OpType, ProtoMakerType> reg(op_type.c_str());
"'%s' is registered more than once.", op_type); reg.info.grad_op_type_ = grad_op_type;
OpInfo op_info;
op_info.creator_ = [](
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs) {
return new OpType(type, inputs, outputs, attrs);
};
op_info.grad_op_type_ = grad_op_type;
if (std::type_index(typeid(ProtoMakerType)) !=
std::type_index(typeid(NOPMaker))) {
op_info.proto_ = new OpProto;
op_info.checker_ = new OpAttrChecker;
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
maker.Validate();
op_info.proto_->set_type(op_type);
PADDLE_ENFORCE(
op_info.proto_->IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_info.proto_->InitializationErrorString());
} else {
op_info.proto_ = nullptr;
op_info.checker_ = nullptr;
}
OpInfoMap::Instance().Insert(op_type, op_info);
// register gradient op // register gradient op
if (!grad_op_type.empty()) { if (!grad_op_type.empty()) {
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, ""); OperatorRegistrar<GradOpType> grad_reg(grad_op_type.c_str());
} }
} }
......
...@@ -173,3 +173,14 @@ TEST(OpRegistry, CustomChecker) { ...@@ -173,3 +173,14 @@ TEST(OpRegistry, CustomChecker) {
int test_attr = op->Attr<int>("test_attr"); int test_attr = op->Attr<int>("test_attr");
ASSERT_EQ(test_attr, 4); ASSERT_EQ(test_attr, 4);
} }
class CosineOpComplete : public paddle::framework::CosineOp {
public:
DEFINE_OP_CONSTRUCTOR(CosineOpComplete, paddle::framework::CosineOp);
DEFINE_OP_CLONE_METHOD(CosineOpComplete);
};
TEST(OperatorRegistrar, Test) {
using namespace paddle::framework;
OperatorRegistrar<CosineOpComplete, CosineOpProtoAndCheckerMaker> reg("cos");
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册