diff --git a/paddle/infrt/dialect/phi/CMakeLists.txt b/paddle/infrt/dialect/phi/CMakeLists.txt index 67f6bb8a2d7bbfa604614e4909169c08ea18e1b3..436ff0a40480cfe2d0907309a8eecf8cdc32adb2 100644 --- a/paddle/infrt/dialect/phi/CMakeLists.txt +++ b/paddle/infrt/dialect/phi/CMakeLists.txt @@ -5,6 +5,10 @@ endif() add_subdirectory(ir) add_subdirectory(pass) +add_executable(phi-ir-exec phi_ir_exec.cc) +target_link_libraries(phi-ir-exec infrt) + + add_executable(phi-exec phi_exec.cc) target_link_libraries(phi-exec infrt) diff --git a/paddle/infrt/dialect/phi/ir/infrt_phi_base.td b/paddle/infrt/dialect/phi/ir/infrt_phi_base.td index 8e21283183d036ac26c117a0a209ba92d1f9febc..376d62deecee7cfb958f7dbb180b5936091f8acf 100644 --- a/paddle/infrt/dialect/phi/ir/infrt_phi_base.td +++ b/paddle/infrt/dialect/phi/ir/infrt_phi_base.td @@ -18,8 +18,8 @@ def PHI_Dialect : Dialect { def PhiOpTrait : NativeOpTrait<"PhiOpTrait">; -class PHI_Type traits = []> - : TypeDef {} +class PHI_Type traits = [], string baseCppClass = "::mlir::Type"> + : TypeDef {} def Allocator : PHI_Type<"Allocator"> { let mnemonic = "allocator"; diff --git a/paddle/infrt/host_context/paddle_mlir.cc b/paddle/infrt/host_context/paddle_mlir.cc index e161dc47075bb3e87399477b3112a4c4c57cec1c..ec12815e3ce94f52b987e845363880cfc3896387 100644 --- a/paddle/infrt/host_context/paddle_mlir.cc +++ b/paddle/infrt/host_context/paddle_mlir.cc @@ -16,6 +16,7 @@ #include "paddle/infrt/dialect/infrt/ir/basic_kernels.h" #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h" #include "paddle/infrt/dialect/pd/common/pd_ops_info.h" +#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" MLIRModelGenImpl::MLIRModelGenImpl() : context_(infrt::Global::getMLIRContext()), builder_(context_) { @@ -24,6 +25,8 @@ MLIRModelGenImpl::MLIRModelGenImpl() context_->getOrLoadDialect(); context_->getOrLoadDialect(); context_->getOrLoadDialect<::infrt::InfrtDialect>(); + context_->getOrLoadDialect<::infrt::phi::PHIDialect>(); + context_->getOrLoadDialect<::infrt::phi::PHIDenseTensorDialect>(); module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(context_)); } @@ -79,7 +82,7 @@ mlir::FuncOp MLIRModelGenImpl::UpdateModelModule( llvm::SmallVector MLIRModelGenImpl::GetModelInputsType( const infrt::paddle::framework_proto::ProgramDesc &program) { llvm::SmallVector operandTypes; - operandTypes.push_back(infrt::DenseHostTensorMapType::get(context_)); + operandTypes.push_back(infrt::phi::DenseTensorMapType::get(context_)); for (auto &op_desc : main_block_.ops()) { if (op_desc.type() != "feed") continue; for (int var_idx = 0; var_idx < op_desc.outputs_size(); ++var_idx) { @@ -180,7 +183,7 @@ void MLIRModelGenImpl::UpdateModelParams( &precision_); mlir::Type type_ = infrt::DenseTensorType::get( context_, infrt::TargetType::CPU, precision_, infrt::LayoutType::ANY); - auto op = builder_.create( + auto op = builder_.create<::infrt::phi::TensorMapGetTensorOp>( mlir::UnknownLoc::get(context_), type_, map, name); params_map_.insert(std::pair( var_desc.name(), op.getOperation()->getResult(0))); diff --git a/paddle/infrt/tests/CMakeLists.txt b/paddle/infrt/tests/CMakeLists.txt index 58543a6864258bd6c0153150bb535262d9a8f00d..6f839cdc3954939e8c8d4792facac5a284d25f3f 100644 --- a/paddle/infrt/tests/CMakeLists.txt +++ b/paddle/infrt/tests/CMakeLists.txt @@ -6,3 +6,4 @@ add_test(NAME test_infrt_by_lit COMMAND sh -c "lit -v ${CMAKE_SOURCE_DIR}/paddle DEPENDS infrtopt infrtexec) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir) +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/phi/linear_cpu.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/phi/linear_cpu.mlir) diff --git a/paddle/infrt/tests/dialect/phi/linear_cpu.mlir.in b/paddle/infrt/tests/dialect/phi/linear_cpu.mlir.in new file mode 100644 index 0000000000000000000000000000000000000000..7ca33fa10a90d4dffef02526b2c19744e388a6aa --- /dev/null +++ b/paddle/infrt/tests/dialect/phi/linear_cpu.mlir.in @@ -0,0 +1,19 @@ +// RUN: infrtexec -i %s +module { + func @main_graph(%arg0: !phi.dense_tensor_map, %arg1: !infrt.dense_tensor) -> !infrt.dense_tensor { + %0 = phi_dt.tensor_map_get_tensor(%arg0) {name = "linear_0.w_0"} -> !infrt.dense_tensor + %1 = phi_dt.tensor_map_get_tensor(%arg0) {name = "linear_0.b_0"} -> !infrt.dense_tensor + %2 = "phi_dt.create_context.cpu"() : () -> !phi.context + %5 = "phi_cpu.matmul.float32.any"(%2, %arg1, %0) {trans_x = false, trans_y = false} : (!phi.context, !infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor + %7 = "phi_cpu.add.float32.any"(%2, %5, %1): (!phi.context, !infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor + infrt.return %7 : !infrt.dense_tensor + } + func @main() { + %ctx = "phi_dt.create_context.cpu" (): () -> !phi.context + %1 = "phi_dt.create_dense_tensor.cpu" (%ctx) {precision=#infrt.precision, layout=#infrt.layout, lod=[1:i64], dims=[16:i64, 784:i64]}: (!phi.context) -> (!infrt.dense_tensor) + %map = phi_dt.load_combined_params(){model_path="@CMAKE_BINARY_DIR@/linear/linear.pdmodel",params_path="@CMAKE_BINARY_DIR@/linear/linear.pdiparams"} + %2 = infrt.call@main_graph(%map, %1) : (!phi.dense_tensor_map, !infrt.dense_tensor) -> !infrt.dense_tensor + phi_dt.print_tensor (%2 : !infrt.dense_tensor) + infrt.return + } +} diff --git a/paddle/infrt/tests/model/linear.py b/paddle/infrt/tests/model/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..602e067365b87318ecb847d13832810aa1db4593 --- /dev/null +++ b/paddle/infrt/tests/model/linear.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# example 1: save layer +import numpy as np +import paddle +import paddle.nn as nn +import paddle.optimizer as opt + +BATCH_SIZE = 16 +BATCH_NUM = 4 +EPOCH_NUM = 4 + +IMAGE_SIZE = 784 +CLASS_NUM = 10 + + +# define a random dataset +class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + +class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + + @paddle.jit.to_static + def forward(self, x): + return self._linear(x) + + +def train(layer, loader, loss_fn, opt): + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + opt.step() + opt.clear_grad() + + +# 1. train & save model. + +# create network +layer = LinearNet() +loss_fn = nn.CrossEntropyLoss() +adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters()) + +# create data loader +dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) +loader = paddle.io.DataLoader( + dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2) + +# train +train(layer, loader, loss_fn, adam) + +# save +path = "linear/linear" +paddle.jit.save(layer, path) diff --git a/paddle/scripts/infrt_build.sh b/paddle/scripts/infrt_build.sh index 1b259023f94df7279066533bb6c182a644b4e9c2..37e19b49f1cd03dc08dadd977358118a3190289c 100755 --- a/paddle/scripts/infrt_build.sh +++ b/paddle/scripts/infrt_build.sh @@ -114,6 +114,7 @@ function create_fake_models() { python3 -m pip install *whl cd ${PADDLE_ROOT}/build python3 ${PADDLE_ROOT}/tools/infrt/fake_models/multi_fc.py + python3 ${PADDLE_ROOT}/paddle/infrt/tests/model/linear.py } function test_infrt() {