如何创建一个 fusion op
Created by: codeWorm2015
这里以一个 fusion_fc_op 为例:
//FushionFcOp.h
using std::string;
using std::vector;
//创建一个 FusionFcMatcher 继承与 FusionOpMatcher
class FusionFcMatcher : public framework::FusionOpMatcher {
public:
// 用于描述 Fc 融合了哪些op, 这里融合了 mul 和 elementwise_add
FusionFcMatcher() {
node_ = framework::Node(G_OP_TYPE_MUL);
node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD);
}
//融合过程中, 优化器会来调用这个方法
void FolderNodes(
framework::Node *node,
std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
/*告诉优化器,
要融合的op的深度(这里就是 op 的个数);
类型;
以及要保留中间op的参数, 这里保留了 elementwise_add op的 Y 参数, 使用 字段 Z替换;
以及要删除的原 node
*/
node->Folder(node_.Depth(), Type(),
{{G_OP_TYPE_ELEMENTWISE_ADD, {"Y", "Z"}}}, removed_nodes);
}
std::string Type() { return G_OP_TYPE_FC; }
};
// 融合 op 的定义和 普通 op 是一样的, 需要根据新的模型描述创建一个 param, 这里是 FushionFcParam
template <typename DeviceType, typename T>
class FushionFcOp : public framework::OperatorWithKernel<DeviceType> {
public:
FushionFcOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType>(type, inputs, outputs, attrs,
scope),
param_(inputs, outputs, attrs, *scope) {}
void RunImpl() const {
operators::FushionFcKernel<DeviceType, T> kernel;
kernel.Compute(param_);
}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
void InferShape() const override;
protected:
FushionFcParam param_;
};
//需要注册一下 FC 融合器
static framework::FusionOpRegistrar fc_registrar(new FusionFcMatcher());
其他使用 和 普通 op 是一致的