// program_translator.h
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/program_desc.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/value.h"

namespace paddle {
namespace translator {

30 31 32 33 34 35 36 37 38 39 40 41
struct VariableDefiningInfo {
  VariableDefiningInfo(ir::OpResult value,
                       bool generated_by_vector = false,
                       int idx_in_vector = -1)
      : value(value),
        generated_by_vector(generated_by_vector),
        idx_in_vector(idx_in_vector) {}
  VariableDefiningInfo() {}

  ir::OpResult value;

  bool generated_by_vector =
42
      false;  // true if target variabe is generated by Vector<Tensor>
43
  int idx_in_vector =
44
      -1;  // positive if target variabe is generated by Vector<Tensor>
45 46 47 48
};

/// Lookup table from a string key to the `VariableDefiningInfo` recorded for
/// it.
/// NOTE(review): the key is presumably the variable name used in the legacy
/// program -- confirm against the translator implementation.
using TranslationContext =
    std::unordered_map<std::string, VariableDefiningInfo>;
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65

class ProgramTranslator {
  using ProgramDesc = ::paddle::framework::ProgramDesc;
  using BlockDesc = ::paddle::framework::BlockDesc;

 public:
  explicit ProgramTranslator(const ProgramDesc* legacy_program,
                             ir::Program* program);

  void Translate();

 private:
  const ProgramDesc* legacy_program;
  ir::Program* program;
  TranslationContext param_map;
  ir::IrContext* ctx;

66 67 68 69 70 71 72 73
  /// In the legacy program desc, there are two special named varibales:
  /// 1. "feed", the input variable of feed op
  /// 2. "fetch", the output variable of fetch op
  /// However, new feed has no input and new fetch has no output
  /// So we don't handle these two vairables when
  /// `ExtractParameterFromSingleBlock`
  static const std::unordered_set<std::string> no_cast_var_names;

74 75 76 77 78 79
  void ExtractParameterFromSingleBlock(const BlockDesc& block);
  void InsertOperationToSingleBlock(const BlockDesc& block);
};

}  // namespace translator
}  // namespace paddle