// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include <string>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/program_desc.h"

namespace paddle {
namespace operators {

// OpVariant is a wrapper class of OpDesc and OperatorBase
// So that API would be the same.
class OpVariant {
  struct InputsVisitor
      : public boost::static_visitor<const framework::VariableNameMap *> {
    template <typename OpType>
    const framework::VariableNameMap *operator()(const OpType *op) const {
      return &(op->Inputs());
    }
  };

  struct OutputsVisitor
      : public boost::static_visitor<const framework::VariableNameMap *> {
    template <typename OpType>
    const framework::VariableNameMap *operator()(const OpType *op) const {
      return &(op->Outputs());
    }
  };

  struct AttributeMapVisitor
      : public boost::static_visitor<const framework::AttributeMap *> {
    const framework::AttributeMap *operator()(
        const framework::OpDesc *op) const {
      return &(op->GetAttrMap());
    }

    const framework::AttributeMap *operator()(
        const framework::OperatorBase *op) const {
      return &(op->Attrs());
    }
  };

  struct RawPointerVisitor : public boost::static_visitor<const void *> {
    template <typename OpType>
    const void *operator()(const OpType *op) const {
      return op;
    }
  };

 public:
  OpVariant(const framework::OperatorBase *op) : op_(op) {}  // NOLINT

  OpVariant(const framework::OpDesc *op) : op_(op) {}  // NOLINT

  const framework::VariableNameMap &Inputs() const {
    return *boost::apply_visitor(InputsVisitor(), op_);
  }

  const framework::VariableNameMap &Outputs() const {
    return *boost::apply_visitor(OutputsVisitor(), op_);
  }

  const framework::AttributeMap &Attrs() const {
    return *boost::apply_visitor(AttributeMapVisitor(), op_);
  }

  template <typename AttrType>
  const AttrType &Attr(const std::string &name) const {
    auto &attrs = Attrs();
    auto it = attrs.find(name);
    PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
    return boost::get<AttrType>(it->second);
  }

  bool operator==(const OpVariant &other) const {
    return RawPointer() == other.RawPointer();
  }

  const void *RawPointer() const {
    return boost::apply_visitor(RawPointerVisitor(), op_);
  }

  int which() const { return static_cast<int>(op_.which()); }

  struct Hasher {
    size_t operator()(const OpVariant &op) const {
      return reinterpret_cast<size_t>(op.RawPointer());
    }
  };

 private:
  const boost::variant<const framework::OperatorBase *,
                       const framework::OpDesc *>
      op_;
};

static std::string GetDebugString(const std::vector<std::string> &names) {
  if (names.empty()) return "";
  std::string ret = names[0];
  for (size_t i = 1; i < names.size(); ++i) {
    ret += (" " + names[i]);
  }
  return ret;
}

// Set skip variables of while_op and while_grad_op
// These variables should be skipped when eager deletion enables.
// It is because:
//  1. while_grad_op needs some variables defined in while_op.
//  2. while_grad_op needs variables from the previous time step.
static void SetSkipVars(const OpVariant &op, std::vector<std::string> attr) {
  auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
  VLOG(2) << "Prepare to skip " << attr.size()
          << " var(s): " << GetDebugString(attr);
  attrs[kSkipEagerDeletionVars] = std::move(attr);
}

// Check whether the forward while_op and while_grad_op match
// The program may have many while_ops.
static bool IsMatchedWhileOpAndWhileGradOp(const OpVariant &fwd_op,
                                           const OpVariant &grad_op) {
  return fwd_op.Inputs().at(kX) == grad_op.Inputs().at(kX) &&
         fwd_op.Outputs().at(kOutputs) == grad_op.Inputs().at(kOutputs);
}

// Test whether the variable is skippable in forward while_op
// The variable is skippable in while_op when the variable used in while_grad
// is not from grad_block.
static bool IsSkippableVar(const std::string &name,
                           framework::BlockDesc *grad_block) {
  return name != framework::kEmptyVarName && !grad_block->HasVar(name);
}

static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
                                            const OpVariant &bwd_op) {
  auto *grad_block = bwd_op.Attr<framework::BlockDesc *>(kStepBlock);

  // Find all skippable variables in forward while_op
  std::unordered_set<std::string> forward_skip_vars;
  for (auto *op_desc : grad_block->AllOps()) {
    for (auto &in_arg_name : op_desc->InputArgumentNames()) {
      if (IsSkippableVar(in_arg_name, grad_block)) {
        forward_skip_vars.insert(in_arg_name);
      }
    }

    for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
      if (IsSkippableVar(out_arg_name, grad_block)) {
        forward_skip_vars.insert(out_arg_name);
      }
    }
  }

