Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle-Lite
  • Issue
  • #300

P
Paddle-Lite
  • 项目概览

PaddlePaddle / Paddle-Lite

通知 338
Star 4
Fork 1
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 271
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 78
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle-Lite
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 271
    • Issue 271
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 78
    • 合并请求 78
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板
已关闭
开放中
Opened 5月 28, 2018 by saxon_zh@saxon_zhGuest

operator部分代码设计

Created by: codeWorm2015

Operator 设计

一个operator 可分为两层:

  • 一层为 op 层, 包括参数获取包装成param结构体(传递给kernel层的参数结构体)和InferShape的操作.
  • 另一层为 kernel 层, 包含了op的具体运算部分, kernel层为可特化到具体平台实现

每一个op需要对应着一段op注册代码, 用于上层在实例化op时使用

op层

template <typename Dtype>
class OperatorBase {
 public:
  /*
   *  @b op 基类的实例化方法, op 获取到了 输入、参数以及提前分配好的输出 tensor
   * */
  OperatorBase(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs,
               std::shared_ptr<Scope> scope);
  virtual ~OperatorBase() {}
  void Run() const;

  std::vector<string> GetOutKeys() const;

  virtual void RunImpl() const = 0;

  virtual void Init() = 0;
  /*
   * @b op 运算所需的输入, 如上一层的输出结果、卷积核
   * */
  const VariableNameMap &Inputs() const { return inputs_; }
  /*
   * @b op 的输出, 内存会提前被分配好, 运算结果会被存到分配好的内存内
   * */
  const VariableNameMap &Outputs() const { return outputs_; }
  /*
   * @b op 类型
   * */
  const std::string &Type() const { return type_; }

  /*
   * @b 根据输入形状和参数计算出输出形状
   * */
  virtual void InferShape() const = 0;

 protected:
  std::shared_ptr<Scope> scope_;
  std::string type_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;

 private:
  void CheckAllInputOutputSet() const;
};
/*
 * @b 这个类为所有带有运算的 op 的父类, 这个 op 继承与 OperatorBase
 * */
template <typename Dtype, typename ParamType, typename KernelType>
class OperatorWithKernel : public OperatorBase<Dtype> {
 public:
  OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     std::shared_ptr<Scope> scope)
      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
        param_(inputs, outputs, attrs, *scope) {}

  virtual void RunImpl() const { this->kernel_.Compute(this->param_); }

  virtual void InferShape() const = 0;

  void Init() {
   // op 实现者可以重写该方法, 对参数进行预处理
    PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), "  %s kernel init failed",
                          this->type_.c_str());
  }

 protected:
  KernelType kernel_;
  ParamType param_;
};

kernel 层

/*
 * @b 所有kernel的父类
 * */
template <typename Dtype, typename P>
class OpKernelBase {
 public:
  /*
   * @b 所有kernel 需实现 Compute 方法
   * @p para 这个参数为 kernel 运算时所需要用到参数组成的一个结构体,
   *    所有结构体存在与: paddle-mobile/src/operators/op_param.h
   * */
  virtual void Compute(const P &para) const = 0;
  virtual bool Init(P *para) { return true; };
  virtual ~OpKernelBase() = default;
};

例子: 一个 relu 的实现

//relu_op.h

template <typename DeviceType, typename T>
class ReluOp
    : public framework::OperatorWithKernel<
          DeviceType, ReluParam, operators::ReluKernel<DeviceType, T>> {
 public:
  /*
   * @b op 的实例化方法, 需要调用父类的实例化方法, 以及实例化自己的参数结构体
   * */
  ReluOp(const std::string &type, const VariableNameMap &inputs,
         const VariableNameMap &outputs, const framework::AttributeMap &attrs,
         std::shared_ptr<framework::Scope> scope)
      : framework::OperatorWithKernel<DeviceType, ReluParam,
                                      operators::ReluKernel<DeviceType, T>>(
            type, inputs, outputs, attrs, scope) {}

  using framework::OperatorWithKernel<
      DeviceType, ReluParam,
      operators::ReluKernel<DeviceType, T>>::OperatorWithKernel;
  void InferShape() const override;

 protected:
};

.cpp 中给出了 InferShape 的实现, 和op注册部分

//relu_op.cpp
namespace paddle_mobile {
namespace operators {

template <typename Dtype, typename T>
void ReluOp<Dtype, T>::InferShape() const {
  auto input_dims = param_.InputX()->dims();
  param_.Out()->Resize(input_dims);
}
template class ReluOp<CPU, float>;
}  // namespace operators
}  // namespace paddle_mobile

/*
 * @b 每一个 op 都需要注册一下的,
 *    USE_OP的参数 和 REGISTER_OPERATOR的第一个参数 都是需要和model中类型对应起来的
 * */
namespace ops = paddle_mobile::operators;
USE_OP(relu);
REGISTER_OPERATOR(relu, ops::ReluOp);

ReluParam 用于包装参数的结构体

//op_param.h
/*
 * @b op 层实例化好这个 param 传递给 kernel 层使用
 * */
class ReluParam : public OpParam {
 public:
  ReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
            const AttributeMap &attrs, const Scope &scope) {
    input_x_ = InputXFrom<Tensor>(inputs, scope);
    out_ = OutFrom<Tensor>(outputs, scope);
  }
  const Tensor *InputX() const { return input_x_; }
  Tensor *Out() const { return out_; }
 private:
  Tensor *input_x_;
  Tensor *out_;
};

kernel 层声明

template <typename DeviceType, typename T>
class ReluKernel : public framework::OpKernelBase<DeviceType, ReluParam> {
 public:
  void Compute(const ReluParam& param) const;
  bool Init(ReluParam* param);
};

特化到 arm 平台

/*
 * @b 特化到具体平台的实现, param 从 op 层传入
 * */
template <>
bool ReluKernel<CPU, float>::Init(ReluParam *param) {
  // 进行一些预处理操作
  return true;
}

/*
 * @b 特化到具体平台的实现, param 从 op 层传入
 * */
template <>
void ReluKernel<CPU, float>::Compute(const ReluParam &param) const {
      // arm 汇编实现 ...
}

image

指派人
分配到
无
里程碑
无
分配里程碑
工时统计
无
截止日期
无
标识: paddlepaddle/Paddle-Lite#300
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7