// limit_by_capacity_op.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/limit_by_capacity_op.h"

namespace paddle {
namespace operators {

// Operator that limits `expert_count` by `capacity` (producing `Out`).
//
// This class only performs shape inference and kernel selection; the actual
// computation is done by the kernel registered for `limit_by_capacity`.
class LimitByCapacityOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks that the `expert_count` and `capacity` inputs and the `Out` output
  // are all present, then gives `Out` the same dims and LoD as `expert_count`.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("expert_count"),
                   "Input",
                   "expert_count",
                   "LimitByCapacity");
    OP_INOUT_CHECK(
        ctx->HasInput("capacity"), "Input", "capacity", "LimitByCapacity");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "LimitByCapacity");

    ctx->ShareDim("expert_count", "Out");
    ctx->ShareLoD("expert_count", "Out");
  }

 protected:
  // Selects the kernel from the dtype of `expert_count` and the place of the
  // execution context.
  //
  // Raises InvalidArgument if `expert_count` and `capacity` have different
  // dtypes, or if that dtype is not INT64.
  //
  // NOTE(review): the CPU kernel is also registered for `int`, but the INT64
  // check below rejects int32 inputs. Confirm whether that registration is
  // meant to be reachable.
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // The dtype of expert_count and capacity must match and must be int64.
    auto expert_count_dtype =
        OperatorWithKernel::IndicateVarDataType(ctx, "expert_count");
    auto capacity_dtype =
        OperatorWithKernel::IndicateVarDataType(ctx, "capacity");

    PADDLE_ENFORCE_EQ(
        expert_count_dtype,
        capacity_dtype,
        platform::errors::InvalidArgument(
            "The dtype of the expert_count and capacity should be same"));

    PADDLE_ENFORCE_EQ(
        expert_count_dtype,
        framework::proto::VarType::INT64,
        platform::errors::InvalidArgument("The dtype of the expert_count and "
                                          "capacity should be same as int64"));

    return phi::KernelKey(expert_count_dtype, ctx.GetPlace());
  }
};

// Declares the inputs, output, attribute and documentation string of the
// `limit_by_capacity` operator.
class LimitByCapacityOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("expert_count", "(Tensor) The input expert count tensor.");
    AddInput("capacity", "(Tensor) The input capacity.");
    AddOutput("Out",
              "(Tensor) The output expert count tensor, limited by capacity.");
    // Fixed typo: "works" -> "workers".
    AddAttr<int>("n_worker", "(int) The number of workers.");
    // Fixed the fused sentence "Operator.limit" in the op description.
    AddComment(
        R"DOC(limit_by_capacity Operator. Limit expert count by capacity.)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

80 81
REGISTER_OP_WITHOUT_GRADIENT(limit_by_capacity,
                             ops::LimitByCapacityOp,
R
Roc 已提交
82
                             ops::LimitByCapacityOpMaker);
H
huangjiyi 已提交
83 84 85 86 87 88 89

PD_REGISTER_STRUCT_KERNEL(limit_by_capacity,
                          CPU,
                          ALL_LAYOUT,
                          ops::LimitByCapacityOpCPUKernel,
                          int,
                          int64_t) {}