  SetSkipVars(fwd_op, std::vector<std::string>(forward_skip_vars.begin(),
                                               forward_skip_vars.end()));

  // Find all skippable variables in while_grad_op
  // The skipped variables are those which would be used across time steps.
  auto &fwd_input = fwd_op.Inputs().at(kX);
  auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX));
  PADDLE_ENFORCE_EQ(
      fwd_input.size(), in_grads.size(),
      "Backward input gradient number does not match forward input number.");

  std::unordered_set<std::string> backward_skip_vars;
  for (size_t i = 0; i < in_grads.size(); ++i) {
    if (in_grads[i] == framework::kEmptyVarName) {
      continue;
    }
    backward_skip_vars.insert(in_grads[i]);
    backward_skip_vars.insert(framework::GradVarName(fwd_input[i]));
  }

  SetSkipVars(bwd_op, std::vector<std::string>(backward_skip_vars.begin(),
                                               backward_skip_vars.end()));
}

// Find all while_ops and while_grad_ops in the graph or program
// The while_grad_op and while_op may located in different blocks
// So we should traverse all blocks in the program and find them out.
static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
                                       std::vector<OpVariant> *while_grad_ops) {
  PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());

  if (while_ops->empty()) return;

  const auto *program =
      while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
  for (size_t i = 1; i < program->Size(); ++i) {
    auto &block = program->Block(i);
    for (size_t j = 0; j < block.OpSize(); ++j) {
      auto *op = block.Op(j);
      if (op->Type() == "while") {
        while_ops->emplace_back(op);
      } else if (op->Type() == "while_grad") {
        while_grad_ops->emplace_back(op);
      }
    }
  }

  PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(),
                    "There are extra while_grad ops in the graph or program");
}

static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
    std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
  FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);

  VLOG(2) << "Found while op num: " << while_ops->size()
          << ", while grad op num: " << while_grad_ops->size();

  if (while_grad_ops->empty()) {
    return;
  }

  std::unordered_set<OpVariant, OpVariant::Hasher> while_op_set(
      while_ops->begin(), while_ops->end());

  for (auto &bwd_op : *while_grad_ops) {
    const OpVariant *matched_fwd_op = nullptr;
    for (auto &fwd_op : while_op_set) {
      if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) {
        PADDLE_ENFORCE(matched_fwd_op == nullptr,
                       "Found multiple matched while ops");
        matched_fwd_op = &fwd_op;
      }
    }
    PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
                            "Cannot find matched forward while op.");
    ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op);
    while_op_set.erase(*matched_fwd_op);
  }
}

void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
    int block_id,
    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
  // If block_id is not 0, returns
  // This is because all while_ops and while_grad_ops in the whole program
  // would be processed when block_id is 0 (i.e. when Executor::Run() or
  // ParallelExecutor constructs).

  // What's more, all while_ops and while_grad_ops must be processed when
  // block_id is zero. If not, while_op may run first and erase variables
  // used in while_grad_op, and in this moment, while_grad_ops may be not
  // constructed yet.
  if (block_id != 0) return;

  std::vector<OpVariant> fwd_ops, bwd_ops;
  for (auto &op : all_ops) {
    if (op->Type() == "while") {
      fwd_ops.emplace_back(op.get());
    } else if (op->Type() == "while_grad") {
      bwd_ops.emplace_back(op.get());
    }
  }
  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}

void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
    const std::vector<framework::OperatorBase *> &while_ops,
    const std::vector<framework::OperatorBase *> &while_grad_ops) {
  std::vector<OpVariant> fwd_ops, bwd_ops;
  fwd_ops.reserve(while_ops.size());
  for (auto *op : while_ops) {
    fwd_ops.emplace_back(op);
  }

  bwd_ops.reserve(while_grad_ops.size());
  for (auto *op : while_grad_ops) {
    bwd_ops.emplace_back(op);
  }

  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
}

}  // namespace operators
}  // namespace paddle