diff --git a/paddle/infrt/dialect/basic_kernels.td b/paddle/infrt/dialect/basic_kernels.td
index 7d8de79fbae2b0cb36ca354b8f6f39fc94851ebe..32845a09351f70fe1acd7659b8c5e3a579ff83e0 100644
--- a/paddle/infrt/dialect/basic_kernels.td
+++ b/paddle/infrt/dialect/basic_kernels.td
@@ -106,10 +106,10 @@ class PrintOp<string suffix, Type type> : INFRT_Op<"print." # suffix> {
   let verifier = ?;
 }
 
-//def PrintI32Op : PrintOp<"i32", I32>;
-//def PrintI64Op : PrintOp<"i64", I64>;
+def PrintI32Op : PrintOp<"i32", I32>;
+def PrintI64Op : PrintOp<"i64", I64>;
 def PrintF32Op : PrintOp<"f32", F32>;
-//def PrintF64Op : PrintOp<"f64", F64>;
+def PrintF64Op : PrintOp<"f64", F64>;
 
 def GetStringOp : INFRT_Op<"get_string"> {
   let summary = "infrt.get_string";
diff --git a/paddle/infrt/dialect/dense_tensor.td b/paddle/infrt/dialect/dense_tensor.td
index 07e70cb2ca1eeaa3ec6ddcc8f0057f1efc55fabc..7156e229512251c14ad300d05ccf1d7e7cd1b68c 100644
--- a/paddle/infrt/dialect/dense_tensor.td
+++ b/paddle/infrt/dialect/dense_tensor.td
@@ -112,23 +112,35 @@ def LoadParamsOp : DT_Op<"load_params", [NoSideEffect]> {
   let verifier = ?;
 }
 
-def GetParamOp : DT_Op<"get_param", [NoSideEffect]> {
-  let summary = "dt.get_param operation";
+def TensorMapGetTensorOp : DT_Op<"tensor_map_get_tensor", [NoSideEffect]> {
+  let summary = "dt.tensor_map_get_tensor operation";
 
   let description = [{
-    An operation that can get a tensor from TensorMap.
+    An operation that can get a tensor from a TensorMap.
   }];
 
-  // input path of model params.
+  // the tensor map and the name of the tensor to look up.
   let arguments = (ins
     TensorMapType:$map,
-    StrAttr:$name
+    StringType:$name
   );
   let results = (outs TensorType:$output);
-  let assemblyFormat = "`(` $map `,` $name `)` attr-dict `->` type($output)";
+  let assemblyFormat = "`(` operands `)` attr-dict `->` type($output)";
   let verifier = ?;
 }
 
+def TensorMapGetSizeOp : DT_Op<"tensor_map_get_size", [NoSideEffect]> {
+  let summary = "dt.tensor_map_get_size operation";
+
+  let description = [{
+    An operation that gets the size of a TensorMap.
+  }];
+
+  let arguments = (ins TensorMapType:$map);
+  let results = (outs I32:$size);
+  let assemblyFormat = "`(` $map `)` attr-dict `->` type($size)";
+}
+
 def GetTensorShapeOp : DT_Op<"get_tensor_shape", [NoSideEffect]> {
   let summary = "dt.get_tensor_shape operation";
 
@@ -141,10 +153,38 @@ def GetTensorShapeOp : DT_Op<"get_tensor_shape", [NoSideEffect]> {
 
   let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
 }
 
+class NaiveElementwiseAddOp<string dtype> :
+    DT_Op<"naive_elementwise_add." # dtype, [NoSideEffect]> {
+  let summary = "dt.naive_elementwise_add operation";
+
+  let description = [{
+    Naive elementwise_add operation.
+    Just for testing.
+  }];
+  let arguments = (ins TensorType:$a, TensorType:$b);
+  let results = (outs TensorType:$output);
+  let assemblyFormat = "`(` $a `,` $b `)` attr-dict `:` `(` type($a) `,` type($b) `)` `->` type($output)";
+}
+
+class NaiveMatmulOp<string dtype> :
+    DT_Op<"naive_matmul." # dtype, [NoSideEffect]> {
+  let summary = "dt.naive_matmul operation";
+
+  let description = [{
+    Naive matmul operation.
+    Just for testing.
+  }];
+  let arguments = (ins TensorType:$x, TensorType:$w);
+  let results = (outs TensorType:$output);
+  let assemblyFormat = "`(` $x `,` $w `)` attr-dict `:` `(` type($x) `,` type($w) `)` `->` type($output)";
+}
+
 foreach dtype = ["ui8", "ui16", "ui32", "ui64", "i32", "f32", "f64", "i64"] in {
   def DT_CreateUninitTensorOp_#dtype : CreateUninitTensorOp<dtype>;
   def DT_FillTensorOp_#dtype : FillTensorWithConstantOp<dtype>;
   def DT_SetTensorOp_#dtype : SetTensorOp<dtype>;
+  def DT_NaiveElementwiseAddOp_#dtype : NaiveElementwiseAddOp<dtype>;
+  def DT_NaiveMatmulOp_#dtype : NaiveMatmulOp<dtype>;
 }
 
 #endif  // DT_OPS
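Note: the `foreach` above stamps out one concrete op per dtype suffix, so each new class yields eight ops. For the `f32` entry the expansion is equivalent to writing the following defs by hand (illustration only, not part of the patch):

```tablegen
def DT_NaiveElementwiseAddOp_f32 : NaiveElementwiseAddOp<"f32">;  // op name: dt.naive_elementwise_add.f32
def DT_NaiveMatmulOp_f32 : NaiveMatmulOp<"f32">;                  // op name: dt.naive_matmul.f32
```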
diff --git a/paddle/infrt/host_context/kernel_frame.h b/paddle/infrt/host_context/kernel_frame.h
index 5186b88fe2c41a8b4939dd70fde9123549764856..298c40322b7ddc6f8f2cbad7a94e464501c090a6 100644
--- a/paddle/infrt/host_context/kernel_frame.h
+++ b/paddle/infrt/host_context/kernel_frame.h
@@ -37,6 +37,14 @@ class KernelFrame {
            (num_results_ == -1 ? 0 : num_results_);
   }
 
+  //! Get something at a specific position \p index. The element might be an
+  //! argument, an attribute or a result.
+  template <typename T>
+  T& GetElementAt(int index) {
+    CHECK_LT(index, GetNumArgs() + GetNumAttributes() + GetNumResults());
+    return value_or_attrs_[index]->template get_or_default<T>();
+  }
+
   template <typename T>
   T& GetArgAt(int index) {
     CHECK_LT(index, GetNumArgs());
diff --git a/paddle/infrt/host_context/kernel_utils.h b/paddle/infrt/host_context/kernel_utils.h
index 33812912ba029c27cf60ff795bd61fd433ec8378..31d411006d2378eb77d254c76baf25809c79bb42 100644
--- a/paddle/infrt/host_context/kernel_utils.h
+++ b/paddle/infrt/host_context/kernel_utils.h
@@ -244,7 +244,7 @@ struct KernelImpl {
       static_assert(out_idx == 0, "Arguments should appear before results");
       static_assert(const_idx == 0,
                     "Arguments and results should appear before attributes.");
-      auto* arg = &frame->GetArgAt<ArgT>(in_idx);
+      auto* arg = &frame->template GetElementAt<ArgT>(in_idx);
       KernelCallHelper<
           Tail...>::template Invoke<in_idx + 1, out_idx, const_idx>(frame,
                                                                     pargs...,
diff --git a/paddle/infrt/host_context/value.h b/paddle/infrt/host_context/value.h
index 5ed89e78f1152d7a20ece2ae1ffaffdd3f8dcc3b..000ce95b827f84039eea499452479c469d690bd7 100644
--- a/paddle/infrt/host_context/value.h
+++ b/paddle/infrt/host_context/value.h
@@ -28,6 +28,9 @@
 #include "paddle/infrt/tensor/dense_tensor_view.h"
 #include "paddle/infrt/tensor/tensor_map.h"
 #include "paddle/infrt/tensor/tensor_shape.h"
+// Disabled temporarily due to a compile failure; will enable later.
+// #include "paddle/pten/backends/cpu/cpu_context.h"
+// #include "paddle/pten/core/dense_tensor.h"
 
 namespace infrt {
 namespace host_context {
@@ -82,13 +85,25 @@ class Value : public common::Object {
 
   template <typename T>
   const T& get() const {
+    CHECK(data.template is<T>());
     return data.get<T>();
   }
+
   template <typename T>
   T& get() {
+    CHECK(data.template is<T>());
     return data.get<T>();
   }
 
+  //! Get the value if it was assigned before; otherwise, store and return a
+  //! default-constructed value.
+  template <typename T>
+  T& get_or_default() {
+    if (!data.template is<T>()) {
+      this->set(T{});
+    }
+    return get<T>();
+  }
+
   template <typename T>
   void set(T&& v) {
     data = std::move(v);
@@ -124,6 +139,7 @@ class ValueRef : common::Shared<Value> {
   using common::Shared<Value>::Reset;
   using common::Shared<Value>::operator->;
   using common::Shared<Value>::operator*;
+  //! Get readonly data.
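The `get_or_default<T>` accessor added above is what lets `KernelFrame::GetElementAt` treat result slots uniformly with arguments: a result `Value` starts out unassigned, and the first typed access default-constructs the payload in place. A minimal sketch of the intended semantics (it mirrors the unit test below):

```cpp
Value v;                           // no payload assigned yet
int& i = v.get_or_default<int>();  // default-constructs int{} inside the variant, returns 0
i = 42;                            // the reference aliases the stored payload
CHECK_EQ(v.get<int>(), 42);        // later typed reads observe the value
```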
   template <typename T>
   const T& get() const {
diff --git a/paddle/infrt/host_context/value_test.cc b/paddle/infrt/host_context/value_test.cc
index 48d49478ce0efbbff172e2ca661a00d017f141b0..5ac9b60a22be2bbba3ef08a889729c62ddbbc660 100644
--- a/paddle/infrt/host_context/value_test.cc
+++ b/paddle/infrt/host_context/value_test.cc
@@ -30,5 +30,15 @@ TEST(ValueRef, test) {
   ASSERT_EQ(z.get<bool>(), true);
 }
 
+// If the value has not been assigned, get_or_default should return a default value.
+TEST(Value, init) {
+  Value x;
+  ASSERT_EQ(x.get_or_default<int>(), 0);
+
+  Value tensor;
+  auto& t = tensor.get_or_default<tensor::DenseHostTensor>();
+  ASSERT_EQ(t.shape().GetRank(), 0);
+}
+
 }  // namespace host_context
 }  // namespace infrt
diff --git a/paddle/infrt/kernel/tensor_kernels.cc b/paddle/infrt/kernel/tensor_kernels.cc
index 51e000492237435de555bc53bb63d23ce7ecbeb2..c6e28c4c79d29aac9cb1f536866eefddc55bf891 100644
--- a/paddle/infrt/kernel/tensor_kernels.cc
+++ b/paddle/infrt/kernel/tensor_kernels.cc
@@ -53,13 +53,62 @@ TensorMap LoadParams(const std::string &path) {
   return *(infrt::tensor::LoadParams(path));
 }
 
-DenseHostTensor GetParam(TensorMap map, Attribute<std::string> nameAttr) {
-  auto &name = nameAttr.get();
-  return *(map[name]);
+void TensorMapGetTensor(TensorMap map,
+                        const std::string &name,
+                        DenseHostTensor *out) {
+  auto it = map.find(name);
+  CHECK(it != map.end()) << "No tensor called " << name << " in the TensorMap";
+  *out = *it->second;
 }
 
+int32_t TensorMapGetSize(TensorMap map) { return map.size(); }
+
 DenseHostTensor ShallowCopyTensor(DenseHostTensor v) { return v; }
 
+template <typename T>
+void NaiveElementwiseAdd(const DenseHostTensor &x,
+                         const DenseHostTensor &y,
+                         DenseHostTensor *out) {
+  CHECK_EQ(x.shape().GetNumElements(), y.shape().GetNumElements());
+
+  // Infer shape
+  *out = DenseHostTensor(x.shape(), GetDType<T>());
+
+  const T *x_data = static_cast<const T *>(x.raw_data());
+  const T *y_data = static_cast<const T *>(y.raw_data());
+  T *out_data = static_cast<T *>(out->raw_data());
+  for (size_t i = 0, n = x.shape().GetNumElements(); i < n; i++) {
+    out_data[i] = x_data[i] + y_data[i];
+  }
+}
+
+//! A naive implementation for x matmul w
+template <typename T>
+void NaiveMatmul(const DenseHostTensor &x,
+                 const DenseHostTensor &w,
+                 DenseHostTensor *out) {
+  CHECK_EQ(x.shape().GetRank(), 2);
+  CHECK_EQ(w.shape().GetRank(), 2);
+  CHECK_EQ(x.shape().GetDim(x.shape().GetRank() - 1), w.shape().GetDim(0));
+  std::vector<int64_t> out_dims({x.shape().GetDim(0), w.shape().GetDim(1)});
+  *out = DenseHostTensor(TensorShape(out_dims), GetDType<T>());
+
+  auto *out_data = static_cast<T *>(out->raw_data());
+  auto *x_data = static_cast<const T *>(x.raw_data());
+  auto *w_data = static_cast<const T *>(w.raw_data());
+
+  const int M = x.shape().GetDim(0);
+  const int K = x.shape().GetDim(1);
+  const int N = w.shape().GetDim(1);
+  for (int i = 0; i < M; i++) {
+    for (int j = 0; j < N; j++) {
+      T sum{};  // accumulate locally: raw_data() is not guaranteed to be zeroed
+      for (int k = 0; k < K; k++) sum += x_data[i * K + k] * w_data[k * N + j];
+      out_data[i * N + j] = sum;
+    }
+  }
+}
+
 /// ===== Kernel end ====
 
 void RegisterTensorKernels(host_context::KernelRegistry *registry) {
@@ -71,10 +120,20 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
                       INFRT_KERNEL(FillTensorWithConstant<float>));
   registry->AddKernel("dt.fill_tensor_with_constant.f64",
                       INFRT_KERNEL(FillTensorWithConstant<double>));
+
+  // TensorMap related methods.
   registry->AddKernel("dt.load_params", INFRT_KERNEL(LoadParams));
-  registry->AddKernel("dt.get_param", INFRT_KERNEL(GetParam));
+  registry->AddKernel("dt.tensor_map_get_tensor",
+                      INFRT_KERNEL(TensorMapGetTensor));
+  registry->AddKernel("dt.tensor_map_get_size", INFRT_KERNEL(TensorMapGetSize));
+
   registry->AddKernel("dt.shallow_copy_tensor", INFRT_KERNEL(ShallowCopyTensor));
+
+  // Naive kernels.
+  registry->AddKernel("dt.naive_elementwise_add.f32",
+                      INFRT_KERNEL(NaiveElementwiseAdd<float>));
+  registry->AddKernel("dt.naive_matmul.f32", INFRT_KERNEL(NaiveMatmul<float>));
 }
 
 }  // namespace kernel
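For reference, a hedged sketch of driving the new naive kernels straight from C++, following the `DenseHostTensor(shape, dtype)` construction pattern used above (exact headers and namespaces are assumptions, and `TensorShape` is assumed to accept a braced list of dims):

```cpp
DenseHostTensor x(TensorShape({2, 8}), GetDType<float>());
DenseHostTensor w(TensorShape({8, 4}), GetDType<float>());
std::fill_n(static_cast<float *>(x.raw_data()), 2 * 8, 1.f);  // x := all ones
std::fill_n(static_cast<float *>(w.raw_data()), 8 * 4, 2.f);  // w := all twos

DenseHostTensor out;
NaiveMatmul<float>(x, w, &out);  // out: shape [2, 4]; every element is 1 * 2 * 8 = 16
```

The expected values match the FileCheck lines in the naive_kernels.mlir test added below.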
registry->AddKernel("dt.load_params", INFRT_KERNEL(LoadParams)); - registry->AddKernel("dt.get_param", INFRT_KERNEL(GetParam)); + registry->AddKernel("dt.tensor_map_get_tensor", + INFRT_KERNEL(TensorMapGetTensor)); + registry->AddKernel("dt.tensor_map_get_size", INFRT_KERNEL(TensorMapGetSize)); + registry->AddKernel("dt.shallow_copy_tensor", INFRT_KERNEL(ShallowCopyTensor)); + + // Naive kernels. + registry->AddKernel("dt.naive_elementwise_add.f32", + INFRT_KERNEL(NaiveElementwiseAdd)); + registry->AddKernel("dt.naive_matmul.f32", INFRT_KERNEL(NaiveMatmul)); } } // namespace kernel diff --git a/paddle/infrt/tests/.gitignore b/paddle/infrt/tests/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d641a9bc74b16fc7affdf72125cafb07d7196939 --- /dev/null +++ b/paddle/infrt/tests/.gitignore @@ -0,0 +1,7 @@ +.DS_Store +.idea +*.log +tmp/ + + +Output diff --git a/paddle/infrt/tests/CMakeLists.txt b/paddle/infrt/tests/CMakeLists.txt index a27fb3d8f18b001181b2bda0bf26b205a197d166..e5cc1ec1121fb7bbff2fad7856151916d8ea0924 100644 --- a/paddle/infrt/tests/CMakeLists.txt +++ b/paddle/infrt/tests/CMakeLists.txt @@ -2,3 +2,5 @@ configure_file(lit.cfg.py.in "${CMAKE_SOURCE_DIR}/paddle/infrt/tests/lit.cfg.py" add_test(NAME test_infrt_by_lit COMMAND sh -c "lit -v ${CMAKE_SOURCE_DIR}/paddle/infrt/tests --filter-out \"disabled_*\"" DEPENDS infrtopt infrtexec) + +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir.in ${CMAKE_CURRENT_SOURCE_DIR}/dialect/tensor/tensor_map.mlir) diff --git a/paddle/infrt/tests/dialect/tensor/.gitignore b/paddle/infrt/tests/dialect/tensor/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..488396e1e896d97a8259159ff14a428b1bbbe1d5 --- /dev/null +++ b/paddle/infrt/tests/dialect/tensor/.gitignore @@ -0,0 +1,5 @@ +.DS_Store +.idea +*.log +tmp/ +tensor_map.mlir diff --git a/paddle/infrt/tests/dialect/tensor/dense_tensor.mlir b/paddle/infrt/tests/dialect/tensor/dense_tensor.mlir new file mode 100644 index 0000000000000000000000000000000000000000..f1def17aa87961d70322ec20b4a86a018250e58d --- /dev/null +++ b/paddle/infrt/tests/dialect/tensor/dense_tensor.mlir @@ -0,0 +1,24 @@ +// RUN: infrtexec -i %s | FileCheck %s +// CHECK-LABEL: dense_shape0 +func @dense_shape0() { + %shape = ts.build_shape [1:i64, 57:i64] + %a = dt.create_uninit_tensor.f32 [12:i64, 23:i64] -> !infrt.tensor + + infrt.return +} + +func @predict(%a: !infrt.tensor, %b: !infrt.tensor) -> (!infrt.tensor, !infrt.tensor) { + %a0 = dt.shallow_copy_tensor %a : !infrt.tensor -> !infrt.tensor + %b0 = dt.shallow_copy_tensor %b : !infrt.tensor -> !infrt.tensor + + infrt.return %a0, %b0: !infrt.tensor, !infrt.tensor +} + + +func @main() { + %shape = ts.build_shape [1:i64, 57:i64] + %a = dt.create_uninit_tensor.f32 [12:i64, 23:i64] -> !infrt.tensor + + %b, %c = infrt.call @predict(%a, %a) : (!infrt.tensor, !infrt.tensor) -> (!infrt.tensor, !infrt.tensor) + infrt.return +} diff --git a/paddle/infrt/tests/dialect/tensor/naive_kernels.mlir b/paddle/infrt/tests/dialect/tensor/naive_kernels.mlir new file mode 100644 index 0000000000000000000000000000000000000000..914e863db49cca3320c74b11b624e3d7dfe3b6f8 --- /dev/null +++ b/paddle/infrt/tests/dialect/tensor/naive_kernels.mlir @@ -0,0 +1,35 @@ +// RUN: infrtexec -i %s | FileCheck %s +// CHECK-LABEL: naive_elementwise_add +func @naive_elementwise_add() { + // create a + %a = dt.create_uninit_tensor.f32 [2:i64, 8:i64] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%a : !infrt.tensor) {value=1.0:f32} + // create 
diff --git a/paddle/infrt/tests/dialect/tensor/tensor_shape.mlir b/paddle/infrt/tests/dialect/tensor/tensor_shape.mlir
new file mode 100644
index 0000000000000000000000000000000000000000..09210078b9d7d139f2bc2534acf07e83aa1146bb
--- /dev/null
+++ b/paddle/infrt/tests/dialect/tensor/tensor_shape.mlir
@@ -0,0 +1,8 @@
+// RUN: infrtexec -i %s | FileCheck %s
+// CHECK-LABEL: @build_tensor1
+func @build_tensor1() {
+  %a = ts.build_shape [1:i64, 57:i64, 92:i64]
+  // CHECK: shape[1,57,92]
+  ts.print_shape %a
+  infrt.return
+}
diff --git a/paddle/infrt/tests/dialect/tensor/tensor_type.mlir b/paddle/infrt/tests/dialect/tensor/tensor_type.mlir
new file mode 100644
index 0000000000000000000000000000000000000000..01a2f7df32608ad64d2929b4b24f96cf4e5062c4
--- /dev/null
+++ b/paddle/infrt/tests/dialect/tensor/tensor_type.mlir
@@ -0,0 +1,10 @@
+// RUN: infrtexec -i %s | FileCheck %s
+// CHECK-LABEL: test_tensor_type
+func @test_tensor_type() {
+  %a = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor<X86, NCHW, F32>
+  dt.fill_tensor_with_constant.f32 (%a : !infrt.tensor<X86, NCHW, F32>) {value=1.0:f32}
+  // CHECK: tensor: shape=shape[3,4], values=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+  dt.print_tensor (%a : !infrt.tensor<X86, NCHW, F32>)
+
+  infrt.return
+}
diff --git a/paddle/scripts/infrt_build.sh b/paddle/scripts/infrt_build.sh
index e6e9759db8efb791ae7860d3c0f0027e6301c9c8..f76fa497d6a03e02a4c6c222be1973b7a853edab 100755
--- a/paddle/scripts/infrt_build.sh
+++ b/paddle/scripts/infrt_build.sh
@@ -100,7 +100,17 @@ function infrt_gen_and_build() {
     echo "ipipe_log_param_Infrt_Build_Time: $[ $endTime_s - $startTime_s ]s" >> ${PADDLE_ROOT}/build/infrt_summary.txt
 }
 
+function create_fake_models() {
+    cd ${PADDLE_ROOT}/build
+    # create the multi_fc model; this generates "multi_fc_model" in the build dir
+    python3 -m pip uninstall -y paddlepaddle
+    python3 -m pip install paddlepaddle
+    python3 ${PADDLE_ROOT}/tools/infrt/fake_models/multi_fc.py
+}
+
 function test_infrt() {
+    create_fake_models
+
     # install llvm-lit toolkit
     python3 -m pip install lit
generate "multi_fc_model" + python3 -m pip uninstall -y paddlepaddle + python3 -m pip install paddlepaddle + python3 ${PADDLE_ROOT}/tools/infrt/fake_models/multi_fc.py +} + function test_infrt() { + create_fake_models + # install llvm-lit toolkit python3 -m pip install lit diff --git a/tools/infrt/fake_models/multi_fc.py b/tools/infrt/fake_models/multi_fc.py new file mode 100644 index 0000000000000000000000000000000000000000..03cf6828cc7e19eedd7b2cd1375b93859e9f3cfa --- /dev/null +++ b/tools/infrt/fake_models/multi_fc.py @@ -0,0 +1,56 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A fake model with multiple FC layers to test CINN on a more complex model. +""" +import numpy +import sys, os +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.fluid.backward import append_backward + +size = 2 +num_layers = 4 +paddle.enable_static() + +a = fluid.layers.data(name="A", shape=[-1, size], dtype='float32') +label = fluid.layers.data(name="label", shape=[size], dtype='float32') + +fc_out = fluid.layers.fc(input=a, + size=size, + act="relu", + bias_attr=fluid.ParamAttr(name="fc_bias"), + num_flatten_dims=1) + +for i in range(num_layers - 1): + fc_out = fluid.layers.fc(input=fc_out, + size=size, + act="relu", + bias_attr=fluid.ParamAttr(name="fc_bias"), + num_flatten_dims=1) + +cost = fluid.layers.square_error_cost(fc_out, label) +avg_cost = fluid.layers.mean(cost) + +optimizer = fluid.optimizer.SGD(learning_rate=0.001) +optimizer.minimize(avg_cost) + +cpu = fluid.core.CPUPlace() +loss = exe = fluid.Executor(cpu) + +exe.run(fluid.default_startup_program()) + +fluid.io.save_inference_model("./multi_fc_model", [a.name], [fc_out], exe) +print('output name', fc_out.name)