提交 a39deed4 编写于 作者: Z zhaojiaying01

op register

上级 52c6945b
......@@ -19,7 +19,9 @@ SOFTWARE.
#include <map>
#include <string>
#include <vector>
#include "framework/attribute.h"
#include "framework/scope.h"
namespace paddle_mobile {
......@@ -37,7 +39,8 @@ template <typename Dtype>
using OpCreator = std::function<framework::OperatorBase<Dtype> *(
const std::string & /*type*/, const VariableNameMap & /*inputs*/,
const VariableNameMap & /*outputs*/,
const framework::AttributeMap & /*attrs*/)>;
const framework::AttributeMap & /*attrs*/,
std::shared_ptr<framework::Scope> /*scope*/)>;
using GradOpMakerFN =
std::function<std::vector<std::unique_ptr<framework::OpDesc>>(
......
......@@ -18,8 +18,10 @@ SOFTWARE.
#pragma once
#include <string>
#include "common/log.h"
#include "common/type_define.h"
#include "framework.pb.h"
#include "framework/framework.pb.h"
namespace paddle_mobile {
namespace framework {
......@@ -45,11 +47,12 @@ template <typename Dtype>
class OpInfoMap {
public:
static OpInfoMap &Instance() {
LOG(paddle_mobile::kLOG_DEBUG1) << " TODO: fix bug";
if (g_op_info_map<Dtype> == nullptr) {
g_op_info_map<Dtype> = new OpInfoMap();
}
return *g_op_info_map<Dtype>;
};
}
bool Has(const std::string &op_type) const {
return map_.find(op_type) != map_.end();
......
......@@ -17,3 +17,121 @@ SOFTWARE.
==============================================================================*/
#pragma once
#include <string>
#include <tuple>
#include "common/log.h"
#include "common/type_define.h"
#include "framework/op_info.h"
#include "framework/operator.h"
namespace paddle_mobile {
namespace framework {
class Registrar {
public:
void Touch() {}
};
template <typename Dtype, size_t I, bool at_end, typename... ARGS>
class OperatorRegistrarRecursive;
template <typename Dtype, typename... ARGS>
struct OperatorRegistrar : public Registrar {
explicit OperatorRegistrar(const std::string& op_type) {
if (OpInfoMap<Dtype>::Instance().Has(op_type)) {
LOG(paddle_mobile::kLOG_DEBUG1)
<< op_type << " is registered more than once.";
return;
}
if (sizeof...(ARGS) == 0) {
LOG(paddle_mobile::kLOG_DEBUG1)
<< "OperatorRegistrar should be invoked at least by OpClass";
return;
}
OpInfo<Dtype> info;
OperatorRegistrarRecursive<Dtype, 0, false, ARGS...>(op_type, &info);
OpInfoMap<Dtype>::Instance().Insert(op_type, info);
}
};
template <typename Dtype, typename T>
struct OpInfoFiller {
void operator()(const std::string& op_type, OpInfo<Dtype>* info) const {
info->creator_ = [](const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs,
std::shared_ptr<Scope> scope) {
return new T(type, inputs, outputs, attrs, scope);
};
}
};
template <typename Dtype, size_t I, typename... ARGS>
class OperatorRegistrarRecursive<Dtype, I, false, ARGS...> {
public:
using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
OperatorRegistrarRecursive(const std::string& op_type, OpInfo<Dtype>* info) {
OpInfoFiller<Dtype, T> fill;
fill(op_type, info);
constexpr auto size = sizeof...(ARGS);
OperatorRegistrarRecursive<Dtype, I + 1, I + 1 == size, ARGS...> reg(
op_type, info);
(void)(reg);
}
};
template <typename Dtype, size_t I, typename... ARGS>
class OperatorRegistrarRecursive<Dtype, I, true, ARGS...> {
public:
OperatorRegistrarRecursive(const std::string& op_type, OpInfo<Dtype>* info) {}
};
template <typename Dtype>
class OpRegistry {
public:
static std::shared_ptr<OperatorBase<Dtype>> CreateOp(
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap attrs,
std::shared_ptr<paddle_mobile::framework::Scope> scope) {
LOG(paddle_mobile::kLOG_DEBUG1) << " type: "
<< type;
LOG(paddle_mobile::kLOG_DEBUG1) << " input size: "
<< inputs.size();
LOG(paddle_mobile::kLOG_DEBUG1) << " output size: "
<< outputs.size();
LOG(paddle_mobile::kLOG_DEBUG1) << " attr size: "
<< attrs.size();
LOG(paddle_mobile::kLOG_DEBUG1) << " OpInfoMap size: "
<< OpInfoMap<Dtype>::Instance().map().size();
LOG(paddle_mobile::kLOG_DEBUG1) << " has type: "
<< type
<< " "
<< OpInfoMap<Dtype>::Instance().Has(type);
auto& info = OpInfoMap<Dtype>::Instance().Get(type);
auto op = info.Creator()(type, inputs, outputs, attrs, scope);
return std::shared_ptr<OperatorBase<Dtype>>(op);
}
};
#define REGISTER_OPERATOR(op_type, op_class) \
template <typename Dtype, typename T> \
class _OpClass_##op_type##_ : public op_class<Dtype, T> { \
public: \
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
}; \
static paddle_mobile::framework::OperatorRegistrar< \
paddle_mobile::CPU, _OpClass_##op_type##_<paddle_mobile::CPU, float>> \
__op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
}
#define USE_OP(op_type) \
extern int TouchOpRegistrar_##op_type(); \
static int use_op_itself_##op_type##_ __attribute__((unused)) = \
TouchOpRegistrar_##op_type()
} // namespace framework
} // namespace paddle_mobile
......@@ -19,18 +19,20 @@ SOFTWARE.
#pragma once
#include <map>
#include "attribute.h"
#include "block_desc.h"
#include <string>
#include <vector>
#include <utility>
#include "framework/attribute.h"
#include "framework/block_desc.h"
#include "common/type_define.h"
#include "common/types.h"
#include "common/variant.h"
#include "op_info.h"
#include "op_kernel_type.h"
#include "paddle_mobile_object.h"
#include "scope.h"
#include "tensor.h"
#include "variable.h"
#include "framework/op_info.h"
#include "framework/op_kernel_type.h"
#include "framework/paddle_mobile_object.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/variable.h"
namespace paddle_mobile {
namespace framework {
......@@ -97,5 +99,12 @@ class OpKernelBase : PaddleMobileObject {
virtual ~OpKernelBase() = default;
};
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls) \
cls(const std::string &type, const ::paddle_mobile::VariableNameMap &inputs, \
const ::paddle_mobile::VariableNameMap &outputs, \
const ::paddle_mobile::framework::AttributeMap &attrs, \
std::shared_ptr<::paddle_mobile::framework::Scope> scope) \
: parent_cls<Dtype, T>(type, inputs, outputs, attrs, scope) {}
} // namespace framework
} // namespace paddle_mobile
......@@ -16,9 +16,11 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
==============================================================================*/
#include "conv_op.h"
#include <vector>
#include "operators/conv_op.h"
#include "framework/data_type.h"
#include "framework/op_proto_maker.h"
#include "framework/op_registry.h"
namespace paddle_mobile {
namespace operators {
......@@ -73,3 +75,7 @@ template class ConvOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
USE_OP(conv2d);
REGISTER_OPERATOR(conv2d, ops::ConvOp);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册