moe_op.cc 2.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"

namespace paddle {
namespace operators {

// Fluid operator wrapper for the fused mixture-of-experts ("moe") op.
// Shape inference is supplied separately at registration time; this class
// only decides which kernel variant to dispatch to.
class MoeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Picks the kernel keyed on the data type of input "X" and the device
  // context the op is currently executing on.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto input_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    const auto& dev_ctx = ctx.device_context();
    return framework::OpKernelType(input_dtype, dev_ctx);
  }
};

// Declares the proto (inputs, output, attributes, doc string) of the "moe" op.
class MoeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Input slots, registered in exactly this order.
    struct InputSpec {
      const char* name;
      const char* comment;
    };
    const InputSpec kInputs[] = {
        {"X", "(Tensor), The source input tensor of Moe op."},
        {"Gate", "(Tensor), The gating input tensor of Moe op."},
        {"Bmm0", "(Tensor), The bmm0 input tensor of Moe op."},
        {"Bias0", "(Tensor), The eltwise0 input tensor of Moe op."},
        {"Bmm1", "(Tensor), The bmm1 input tensor of Moe op."},
        {"Bias1", "(Tensor), The eltwise1 input tensor of Moe op."},
    };
    for (const auto& spec : kInputs) {
      AddInput(spec.name, spec.comment);
    }

    AddOutput("Out", "(Tensor), The output tensor of Moe op.");

    // Activation selector; falls back to "gelu" when the caller omits it.
    auto& act_attr = AddAttr<std::string>(
        "act_type",
        R"DOC(activation type, currently only support `gelu`, `relu`. Default value is: `gelu`. )DOC");
    act_attr.SetDefault("gelu");

    AddComment(
        R"DOC(FusedEcMoe kernel. For more details you can refer to `FusedEcMoE` python documents. )DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Adapts phi::MoeInferMeta into a fluid InferShape functor named
// MoeInferShapeFunctor, so the legacy "moe" operator reuses the phi
// meta-inference function for its output.
DECLARE_INFER_SHAPE_FUNCTOR(moe,
                            MoeInferShapeFunctor,
                            PD_INFER_META(phi::MoeInferMeta));
// Registers the "moe" operator with its operator class, proto maker, and the
// infer-shape functor declared above.
// NOTE(review): no grad op maker is passed here, so this op presumably has no
// backward registered in this file — confirm.
REGISTER_OPERATOR(moe, ops::MoeOp, ops::MoeOpMaker, MoeInferShapeFunctor);