// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h"

namespace paddle {
namespace operators {

class DistributedFusedLambInitOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
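  // InferShape is intentionally a no-op: the shapes of the fused outputs
  // depend on the actual parameter sizes and are determined at kernel run
  // time.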
  void InferShape(framework::InferShapeContext *ctx) const override {}

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto dtype = framework::proto::VarType::FP32;  // dtype is not important
    return framework::OpKernelType(dtype, ctx.GetPlace());
  }
};

class DistributedFusedLambInitOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Param", "The initial parameter list.").AsDuplicable();
    AddInput("Grad", "The initial gradient list.").AsDuplicable();

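    // Fused buffers: the fp32 parameters and fp16 master weights are
    // flattened into one [M1 + M2] fp32 buffer, and the fp16 parameters into
    // one [M2] fp16 buffer. M1 and M2 are zero-padded so that each fused
    // buffer can be evenly sharded across the N ranks.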
    AddOutput("FP32FusedParam",
              "The fp32 fused param and fp16 fused master weight tensor. Its "
              "shape is [M1+M2], where M1 is the fp32 fused parameter size and "
              "M2 is the fp16 fused master weight parameter size. Note that M1 "
              "and M2 are both exactly divisible by N (guaranteed by extra "
              "zero padding), where N is the world size.")
        .AsDispensable();
    AddOutput("FP32FusedGrad", "The fp32 fused grad tensor. Its shape is [M1].")
        .AsDispensable();
    AddOutput("FP16FusedParam",
              "The fp16 fused param tensor. Its shape is [M2].")
        .AsDispensable();
    AddOutput("FP16FusedGrad", "The fp16 fused grad tensor. Its shape is [M2].")
        .AsDispensable();

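    // Optimizer states are sharded: each rank only allocates the 1/N slice of
    // the fused moment buffers that it owns.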
    AddOutput("Moment1",
              "The sharded fp32 moment1 tensor. Its shape is [(M1+M2)/N].");
    AddOutput("Moment2",
              "The sharded fp32 moment2 tensor. Its shape is [(M1+M2)/N].");
    AddOutput("Beta1Pow",
              "The fp32 beta1 power accumulator tensor. Its shape is [1].");
    AddOutput("Beta2Pow",
              "The fp32 beta2 power accumulator tensor. Its shape is [1].");
    AddOutput(
        "FusedParamOffsets",
        "The numel offset of each parameter inside the FP32FusedParam. Its "
        "shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 "
        "+ n_2, ...]. It should be in CPUPlace.");
    AddOutput(
        "FP32ShardFusedParamOffsets",
        "The sharded numel offset of each parameter in the local rank. "
        "Its shape is [fp32_local_param_num + 1]. It should be in CPUPlace.");
    AddOutput(
        "FP16ShardFusedParamOffsets",
        "The sharded numel offset of each parameter in the local rank. "
        "Its shape is [fp16_local_param_num + 1]. It should be in CPUPlace.");
    AddOutput("ParamInfo",
              "The param info. It should be in CPUPlace, and its shape is "
              "[8]. It is [fp32_shard_param_start_idx, fp32_local_param_num, "
              "fp32_global_param_num, fp32_weight_decay_end_idx, "
              "fp16_shard_param_start_idx, "
              "fp16_local_param_num, fp16_global_param_num, "
              "fp16_weight_decay_end_idx].");
    AddOutput("ParamOrder",
              "The reordered parameter order. Inside this op, "
              "the parameters are reordered by data type and weight decay "
              "value.");
    AddOutput("ParamOut", "The output parameter list.").AsDuplicable();
    AddOutput("MasterParamOut",
              "The output master parameter list. It would share the memory of "
              "each fp32 parameter and fp16 master parameter.")
        .AsDuplicable();
    AddOutput("GradOut", "The output gradient list.").AsDuplicable();
    AddOutput("GlobalScale",
              "The global scale. It is usually the scale factor for AMP.");
    AddOutput("Step",
              "The global step counter, which does not count steps skipped "
              "due to NaN/Inf.");

    AddAttr<float>("beta1", "The initial value of Beta1Pow.");
    AddAttr<float>("beta2", "The initial value of Beta2Pow.");
    AddAttr<std::vector<int>>(
        "apply_weight_decay",
        "Whether to apply weight decay to each parameter.");
    AddAttr<int>("alignment", "The alignment in bytes for the fused tensors.");
    AddAttr<int>("rank", "The global rank of the current process.");
    AddAttr<int>("nranks", "The global world size.");
    AddComment(
        R"DOC(The init operator for the DistributedFusedLamb optimizer.)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb_init,
                             ops::DistributedFusedLambInitOp,
                             ops::DistributedFusedLambInitOpMaker);

REGISTER_OP_CPU_KERNEL(
    distributed_fused_lamb_init,
    ops::DistributedFusedLambInitOpKernel<plat::CPUDeviceContext, float>);