// optimizer_extract_pass.cc (13.9 KB)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h"

#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

// Optimizer-role op types that ApplyImpl accepts without translating them into
// attributes of the generated "popart_optimizer" op. "scale" ops under a
// /regularization namescope are matched before this set is consulted, so the
// "scale" entry only covers the remaining uses (e.g. adamax).
std::set<std::string> ignored_ops = {
    "sign",
    "sum",
    "clip",
    "clip_by_norm",
    "reduce_sum",
    "sqrt",
    "elementwise_max",
    "elementwise_div",
    "elementwise_mul",
    "scale",            // adamax
    "assign",           // adamw
    "squared_l2_norm",  // gradient_clip_norm
    "cast",             // mix-precision support
};

// Returns true when `str` begins with `pre`; an empty `pre` always matches.
// rfind() with pos 0 can only report a match that starts at index 0, so this
// is a prefix test that never scans past the beginning of `str`.
bool startswith(const std::string& str, const std::string& pre) {
  return str.rfind(pre, 0) == 0;
}

// Returns true when an op's namescope starts with "/gradient_clip", i.e. the
// op was emitted as part of gradient clipping.
bool is_grad_clip_op(const std::string& op_namescope) {
  return startswith(op_namescope, "/gradient_clip");
}

// Returns true when an op's namescope starts with "/optimizer".
bool is_optimizer_op(const std::string& op_namescope) {
  return startswith(op_namescope, "/optimizer");
}

// Returns true when an op's namescope starts with "/regularization", i.e. the
// op was emitted by a weight-decay regularizer.
bool is_regularization_op(const std::string& op_namescope) {
  return startswith(op_namescope, "/regularization");
}

