operator部分代码设计
Created by: codeWorm2015
Operator 设计
一个operator 可分为两层:
- 一层为 op 层, 包括参数获取包装成param结构体(传递给kernel层的参数结构体)和InferShape的操作.
- 另一层为 kernel 层, 包含了op的具体运算部分, kernel层为可特化到具体平台实现
每一个op需要对应着一段op注册代码, 用于上层在实例化op时使用
op层
template <typename Dtype>
class OperatorBase {
public:
/*
* @b op 基类的实例化方法, op 获取到了 输入、参数以及提前分配好的输出 tensor
* */
OperatorBase(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
std::shared_ptr<Scope> scope);
virtual ~OperatorBase() {}
void Run() const;
std::vector<string> GetOutKeys() const;
virtual void RunImpl() const = 0;
virtual void Init() = 0;
/*
* @b op 运算所需的输入, 如上一层的输出结果、卷积核
* */
const VariableNameMap &Inputs() const { return inputs_; }
/*
* @b op 的输出, 内存会提前被分配好, 运算结果会被存到分配好的内存内
* */
const VariableNameMap &Outputs() const { return outputs_; }
/*
* @b op 类型
* */
const std::string &Type() const { return type_; }
/*
* @b 根据输入形状和参数计算出输出形状
* */
virtual void InferShape() const = 0;
protected:
std::shared_ptr<Scope> scope_;
std::string type_;
VariableNameMap inputs_;
VariableNameMap outputs_;
AttributeMap attrs_;
private:
void CheckAllInputOutputSet() const;
};
/*
* @b 这个类为所有带有运算的 op 的父类, 这个 op 继承与 OperatorBase
* */
template <typename Dtype, typename ParamType, typename KernelType>
class OperatorWithKernel : public OperatorBase<Dtype> {
public:
OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
std::shared_ptr<Scope> scope)
: OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
param_(inputs, outputs, attrs, *scope) {}
virtual void RunImpl() const { this->kernel_.Compute(this->param_); }
virtual void InferShape() const = 0;
void Init() {
// op 实现者可以重写该方法, 对参数进行预处理
PADDLE_MOBILE_ENFORCE(kernel_.Init(¶m_), " %s kernel init failed",
this->type_.c_str());
}
protected:
KernelType kernel_;
ParamType param_;
};
kernel 层
/*
* @b 所有kernel的父类
* */
template <typename Dtype, typename P>
class OpKernelBase {
public:
/*
* @b 所有kernel 需实现 Compute 方法
* @p para 这个参数为 kernel 运算时所需要用到参数组成的一个结构体,
* 所有结构体存在与: paddle-mobile/src/operators/op_param.h
* */
virtual void Compute(const P ¶) const = 0;
virtual bool Init(P *para) { return true; };
virtual ~OpKernelBase() = default;
};
例子: 一个 relu 的实现
//relu_op.h
template <typename DeviceType, typename T>
class ReluOp
: public framework::OperatorWithKernel<
DeviceType, ReluParam, operators::ReluKernel<DeviceType, T>> {
public:
/*
* @b op 的实例化方法, 需要调用父类的实例化方法, 以及实例化自己的参数结构体
* */
ReluOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, ReluParam,
operators::ReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, ReluParam,
operators::ReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
};
.cpp 中给出了 InferShape 的实现, 和op注册部分
//relu_op.cpp
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void ReluOp<Dtype, T>::InferShape() const {
auto input_dims = param_.InputX()->dims();
param_.Out()->Resize(input_dims);
}
template class ReluOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
/*
* @b 每一个 op 都需要注册一下的,
* USE_OP的参数 和 REGISTER_OPERATOR的第一个参数 都是需要和model中类型对应起来的
* */
namespace ops = paddle_mobile::operators;
USE_OP(relu);
REGISTER_OPERATOR(relu, ops::ReluOp);
ReluParam 用于包装参数的结构体
//op_param.h
/*
* @b op 层实例化好这个 param 传递给 kernel 层使用
* */
class ReluParam : public OpParam {
public:
ReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<Tensor>(inputs, scope);
out_ = OutFrom<Tensor>(outputs, scope);
}
const Tensor *InputX() const { return input_x_; }
Tensor *Out() const { return out_; }
private:
Tensor *input_x_;
Tensor *out_;
};
kernel 层声明
template <typename DeviceType, typename T>
class ReluKernel : public framework::OpKernelBase<DeviceType, ReluParam> {
public:
void Compute(const ReluParam& param) const;
bool Init(ReluParam* param);
};
特化到 arm 平台
/*
* @b 特化到具体平台的实现, param 从 op 层传入
* */
template <>
bool ReluKernel<CPU, float>::Init(ReluParam *param) {
// 进行一些预处理操作
return true;
}
/*
* @b 特化到具体平台的实现, param 从 op 层传入
* */
template <>
void ReluKernel<CPU, float>::Compute(const ReluParam ¶m) const {
// arm 汇编实现 ...
}