// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/dialect/pd_dialect.h"
#include "paddle/fluid/dialect/pd_attribute.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/dialect/CMakeLists.txt.
#include "paddle/fluid/dialect/pd_op.h"
#include "paddle/fluid/dialect/pd_type.h"
#include "paddle/fluid/dialect/pd_type_storage.h"
#include "paddle/fluid/dialect/utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/ir/core/dialect_interface.h"
#include "paddle/phi/core/dense_tensor.h"

namespace paddle {
namespace dialect {
std::shared_ptr<paddle::framework::Variable>
Z
zhangbo9674 已提交
31
ParameterConvertInterface::ParameterToVariable(ir::Parameter *parameter) {
32 33 34 35
  if (parameter->type().isa<DenseTensorType>()) {
    VLOG(4) << "Convert a DenseTensor Parameter to a variable.";
    std::shared_ptr<paddle::framework::Variable> var =
        std::make_shared<paddle::framework::Variable>();
Z
zhangbo9674 已提交
36
    phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
37
    // Init DenseTensor
38
    auto dim = parameter->type().dyn_cast<DenseTensorType>().dims();
39 40 41
    phi::DenseTensorMeta meta(
        TransToPhiDataType(
            parameter->type().dyn_cast<DenseTensorType>().dtype()),
42 43 44
        dim,

        parameter->type().dyn_cast<DenseTensorType>().data_layout(),
45 46 47
        parameter->type().dyn_cast<DenseTensorType>().lod(),
        parameter->type().dyn_cast<DenseTensorType>().offset());
    tensor->set_meta(meta);
Z
zhangbo9674 已提交
48
    paddle::platform::DeviceContext *dev_ctx =
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
        paddle::platform::DeviceContextPool::Instance().Get(
            paddle::platform::CPUPlace());
    dev_ctx->Alloc(tensor,
                   TransToPhiDataType(
                       parameter->type().dyn_cast<DenseTensorType>().dtype()));
    memcpy(tensor->data(),
           parameter->data(),
           tensor->numel() * phi::SizeOf(tensor->dtype()));
    return var;
  } else {
    return nullptr;
  }
}

std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
Z
zhangbo9674 已提交
64
    paddle::framework::Variable *var) {
65
  if (var->IsType<phi::DenseTensor>()) {
Z
zhangbo9674 已提交
66
    phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
67
    // Get Meta
Z
zhangbo9674 已提交
68
    ir::IrContext *ctx = ir::IrContext::Instance();
69
    ir::Type data_type = TransToIrDataType(tensor->dtype(), ctx);
Z
zhangbo9674 已提交
70
    void *data = tensor->data();
71 72 73 74 75 76
    ir::Type dense_tensor_type = DenseTensorType::get(ctx,
                                                      data_type,
                                                      tensor->dims(),
                                                      tensor->layout(),
                                                      tensor->lod(),
                                                      tensor->meta().offset);
77 78 79 80 81 82 83 84 85
    return std::make_unique<ir::Parameter>(
        data,
        tensor->numel() * phi::SizeOf(tensor->dtype()),
        dense_tensor_type);
  } else {
    return nullptr;
  }
}
/// Constructs the dialect for `context`, passing the dialect name and the
/// TypeId of PaddleDialect to the ir::Dialect base, then runs initialize().
PaddleDialect::PaddleDialect(ir::IrContext *context)
    : ir::Dialect(name(),
                  context,
                  ir::TypeId::get<PaddleDialect>()) {
  this->initialize();
}

/// Registers everything this dialect provides, in order: its type, its
/// attributes, its operations, and the ParameterConvertInterface.
void PaddleDialect::initialize() {
  // Types.
  RegisterTypes<DenseTensorType>();

  // Attributes.
  RegisterAttributes<IntArrayAttribute,
                     ScalarAttribute,
                     DataTypeAttribute,
                     PlaceAttribute,
                     DataLayoutAttribute>();

  // Operations.
  // NOTE(zhangbo9674): pd_op.h is generated by op_gen.py (see
  // paddle/fluid/dialect/CMakeLists.txt). Defining GET_OP_LIST before
  // including it supplies the template argument list for RegisterOps.
  RegisterOps<
#define GET_OP_LIST
#include "paddle/fluid/dialect/pd_op.h"  // NOLINT
      >();

  // Interfaces.
  RegisterInterfaces<ParameterConvertInterface>();
}
/// Writes a textual form of `type` to `os`.
///
/// A DenseTensorType is printed as `tensor<` followed by each dimension and an
/// `x` separator, then the element type as printed by its own print(), then
/// `>`. Any other type is printed as a placeholder instead of being cast
/// unchecked.
///
/// @param type Type to print.
/// @param os   Destination stream.
void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
  // Check the type before casting, as ParameterToVariable does in this file.
  if (!type.isa<DenseTensorType>()) {
    os << "<<unknown type>>";
    return;
  }

  DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();

  os << "tensor<";
  for (auto d : phi::vectorize(tensor_type.dims())) {
    os << d;
    os << "x";
  }
  tensor_type.dtype().print(os);
  os << ">";
}

}  // namespace dialect
}  // namespace paddle