// Scans the graph for optimizer-role and loss-role ops and summarises them in
// one new "popart_optimizer" op node. That node's attributes describe the
// optimizer ("type", "raw_type", "lr_var", betas/eps/momentum, weight decay,
// grad-clip norm, "loss_var", ...).
//
// Errors are raised through PADDLE_ENFORCE_* / PADDLE_THROW when:
//   - an optimizer mode is not supported on IPU;
//   - an unknown optimizer-role op is met;
//   - L1Decay, ClipGradByValue or ClipGradByNorm is used;
//   - no supported optimizer op is found at all.
void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
  // optimizer values will be extracted when lowering optimizer in ipu_backend
  OpDesc new_op("popart_optimizer", {}, {}, {});
  new_op.SetAttr("op_role", 0);
  new_op.SetAttr("with_lr_sched", false);

  // Optimizer-role op types already handled; only the first op of each type
  // contributes attributes.
  std::set<std::string> set_ops{};

  // save the weight decay tensor_name and weight_decay_value for Lamb
  std::vector<std::string> weight_decay_vars{};
  std::vector<float> weight_decay_values{};

  // NOTE(review): because only the first op of each type is inspected, the
  // result depends on the iteration order of graph->Nodes(). If that order is
  // unspecified, a /regularization "scale" (or /gradient_clip
  // "fill_constant") may be shadowed by another op of the same type — confirm.
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) {
      continue;
    }

    auto op = node->Op();
    auto op_type = op->Type();
    int op_role_ = BOOST_GET_CONST(
        int, op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
    auto op_role = static_cast<OpRole>(op_role_);

    if (op_role == OpRole::kOptimize) {
      // save weight decay value from every lamb optimizer op; this runs before
      // the de-duplication below so that every lamb parameter is recorded
      if (op_type == "lamb" && op->HasAttr("weight_decay")) {
        auto weight_decay_value =
            BOOST_GET_CONST(float, op->GetAttr("weight_decay"));
        auto params = op->Output("ParamOut");
        weight_decay_vars.push_back(params[0]);
        weight_decay_values.push_back(weight_decay_value);
      }

      if (set_ops.count(op_type)) {
        continue;
      }

      auto op_namescope =
          BOOST_GET_CONST(std::string, op->GetAttr("op_namescope"));
      bool is_grad_clip = is_grad_clip_op(op_namescope);
      bool is_regularization = is_regularization_op(op_namescope);

      VLOG(10) << "found optimizer releated op: " << op_type;
      // initial learning_rate will be set in ipu_backend
      set_ops.insert(op_type);
      if (op_type == "sgd") {
        auto type = std::string{"sgd"};
        auto lr_var = op->Input("LearningRate").front();
        new_op.SetAttr("type", type);
        new_op.SetAttr("lr_var", lr_var);
        new_op.SetAttr("weight_decay", 0.0f);
        new_op.SetAttr("momentum", 0.0f);
        new_op.SetAttr("raw_type", op_type);
      } else if (op_type == "momentum") {
        auto type = std::string{"sgd"};
        auto use_nesterov = BOOST_GET_CONST(bool, op->GetAttr("use_nesterov"));
        PADDLE_ENFORCE_EQ(use_nesterov, false,
                          platform::errors::Unimplemented(
                              "ipu does not support nesterov mode."));
        auto regularization_method =
            BOOST_GET_CONST(std::string, op->GetAttr("regularization_method"));
        PADDLE_ENFORCE_NE(regularization_method, "l1_decay",
                          platform::errors::Unimplemented(
                              "ipu does not support l1_decay mode."));
        auto multi_precision =
            BOOST_GET_CONST(bool, op->GetAttr("multi_precision"));
        PADDLE_ENFORCE_EQ(multi_precision, false,
                          platform::errors::Unimplemented(
                              "ipu does not support multi_precision mode."));
        auto rescale_grad = BOOST_GET_CONST(float, op->GetAttr("rescale_grad"));
        PADDLE_ENFORCE_EQ(rescale_grad, 1.0,
                          platform::errors::Unimplemented(
                              "ipu does not support rescale_grad mode."));
        auto regularization_coeff =
            BOOST_GET_CONST(float, op->GetAttr("regularization_coeff"));
        auto lr_var = op->Input("LearningRate").front();
        auto momentum = BOOST_GET_CONST(float, op->GetAttr("mu"));
        new_op.SetAttr("type", type);
        new_op.SetAttr("lr_var", lr_var);
        new_op.SetAttr("momentum", momentum);
        new_op.SetAttr("weight_decay", regularization_coeff);
        new_op.SetAttr("raw_type", op_type);
      } else if (op_type == "adam" || op_type == "adamw") {
        auto type = std::string{"adam"};
        auto lr_var = op->Input("LearningRate").front();
        auto beta1 = BOOST_GET_CONST(float, op->GetAttr("beta1"));
        auto beta2 = BOOST_GET_CONST(float, op->GetAttr("beta2"));
        auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
        auto lazy_mode = BOOST_GET_CONST(bool, op->GetAttr("lazy_mode"));
        auto multi_precision =
            BOOST_GET_CONST(bool, op->GetAttr("multi_precision"));
        PADDLE_ENFORCE_EQ(lazy_mode, false,
                          platform::errors::Unimplemented(
                              "ipu does not support lazy_mode mode."));
        PADDLE_ENFORCE_EQ(multi_precision, false,
                          platform::errors::Unimplemented(
                              "ipu does not support multi_precision mode."));
        new_op.SetAttr("type", type);
        new_op.SetAttr("lr_var", lr_var);
        new_op.SetAttr("weight_decay", 0.0f);
        new_op.SetAttr("beta1", beta1);
        new_op.SetAttr("beta2", beta2);
        new_op.SetAttr("eps", epsilon);
        new_op.SetAttr("adam_mode", std::string{"adam"});
        // adam uses L2 regularization, adamw uses decoupled decay
        if (op_type == "adam") {
          new_op.SetAttr("weight_decay_mode", std::string{"l2_regularization"});
          new_op.SetAttr("raw_type", std::string{"adam"});
        } else {
          new_op.SetAttr("weight_decay_mode", std::string{"decay"});
          new_op.SetAttr("raw_type", std::string{"adamw"});
        }
      } else if (op_type == "adamax") {
        auto type = std::string{"adam"};
        auto lr_var = op->Input("LearningRate").front();
        auto beta1 = BOOST_GET_CONST(float, op->GetAttr("beta1"));
        auto beta2 = BOOST_GET_CONST(float, op->GetAttr("beta2"));
        auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
        new_op.SetAttr("type", type);
        new_op.SetAttr("lr_var", lr_var);
        new_op.SetAttr("weight_decay", 0.0f);
        new_op.SetAttr("beta1", beta1);
        new_op.SetAttr("beta2", beta2);
        new_op.SetAttr("eps", epsilon);
        new_op.SetAttr("adam_mode", std::string{"adamax"});
        new_op.SetAttr("weight_decay_mode", std::string{"l2_regularization"});
        new_op.SetAttr("raw_type", op_type);
      } else if (op_type == "lamb") {
        // use decay mode
        auto type = std::string{"adam"};
        auto lr_var = op->Input("LearningRate").front();
        auto weight_decay = BOOST_GET_CONST(float, op->GetAttr("weight_decay"));
        auto beta1 = BOOST_GET_CONST(float, op->GetAttr("beta1"));
        auto beta2 = BOOST_GET_CONST(float, op->GetAttr("beta2"));
        auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
        new_op.SetAttr("type", type);
        new_op.SetAttr("lr_var", lr_var);
        new_op.SetAttr("weight_decay", weight_decay);
        new_op.SetAttr("beta1", beta1);
        new_op.SetAttr("beta2", beta2);
        new_op.SetAttr("eps", epsilon);
        new_op.SetAttr("adam_mode", std::string{"lamb"});
        new_op.SetAttr("weight_decay_mode", std::string{"decay"});
        new_op.SetAttr("raw_type", op_type);
      } else if (op_type == "adadelta") {
        // adadelta has no LearningRate input, so no "lr_var" is set
        auto type = std::string{"adaptive"};
        auto rho = BOOST_GET_CONST(float, op->GetAttr("rho"));
        auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
        new_op.SetAttr("type", type);
        new_op.SetAttr("weight_decay", 0.0f);
        new_op.SetAttr("alpha", rho);
        new_op.SetAttr("eps", epsilon);
        new_op.SetAttr("momentum", 0.0f);
        new_op.SetAttr("adaptive_mode", std::string{"adadelta"});
        new_op.SetAttr("weight_decay_mode", std::string{"l2_regularization"});
        new_op.SetAttr("raw_type", op_type);
      } else if (op_type == "adagrad") {
        auto type = std::string{"adaptive"};
        auto lr_var = op->Input("LearningRate").front();
        auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
        new_op.SetAttr("type", type);
        new_op.SetAttr("lr_var", lr_var);
        new_op.SetAttr("weight_decay", 0.0f);
        // `alpha` use default
        new_op.SetAttr("alpha", 0.99f);
        new_op.SetAttr("eps", epsilon);
        new_op.SetAttr("momentum", 0.0f);
        new_op.SetAttr("adaptive_mode", std::string{"adagrad"});
        new_op.SetAttr("weight_decay_mode", std::string{"l2_regularization"});
        new_op.SetAttr("raw_type", op_type);
      } else if (op_type == "rmsprop") {
        auto type = std::string{"adaptive"};
        auto lr_var = op->Input("LearningRate").front();
        auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
        auto decay = BOOST_GET_CONST(float, op->GetAttr("decay"));
        auto momentum = BOOST_GET_CONST(float, op->GetAttr("momentum"));
        auto centered = BOOST_GET_CONST(bool, op->GetAttr("centered"));
        new_op.SetAttr("type", type);
        // Like every other optimizer with a LearningRate input, export it;
        // previously lr_var was read here but never recorded on new_op.
        new_op.SetAttr("lr_var", lr_var);
        new_op.SetAttr("weight_decay", 0.0f);
        new_op.SetAttr("alpha", decay);
        new_op.SetAttr("eps", epsilon);
        new_op.SetAttr("momentum", momentum);
        new_op.SetAttr("weight_decay_mode", std::string{"l2_regularization"});
        if (centered) {
          new_op.SetAttr("adaptive_mode", std::string{"centered_rmsprop"});
          new_op.SetAttr("raw_type", op_type);
        } else {
          new_op.SetAttr("adaptive_mode", std::string{"rmsprop"});
          new_op.SetAttr("raw_type", op_type);
        }
      } else if (is_regularization && op_type == "scale") {
        // set weight_decay for L2Decay
        auto scale = BOOST_GET_CONST(float, op->GetAttr("scale"));
        new_op.SetAttr("weight_decay", scale);
      } else if (is_grad_clip && op_type == "fill_constant") {
        // set clip_norm for ClipGradByGlobalNorm
        auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
        new_op.SetAttr("clip_norm", value);
      } else if (ignored_ops.count(op_type)) {
        VLOG(10) << "Ignore optimizer releated op: " << op_type;
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unknown optimizer releated op_type: %s", op_type));
      }
    } else if (op_role == OpRole::kLoss) {
      VLOG(10) << "found loss op type: " << op->Type();
      auto outputs = op->Outputs();
      PADDLE_ENFORCE_EQ(
          outputs.size(), 1,
          platform::errors::InvalidArgument("Can only support one loss key"));
      auto losses = outputs.begin()->second;
      PADDLE_ENFORCE_EQ(
          losses.size(), 1,
          platform::errors::InvalidArgument("Can only support one loss name"));
      auto loss_var = losses.front();
      new_op.SetAttr("loss_var", loss_var);
    } else if (op_role == OpRole::kLRSched) {
      // op_role == OpRole::kLRSched | OpRole::kOptimize
      new_op.SetAttr("with_lr_sched", true);
    }
  }

  // seems with_lr_sched is always true
  // NOTE(review): this unconditional set overrides both the initial `false`
  // and the kLRSched branch above, making that branch redundant — confirm
  // this is intended.
  new_op.SetAttr("with_lr_sched", true);

  // setup weight decay for Lamb
  new_op.SetAttr("weight_decay_vars", weight_decay_vars);
  new_op.SetAttr("weight_decay_values", weight_decay_values);

  // weight_decay/coeff is "scale" attr of scale_op
  if (set_ops.count("scale") && set_ops.count("sum")) {
    if (set_ops.count("sign")) {
      // L1Decay
      // sign + scale + sum
      PADDLE_THROW(
          platform::errors::Unimplemented("Unsupported L1Decay regularizer"));
    } else {
      // L2Decay
      // scale + sum
      new_op.SetAttr("weight_decay_mode", std::string{"l2_regularization"});
    }
  } else {
    VLOG(10) << "No weight deacy setting found";
  }

  // setup grad clip
  if (set_ops.count("clip")) {
    // ClipGradByValue
    PADDLE_THROW(
        platform::errors::Unimplemented("Unsupported ClipGradByValue"));
  } else if (set_ops.count("clip_by_norm")) {
    // ClipGradByNorm
    PADDLE_THROW(platform::errors::Unimplemented("Unsupported ClipGradByNorm"));
  }

  // ClipGradByGlobalNorm
  // use graph pattern match ClipGradByGlobalNorm
  // square + reduce_sum + sum + sqrt + fill_constant
  // + elementwise_max + elementwise_div + elementwise_mul
  // clip_norm from fill_constant`s attr `value` dtype float

  // "type" is only set by one of the recognised optimizer branches above.
  if (new_op.HasAttr("type")) {
    auto new_node = graph->CreateOpNode(&new_op);
    VLOG(10) << "New Optimizer Node:";
    VLOG(10) << DebugString(new_node);
  } else {
    PADDLE_THROW(platform::errors::NotFound(
        "No optimizer found, optimizer must be one of these types: sgd, "
        "momentum, adam, adamw, adamax, lamb, adadelta, adagrad or rmsprop"));
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Registers IpuOptimizerExtractPass under the pass name
// "optimizer_extract_pass".
REGISTER_PASS(optimizer_extract_pass,
              paddle::framework::ir::IpuOptimizerExtractPass);