提交 04b9bfbd 编写于 作者: J jim19930609

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into develop

include(FetchContent) include(FetchContent)
set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/CINN/llvm11.tar.gz) set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/infrt/llvm_b5149f4e66a49a98b67e8e2de4e24a4af8e2781b.tar.gz)
set(LLVM_MD5 39d32b6be466781dddf5869318dcba53) set(LLVM_MD5 022819bb5760817013cf4b8a37e97d5e)
set(FETCHCONTENT_BASE_DIR ${THIRD_PARTY_PATH}/llvm) set(FETCHCONTENT_BASE_DIR ${THIRD_PARTY_PATH}/llvm)
set(FETCHCONTENT_QUIET OFF) set(FETCHCONTENT_QUIET OFF)
...@@ -51,7 +51,7 @@ message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") ...@@ -51,7 +51,7 @@ message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
# To build with MLIR, the LLVM is build from source code using the following flags: # To build with MLIR, the LLVM is build from source code using the following flags:
#[==[ #[==[
cmake -G Ninja ../llvm \ cmake ../llvm -G "Unix Makefiles" \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \ -DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_BUILD_EXAMPLES=OFF \ -DLLVM_BUILD_EXAMPLES=OFF \
-DLLVM_TARGETS_TO_BUILD="X86" \ -DLLVM_TARGETS_TO_BUILD="X86" \
...@@ -59,8 +59,10 @@ cmake -G Ninja ../llvm \ ...@@ -59,8 +59,10 @@ cmake -G Ninja ../llvm \
-DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_ZLIB=OFF \ -DLLVM_ENABLE_ZLIB=OFF \
-DLLVM_ENABLE_RTTI=ON \ -DLLVM_ENABLE_RTTI=ON \
-DLLVM_INSTALL_UTILS=ON \
-DCMAKE_INSTALL_PREFIX=./install
#]==] #]==]
# The matched llvm-project version is f9dc2b7079350d0fed3bb3775f496b90483c9e42 (currently a temporary commit) # The matched llvm-project version is b5149f4e66a49a98b67e8e2de4e24a4af8e2781b (currently a temporary commit)
add_definitions(${LLVM_DEFINITIONS}) add_definitions(${LLVM_DEFINITIONS})
...@@ -75,7 +77,7 @@ add_definitions(${LLVM_DEFINITIONS}) ...@@ -75,7 +77,7 @@ add_definitions(${LLVM_DEFINITIONS})
# The minimum needed libraries for MLIR IR parse and transform. # The minimum needed libraries for MLIR IR parse and transform.
set(MLIR_IR_LIBS MLIRAnalysis MLIRStandardOps MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib) set(MLIR_IR_LIBS MLIRAnalysis MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib)
# tb_base is the name of a xxx.td file (without the .td suffix) # tb_base is the name of a xxx.td file (without the .td suffix)
...@@ -89,6 +91,7 @@ function(mlir_tablegen_on td_base) ...@@ -89,6 +91,7 @@ function(mlir_tablegen_on td_base)
mlir_tablegen(${td_base}.cpp.inc -gen-op-defs) mlir_tablegen(${td_base}.cpp.inc -gen-op-defs)
if (mlir_tablegen_on_DIALECT) if (mlir_tablegen_on_DIALECT)
mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls -dialect=${mlir_tablegen_on_DIALECT}) mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls -dialect=${mlir_tablegen_on_DIALECT})
mlir_tablegen(${td_base}_dialect.cpp.inc --gen-dialect-defs -dialect=${mlir_tablegen_on_DIALECT})
endif() endif()
add_public_tablegen_target(${td_base}_IncGen) add_public_tablegen_target(${td_base}_IncGen)
add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen) add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen)
......
...@@ -46,8 +46,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( ...@@ -46,8 +46,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
<< " is diabled by config in TensorRT"; << " is diabled by config in TensorRT";
return false; return false;
} }
return tensorrt::OpTeller::Global().Tell(node, no_calib_int8, bool is_ok = tensorrt::OpTeller::Global().Tell(node, no_calib_int8,
with_dynamic_shape); with_dynamic_shape);
if (!is_ok)
VLOG(3) << node->Op()->Type().c_str() << " op is not in TensorRT";
return is_ok;
}; };
framework::ir::SubGraphFuser fuser( framework::ir::SubGraphFuser fuser(
......
...@@ -1416,6 +1416,7 @@ USE_TRT_CONVERTER(elementwise_min_tensor); ...@@ -1416,6 +1416,7 @@ USE_TRT_CONVERTER(elementwise_min_tensor);
USE_TRT_CONVERTER(elementwise_pow_tensor); USE_TRT_CONVERTER(elementwise_pow_tensor);
USE_TRT_CONVERTER(transpose); USE_TRT_CONVERTER(transpose);
USE_TRT_CONVERTER(flatten); USE_TRT_CONVERTER(flatten);
USE_TRT_CONVERTER(flatten_contiguous_range);
USE_TRT_CONVERTER(matmul); USE_TRT_CONVERTER(matmul);
USE_TRT_CONVERTER(conv2d); USE_TRT_CONVERTER(conv2d);
USE_TRT_CONVERTER(relu); USE_TRT_CONVERTER(relu);
......
...@@ -3,7 +3,7 @@ nv_library(tensorrt_converter ...@@ -3,7 +3,7 @@ nv_library(tensorrt_converter
SRCS matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc SRCS matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc group_norm_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc group_norm_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc flatten_contiguous_range_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
gather_op.cc gather_op.cc
anchor_generator_op.cc anchor_generator_op.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
/*
 * TensorRT converter for the flatten_contiguous_range op.
 *
 * Collapses the input dimensions in [start_axis, stop_axis] into a single
 * dimension by emitting an IShuffleLayer. Two strategies are used:
 *   - dynamic shape: the target shape is assembled at runtime from the
 *     output of a Shape layer (slice / reduce-prod / concat) and fed to the
 *     shuffle layer as its second input;
 *   - static shape: the target shape is computed at conversion time from the
 *     input tensor's dimensions and set as fixed reshape dimensions.
 */
class FlattenContiguousRangeOpConverter : public OpConverter {
 public:
  /*
   * Adds the layers implementing one flatten_contiguous_range op.
   *
   * op        : serialized op description; reads input "X", output "Out" and
   *             the int attributes "start_axis" / "stop_axis".
   * scope     : not used by this converter.
   * test_mode : forwarded unchanged to RreplenishLayerAndOutput.
   */
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    framework::OpDesc op_desc(op, nullptr);

    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    int dims = input->getDimensions().nbDims;
    int start_axis = BOOST_GET_CONST(int, op_desc.GetAttr("start_axis"));
    int stop_axis = BOOST_GET_CONST(int, op_desc.GetAttr("stop_axis"));

    nvinfer1::IShuffleLayer* layer = nullptr;
    if (engine_->with_dynamic_shape()) {
      // Negative axes count from the end of the input tensor's dimensions.
      start_axis = (start_axis < 0) ? start_axis + dims : start_axis;
      stop_axis = (stop_axis < 0) ? stop_axis + dims : stop_axis;

      // Builds a rank-1 Dims whose only entry is `value`; used for the
      // start/size/stride arguments of slices over the 1-D shape tensor.
      auto dims_1d = [](int value) {
        nvinfer1::Dims d;
        d.nbDims = 1;
        d.d[0] = value;
        return d;
      };

      auto* shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
      auto* runtime_shape = shape_layer->getOutput(0);

      // Extents of the axes being flattened: shape[start_axis .. stop_axis].
      auto* flattened_slice = TRT_ENGINE_ADD_LAYER(
          engine_, Slice, *runtime_shape, dims_1d(start_axis),
          dims_1d(stop_axis - start_axis + 1), dims_1d(1));

      // Multiply those extents together. The value 1 is passed as the reduce
      // axes argument and keepDimensions is true, as in the layer call below.
      uint32_t reduce_axes = 1;
      auto* merged_extent = TRT_ENGINE_ADD_LAYER(
          engine_, Reduce, *(flattened_slice->getOutput(0)),
          nvinfer1::ReduceOperation::kPROD, reduce_axes, true);

      nvinfer1::ITensor* target_shape = nullptr;
      const bool flatten_everything =
          (start_axis == 0) && (stop_axis == dims - 1);
      if (flatten_everything) {
        // Every axis is merged: the product alone is the new shape.
        target_shape = merged_extent->getOutput(0);
      } else {
        // New shape = [axes before start_axis] + [product] + [axes after
        // stop_axis]; the leading / trailing part is skipped when empty.
        std::vector<nvinfer1::ITensor*> pieces;
        if (start_axis > 0) {
          auto* leading = TRT_ENGINE_ADD_LAYER(
              engine_, Slice, *runtime_shape, dims_1d(0), dims_1d(start_axis),
              dims_1d(1));
          pieces.push_back(leading->getOutput(0));
        }
        pieces.push_back(merged_extent->getOutput(0));
        if (stop_axis < dims - 1) {
          auto* trailing = TRT_ENGINE_ADD_LAYER(
              engine_, Slice, *runtime_shape, dims_1d(stop_axis + 1),
              dims_1d(dims - stop_axis - 1), dims_1d(1));
          pieces.push_back(trailing->getOutput(0));
        }
        auto* joined = TRT_ENGINE_ADD_LAYER(engine_, Concatenation,
                                            pieces.data(), pieces.size());
        joined->setAxis(0);
        target_shape = joined->getOutput(0);
      }

      layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
      layer->setInput(1, *target_shape);
    } else {
      // Here a negative axis is offset by dims + 1 and input dimension i is
      // matched against op axis i + 1.
      // NOTE(review): presumably because the static-shape TensorRT tensor
      // does not carry the batch dimension -- confirm.
      if (start_axis < 0) start_axis += dims + 1;
      if (stop_axis < 0) stop_axis += dims + 1;

      const nvinfer1::Dims in_dims = input->getDimensions();
      nvinfer1::Dims out_dims;
      out_dims.nbDims = dims - (stop_axis - start_axis);

      int merged = 1;
      int out_idx = 0;
      for (int axis = 1; axis <= dims; ++axis) {
        int extent = in_dims.d[axis - 1];
        if (axis < start_axis || stop_axis < axis) {
          // Outside the flattened range: copy the extent through.
          out_dims.d[out_idx++] = extent;
          continue;
        }
        PADDLE_ENFORCE_GT(extent, 0, platform::errors::InvalidArgument(
                                         "flatten_contiguous_range input dim "
                                         "should be > 0, but got %d.",
                                         extent));
        merged *= extent;
        // Emit the accumulated product once the last merged axis is reached.
        if (axis == stop_axis) {
          out_dims.d[out_idx++] = merged;
        }
      }

      layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
      layer->setReshapeDimensions(out_dims);
    }

    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, "flatten_contiguous_range", {output_name},
                             test_mode);
  }
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(flatten_contiguous_range,
FlattenContiguousRangeOpConverter);
...@@ -55,6 +55,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -55,6 +55,7 @@ struct SimpleOpTypeSetTeller : public Teller {
// #endif // #endif
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
teller_set.insert("tile"); teller_set.insert("tile");
teller_set.insert("flatten_contiguous_range");
#endif #endif
#if CUDA_VERSION >= 10020 #if CUDA_VERSION >= 10020
teller_set.insert("reshape"); teller_set.insert("reshape");
...@@ -531,6 +532,37 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -531,6 +532,37 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (axis != 1) return false; if (axis != 1) return false;
} }
} }
// flatten_contiguous_range: in static-shape mode the op is only accepted when
// the flattened range does not start at axis 0 and every dimension in
// [start_axis, stop_axis] has a known, positive extent. Nothing is checked
// here when dynamic shape is enabled.
if (op_type == "flatten_contiguous_range") {
  if (!with_dynamic_shape) {
    int start_axis = BOOST_GET_CONST(int, desc.GetAttr("start_axis"));
    int stop_axis = BOOST_GET_CONST(int, desc.GetAttr("stop_axis"));
    auto x_var_name = desc.Input("X")[0];
    auto* block = desc.Block();
    if (block == nullptr) {
      VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
                 "Developers need to check whether block_desc is passed in "
                 "the pass.";
      return false;
    }
    auto* x_var_desc = block->FindVar(x_var_name);
    // The variable may be missing from the block; refuse instead of
    // dereferencing a null pointer.
    if (x_var_desc == nullptr) {
      VLOG(3) << "Can not find the input variable " << x_var_name
              << " of flatten_contiguous_range in the block.";
      return false;
    }
    const auto x_shape = x_var_desc->GetShape();
    int dims = x_shape.size();
    if (start_axis < 0) start_axis += dims;
    if (start_axis == 0) {
      VLOG(3) << "TRT flatten_contiguous_range not support the "
                 "batch-dimension being changed";
      return false;
    }
    if (stop_axis < 0) stop_axis += dims;
    // After normalization both axes must index into x_shape and form a
    // non-empty range; otherwise the loop below would read out of bounds.
    if (start_axis < 0 || stop_axis >= dims || start_axis > stop_axis) {
      VLOG(3) << "flatten_contiguous_range has invalid start_axis/stop_axis "
                 "for an input of rank "
              << dims;
      return false;
    }
    for (int i = start_axis; i <= stop_axis; ++i) {
      // Reject unknown (negative) and zero extents alike: the log message
      // below states that the dimension must be > 0.
      if (x_shape[i] <= 0) {
        VLOG(3) << "On TRT static shape,flatten_contiguous_range input dim "
                   "should be > 0";
        return false;
      }
    }
  }
}
if (op_type == "gather") { if (op_type == "gather") {
auto gather_inputs = desc.Inputs(); auto gather_inputs = desc.Inputs();
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,9 +12,12 @@ ...@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/stack_op.h" #include "paddle/fluid/operators/stack_op.h"
#include <string> #include <string>
#ifdef PADDLE_WITH_XPU #include <vector>
#include "paddle/fluid/operators/concat_op.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -59,14 +62,44 @@ class StackXPUKernel : public framework::OpKernel<T> { ...@@ -59,14 +62,44 @@ class StackXPUKernel : public framework::OpKernel<T> {
} }
}; };
// XPU backward kernel of the stack op.
//
// Reads the gradient of "Y", allocates every gradient output of "X" on the
// execution place and fills them by splitting the incoming gradient along
// `axis` into dx.size() pieces of extent 1 each via xpu::split.
template <typename DeviceContext, typename T>
class StackGradXPUKernel : public framework::OpKernel<T> {
 public:
  // Raises an External error when xpu::split does not return XPU_SUCCESS.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
    auto axis = ctx.Attr<int>("axis");
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto dy_dims = dy->dims();

    // `axis` selects a dimension of dy (it is passed to xpu::split together
    // with dy_shape), so a negative value must be offset by dy's own rank.
    // Adding rank + 1 would map -1 to `rank`, one past the last valid axis.
    if (axis < 0) axis += dy_dims.size();

    auto dy_shape = framework::vectorize<int>(dy_dims);
    // Each output receives a slice of extent 1 along the split axis.
    std::vector<int> dx_dims_list(dx.size(), 1);
    std::vector<T*> dx_lists;
    dx_lists.reserve(dx.size());
    for (auto out : dx) {
      dx_lists.push_back(out->mutable_data<T>(ctx.GetPlace()));
    }

    int r = xpu::split<T>(dev_ctx.x_context(), dy->data<T>(), dx_lists,
                          dy_shape, dx_dims_list, axis);
    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                      platform::errors::External(
                          "The stack_grad XPU kernel return wrong value[%d %s]",
                          r, XPUAPIErrorMsg[r]));
  }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace plat = paddle::platform; namespace plat = paddle::platform;
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(stack, REGISTER_OP_XPU_KERNEL(stack,
ops::StackXPUKernel<plat::XPUDeviceContext, int64_t>, ops::StackXPUKernel<plat::XPUDeviceContext, float>,
ops::StackXPUKernel<plat::XPUDeviceContext, int>, ops::StackXPUKernel<plat::XPUDeviceContext, int>,
ops::StackXPUKernel<plat::XPUDeviceContext, float>); ops::StackXPUKernel<plat::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(stack_grad,
ops::StackGradXPUKernel<plat::XPUDeviceContext, float>,
ops::StackGradXPUKernel<plat::XPUDeviceContext, int>);
#endif #endif
...@@ -300,6 +300,7 @@ XPUOpMap& get_kl1_ops() { ...@@ -300,6 +300,7 @@ XPUOpMap& get_kl1_ops() {
pOpKernelType(vartype::UINT8, XPUPlace()), pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"stack_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
...@@ -333,6 +333,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -333,6 +333,8 @@ XPUOpMap& get_kl2_ops() {
{"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})}, pOpKernelType(vartype::INT32, XPUPlace())})},
{"stack_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
......
...@@ -77,7 +77,6 @@ add_subdirectory(paddle) ...@@ -77,7 +77,6 @@ add_subdirectory(paddle)
# MLIR td file generations # MLIR td file generations
set(infrt_mlir_incs set(infrt_mlir_incs
ops_inc
basic_kernels_inc basic_kernels_inc
test_kernels_inc test_kernels_inc
infrt_base_inc infrt_base_inc
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#include "mlir/IR/MLIRContext.h" #include <mlir/IR/MLIRContext.h>
#include "paddle/infrt/tensor/dense_host_tensor.h" #include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt { namespace infrt {
......
...@@ -2,7 +2,6 @@ core_gather_headers() ...@@ -2,7 +2,6 @@ core_gather_headers()
gather_srcs(infrt_src SRCS gather_srcs(infrt_src SRCS
dialect.cc dialect.cc
types.cc
basic_kernels.cc basic_kernels.cc
test_kernels.cc test_kernels.cc
infrt_base.cc infrt_base.cc
...@@ -14,8 +13,6 @@ gather_srcs(infrt_src SRCS ...@@ -14,8 +13,6 @@ gather_srcs(infrt_src SRCS
pd_types.cc pd_types.cc
pd_ops.cc pd_ops.cc
) )
mlir_tablegen_on(ops)
mlir_tablegen_on(basic_kernels) mlir_tablegen_on(basic_kernels)
mlir_tablegen_on(test_kernels) mlir_tablegen_on(test_kernels)
mlir_tablegen_on(infrt_base DIALECT infrt) mlir_tablegen_on(infrt_base DIALECT infrt)
...@@ -27,8 +24,7 @@ mlir_add_rewriter(rewrite) ...@@ -27,8 +24,7 @@ mlir_add_rewriter(rewrite)
# TODO(Superjomn) add a cmake function cc_executable to ecapsulate the following code # TODO(Superjomn) add a cmake function cc_executable to ecapsulate the following code
add_executable(infrtopt opt.cc) add_executable(infrtopt opt.cc)
target_link_libraries(infrtopt infrt ${mlir_libs}) target_link_libraries(infrtopt infrt)
add_dependencies(infrtopt infrt)
add_executable(print-ir print_ir.cc) add_executable(print-ir print_ir.cc)
target_link_libraries(print-ir infrt ${mlir_libs}) target_link_libraries(print-ir infrt ${mlir_libs})
......
...@@ -17,17 +17,17 @@ ...@@ -17,17 +17,17 @@
#include <llvm/ADT/STLExtras.h> #include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h> #include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/IR/Function.h> #include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Module.h> #include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h> #include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h> #include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h> #include <mlir/Support/LogicalResult.h>
#include "paddle/infrt/dialect/dense_tensor.h" #include "paddle/infrt/dialect/dense_tensor.h"
namespace infrt::dialect { namespace infrt {
namespace dialect {
using namespace mlir; // NOLINT using namespace mlir; // NOLINT
static ParseResult parseCallOp(OpAsmParser &parser, // NOLINT static ParseResult parseCallOp(OpAsmParser &parser, // NOLINT
...@@ -71,12 +71,12 @@ static ParseResult parseConstantF64Op(OpAsmParser &parser, // NOLINT ...@@ -71,12 +71,12 @@ static ParseResult parseConstantF64Op(OpAsmParser &parser, // NOLINT
static ParseResult parseConstantI32Op(OpAsmParser &parser, // NOLINT static ParseResult parseConstantI32Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT OperationState &result) { // NOLINT
return parseConstantOp( return parseConstantOp(
IntegerType::get(32, result.getContext()), parser, result); IntegerType::get(result.getContext(), 32), parser, result);
} }
static ParseResult parseConstantI64Op(OpAsmParser &parser, // NOLINT static ParseResult parseConstantI64Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT OperationState &result) { // NOLINT
return parseConstantOp( return parseConstantOp(
IntegerType::get(64, result.getContext()), parser, result); IntegerType::get(result.getContext(), 64), parser, result);
} }
static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT
...@@ -90,10 +90,10 @@ static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT ...@@ -90,10 +90,10 @@ static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT
} }
static void print(OpAsmPrinter &p, CallOp op) { // NOLINT static void print(OpAsmPrinter &p, CallOp op) { // NOLINT
p << "infrt.call " << op.getAttr("callee") << "("; p << "infrt.call " << op->getAttr("callee") << "(";
p.printOperands(op.getOperands()); p.printOperands(op.getOperands());
p << ")"; p << ")";
p.printOptionalAttrDict(op.getAttrs(), {"callee"}); p.printOptionalAttrDict(op->getAttrs(), {"callee"});
p << " : "; p << " : ";
} }
...@@ -145,7 +145,7 @@ static LogicalResult verify(ConstantF64Op op) { return success(); } ...@@ -145,7 +145,7 @@ static LogicalResult verify(ConstantF64Op op) { return success(); }
static LogicalResult verify(ConstantI64Op op) { return success(); } static LogicalResult verify(ConstantI64Op op) { return success(); }
static LogicalResult verify(ReturnOp op) { static LogicalResult verify(ReturnOp op) {
auto function = dyn_cast<FuncOp>(op.getParentOp()); auto function = dyn_cast<FuncOp>(op->getParentOp());
if (!function) return success(); if (!function) return success();
...@@ -157,8 +157,8 @@ static LogicalResult verify(ReturnOp op) { ...@@ -157,8 +157,8 @@ static LogicalResult verify(ReturnOp op) {
return success(); return success();
} }
} // namespace dialect
} // namespace infrt
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.cpp.inc" #include "paddle/infrt/dialect/basic_kernels.cpp.inc"
} // namespace infrt::dialect
...@@ -13,12 +13,9 @@ ...@@ -13,12 +13,9 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h> #include <mlir/Interfaces/SideEffectInterfaces.h>
using namespace mlir; // NOLINT
namespace infrt::dialect {
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.hpp.inc" #include "paddle/infrt/dialect/basic_kernels.hpp.inc"
} // namespace infrt::dialect
...@@ -27,7 +27,7 @@ def CallOp : INFRT_Op<"call"> { ...@@ -27,7 +27,7 @@ def CallOp : INFRT_Op<"call"> {
let results = (outs Variadic<AnyType>); let results = (outs Variadic<AnyType>);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
StringRef getCallee() { return callee(); } mlir::StringRef getCallee() { return callee(); }
mlir::FunctionType getCalleeType(); mlir::FunctionType getCalleeType();
}]; }];
} }
...@@ -57,9 +57,8 @@ def ReturnOp : INFRT_Op<"return", [Terminator]> { ...@@ -57,9 +57,8 @@ def ReturnOp : INFRT_Op<"return", [Terminator]> {
let arguments = (ins Variadic<AnyType>:$operands); let arguments = (ins Variadic<AnyType>:$operands);
let builders = [OpBuilder< let builders = [OpBuilder<(ins),
"OpBuilder &b, OperationState &result", [{ build($_builder, $_state, llvm::None); }]>];
[{ build(b, result, llvm::None); }]>];
} }
class AddOp<string suffix, Type type> : INFRT_Op<"add." # suffix, [NoSideEffect]> { class AddOp<string suffix, Type type> : INFRT_Op<"add." # suffix, [NoSideEffect]> {
......
...@@ -17,12 +17,11 @@ ...@@ -17,12 +17,11 @@
#include <llvm/ADT/STLExtras.h> #include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h> #include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h> #include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h> #include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h> #include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h> #include <mlir/Support/LogicalResult.h>
...@@ -31,68 +30,37 @@ ...@@ -31,68 +30,37 @@
#include "paddle/infrt/common/global.h" #include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensor_shape.h" #include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt::dt { namespace infrt {
namespace dt {
void DTDialect::initialize() { void DTDialect::initialize() {
allowUnknownTypes();
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "paddle/infrt/dialect/dense_tensor.cpp.inc" #include "paddle/infrt/dialect/dense_tensor.cpp.inc"
>(); >();
} }
namespace detail {
struct TensorTypeStorage : public mlir::TypeStorage {
TensorTypeStorage(TargetType target,
LayoutType layout,
PrecisionType precision)
: target_(target), layout_(layout), precision_(precision) {}
using KeyTy = std::tuple<TargetType, LayoutType, PrecisionType>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(target_, layout_, precision_);
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static TensorTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, // NOLINT
const KeyTy &key) {
return new (allocator.allocate<TensorTypeStorage>())
TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key));
}
TargetType target_;
LayoutType layout_;
PrecisionType precision_;
};
} // namespace detail
llvm::Optional<TargetType> GetTargetType(mlir::StringRef key) { llvm::Optional<TargetType> GetTargetType(mlir::StringRef key) {
if (key.equals_lower("x86")) if (key.equals_insensitive("x86"))
return TargetType::X86; return TargetType::X86;
else if (key.equals_lower("cuda")) else if (key.equals_insensitive("cuda"))
return TargetType::CUDA; return TargetType::CUDA;
else else
return llvm::None; return llvm::None;
} }
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key) { llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key) {
if (key.equals_lower("nchw")) if (key.equals_insensitive("nchw"))
return LayoutType::NCHW; return LayoutType::NCHW;
else if (key.equals_lower("nhwc")) else if (key.equals_insensitive("nhwc"))
return LayoutType::NHWC; return LayoutType::NHWC;
else else
return llvm::None; return llvm::None;
} }
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key) { llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key) {
if (key.equals_lower("i32")) if (key.equals_insensitive("i32"))
return PrecisionType::I32; return PrecisionType::I32;
else if (key.equals_lower("f32")) else if (key.equals_insensitive("f32"))
return PrecisionType::F32; return PrecisionType::F32;
else else
return llvm::None; return llvm::None;
...@@ -111,7 +79,7 @@ LayoutType TensorType::layout() { return getImpl()->layout_; } ...@@ -111,7 +79,7 @@ LayoutType TensorType::layout() { return getImpl()->layout_; }
PrecisionType TensorType::precision() { return getImpl()->precision_; } PrecisionType TensorType::precision() { return getImpl()->precision_; }
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType) { mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType) {
os << "TensorType<" << tensorType.target() << ", " << tensorType.layout() os << "TensorType<" << tensorType.target() << ", " << tensorType.layout()
<< ", " << tensorType.precision() << ">"; << ", " << tensorType.precision() << ">";
return os; return os;
...@@ -133,7 +101,7 @@ StringType StringType::get(mlir::MLIRContext *context) { ...@@ -133,7 +101,7 @@ StringType StringType::get(mlir::MLIRContext *context) {
return Base::get(context); return Base::get(context);
} }
raw_ostream &operator<<(raw_ostream &os, TargetType type) { mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type) {
switch (type) { switch (type) {
case (TargetType::X86): case (TargetType::X86):
os << "X86"; os << "X86";
...@@ -147,7 +115,7 @@ raw_ostream &operator<<(raw_ostream &os, TargetType type) { ...@@ -147,7 +115,7 @@ raw_ostream &operator<<(raw_ostream &os, TargetType type) {
return os; return os;
} }
raw_ostream &operator<<(raw_ostream &os, LayoutType type) { mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type) {
switch (type) { switch (type) {
case (LayoutType::NCHW): case (LayoutType::NCHW):
os << "NCHW"; os << "NCHW";
...@@ -161,7 +129,7 @@ raw_ostream &operator<<(raw_ostream &os, LayoutType type) { ...@@ -161,7 +129,7 @@ raw_ostream &operator<<(raw_ostream &os, LayoutType type) {
return os; return os;
} }
raw_ostream &operator<<(raw_ostream &os, PrecisionType type) { mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type) {
switch (type) { switch (type) {
case (PrecisionType::I32): case (PrecisionType::I32):
os << "I32"; os << "I32";
...@@ -175,103 +143,69 @@ raw_ostream &operator<<(raw_ostream &os, PrecisionType type) { ...@@ -175,103 +143,69 @@ raw_ostream &operator<<(raw_ostream &os, PrecisionType type) {
return os; return os;
} }
static Type getTensorType(mlir::MLIRContext *context) { static mlir::Type getTensorType(mlir::MLIRContext *context) {
auto t_dialect = Identifier::get("t", context); auto t_dialect = mlir::Identifier::get("t", context);
return OpaqueType::get(t_dialect, "tensor", context); return mlir::OpaqueType::get(t_dialect, "tensor");
} }
static ParseResult parseCreateUninitTensorOp( static mlir::ParseResult parseCreateUninitTensorOp(
OpAsmParser &parser, // NOLINT mlir::OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT mlir::OperationState &result) { // NOLINT
auto loc = parser.getCurrentLocation(); auto loc = parser.getCurrentLocation();
::mlir::Type outputRawTypes[1]; mlir::Type outputRawTypes[1];
::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes); ::llvm::ArrayRef<mlir::Type> outputTypes(outputRawTypes);
mlir::ArrayAttr shapeAttr; mlir::ArrayAttr shapeAttr;
if (parser.parseAttribute(shapeAttr, if (parser.parseAttribute(shapeAttr,
parser.getBuilder().getI64Type(), parser.getBuilder().getI64Type(),
"shape", "shape",
result.attributes)) result.attributes))
return failure(); return mlir::failure();
if (parser.parseOptionalAttrDict(result.attributes)) return failure(); if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure();
if (parser.parseArrow()) return failure(); if (parser.parseArrow()) return mlir::failure();
if (parser.parseType(outputRawTypes[0])) return failure(); if (parser.parseType(outputRawTypes[0])) return mlir::failure();
if (!outputRawTypes[0].isa<TensorType>()) if (!outputRawTypes[0].isa<TensorType>())
return parser.emitError(loc, "invalid kind of type specified"); return parser.emitError(loc, "invalid kind of type specified");
result.addTypes(outputTypes); result.addTypes(outputTypes);
return success(); return mlir::success();
} }
template <typename CreateUninitTensorOp> template <typename CreateUninitTensorOp>
static void printCreateUninitTensorOp(OpAsmPrinter &p, // NOLINT static void printCreateUninitTensorOp(mlir::OpAsmPrinter &p, // NOLINT
CreateUninitTensorOp op) { CreateUninitTensorOp op) {
p << CreateUninitTensorOp::getOperationName(); p << CreateUninitTensorOp::getOperationName();
p << " "; p << " ";
p.printAttributeWithoutType(op.shapeAttr()); p.printAttributeWithoutType(op.shapeAttr());
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
p << " -> "; p << " -> ";
p << op.getOperation()->getResultTypes(); p << op.getOperation()->getResultTypes();
} }
// TODO(shibo): can be removed? static mlir::ParseResult parseSetTensorOp(
// static ParseResult parseFillTensorWithConstantOp(OpAsmParser& parser, mlir::OpAsmParser &parser, // NOLINT
// OperationState& result) { mlir::OperationState &result) { // NOLINT
// auto loc = parser.getCurrentLocation(); llvm::SmallVector<mlir::OpAsmParser::OperandType, 1> operands;
// ::mlir::OpAsmParser::OperandType inputRawOperands[1]; if (parser.parseOperandList(operands, 1)) return mlir::failure();
// ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType>
// inputOperands(inputRawOperands);
// ::mlir::Type inputRawTypes[1];
// ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes);
//
// if (parser.parseOperand(inputRawOperands[0])) return failure();
//
// if (parser.parseColon()) return failure();
// if (parser.parseType(inputRawTypes[0])) return failure();
// if (!inputRawTypes[0].isa<TensorType>())
// return parser.emitError(loc, "invalid kind of type specified");
//
// Attribute value_attr;
// if (parser.resolveOperands(inputOperands, inputTypes, loc, result.operands))
// return failure();
// if (parser.parseAttribute(value_attr, "value", result.attributes)) return
// failure();
// return success();
//}
// TODO(shibo): can be removed?
// template <typename FillTensorOp>
// static void printFillTensorWithConstantOp(OpAsmPrinter& p, FillTensorOp op) {
// p << FillTensorOp::getOperationName();
// p << " ";
// p.printOperand(op.getOperand());
// p << " : ";
// p << op.getOperation()->getOperandTypes();
// p << " ";
// p << op.getAttr("value");
//}
static ParseResult parseSetTensorOp(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
SmallVector<OpAsmParser::OperandType, 1> operands;
if (parser.parseOperandList(operands, 1)) return failure();
auto tensor_type = getTensorType(result.getContext()); auto tensor_type = getTensorType(result.getContext());
Attribute value_attr; mlir::Attribute value_attr;
return failure( return mlir::failure(
parser.resolveOperand(operands[0], tensor_type, result.operands) || parser.resolveOperand(operands[0], tensor_type, result.operands) ||
parser.parseAttribute(value_attr, "values", result.attributes)); parser.parseAttribute(value_attr, "values", result.attributes));
} }
template <typename SetTensorOp> template <typename SetTensorOp>
static void printSetTensorOp(OpAsmPrinter &p, SetTensorOp op) { // NOLINT static void printSetTensorOp(mlir::OpAsmPrinter &p, SetTensorOp op) { // NOLINT
p << SetTensorOp::getOperationName() << " "; p << SetTensorOp::getOperationName() << " ";
p.printOperand(op.getOperand()); p.printOperand(op.getOperand());
p << " " << op.getAttr("values"); p << " " << op->getAttr("values");
} }
} // namespace dt
} // namespace infrt
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.cpp.inc" // NOLINT #include "paddle/infrt/dialect/dense_tensor.cpp.inc" // NOLINT
} // namespace infrt::dt #include "paddle/infrt/dialect/dense_tensor_dialect.cpp.inc"
...@@ -19,13 +19,8 @@ ...@@ -19,13 +19,8 @@
#include <string> #include <string>
using namespace mlir; // NOLINT namespace infrt {
namespace infrt::dt { namespace dt {
namespace detail {
struct TensorTypeStorage;
} // namespace detail
enum class TargetType : uint8_t { X86, CUDA }; enum class TargetType : uint8_t { X86, CUDA };
enum class LayoutType : uint8_t { NCHW, NHWC }; enum class LayoutType : uint8_t { NCHW, NHWC };
enum class PrecisionType : uint8_t { I32, F32 }; enum class PrecisionType : uint8_t { I32, F32 };
...@@ -34,9 +29,39 @@ llvm::Optional<TargetType> GetTargetType(mlir::StringRef key); ...@@ -34,9 +29,39 @@ llvm::Optional<TargetType> GetTargetType(mlir::StringRef key);
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key); llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key);
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key); llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key);
raw_ostream &operator<<(raw_ostream &os, TargetType type); mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type);
raw_ostream &operator<<(raw_ostream &os, LayoutType type); mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type);
raw_ostream &operator<<(raw_ostream &os, PrecisionType type); mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type);
namespace detail {
struct TensorTypeStorage : public mlir::TypeStorage {
TensorTypeStorage(TargetType target,
LayoutType layout,
PrecisionType precision)
: target_(target), layout_(layout), precision_(precision) {}
using KeyTy = std::tuple<TargetType, LayoutType, PrecisionType>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(target_, layout_, precision_);
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static TensorTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, // NOLINT
const KeyTy &key) {
return new (allocator.allocate<TensorTypeStorage>())
TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key));
}
TargetType target_;
LayoutType layout_;
PrecisionType precision_;
};
} // namespace detail
class TensorType : public mlir::Type::TypeBase<TensorType, class TensorType : public mlir::Type::TypeBase<TensorType,
mlir::Type, mlir::Type,
...@@ -52,7 +77,7 @@ class TensorType : public mlir::Type::TypeBase<TensorType, ...@@ -52,7 +77,7 @@ class TensorType : public mlir::Type::TypeBase<TensorType,
PrecisionType precision(); PrecisionType precision();
}; };
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType); mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType);
class TensorMapType : public mlir::Type::TypeBase<TensorMapType, class TensorMapType : public mlir::Type::TypeBase<TensorMapType,
mlir::Type, mlir::Type,
...@@ -70,10 +95,10 @@ class StringType ...@@ -70,10 +95,10 @@ class StringType
static StringType get(); static StringType get();
static StringType get(mlir::MLIRContext *context); static StringType get(mlir::MLIRContext *context);
}; };
} // namespace dt
} // namespace infrt
#include "paddle/infrt/dialect/dense_tensor_dialect.hpp.inc" #include "paddle/infrt/dialect/dense_tensor_dialect.hpp.inc"
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.hpp.inc" #include "paddle/infrt/dialect/dense_tensor.hpp.inc"
} // namespace infrt::dt
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
#include "paddle/infrt/dialect/diagnostic_utils.h" #include "paddle/infrt/dialect/diagnostic_utils.h"
#include <llvm/Support/raw_ostream.h>
#include <string> #include <string>
namespace infrt::dialect { namespace infrt {
namespace dialect {
struct MyScopedDiagnosicHandler::Impl { struct MyScopedDiagnosicHandler::Impl {
Impl() : diag_stream_(diag_str_) {} Impl() : diag_stream_(diag_str_) {}
...@@ -49,4 +51,5 @@ mlir::LogicalResult MyScopedDiagnosicHandler::handler(mlir::Diagnostic *diag) { ...@@ -49,4 +51,5 @@ mlir::LogicalResult MyScopedDiagnosicHandler::handler(mlir::Diagnostic *diag) {
return mlir::failure(true); return mlir::failure(true);
} }
} // namespace infrt::dialect } // namespace dialect
} // namespace infrt
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include <memory> #include <memory>
namespace infrt::dialect { namespace infrt {
namespace dialect {
/** /**
* A scoped diagnostic handler to help debug MLIR process. * A scoped diagnostic handler to help debug MLIR process.
...@@ -36,4 +37,5 @@ class MyScopedDiagnosicHandler : public mlir::SourceMgrDiagnosticHandler { ...@@ -36,4 +37,5 @@ class MyScopedDiagnosicHandler : public mlir::SourceMgrDiagnosticHandler {
std::unique_ptr<Impl> impl_; std::unique_ptr<Impl> impl_;
}; };
} // namespace infrt::dialect } // namespace dialect
} // namespace infrt
...@@ -13,24 +13,26 @@ ...@@ -13,24 +13,26 @@
// limitations under the License. // limitations under the License.
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h> #include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h> #include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Interfaces/SideEffectInterfaces.h> #include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LogicalResult.h> #include <mlir/Support/LogicalResult.h>
namespace infrt::hlir::dialect { namespace infrt {
namespace hlir {
namespace dialect {
class CinnDialect : public ::mlir::Dialect { class CinnDialect : public mlir::Dialect {
public: public:
explicit CinnDialect(::mlir::MLIRContext* ctx); explicit CinnDialect(mlir::MLIRContext* ctx);
//! We should register this function in dialect //! We should register this function in dialect
static llvm::StringRef getDialectNamespace() { static llvm::StringRef getDialectNamespace() {
return "infrt::hlir::dialect"; return "infrt::hlir::dialect";
} }
}; };
} // namespace dialect
} // namespace infrt::hlir::dialect } // namespace hlir
} // namespace infrt
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include "paddle/infrt/dialect/dense_tensor.h" #include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/test_kernels.h" #include "paddle/infrt/dialect/test_kernels.h"
namespace infrt::dialect { namespace infrt {
namespace dialect {
// ----INFRTDialect definition begin---- // ----INFRTDialect definition begin----
void INFRTDialect::initialize() { void INFRTDialect::initialize() {
...@@ -124,4 +125,5 @@ void INFRTDialect::printType(mlir::Type type, ...@@ -124,4 +125,5 @@ void INFRTDialect::printType(mlir::Type type,
// ----INFRTDialect definition end---- // ----INFRTDialect definition end----
} // namespace infrt::dialect } // namespace dialect
} // namespace infrt
...@@ -18,19 +18,17 @@ ...@@ -18,19 +18,17 @@
#include <mlir/IR/Dialect.h> #include <mlir/IR/Dialect.h>
#include <mlir/IR/DialectImplementation.h> #include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/MLIRContext.h> #include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h> #include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h> #include <mlir/IR/Types.h>
#include "paddle/infrt/dialect/infrt_base.hpp.inc" #include "paddle/infrt/dialect/infrt_base.hpp.inc"
namespace infrt::dialect { namespace infrt {
namespace dialect {
class INFRTDialect : public ::mlir::Dialect { class INFRTDialect : public mlir::Dialect {
explicit INFRTDialect(::mlir::MLIRContext *context) explicit INFRTDialect(mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(), : mlir::Dialect(
context, getDialectNamespace(), context, mlir::TypeID::get<INFRTDialect>()) {
::mlir::TypeID::get<INFRTDialect>()) {
initialize(); initialize();
} }
...@@ -41,15 +39,12 @@ class INFRTDialect : public ::mlir::Dialect { ...@@ -41,15 +39,12 @@ class INFRTDialect : public ::mlir::Dialect {
mlir::DialectAsmPrinter &printer) const override; mlir::DialectAsmPrinter &printer) const override;
void initialize(); void initialize();
friend class ::mlir::MLIRContext; friend class mlir::MLIRContext;
public: public:
static ::llvm::StringRef getDialectNamespace() { return "infrt"; } static ::llvm::StringRef getDialectNamespace() { return "infrt"; }
}; };
} // namespace dialect
} // namespace infrt::dialect
namespace mlir {
template <typename T> template <typename T>
static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT
...@@ -58,17 +53,16 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT ...@@ -58,17 +53,16 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT
return b.getIntegerAttr(b.getI32Type(), constant); return b.getIntegerAttr(b.getI32Type(), constant);
} }
static mlir::SmallVector<::mlir::Value, 4> cvtValueToValueRange( static mlir::SmallVector<mlir::Value, 4> cvtValueToValueRange(
const mlir::Value &operand) { const mlir::Value &operand) {
return mlir::SmallVector<::mlir::Value, 4>(1, operand); return mlir::SmallVector<mlir::Value, 4>(1, operand);
} }
static mlir::SmallVector<::mlir::Value, 4> concatTwoValueRange( static mlir::SmallVector<mlir::Value, 4> concatTwoValueRange(
mlir::ValueRange operand_0, mlir::ValueRange operand_1) { mlir::ValueRange operand_0, mlir::ValueRange operand_1) {
mlir::SmallVector<::mlir::Value, 4> operands; mlir::SmallVector<mlir::Value, 4> operands;
operands.append(operand_0.begin(), operand_0.end()); operands.append(operand_0.begin(), operand_0.end());
operands.append(operand_1.begin(), operand_1.end()); operands.append(operand_1.begin(), operand_1.end());
return operands; return operands;
} }
} // namespace infrt
} // namespace mlir
...@@ -28,11 +28,11 @@ def TensorMapType : ...@@ -28,11 +28,11 @@ def TensorMapType :
def BufferType : OpaqueType<"b", "buffer", "buffer">; def BufferType : OpaqueType<"b", "buffer", "buffer">;
class INFRT_createI32Attr<string value> : NativeCodeCall< class INFRT_createI32Attr<string value> : NativeCodeCall<
"mlir::createI32Attr($_builder, $_loc, " # value # ")">; "infrt::createI32Attr($_builder, $_loc, " # value # ")">;
def INFRT_cvtValueToValueRange : NativeCodeCall< def INFRT_cvtValueToValueRange : NativeCodeCall<
"mlir::cvtValueToValueRange($0)">; "infrt::cvtValueToValueRange($0)">;
def INFRT_concatTwoValueRange : NativeCodeCall< def INFRT_concatTwoValueRange : NativeCodeCall<
"mlir::concatTwoValueRange($0, $1)">; "infrt::concatTwoValueRange($0, $1)">;
#endif // INFRT_BASE #endif // INFRT_BASE
...@@ -23,12 +23,10 @@ ...@@ -23,12 +23,10 @@
#include "paddle/infrt/dialect/tensor_shape.h" #include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt { namespace infrt {
void registerCinnDialects(mlir::DialectRegistry &registry) { // NOLINT
void RegisterCinnDialects(mlir::DialectRegistry& registry) { // NOLINT registry.insert<ts::TensorShapeDialect,
registry.insert<ts::TensorShapeDialect>(); dialect::INFRTDialect,
registry.insert<dialect::INFRTDialect>(); dt::DTDialect,
registry.insert<dt::DTDialect>(); mlir::pd::PaddleDialect>();
registry.insert<mlir::pd::PaddleDialect>();
} }
} // namespace infrt } // namespace infrt
...@@ -14,10 +14,8 @@ ...@@ -14,10 +14,8 @@
#pragma once #pragma once
#include "mlir/IR/Dialect.h" #include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h>
namespace infrt { namespace infrt {
void registerCinnDialects(mlir::DialectRegistry &registry); // NOLINT
void RegisterCinnDialects(mlir::DialectRegistry& registry); // NOLINT
} // namespace infrt } // namespace infrt
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include <llvm/Support/SourceMgr.h> #include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h> #include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h> #include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h> #include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h> #include <mlir/Parser.h>
#include <unordered_map> #include <unordered_map>
...@@ -30,12 +30,15 @@ ...@@ -30,12 +30,15 @@
#include "paddle/infrt/dialect/diagnostic_utils.h" #include "paddle/infrt/dialect/diagnostic_utils.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h" #include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect { namespace infrt {
namespace dialect {
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source) { const std::string& mlir_source) {
// context->allowUnregisteredDialects(); // context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry()); mlir::DialectRegistry registry;
registerCinnDialects(registry);
context->appendDialectRegistry(registry);
// Currently, we only use the CinnDialect, and mlir::BuiltinDialect is // Currently, we only use the CinnDialect, and mlir::BuiltinDialect is
// enough. StandardOpsDialect is not needed. // enough. StandardOpsDialect is not needed.
// context->getDialectRegistry().insert<mlir::StandardOpsDialect>(); // context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
...@@ -57,9 +60,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, ...@@ -57,9 +60,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context) { mlir::MLIRContext* context) {
// context->allowUnregisteredDialects(); // context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry()); mlir::DialectRegistry registry;
context->getDialectRegistry().insert<mlir::StandardOpsDialect>(); registerCinnDialects(registry);
context->appendDialectRegistry(registry);
mlir::ScopedDiagnosticHandler scope_handler( mlir::ScopedDiagnosticHandler scope_handler(
context, [](mlir::Diagnostic& diag) { context, [](mlir::Diagnostic& diag) {
if (diag.getSeverity() != mlir::DiagnosticSeverity::Error) if (diag.getSeverity() != mlir::DiagnosticSeverity::Error)
...@@ -71,4 +74,5 @@ mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, ...@@ -71,4 +74,5 @@ mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
return mlir::parseSourceFile(std::string(file_name), context); return mlir::parseSourceFile(std::string(file_name), context);
} }
} // namespace infrt::dialect } // namespace dialect
} // namespace infrt
...@@ -15,16 +15,17 @@ ...@@ -15,16 +15,17 @@
#pragma once #pragma once
#include <glog/logging.h> #include <glog/logging.h>
#include <mlir/IR/Module.h> #include <mlir/IR/BuiltinOps.h>
#include <string> #include <string>
#include <memory> #include <memory>
namespace infrt::dialect { namespace infrt {
namespace dialect {
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source); const std::string& mlir_source);
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context); mlir::MLIRContext* context);
} // namespace dialect
} // namespace infrt::dialect } // namespace infrt
...@@ -17,14 +17,15 @@ ...@@ -17,14 +17,15 @@
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <llvm/Support/SourceMgr.h> #include <llvm/Support/SourceMgr.h>
#include <mlir/IR/Function.h> #include <mlir/IR/BuiltinTypes.h>
#include <mlir/Parser.h> #include <mlir/Parser.h>
#include <string> #include <string>
#include "paddle/infrt/dialect/init_infrt_dialects.h" #include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect { namespace infrt {
namespace dialect {
TEST(MlirLoader, basic) { TEST(MlirLoader, basic) {
mlir::MLIRContext context; mlir::MLIRContext context;
...@@ -42,8 +43,7 @@ func @main() -> f32 { ...@@ -42,8 +43,7 @@ func @main() -> f32 {
)ROC"; )ROC";
auto module = LoadMlirSource(&context, source); auto module = LoadMlirSource(&context, source);
module->verify(); EXPECT_TRUE(mlir::succeeded(module->verify()));
LOG(INFO) << "module name: " << module->getOperationName().data(); LOG(INFO) << "module name: " << module->getOperationName().data();
for (auto func : module->getOps<mlir::FuncOp>()) { for (auto func : module->getOps<mlir::FuncOp>()) {
LOG(INFO) << "get func " << func.getName().str(); LOG(INFO) << "get func " << func.getName().str();
...@@ -54,4 +54,5 @@ func @main() -> f32 { ...@@ -54,4 +54,5 @@ func @main() -> f32 {
} }
} }
} // namespace infrt::dialect } // namespace dialect
} // namespace infrt
...@@ -20,5 +20,5 @@ func @main() -> tensor<?xf32> { ...@@ -20,5 +20,5 @@ func @main() -> tensor<?xf32> {
%c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32> %e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
infrt.return %e2 : tensor<?xf32> "pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
} }
\ No newline at end of file
...@@ -11,5 +11,5 @@ func @main() -> tensor<?xf32> { ...@@ -11,5 +11,5 @@ func @main() -> tensor<?xf32> {
%c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor<?x3x256x256xf32>, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32> %c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor<?x3x256x256xf32>, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
%d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor<?x3x256x256xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32> %d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor<?x3x256x256xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
infrt.return %d : tensor<?x3x256x256xf32> "pd.fetch"(%d) {name="output"} :(tensor<?x3x256x256xf32>)->()
} }
\ No newline at end of file
...@@ -18,5 +18,5 @@ func @main() -> tensor<?xf32> { ...@@ -18,5 +18,5 @@ func @main() -> tensor<?xf32> {
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32> %e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
"pd.fetch"(%e2) :(tensor<?xf32>)->() "pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
} }
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/infrt_base.td"
class INFRT_Op<string mnemonic, list<OpTrait> traits = []> :
Op<INFRT_Dialect, mnemonic, traits>;
...@@ -12,34 +12,14 @@ ...@@ -12,34 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <glog/logging.h>
#include <llvm/Support/CommandLine.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Dialect.h>
#include <mlir/InitAllDialects.h>
#include <mlir/InitAllPasses.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/MlirOptMain.h> #include <mlir/Support/MlirOptMain.h>
#include <mlir/Transforms/Passes.h> #include <mlir/Transforms/Passes.h>
#include <iostream>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h" #include "paddle/infrt/dialect/init_infrt_dialects.h"
#include "paddle/infrt/dialect/mlir_loader.h"
int main(int argc, char **argv) { int main(int argc, char **argv) {
mlir::MLIRContext *context = infrt::Global::getMLIRContext(); mlir::DialectRegistry registry;
infrt::registerCinnDialects(registry);
auto &registry = context->getDialectRegistry();
infrt::RegisterCinnDialects(registry);
mlir::registerCanonicalizerPass(); mlir::registerCanonicalizerPass();
return mlir::failed( return mlir::failed(
mlir::MlirOptMain(argc, argv, "INFRT mlir pass driver", registry)); mlir::MlirOptMain(argc, argv, "infrt mlir pass driver", registry));
} }
...@@ -16,7 +16,7 @@ def PD_Dialect : Dialect { ...@@ -16,7 +16,7 @@ def PD_Dialect : Dialect {
This dialect contains the PaddlePaddle operators. This dialect contains the PaddlePaddle operators.
}]; }];
let cppNamespace = "::mlir::pd"; let cppNamespace = "mlir::pd";
} }
class PD_Op<string mnemonic, list<OpTrait> traits = []> : class PD_Op<string mnemonic, list<OpTrait> traits = []> :
......
...@@ -14,10 +14,15 @@ ...@@ -14,10 +14,15 @@
#include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/pd_ops.h"
#include "mlir/IR/Matchers.h" #include <mlir/IR/Matchers.h>
#include "mlir/IR/PatternMatch.h" #include <mlir/IR/PatternMatch.h>
#include "paddle/infrt/dialect/infrt_base.h" #include "paddle/infrt/dialect/infrt_base.h"
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
namespace mlir { namespace mlir {
namespace pd { namespace pd {
PaddleDialect::PaddleDialect(MLIRContext *context) PaddleDialect::PaddleDialect(MLIRContext *context)
...@@ -36,12 +41,6 @@ mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder, ...@@ -36,12 +41,6 @@ mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder,
return builder.create<ConstantOp>(loc, value); return builder.create<ConstantOp>(loc, value);
} }
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
#undef GET_OP_CLASSES
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
void ConstantOp::build(OpBuilder &builder, void ConstantOp::build(OpBuilder &builder,
OperationState &state, OperationState &state,
Attribute value) { Attribute value) {
...@@ -66,8 +65,8 @@ LogicalResult ConstantOp::inferReturnTypes( ...@@ -66,8 +65,8 @@ LogicalResult ConstantOp::inferReturnTypes(
inferredReturnTypes.push_back(attributes.get("value").getType()); inferredReturnTypes.push_back(attributes.get("value").getType());
return success(); return success();
} }
::mlir::OpFoldResult ConstantOp::fold( mlir::OpFoldResult ConstantOp::fold(
::llvm::ArrayRef<::mlir::Attribute> operands) { ::llvm::ArrayRef<mlir::Attribute> operands) {
return value(); return value();
} }
...@@ -82,11 +81,11 @@ LogicalResult ElementwiseAdd::inferReturnTypes( ...@@ -82,11 +81,11 @@ LogicalResult ElementwiseAdd::inferReturnTypes(
return success(); return success();
} }
void ElementwiseAdd::getCanonicalizationPatterns( void ElementwiseAdd::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseMulAdd>(context); results.insert<FuseMulAdd>(context);
} }
::mlir::OpFoldResult ElementwiseAdd::fold( mlir::OpFoldResult ElementwiseAdd::fold(
llvm::ArrayRef<mlir::Attribute> operands) { llvm::ArrayRef<mlir::Attribute> operands) {
if (getElementTypeOrSelf(getType()).isa<FloatType>()) { if (getElementTypeOrSelf(getType()).isa<FloatType>()) {
if (!operands[0] || !operands[1]) return {}; if (!operands[0] || !operands[1]) return {};
...@@ -154,17 +153,17 @@ LogicalResult MulOp::inferReturnTypes( ...@@ -154,17 +153,17 @@ LogicalResult MulOp::inferReturnTypes(
} }
void ReluOp::getCanonicalizationPatterns( void ReluOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseFCRelu>(context); results.insert<FuseFCRelu>(context);
} }
void FusedRepeatedFCRelu::getCanonicalizationPatterns( void FusedRepeatedFCRelu::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseRepeatedFCRelu2>(context); results.insert<FuseRepeatedFCRelu2>(context);
} }
void BatchNormOp::getCanonicalizationPatterns( void BatchNormOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseBatchNormWithConvPattern>(context); results.insert<FuseBatchNormWithConvPattern>(context);
} }
......
...@@ -14,21 +14,20 @@ ...@@ -14,21 +14,20 @@
#pragma once #pragma once
#include "mlir/Dialect/Traits.h" #include <mlir/Dialect/Traits.h>
#include "mlir/IR/Attributes.h" #include <mlir/IR/Attributes.h>
#include "mlir/IR/Builders.h" #include <mlir/IR/Builders.h>
#include "mlir/IR/Dialect.h" #include <mlir/IR/BuiltinOps.h>
#include "mlir/IR/Function.h" #include <mlir/IR/BuiltinTypes.h>
#include "mlir/IR/Matchers.h" #include <mlir/IR/Dialect.h>
#include "mlir/IR/Module.h" #include <mlir/IR/Matchers.h>
#include "mlir/IR/OpImplementation.h" #include <mlir/IR/OpImplementation.h>
#include "mlir/IR/StandardTypes.h" #include <mlir/IR/TypeUtilities.h>
#include "mlir/IR/TypeUtilities.h" #include <mlir/Interfaces/CallInterfaces.h>
#include "mlir/Interfaces/CallInterfaces.h" #include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" #include <mlir/Interfaces/InferTypeOpInterface.h>
#include "mlir/Interfaces/InferTypeOpInterface.h" #include <mlir/Interfaces/LoopLikeInterface.h>
#include "mlir/Interfaces/LoopLikeInterface.h" #include <mlir/Interfaces/SideEffectInterfaces.h>
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir { namespace mlir {
namespace pd { namespace pd {
...@@ -53,9 +52,8 @@ class PaddleDialect : public Dialect { ...@@ -53,9 +52,8 @@ class PaddleDialect : public Dialect {
} }
}; };
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
#undef GET_OP_CLASSES
} // namespace pd } // namespace pd
} // namespace mlir } // namespace mlir
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
...@@ -24,6 +24,16 @@ def PD_FeedOp : PD_Op<"feed"> { ...@@ -24,6 +24,16 @@ def PD_FeedOp : PD_Op<"feed"> {
def PD_FetchOp : PD_Op<"fetch", [Terminator]> { def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
let summary = "fetch Op"; let summary = "fetch Op";
let description = [{
Return the output tensor from the subgraph.
}];
let arguments = (ins PD_Tensor :$inputs, StrAttr:$name);
}
def PD_ReturnOp : PD_Op<"return", [Terminator]> {
let summary = "return Op";
let description = [{ let description = [{
Fetch tensor from the graph. Fetch tensor from the graph.
}]; }];
...@@ -31,7 +41,7 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> { ...@@ -31,7 +41,7 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
let arguments = (ins Variadic<PD_Tensor>:$inputs); let arguments = (ins Variadic<PD_Tensor>:$inputs);
} }
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> { def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"ReturnOp">]> {
let summary = "paddle graph Op"; let summary = "paddle graph Op";
let description = [{ let description = [{
Describe a paddle graph or subgraph. Describe a paddle graph or subgraph.
...@@ -50,7 +60,7 @@ def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInte ...@@ -50,7 +60,7 @@ def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInte
let hasFolder = 1; let hasFolder = 1;
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute value">, OpBuilder<(ins "Attribute":$value)>,
]; ];
} }
......
...@@ -18,12 +18,11 @@ ...@@ -18,12 +18,11 @@
#pragma once #pragma once
#include "mlir/IR/Diagnostics.h" #include <mlir/IR/Diagnostics.h>
#include "mlir/IR/Location.h" #include <mlir/IR/Location.h>
#include "mlir/IR/Operation.h" #include <mlir/IR/Operation.h>
#include "mlir/IR/StandardTypes.h" #include <mlir/IR/TypeUtilities.h>
#include "mlir/IR/TypeUtilities.h" #include <mlir/IR/Types.h>
#include "mlir/IR/Types.h"
namespace mlir { namespace mlir {
namespace PD { namespace PD {
......
...@@ -11,26 +11,25 @@ ...@@ -11,26 +11,25 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <llvm/ADT/Optional.h>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/ScopedPrinter.h>
#include <mlir/IR/BuiltinOps.h>
#include <llvm/Support/raw_os_ostream.h>
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Block.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/Region.h>
#include <mlir/IR/Verifier.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/Passes.h>
#include <iostream> #include <iostream>
#include "llvm/ADT/Optional.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "paddle/infrt/common/global.h" #include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h" #include "paddle/infrt/dialect/init_infrt_dialects.h"
...@@ -114,17 +113,15 @@ int main(int argc, char **argv) { ...@@ -114,17 +113,15 @@ int main(int argc, char **argv) {
mlir::registerPassManagerCLOptions(); mlir::registerPassManagerCLOptions();
cl::ParseCommandLineOptions(argc, argv, "mlir demo"); cl::ParseCommandLineOptions(argc, argv, "mlir demo");
mlir::MLIRContext *context = infrt::Global::getMLIRContext(); mlir::DialectRegistry registry;
// context->allowUnregisteredDialects(); infrt::registerCinnDialects(registry);
auto &registry = context->getDialectRegistry(); mlir::MLIRContext context(registry);
infrt::RegisterCinnDialects(registry);
// mlir will verify module automatically after parsing. // mlir will verify module automatically after parsing.
// https://github.com/llvm/llvm-project/blob/38d18d93534d290d045bbbfa86337e70f1139dc2/mlir/lib/Parser/Parser.cpp#L2051 // https://github.com/llvm/llvm-project/blob/38d18d93534d290d045bbbfa86337e70f1139dc2/mlir/lib/Parser/Parser.cpp#L2051
// mlir::OwningModuleRef module_ref = mlir::parseSourceString(mlir_source, // mlir::OwningModuleRef module_ref = mlir::parseSourceString(mlir_source,
// context); // context);
mlir::OwningModuleRef module_ref = mlir::OwningModuleRef module_ref =
mlir::parseSourceFile(inputFilename, context); mlir::parseSourceFile(inputFilename, &context);
std::cout << "----------print IR Structure begin----------" << std::endl; std::cout << "----------print IR Structure begin----------" << std::endl;
printOperation(module_ref->getOperation(), 0); printOperation(module_ref->getOperation(), 0);
std::cout << "----------print IR Structure end----------" << std::endl; std::cout << "----------print IR Structure end----------" << std::endl;
......
...@@ -17,16 +17,16 @@ ...@@ -17,16 +17,16 @@
#include <llvm/ADT/STLExtras.h> #include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h> #include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h> #include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h> #include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h> #include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h> #include <mlir/Support/LogicalResult.h>
namespace infrt::ts { namespace infrt {
namespace ts {
using namespace mlir; // NOLINT using namespace mlir; // NOLINT
void TensorShapeDialect::initialize() { void TensorShapeDialect::initialize() {
...@@ -48,8 +48,8 @@ Type TensorShapeDialect::parseType(DialectAsmParser &parser) const { ...@@ -48,8 +48,8 @@ Type TensorShapeDialect::parseType(DialectAsmParser &parser) const {
return Type(); return Type();
} }
void TensorShapeDialect::printType(::mlir::Type type, void TensorShapeDialect::printType(mlir::Type type,
::mlir::DialectAsmPrinter &os) const { mlir::DialectAsmPrinter &os) const {
if (type.isa<ShapeType>()) { if (type.isa<ShapeType>()) {
os << "shape"; os << "shape";
return; return;
...@@ -61,8 +61,10 @@ void TensorShapeDialect::printType(::mlir::Type type, ...@@ -61,8 +61,10 @@ void TensorShapeDialect::printType(::mlir::Type type,
} }
llvm_unreachable("unexpected 'shape' type kind"); llvm_unreachable("unexpected 'shape' type kind");
} }
} // namespace ts
} // namespace infrt
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensor_shape.cpp.inc" // NOLINT #include "paddle/infrt/dialect/tensor_shape.cpp.inc" // NOLINT
} // namespace infrt::ts #include "paddle/infrt/dialect/tensor_shape_dialect.cpp.inc"
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h> #include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt::ts { namespace infrt {
namespace ts {
class ShapeType class ShapeType
: public mlir::Type::TypeBase<ShapeType, mlir::Type, mlir::TypeStorage> { : public mlir::Type::TypeBase<ShapeType, mlir::Type, mlir::TypeStorage> {
...@@ -31,10 +32,9 @@ class PartialShapeType : public mlir::Type::TypeBase<PartialShapeType, ...@@ -31,10 +32,9 @@ class PartialShapeType : public mlir::Type::TypeBase<PartialShapeType,
public: public:
using Base::Base; using Base::Base;
}; };
} // namespace ts
} // namespace infrt
using namespace mlir; // NOLINT
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensor_shape.hpp.inc" #include "paddle/infrt/dialect/tensor_shape.hpp.inc"
#include "paddle/infrt/dialect/tensor_shape_dialect.hpp.inc" #include "paddle/infrt/dialect/tensor_shape_dialect.hpp.inc"
} // namespace infrt::ts
...@@ -19,7 +19,7 @@ def TensorShapeDialect : Dialect { ...@@ -19,7 +19,7 @@ def TensorShapeDialect : Dialect {
def TS_Shape : DialectType<TensorShapeDialect, def TS_Shape : DialectType<TensorShapeDialect,
CPred<"$_self.isa<::infrt::ts::ShapeType>()">, "!ts.shape type">, CPred<"$_self.isa<::infrt::ts::ShapeType>()">, "!ts.shape type">,
BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> { BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> {
let typeDescription = [{ let description = [{
`!ts.shape type` represents a static tensor shape. `!ts.shape type` represents a static tensor shape.
}]; }];
} }
...@@ -27,7 +27,7 @@ BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> { ...@@ -27,7 +27,7 @@ BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> {
def TS_PartialShape : DialectType<TensorShapeDialect, def TS_PartialShape : DialectType<TensorShapeDialect,
CPred<"$_self.isa<::infrt::ts::PartialShapeType>()">, "!ts.partial_shape type">, CPred<"$_self.isa<::infrt::ts::PartialShapeType>()">, "!ts.partial_shape type">,
BuildableType<"$_builder.getType<::infrt::ts::PartialShapeType>()"> { BuildableType<"$_builder.getType<::infrt::ts::PartialShapeType>()"> {
let typeDescription = [{ let description = [{
`!ts.partial_shape type` represents either a static tensor shape, unranked `!ts.partial_shape type` represents either a static tensor shape, unranked
tensor shape or a ranked tensor shape with unknown dimension sizes. tensor shape or a ranked tensor shape with unknown dimension sizes.
}]; }];
......
...@@ -11,10 +11,10 @@ ...@@ -11,10 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <llvm/Support/CommandLine.h>
#include <mlir/Pass/PassManager.h>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/PassManager.h"
#include "paddle/infrt/common/global.h" #include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/mlir_loader.h" #include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
......
...@@ -14,14 +14,13 @@ ...@@ -14,14 +14,13 @@
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <paddle/infrt/dialect/pd_ops.h>
#include <list> #include <list>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
...@@ -32,9 +31,9 @@ namespace { ...@@ -32,9 +31,9 @@ namespace {
// Reference the function named "FlexibleDFS" but defined in: // Reference the function named "FlexibleDFS" but defined in:
// paddle/fluid/framework/ir/subgraph_detector.cc. // paddle/fluid/framework/ir/subgraph_detector.cc.
bool reverseDfs(std::vector<::mlir::Operation *> source, bool reverseDfs(std::vector<mlir::Operation *> source,
const std::function<bool(const ::mlir::Operation *)> &func) { const std::function<bool(const mlir::Operation *)> &func) {
std::unordered_set<const ::mlir::Operation *> visited; std::unordered_set<const mlir::Operation *> visited;
while (!source.empty()) { while (!source.empty()) {
auto node = source.back(); auto node = source.back();
source.pop_back(); source.pop_back();
...@@ -44,7 +43,7 @@ bool reverseDfs(std::vector<::mlir::Operation *> source, ...@@ -44,7 +43,7 @@ bool reverseDfs(std::vector<::mlir::Operation *> source,
auto values = node->getOperands(); auto values = node->getOperands();
for (auto value : values) { for (auto value : values) {
// if the value is a block argument, the node is nullptr. // if the value is a block argument, the node is nullptr.
::mlir::Operation *node = value.getDefiningOp(); mlir::Operation *node = value.getDefiningOp();
if (node != nullptr && !visited.count(node)) { if (node != nullptr && !visited.count(node)) {
source.emplace_back(node); source.emplace_back(node);
} }
...@@ -54,19 +53,19 @@ bool reverseDfs(std::vector<::mlir::Operation *> source, ...@@ -54,19 +53,19 @@ bool reverseDfs(std::vector<::mlir::Operation *> source,
} }
// merge the first&second graph op to a new graph op. // merge the first&second graph op to a new graph op.
void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
::mlir::pd::GraphOp first, mlir::pd::GraphOp first,
::mlir::pd::GraphOp second) { mlir::pd::GraphOp second) {
// compute inputs and outputs // compute inputs and outputs
::llvm::SmallVector<::mlir::Value, 4> inputs(first.getOperands()), outputs; ::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
for (::mlir::Value input : second.getOperands()) { for (mlir::Value input : second.getOperands()) {
if (input.getDefiningOp() != first) { if (input.getDefiningOp() != first) {
inputs.push_back(input); inputs.push_back(input);
} }
} }
::llvm::DenseMap<::mlir::Value, unsigned int> op_output_mapping; ::llvm::DenseMap<mlir::Value, unsigned int> op_output_mapping;
for (::mlir::Value output : first.getResults()) { for (mlir::Value output : first.getResults()) {
for (::mlir::Operation *user : output.getUsers()) { for (mlir::Operation *user : output.getUsers()) {
if (user != second && user->getParentOp() != second) { if (user != second && user->getParentOp() != second) {
op_output_mapping[output] = outputs.size(); op_output_mapping[output] = outputs.size();
outputs.push_back(output); outputs.push_back(output);
...@@ -74,19 +73,19 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT ...@@ -74,19 +73,19 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
} }
} }
} }
auto fetch_op = second.getBody()->getTerminator(); auto return_op = second.getBody()->getTerminator();
outputs.append(fetch_op->getOperands().begin(), outputs.append(return_op->getOperands().begin(),
fetch_op->getOperands().end()); return_op->getOperands().end());
::llvm::SmallVector<::mlir::Type, 4> fetch_types; ::llvm::SmallVector<mlir::Type, 4> return_types;
for (auto value : outputs) { for (auto value : outputs) {
fetch_types.push_back(value.getType()); return_types.push_back(value.getType());
} }
// create the new graph op // create the new graph op
builder.setInsertionPoint(first); builder.setInsertionPoint(first);
auto loc = first.getLoc(); auto loc = first.getLoc();
auto graph_op = builder.create<::mlir::pd::GraphOp>(loc, fetch_types, inputs); auto graph_op = builder.create<mlir::pd::GraphOp>(loc, return_types, inputs);
::mlir::Block *block = new ::mlir::Block; mlir::Block *block = new mlir::Block;
auto copy_range = second.getBody()->without_terminator(); auto copy_range = second.getBody()->without_terminator();
block->getOperations().splice(block->begin(), block->getOperations().splice(block->begin(),
second.getBody()->getOperations(), second.getBody()->getOperations(),
...@@ -98,18 +97,18 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT ...@@ -98,18 +97,18 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
copy_range.begin(), copy_range.begin(),
copy_range.end()); copy_range.end());
builder.setInsertionPointToEnd(block); builder.setInsertionPointToEnd(block);
builder.create<mlir::pd::FetchOp>(loc, outputs); builder.create<mlir::pd::ReturnOp>(loc, outputs);
graph_op.body().push_back(block); graph_op.body().push_back(block);
// mapping the output // mapping the output
unsigned int num_result = first.getNumResults(); unsigned int num_result = first.getNumResults();
fetch_op = first.getBody()->getTerminator(); return_op = first.getBody()->getTerminator();
for (unsigned int index = 0; index < num_result; ++index) { for (unsigned int index = 0; index < num_result; ++index) {
auto origin_value = first.getResult(index); auto origin_value = first.getResult(index);
if (op_output_mapping.find(origin_value) == op_output_mapping.end()) { if (op_output_mapping.find(origin_value) == op_output_mapping.end()) {
origin_value.replaceAllUsesWith(fetch_op->getOperand(index)); origin_value.replaceAllUsesWith(return_op->getOperand(index));
} else { } else {
auto inner_value = fetch_op->getOperand(index); auto inner_value = return_op->getOperand(index);
auto outer_value = graph_op.getResult(op_output_mapping[origin_value]); auto outer_value = graph_op.getResult(op_output_mapping[origin_value]);
while (!origin_value.use_empty()) { while (!origin_value.use_empty()) {
auto replace_value = auto replace_value =
...@@ -128,13 +127,13 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT ...@@ -128,13 +127,13 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
// Topological sort the function op. // Topological sort the function op.
void topoSortBlock(mlir::Block &body) { // NOLINT void topoSortBlock(mlir::Block &body) { // NOLINT
llvm::SetVector<Operation *> toSort; llvm::SetVector<mlir::Operation *> toSort;
if (body.empty()) return; if (body.empty()) return;
for (auto it = body.rbegin(); it != body.rend(); ++it) { for (auto it = body.rbegin(); it != body.rend(); ++it) {
toSort.insert(&*it); toSort.insert(&*it);
} }
llvm::SetVector<Operation *> result = llvm::SetVector<mlir::Operation *> result =
::mlir::topologicalSort(std::move(toSort)); mlir::topologicalSort(std::move(toSort));
for (auto *op : result) { for (auto *op : result) {
op->moveBefore(body.getTerminator()); op->moveBefore(body.getTerminator());
} }
...@@ -145,21 +144,21 @@ void topoSortBlock(mlir::Block &body) { // NOLINT ...@@ -145,21 +144,21 @@ void topoSortBlock(mlir::Block &body) { // NOLINT
// Implementation of the trtGraphFusePass. // Implementation of the trtGraphFusePass.
void trtGraphFusePass::runOnFunction() { void trtGraphFusePass::runOnFunction() {
mlir::Block &body = getFunction().front(); mlir::Block &body = getFunction().front();
::mlir::OpBuilder builder(&body, body.begin()); mlir::OpBuilder builder(&body, body.begin());
bool changed = false; bool changed = false;
do { do {
changed = false; changed = false;
for (auto &op : body) { for (auto &op : body) {
::mlir::pd::GraphOp graph_op = mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op); ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr == graph_op) continue; if (nullptr == graph_op) continue;
for (auto user_op : op.getUsers()) { for (auto user_op : op.getUsers()) {
::mlir::pd::GraphOp user_graph_op = mlir::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(user_op); ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(user_op);
if (nullptr == user_graph_op) continue; if (nullptr == user_graph_op) continue;
// get all dst input nodes except src. // get all dst input nodes except src.
std::vector<::mlir::Operation *> source_nodes; std::vector<mlir::Operation *> source_nodes;
for (auto operand : user_op->getOperands()) { for (auto operand : user_op->getOperands()) {
auto input = operand.getDefiningOp(); auto input = operand.getDefiningOp();
if (input != &op && input != nullptr) { if (input != &op && input != nullptr) {
...@@ -167,9 +166,8 @@ void trtGraphFusePass::runOnFunction() { ...@@ -167,9 +166,8 @@ void trtGraphFusePass::runOnFunction() {
} }
} }
// Reverse DFS from the source_nodes. // Reverse DFS from the source_nodes.
if (!reverseDfs(source_nodes, [&op](const ::mlir::Operation *n) { if (!reverseDfs(source_nodes,
return n == &op; [&op](const mlir::Operation *n) { return n == &op; })) {
})) {
mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op); mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
changed = true; changed = true;
break; break;
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "mlir/Pass/Pass.h" #include <mlir/Pass/Pass.h>
namespace infrt { namespace infrt {
namespace trt { namespace trt {
...@@ -28,15 +28,15 @@ namespace trt { ...@@ -28,15 +28,15 @@ namespace trt {
* %a = "pd.feed"()... * %a = "pd.feed"()...
* %c = "pd.graph"(%a) { * %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "pd.fetch" %m * "pd.return" %m
* } ... * } ...
* %d = "pd.graph"(%c) { * %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)... * %m = "pd.conv3d"(%c)...
* "pd.fetch" %m * "pd.return" %m
* } ... * } ...
* %f = "pd.graph"(%a) { * %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "pd.fetch" %m * "pd.return" %m
* } ... * } ...
* "pd.fetch" %d, %f * "pd.fetch" %d, %f
* *
...@@ -47,13 +47,13 @@ namespace trt { ...@@ -47,13 +47,13 @@ namespace trt {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)... * %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)... * %s = "pd.conv2d"(%a)...
* "pd.fetch" %n, %s * "pd.return" %n, %s
* } ... * } ...
* "pd.fetch" %d, %f * "pd.fetch" %d, %f
* } * }
*/ */
class trtGraphFusePass class trtGraphFusePass
: public ::mlir::PassWrapper<trtGraphFusePass, ::mlir::FunctionPass> { : public mlir::PassWrapper<trtGraphFusePass, mlir::FunctionPass> {
public: public:
::llvm::StringRef getName() const override { return "trtGraphFusePass"; } ::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
void runOnFunction() override; void runOnFunction() override;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "mlir/IR/Builders.h" #include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
...@@ -22,24 +22,24 @@ namespace infrt { ...@@ -22,24 +22,24 @@ namespace infrt {
namespace trt { namespace trt {
// Implementation of the trtGraphSplitPass. // Implementation of the trtGraphSplitPass.
void trtGraphSplitPass::runOnFunction() { void trtGraphSplitPass::runOnFunction() {
std::vector<::mlir::pd::GraphOp> worklist; std::vector<mlir::pd::GraphOp> worklist;
::mlir::Block& block = getFunction().front(); mlir::Block& block = getFunction().front();
for (auto& op : block) { for (auto& op : block) {
::mlir::pd::GraphOp graph_op = mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op); ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr != graph_op && if (nullptr != graph_op &&
graph_op.getBody()->getOperations().size() <= min_subgraph_size_) { graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
worklist.push_back(graph_op); worklist.push_back(graph_op);
} }
} }
while (!worklist.empty()) { while (!worklist.empty()) {
::mlir::pd::GraphOp graph_op = worklist.back(); mlir::pd::GraphOp graph_op = worklist.back();
worklist.pop_back(); worklist.pop_back();
::mlir::Block* body = graph_op.getBody(); mlir::Block* body = graph_op.getBody();
auto fetch_op = body->getTerminator(); auto return_op = body->getTerminator();
graph_op.replaceAllUsesWith(fetch_op->getOperands()); graph_op.replaceAllUsesWith(return_op->getOperands());
auto copy_range = body->without_terminator(); auto copy_range = body->without_terminator();
block.getOperations().splice(::mlir::Block::iterator(graph_op), block.getOperations().splice(mlir::Block::iterator(graph_op),
body->getOperations(), body->getOperations(),
copy_range.begin(), copy_range.begin(),
copy_range.end()); copy_range.end());
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "mlir/Pass/Pass.h" #include <mlir/Pass/Pass.h>
namespace infrt { namespace infrt {
namespace trt { namespace trt {
...@@ -31,9 +31,9 @@ namespace trt { ...@@ -31,9 +31,9 @@ namespace trt {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)... * %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)... * %s = "pd.conv2d"(%a)...
* "pd.fetch" %n, %s * "pd.return" (%n, %s)
* } ... * } ...
* "pd.fetch" %d, %f * "pd.fetch" (%d, %f)
* } * }
* *
* destination func: * destination func:
...@@ -42,11 +42,11 @@ namespace trt { ...@@ -42,11 +42,11 @@ namespace trt {
* %c = "pd.conv2d"(%a) ... * %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ... * %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ... * %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f * "pd.fetch" (%d, %f)
* } * }
*/ */
class trtGraphSplitPass class trtGraphSplitPass
: public ::mlir::PassWrapper<trtGraphSplitPass, ::mlir::FunctionPass> { : public mlir::PassWrapper<trtGraphSplitPass, mlir::FunctionPass> {
public: public:
::llvm::StringRef getName() const override { return "trtGraphSplitPass"; } ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }
void runOnFunction() override; void runOnFunction() override;
......
...@@ -14,49 +14,48 @@ ...@@ -14,49 +14,48 @@
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
#include "mlir/IR/Builders.h" #include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
// Implementation of the trtOpTellerPass. // Implementation of the trtOpTellerPass.
void trtOpTellerPass::runOnFunction() { void trtOpTellerPass::runOnFunction() {
::mlir::Block &body = getFunction().front(); mlir::Block &body = getFunction().front();
std::vector<::mlir::Operation *> worklist; std::vector<mlir::Operation *> worklist;
worklist.reserve(body.getOperations().size()); worklist.reserve(body.getOperations().size());
for (auto &op : body) { for (auto &op : body) {
worklist.push_back(&op); worklist.push_back(&op);
} }
// Build GraphOp. // Build GraphOp.
::mlir::OpBuilder builder(&body, body.begin()); mlir::OpBuilder builder(&body, body.begin());
while (!worklist.empty()) { while (!worklist.empty()) {
auto *op = worklist.back(); auto *op = worklist.back();
worklist.pop_back(); worklist.pop_back();
if (op == nullptr) continue; if (op == nullptr) continue;
auto op1 = ::llvm::dyn_cast_or_null<::mlir::pd::FeedOp>(op); auto op1 = ::llvm::dyn_cast_or_null<mlir::pd::FeedOp>(op);
if (op1) continue; if (op1) continue;
auto op2 = ::llvm::dyn_cast_or_null<::mlir::pd::FetchOp>(op); auto op2 = ::llvm::dyn_cast_or_null<mlir::pd::FetchOp>(op);
if (op2) continue; if (op2) continue;
auto op3 = ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(op); auto op3 = ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(op);
if (op3) continue; if (op3) continue;
builder.setInsertionPoint(op); builder.setInsertionPoint(op);
auto loc = getFunction().getLoc(); auto loc = getFunction().getLoc();
auto graph_op = builder.create<::mlir::pd::GraphOp>( auto graph_op = builder.create<mlir::pd::GraphOp>(
loc, op->getResultTypes(), op->getOperands()); loc, op->getResultTypes(), op->getOperands());
::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; ::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values;
for (auto v : for (auto v :
::llvm::SmallVector<::mlir::Value, 4>{graph_op.getODSResults(0)}) { ::llvm::SmallVector<mlir::Value, 4>{graph_op.getODSResults(0)}) {
tblgen_repl_values.push_back(v); tblgen_repl_values.push_back(v);
} }
op->replaceAllUsesWith(tblgen_repl_values); op->replaceAllUsesWith(tblgen_repl_values);
// Build graph op. // Build graph op.
::mlir::Block *block = new ::mlir::Block; mlir::Block *block = new mlir::Block;
graph_op.body().push_back(block); graph_op.body().push_back(block);
op->moveBefore(block, block->begin()); op->moveBefore(block, block->begin());
builder.setInsertionPointToEnd(block); builder.setInsertionPointToEnd(block);
builder.create<mlir::pd::FetchOp>(loc, op->getResults()); builder.create<mlir::pd::ReturnOp>(loc, op->getResults());
} }
} }
} // namespace trt } // namespace trt
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "mlir/Pass/Pass.h" #include <mlir/Pass/Pass.h>
namespace infrt { namespace infrt {
namespace trt { namespace trt {
...@@ -29,7 +29,7 @@ namespace trt { ...@@ -29,7 +29,7 @@ namespace trt {
* %c = "pd.conv2d"(%a) ... * %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ... * %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ... * %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f * "pd.fetch" (%d, %f)
* } * }
* *
* destination func: * destination func:
...@@ -37,23 +37,23 @@ namespace trt { ...@@ -37,23 +37,23 @@ namespace trt {
* %a = "pd.feed"()... * %a = "pd.feed"()...
* %c = "pd.graph"(%a) { * %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "pd.fetch" %m * "pd.return" (%m)
* } ... * } ...
* %d = "pd.graph"(%c) { * %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)... * %m = "pd.conv3d"(%c)...
* "pd.fetch" %m * "pd.return" (%m)
* } ... * } ...
* %f = "pd.graph"(%a) { * %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "pd.fetch" %m * "pd.return" (%m)
* } ... * } ...
* "pd.fetch" %d, %f * "pd.fetch" (%d, %f)
* } * }
* TODO(winter-wang): Document how to judge whether an operator is supported * TODO(winter-wang): Document how to judge whether an operator is supported
* by TensorRT. * by TensorRT.
*/ */
class trtOpTellerPass class trtOpTellerPass
: public ::mlir::PassWrapper<trtOpTellerPass, ::mlir::FunctionPass> { : public mlir::PassWrapper<trtOpTellerPass, mlir::FunctionPass> {
public: public:
::llvm::StringRef getName() const override { return "trtOpTellerPass"; } ::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
void runOnFunction() override; void runOnFunction() override;
......
...@@ -13,27 +13,25 @@ ...@@ -13,27 +13,25 @@
// limitations under the License. // limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "mlir/IR/Matchers.h" #include <mlir/IR/Matchers.h>
#include "mlir/IR/OpImplementation.h" #include <mlir/IR/OpImplementation.h>
#include "mlir/IR/PatternMatch.h" #include <mlir/IR/PatternMatch.h>
#include "mlir/Interfaces/CallInterfaces.h" #include <mlir/Interfaces/CallInterfaces.h>
#include "mlir/Interfaces/SideEffectInterfaces.h" #include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt { namespace infrt {
namespace trt { namespace trt {
TensorRTDialect::TensorRTDialect(::mlir::MLIRContext *context) TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context)
: ::mlir::Dialect("trt", context, ::mlir::TypeID::get<TensorRTDialect>()) { : mlir::Dialect("trt", context, mlir::TypeID::get<TensorRTDialect>()) {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT #include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
>(); >();
#undef GET_OP_LIST
} }
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
#undef GET_OP_CLASSES
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
...@@ -14,37 +14,32 @@ ...@@ -14,37 +14,32 @@
#pragma once #pragma once
#include "mlir/Dialect/Traits.h" #include <mlir/Dialect/Traits.h>
#include "mlir/IR/Attributes.h" #include <mlir/IR/Attributes.h>
#include "mlir/IR/Builders.h" #include <mlir/IR/Builders.h>
#include "mlir/IR/Dialect.h" #include <mlir/IR/BuiltinOps.h>
#include "mlir/IR/Function.h" #include <mlir/IR/BuiltinTypes.h>
#include "mlir/IR/Matchers.h" #include <mlir/IR/Dialect.h>
#include "mlir/IR/Module.h" #include <mlir/IR/Matchers.h>
#include "mlir/IR/OpImplementation.h" #include <mlir/IR/OpImplementation.h>
#include "mlir/IR/StandardTypes.h" #include <mlir/IR/TypeUtilities.h>
#include "mlir/IR/TypeUtilities.h" #include <mlir/Interfaces/CallInterfaces.h>
#include "mlir/Interfaces/CallInterfaces.h" #include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" #include <mlir/Interfaces/InferTypeOpInterface.h>
#include "mlir/Interfaces/InferTypeOpInterface.h" #include <mlir/Interfaces/LoopLikeInterface.h>
#include "mlir/Interfaces/LoopLikeInterface.h" #include <mlir/Interfaces/SideEffectInterfaces.h>
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
class TensorRTDialect : public ::mlir::Dialect { class TensorRTDialect : public mlir::Dialect {
public: public:
explicit TensorRTDialect(::mlir::MLIRContext* context); explicit TensorRTDialect(mlir::MLIRContext* context);
static llvm::StringRef getDialectNamespace() { return "trt"; } static llvm::StringRef getDialectNamespace() { return "trt"; }
}; };
// mlir bug。 can be removed safety when update mlir to llvm11. } // namespace trt
using namespace mlir; // NOLINT } // namespace infrt
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.hpp.inc" #include "paddle/infrt/dialect/tensorrt/trt_ops.hpp.inc"
#undef GET_OP_CLASSES
} // namespace trt
} // namespace infrt
...@@ -14,14 +14,13 @@ ...@@ -14,14 +14,13 @@
#include "paddle/infrt/dialect/test_kernels.h" #include "paddle/infrt/dialect/test_kernels.h"
#include "mlir/IR/Builders.h" #include <mlir/IR/Builders.h>
#include "mlir/IR/OpDefinition.h" #include <mlir/IR/OpDefinition.h>
#include "mlir/IR/OpImplementation.h" #include <mlir/IR/OpImplementation.h>
#include "mlir/IR/StandardTypes.h" #include <mlir/IR/TypeUtilities.h>
#include "mlir/IR/TypeUtilities.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// BenchmarkOp // BenchmarkOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
...@@ -32,65 +31,67 @@ namespace infrt::dialect { ...@@ -32,65 +31,67 @@ namespace infrt::dialect {
// ... // ...
// } // }
static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT static mlir::ParseResult parseBenchmarkOp(
OperationState &result) { // NOLINT mlir::OpAsmParser &parser, // NOLINT
StringAttr nameAttr; mlir::OperationState &result) { // NOLINT
mlir::StringAttr nameAttr;
if (parser.parseAttribute(nameAttr, "name", result.attributes)) if (parser.parseAttribute(nameAttr, "name", result.attributes))
return failure(); return mlir::failure();
// Parse the operands, e.g. (%c : i32, %d : f32) // Parse the operands, e.g. (%c : i32, %d : f32)
if (parser.parseLParen()) return failure(); if (parser.parseLParen()) return mlir::failure();
SmallVector<OpAsmParser::OperandType, 4> operands; llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands;
SmallVector<Type, 4> types; llvm::SmallVector<mlir::Type, 4> types;
llvm::SMLoc type_loc = parser.getCurrentLocation(); llvm::SMLoc type_loc = parser.getCurrentLocation();
if (parser.parseOptionalRParen()) { if (parser.parseOptionalRParen()) {
// Parse non-empty operands // Parse non-empty operands
do { do {
// Parse %c : i32, // Parse %c : i32,
OpAsmParser::OperandType operand; mlir::OpAsmParser::OperandType operand;
Type type; mlir::Type type;
if (parser.parseOperand(operand) || parser.parseColonType(type)) if (parser.parseOperand(operand) || parser.parseColonType(type))
return failure(); return mlir::failure();
operands.push_back(operand); operands.push_back(operand);
types.push_back(type); types.push_back(type);
} while (succeeded(parser.parseOptionalComma())); } while (succeeded(parser.parseOptionalComma()));
if (parser.parseRParen()) return failure(); if (parser.parseRParen()) return mlir::failure();
} }
if (parser.resolveOperands(operands, types, type_loc, result.operands)) if (parser.resolveOperands(operands, types, type_loc, result.operands))
return failure(); return mlir::failure();
// Parse the keyword attribute, e.g. max_count = 100, duration_secs = 1 // Parse the keyword attribute, e.g. max_count = 100, duration_secs = 1
do { do {
StringRef attr; mlir::StringRef attr;
Attribute resultAttr; mlir::Attribute resultAttr;
if (parser.parseKeyword(&attr) || parser.parseEqual() || if (parser.parseKeyword(&attr) || parser.parseEqual() ||
parser.parseAttribute(resultAttr, parser.parseAttribute(resultAttr,
parser.getBuilder().getIntegerType(32), parser.getBuilder().getIntegerType(32),
attr, attr,
result.attributes)) result.attributes))
return failure(); return mlir::failure();
} while (succeeded(parser.parseOptionalComma())); } while (mlir::succeeded(parser.parseOptionalComma()));
// Set the default attribute num_warmup_runs to 1 if unset // Set the default attribute num_warmup_runs to 1 if unset
auto setDefaultAttrIfUnset = [&](const char *attr_name, int value) { auto setDefaultAttrIfUnset = [&](const char *attr_name, int value) {
bool found = llvm::any_of(result.attributes, bool found = llvm::any_of(result.attributes,
[attr_name](const NamedAttribute &attr) { [attr_name](const mlir::NamedAttribute &attr) {
return attr.first == attr_name; return attr.getName() == attr_name;
}); });
if (!found) { if (!found) {
IntegerAttr default_val = parser.getBuilder().getI32IntegerAttr(value); mlir::IntegerAttr default_val =
parser.getBuilder().getI32IntegerAttr(value);
result.addAttribute(attr_name, default_val); result.addAttribute(attr_name, default_val);
} }
}; };
setDefaultAttrIfUnset("num_warmup_runs", 1); setDefaultAttrIfUnset("num_warmup_runs", 1);
Region *target = result.addRegion(); mlir::Region *target = result.addRegion();
return parser.parseRegion(*target, return parser.parseRegion(*target,
operands, operands,
types, types,
...@@ -102,11 +103,11 @@ static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT ...@@ -102,11 +103,11 @@ static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT
// max_count = 100, duration_secs = 1 { // max_count = 100, duration_secs = 1 {
// ... // ...
// } // }
static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT static void print(mlir::OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
p << "infrt.benchmark "; p << "infrt.benchmark ";
// Print the name attribute, e.g. "add.i32" // Print the name attribute, e.g. "add.i32"
auto name_attr = op.getAttr("name"); auto name_attr = op->getAttr("name");
p << name_attr; p << name_attr;
// Print the operands and types, e.g. (%c : i32, %d : f32) // Print the operands and types, e.g. (%c : i32, %d : f32)
...@@ -120,13 +121,13 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT ...@@ -120,13 +121,13 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
bool need_comma = false; bool need_comma = false;
// Print the attributes, e.g. max_count = 100, duration_secs = 1 // Print the attributes, e.g. max_count = 100, duration_secs = 1
for (auto &name_attr : op.getAttrs()) { for (auto &name_attr : op->getAttrs()) {
auto id = name_attr.first; auto id = name_attr.getName();
if (id == "name") continue; if (id == "name") continue;
if (need_comma) p << ", "; if (need_comma) p << ", ";
auto attr = name_attr.second; auto attr = name_attr.getValue();
p << id << " = "; p << id << " = ";
if (auto int_attr = attr.dyn_cast<IntegerAttr>()) { if (auto int_attr = attr.dyn_cast<mlir::IntegerAttr>()) {
int_attr.getValue().print(p.getStream(), /*isSigned=*/false); int_attr.getValue().print(p.getStream(), /*isSigned=*/false);
} else { } else {
op.emitOpError("Unexpected attribute"); op.emitOpError("Unexpected attribute");
...@@ -142,7 +143,7 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT ...@@ -142,7 +143,7 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
p.printRegion(op.region(), /*printEntryBlockArgs=*/false); p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
} }
static LogicalResult verify(BenchmarkOp op) { static mlir::LogicalResult verify(BenchmarkOp op) {
// Verify that the target benchmark region has exactly one return value. // Verify that the target benchmark region has exactly one return value.
auto &region = op.region(); auto &region = op.region();
auto &last_op = region.front().back(); auto &last_op = region.front().back();
...@@ -154,10 +155,10 @@ static LogicalResult verify(BenchmarkOp op) { ...@@ -154,10 +155,10 @@ static LogicalResult verify(BenchmarkOp op) {
"incorrect number of return values. One return value is expected"); "incorrect number of return values. One return value is expected");
} }
return success(); return mlir::success();
} }
} // namespace dialect
} // namespace infrt
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.cpp.inc" #include "paddle/infrt/dialect/test_kernels.cpp.inc"
} // namespace infrt::dialect
...@@ -13,11 +13,8 @@ ...@@ -13,11 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "mlir/IR/OpDefinition.h" #include <mlir/IR/OpDefinition.h>
#include "mlir/Interfaces/SideEffectInterfaces.h" #include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt::dialect {
using namespace mlir; // NOLINT
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.hpp.inc" #include "paddle/infrt/dialect/test_kernels.hpp.inc"
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/types.h"
namespace infrt::hlir::mlir {} // namespace infrt::hlir::mlir
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/StandardTypes.h>
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
#include "paddle/infrt/host_context/op_executable.h" #include "paddle/infrt/host_context/op_executable.h"
#include "paddle/infrt/host_context/symbol_table.h" #include "paddle/infrt/host_context/symbol_table.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
struct CoreRuntime::Impl { struct CoreRuntime::Impl {
KernelRegistry* kernel_registry{}; KernelRegistry* kernel_registry{};
...@@ -90,4 +91,5 @@ llvm::SmallVector<ValueRef, 4> CoreRuntime::GetResults( ...@@ -90,4 +91,5 @@ llvm::SmallVector<ValueRef, 4> CoreRuntime::GetResults(
CoreRuntime::~CoreRuntime() {} CoreRuntime::~CoreRuntime() {}
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
#include "paddle/infrt/host_context/value.h" #include "paddle/infrt/host_context/value.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
class KernelRegistry; class KernelRegistry;
class OpExecutable; class OpExecutable;
...@@ -83,4 +84,5 @@ class CoreRuntimeBuilder : public CoreRuntime { ...@@ -83,4 +84,5 @@ class CoreRuntimeBuilder : public CoreRuntime {
OpExecutableBuilder* NewOpExecutable(const std::string& op_name); OpExecutableBuilder* NewOpExecutable(const std::string& op_name);
}; };
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "paddle/infrt/host_context/value.h" #include "paddle/infrt/host_context/value.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
/** /**
* KernelFrame captures the state (input arguments, attributes, results) * KernelFrame captures the state (input arguments, attributes, results)
...@@ -163,4 +164,5 @@ class KernelFrameBuilder : public KernelFrame { ...@@ -163,4 +164,5 @@ class KernelFrameBuilder : public KernelFrame {
} }
}; };
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include "paddle/infrt/host_context/kernel_utils.h" #include "paddle/infrt/host_context/kernel_utils.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
int add_i32(int a, int b) { return a + b; } int add_i32(int a, int b) { return a + b; }
...@@ -44,4 +45,5 @@ TEST(KernelRegistry, basic) { ...@@ -44,4 +45,5 @@ TEST(KernelRegistry, basic) {
ASSERT_EQ(results[0]->get<int>(), 3); ASSERT_EQ(results[0]->get<int>(), 3);
} }
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
namespace infrt::host_context { namespace infrt {
namespace host_context {
int add_i32(int a, int b) { return a + b; } int add_i32(int a, int b) { return a + b; }
float add_f32(float a, float b) { return a + b; } float add_f32(float a, float b) { return a + b; }
...@@ -66,4 +67,5 @@ TEST(KernelImpl, pair) { ...@@ -66,4 +67,5 @@ TEST(KernelImpl, pair) {
ASSERT_EQ(results[1]->get<float>(), 3.f); ASSERT_EQ(results[1]->get<float>(), 3.f);
} }
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/infrt/host_context/mlir_function_executable.h" #include "paddle/infrt/host_context/mlir_function_executable.h"
#include <glog/logging.h> #include <glog/logging.h>
#include <mlir/IR/BuiltinOps.h>
#include <string> // NOLINT #include <string> // NOLINT
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <mlir/IR/Function.h> #include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Region.h>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
#pragma once #pragma once
#include <mlir/Dialect/StandardOps/IR/Ops.h> #include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h> #include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OperationSupport.h> #include <mlir/IR/OperationSupport.h>
#include <unordered_map> #include <unordered_map>
......
...@@ -16,8 +16,9 @@ ...@@ -16,8 +16,9 @@
#include <llvm/Support/SourceMgr.h> #include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h> #include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h> #include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h> #include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h> #include <mlir/Parser.h>
...@@ -40,7 +41,8 @@ ...@@ -40,7 +41,8 @@
#include "paddle/infrt/host_context/value.h" #include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/tensor/tensor_shape.h" #include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
template <typename T> template <typename T>
std::string DumpToString(T& op) { // NOLINT std::string DumpToString(T& op) { // NOLINT
...@@ -113,10 +115,10 @@ bool MlirToRuntimeTranslator::EmitConstantOp(mlir::Operation* op) { ...@@ -113,10 +115,10 @@ bool MlirToRuntimeTranslator::EmitConstantOp(mlir::Operation* op) {
template <> template <>
boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute( boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::IntegerAttr>()) return boost::none; if (!attr.isa<mlir::IntegerAttr>()) return boost::none;
if (attr->isa<mlir::IntegerAttr>()) { if (attr.isa<mlir::IntegerAttr>()) {
auto val = attr->cast<mlir::IntegerAttr>(); auto val = attr.cast<mlir::IntegerAttr>();
if (val.getType().isInteger(32)) { if (val.getType().isInteger(32)) {
return val.getInt(); return val.getInt();
} }
...@@ -125,10 +127,10 @@ boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute( ...@@ -125,10 +127,10 @@ boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute(
} }
template <> template <>
boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute( boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::IntegerAttr>()) return boost::none; if (!attr.isa<mlir::IntegerAttr>()) return boost::none;
if (attr->isa<mlir::IntegerAttr>()) { if (attr.isa<mlir::IntegerAttr>()) {
auto val = attr->cast<mlir::IntegerAttr>(); auto val = attr.cast<mlir::IntegerAttr>();
if (val.getType().isInteger(64)) { if (val.getType().isInteger(64)) {
return val.getInt(); return val.getInt();
} }
...@@ -139,10 +141,10 @@ boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute( ...@@ -139,10 +141,10 @@ boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute(
// TODO(Superjomn) Make double and float parsing share common code. // TODO(Superjomn) Make double and float parsing share common code.
template <> template <>
boost::optional<float> MlirToRuntimeTranslator::EmitAttribute( boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::FloatAttr>()) return boost::none; if (!attr.isa<mlir::FloatAttr>()) return boost::none;
if (attr->isa<mlir::FloatAttr>()) { if (attr.isa<mlir::FloatAttr>()) {
auto val = attr->cast<mlir::FloatAttr>(); auto val = attr.cast<mlir::FloatAttr>();
if (val.getType().isF32()) return val.getValueAsDouble(); if (val.getType().isF32()) return val.getValueAsDouble();
} }
return boost::none; return boost::none;
...@@ -150,10 +152,10 @@ boost::optional<float> MlirToRuntimeTranslator::EmitAttribute( ...@@ -150,10 +152,10 @@ boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
template <> template <>
boost::optional<double> MlirToRuntimeTranslator::EmitAttribute( boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::FloatAttr>()) return boost::none; if (!attr.isa<mlir::FloatAttr>()) return boost::none;
if (attr->isa<mlir::FloatAttr>()) { if (attr.isa<mlir::FloatAttr>()) {
auto val = attr->cast<mlir::FloatAttr>(); auto val = attr.cast<mlir::FloatAttr>();
if (val.getType().isF64()) return val.getValueAsDouble(); if (val.getType().isF64()) return val.getValueAsDouble();
} }
return boost::none; return boost::none;
...@@ -161,17 +163,17 @@ boost::optional<double> MlirToRuntimeTranslator::EmitAttribute( ...@@ -161,17 +163,17 @@ boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
template <> template <>
boost::optional<std::string> MlirToRuntimeTranslator::EmitAttribute( boost::optional<std::string> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::StringAttr>()) return boost::none; if (!attr.isa<mlir::StringAttr>()) return boost::none;
return attr->cast<mlir::StringAttr>().getValue().str(); return attr.cast<mlir::StringAttr>().getValue().str();
} }
#define PROCESS_ARRAY_INT(type__, bits__) \ #define PROCESS_ARRAY_INT(type__, bits__) \
template <> \ template <> \
boost::optional<std::vector<type__>> MlirToRuntimeTranslator::EmitAttribute( \ boost::optional<std::vector<type__>> MlirToRuntimeTranslator::EmitAttribute( \
const mlir::Attribute* attr) { \ const mlir::Attribute& attr) { \
if (!attr->isa<mlir::ArrayAttr>()) return boost::none; \ if (!attr.isa<mlir::ArrayAttr>()) return boost::none; \
auto array = attr->cast<mlir::ArrayAttr>(); \ auto array = attr.cast<mlir::ArrayAttr>(); \
CHECK(!array.empty()); \ CHECK(!array.empty()); \
\ \
if (!array[0].getType().isInteger(bits__)) { \ if (!array[0].getType().isInteger(bits__)) { \
...@@ -191,9 +193,9 @@ PROCESS_ARRAY_INT(int64_t, 64); ...@@ -191,9 +193,9 @@ PROCESS_ARRAY_INT(int64_t, 64);
template <> template <>
boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute( boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::ArrayAttr>()) return boost::none; if (!attr.isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr->cast<mlir::ArrayAttr>(); auto array = attr.cast<mlir::ArrayAttr>();
CHECK(!array.empty()); CHECK(!array.empty());
if (!array[0].getType().isF32()) return boost::none; if (!array[0].getType().isF32()) return boost::none;
...@@ -207,9 +209,9 @@ boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute( ...@@ -207,9 +209,9 @@ boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute(
template <> template <>
boost::optional<std::vector<double>> MlirToRuntimeTranslator::EmitAttribute( boost::optional<std::vector<double>> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::ArrayAttr>()) return boost::none; if (!attr.isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr->cast<mlir::ArrayAttr>(); auto array = attr.cast<mlir::ArrayAttr>();
CHECK(!array.empty()); CHECK(!array.empty());
if (!array[0].getType().isF64()) return boost::none; if (!array[0].getType().isF64()) return boost::none;
...@@ -236,7 +238,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { ...@@ -236,7 +238,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
for (int i = 0, e = op->getNumOperands(); i < e; i++) { for (int i = 0, e = op->getNumOperands(); i < e; i++) {
// function argument as value // function argument as value
auto operand = op->getOperand(i); auto operand = op->getOperand(i);
if (operand.getKind() == mlir::Value::Kind::BlockArgument) { /// if (operand.getKind() == mlir::Value::Kind::BlockArgument) {
if (operand.isa<mlir::BlockArgument>()) {
mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>(); mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
Value* arg_value = GetValue(arg); Value* arg_value = GetValue(arg);
impl_->cur_op->AppendArgument(arg_value); impl_->cur_op->AppendArgument(arg_value);
...@@ -283,25 +286,25 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { ...@@ -283,25 +286,25 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
for (size_t i = 0; i < attrs.size(); i++) { for (size_t i = 0; i < attrs.size(); i++) {
auto& attr = attrs[i]; auto& attr = attrs[i];
if (auto v = EmitAttribute<int32_t>(&attr.second)) { if (auto v = EmitAttribute<int32_t>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v)); impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<int64_t>(&attr.second)) { } else if (auto v = EmitAttribute<int64_t>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v)); impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<float>(&attr.second)) { } else if (auto v = EmitAttribute<float>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v)); impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<double>(&attr.second)) { } else if (auto v = EmitAttribute<double>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v)); impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<std::string>(&attr.second)) { } else if (auto v = EmitAttribute<std::string>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int16_t>>(&attr.second)) { } else if (auto v = EmitAttribute<std::vector<int16_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int32_t>>(&attr.second)) { } else if (auto v = EmitAttribute<std::vector<int32_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int64_t>>(&attr.second)) { } else if (auto v = EmitAttribute<std::vector<int64_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<float>>(&attr.second)) { } else if (auto v = EmitAttribute<std::vector<float>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<double>>(&attr.second)) { } else if (auto v = EmitAttribute<std::vector<double>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else { } else {
LOG(FATAL) << "Not supported attribute type"; LOG(FATAL) << "Not supported attribute type";
...@@ -330,7 +333,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { ...@@ -330,7 +333,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
llvm::SmallVector<mlir::Type, 0> results; llvm::SmallVector<mlir::Type, 0> results;
auto func_type = auto func_type =
mlir::FunctionType::get(inputs, results, region.getContext()); mlir::FunctionType::get(region.getContext(), inputs, results);
auto* function = impl_->cur_op->CreateFunctionExecutable( auto* function = impl_->cur_op->CreateFunctionExecutable(
&region, func_type, &impl_->func_defs); &region, func_type, &impl_->func_defs);
impl_->cur_op->AppendAttribute(new Value(function)); impl_->cur_op->AppendAttribute(new Value(function));
...@@ -555,4 +558,5 @@ void TestMlir(mlir::ModuleOp module, KernelRegistry* registry) { ...@@ -555,4 +558,5 @@ void TestMlir(mlir::ModuleOp module, KernelRegistry* registry) {
execute.Run(); execute.Run();
} }
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -29,7 +29,8 @@ class Attribute; ...@@ -29,7 +29,8 @@ class Attribute;
class Value; class Value;
} // namespace mlir } // namespace mlir
namespace infrt::host_context { namespace infrt {
namespace host_context {
class CoreRuntimeBuilder; class CoreRuntimeBuilder;
class Value; class Value;
...@@ -73,7 +74,7 @@ class MlirToRuntimeTranslator { ...@@ -73,7 +74,7 @@ class MlirToRuntimeTranslator {
bool EmitCallOp(mlir::Operation* op, function_defs_t* function_table); bool EmitCallOp(mlir::Operation* op, function_defs_t* function_table);
template <typename T> template <typename T>
boost::optional<T> EmitAttribute(const mlir::Attribute* attr); boost::optional<T> EmitAttribute(const mlir::Attribute& attr);
Value* GetOpResult(mlir::Operation* op); Value* GetOpResult(mlir::Operation* op);
...@@ -104,4 +105,5 @@ void MlirToRuntimeTranslate(mlir::ModuleOp module, CoreRuntimeBuilder* runtime); ...@@ -104,4 +105,5 @@ void MlirToRuntimeTranslate(mlir::ModuleOp module, CoreRuntimeBuilder* runtime);
*/ */
void TestMlir(mlir::ModuleOp module, KernelRegistry* registry); void TestMlir(mlir::ModuleOp module, KernelRegistry* registry);
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -29,7 +29,8 @@ ...@@ -29,7 +29,8 @@
#include "paddle/infrt/kernel/tensor_shape_kernels.h" #include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include "paddle/infrt/kernel/test_kernels.h" #include "paddle/infrt/kernel/test_kernels.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
TEST(MlirToRuntimeTranslate, basic) { TEST(MlirToRuntimeTranslate, basic) {
mlir::MLIRContext context; mlir::MLIRContext context;
...@@ -48,7 +49,7 @@ func @main() -> () { ...@@ -48,7 +49,7 @@ func @main() -> () {
)ROC"; )ROC";
auto module = dialect::LoadMlirSource(&context, source); auto module = dialect::LoadMlirSource(&context, source);
module->verify(); EXPECT_TRUE(mlir::succeeded(module->verify()));
KernelRegistry registry; KernelRegistry registry;
kernel::RegisterFloatBasicKernels(&registry); kernel::RegisterFloatBasicKernels(&registry);
...@@ -74,7 +75,7 @@ func @main() -> () { ...@@ -74,7 +75,7 @@ func @main() -> () {
)ROC"; )ROC";
auto module = dialect::LoadMlirSource(&context, source); auto module = dialect::LoadMlirSource(&context, source);
module->verify(); EXPECT_TRUE(mlir::succeeded(module->verify()));
KernelRegistry registry; KernelRegistry registry;
kernel::RegisterFloatBasicKernels(&registry); kernel::RegisterFloatBasicKernels(&registry);
...@@ -115,7 +116,7 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F ...@@ -115,7 +116,7 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F
// LOG(INFO) << "content: " << content << std::endl; // LOG(INFO) << "content: " << content << std::endl;
auto module = dialect::LoadMlirSource(context, content); auto module = dialect::LoadMlirSource(context, content);
module->verify(); EXPECT_TRUE(mlir::succeeded(module->verify()));
host_context::KernelRegistry registry; host_context::KernelRegistry registry;
...@@ -157,4 +158,5 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F ...@@ -157,4 +158,5 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F
} }
} }
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/infrt/host_context/op_executable.h" #include "paddle/infrt/host_context/op_executable.h"
#include <mlir/IR/BuiltinOps.h>
#include <string> #include <string>
#include "paddle/infrt/host_context/kernel_frame.h" #include "paddle/infrt/host_context/kernel_frame.h"
...@@ -21,7 +22,8 @@ ...@@ -21,7 +22,8 @@
#include "paddle/infrt/host_context/mlir_function_executable.h" #include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/host_context/symbol_table.h" #include "paddle/infrt/host_context/symbol_table.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
struct OpExecutable::Impl { struct OpExecutable::Impl {
Impl(const std::string& op_name, Impl(const std::string& op_name,
...@@ -148,4 +150,5 @@ void OpExecutable::Execute() { ...@@ -148,4 +150,5 @@ void OpExecutable::Execute() {
OpExecutable::~OpExecutable() {} OpExecutable::~OpExecutable() {}
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -14,19 +14,18 @@ ...@@ -14,19 +14,18 @@
#pragma once #pragma once
#include <llvm/ADT/ArrayRef.h> #include <llvm/ADT/ArrayRef.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Region.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "mlir/IR/Function.h"
#include "mlir/IR/Region.h"
namespace mlir { namespace mlir {
class FuncOp; class FuncOp;
} // namespace mlir } // namespace mlir
namespace infrt::host_context { namespace infrt {
namespace host_context {
class SymbolTable; class SymbolTable;
class KernelRegistry; class KernelRegistry;
...@@ -89,4 +88,5 @@ class OpExecutableBuilder : public OpExecutable { ...@@ -89,4 +88,5 @@ class OpExecutableBuilder : public OpExecutable {
function_defs_t* function_defs); function_defs_t* function_defs);
}; };
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
using infrt::host_context::Attribute; using infrt::host_context::Attribute;
namespace infrt::kernel { namespace infrt {
namespace kernel {
template <typename T> template <typename T>
T add(T a, T b) { T add(T a, T b) {
...@@ -82,4 +83,5 @@ void RegisterFloatBasicKernels(host_context::KernelRegistry *registry) { ...@@ -82,4 +83,5 @@ void RegisterFloatBasicKernels(host_context::KernelRegistry *registry) {
registry->AddKernel("infrt.print.f32", INFRT_KERNEL(print<float>)); registry->AddKernel("infrt.print.f32", INFRT_KERNEL(print<float>));
} }
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -15,13 +15,16 @@ ...@@ -15,13 +15,16 @@
#pragma once #pragma once
#include <string> #include <string>
namespace infrt::host_context { namespace infrt {
namespace host_context {
struct KernelRegistry; struct KernelRegistry;
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
namespace infrt::kernel { namespace infrt {
namespace kernel {
/** /**
* Register all the basic kernels to \p registry. * Register all the basic kernels to \p registry.
...@@ -31,4 +34,5 @@ void RegisterBasicKernels(host_context::KernelRegistry* registry); ...@@ -31,4 +34,5 @@ void RegisterBasicKernels(host_context::KernelRegistry* registry);
void RegisterIntBasicKernels(host_context::KernelRegistry* registry); void RegisterIntBasicKernels(host_context::KernelRegistry* registry);
void RegisterFloatBasicKernels(host_context::KernelRegistry* registry); void RegisterFloatBasicKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -25,7 +25,8 @@ ...@@ -25,7 +25,8 @@
#include "paddle/infrt/tensor/tensor_map.h" #include "paddle/infrt/tensor/tensor_map.h"
#include "paddle/infrt/tensor/tensor_shape.h" #include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::kernel { namespace infrt {
namespace kernel {
using namespace host_context; // NOLINT using namespace host_context; // NOLINT
using namespace tensor; // NOLINT using namespace tensor; // NOLINT
...@@ -76,4 +77,5 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) { ...@@ -76,4 +77,5 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
INFRT_KERNEL(ShallowCopyTensor)); INFRT_KERNEL(ShallowCopyTensor));
} }
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -14,12 +14,16 @@ ...@@ -14,12 +14,16 @@
#pragma once #pragma once
namespace infrt::host_context { namespace infrt {
namespace host_context {
struct KernelRegistry; struct KernelRegistry;
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
namespace infrt::kernel { namespace infrt {
namespace kernel {
void RegisterTensorKernels(host_context::KernelRegistry* registry); void RegisterTensorKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
#include "paddle/infrt/host_context/kernel_utils.h" #include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/tensor/tensor_shape.h" #include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::kernel { namespace infrt {
namespace kernel {
void PrintShape(const tensor::TensorShape& shape) { void PrintShape(const tensor::TensorShape& shape) {
llvm::raw_os_ostream oos(std::cout); llvm::raw_os_ostream oos(std::cout);
...@@ -35,4 +36,5 @@ void RegisterTensorShapeKernels(host_context::KernelRegistry* registry) { ...@@ -35,4 +36,5 @@ void RegisterTensorShapeKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("ts.print_shape", INFRT_KERNEL(PrintShape)); registry->AddKernel("ts.print_shape", INFRT_KERNEL(PrintShape));
} }
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -14,14 +14,18 @@ ...@@ -14,14 +14,18 @@
#pragma once #pragma once
namespace infrt::host_context { namespace infrt {
namespace host_context {
class KernelRegistry; class KernelRegistry;
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
namespace infrt::kernel { namespace infrt {
namespace kernel {
void RegisterTensorShapeKernels(host_context::KernelRegistry* registry); void RegisterTensorShapeKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -33,7 +33,8 @@ using infrt::host_context::Attribute; ...@@ -33,7 +33,8 @@ using infrt::host_context::Attribute;
using infrt::host_context::MlirFunctionExecutable; using infrt::host_context::MlirFunctionExecutable;
using infrt::host_context::RemainingArguments; using infrt::host_context::RemainingArguments;
namespace infrt::kernel { namespace infrt {
namespace kernel {
namespace { namespace {
class BenchmarkStats { class BenchmarkStats {
public: public:
...@@ -197,4 +198,5 @@ void RegisterTestKernels(host_context::KernelRegistry *registry) { ...@@ -197,4 +198,5 @@ void RegisterTestKernels(host_context::KernelRegistry *registry) {
INFRT_KERNEL(ShadowCopyTensor)); INFRT_KERNEL(ShadowCopyTensor));
} }
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -15,17 +15,21 @@ ...@@ -15,17 +15,21 @@
#pragma once #pragma once
#include <string> #include <string>
namespace infrt::host_context { namespace infrt {
namespace host_context {
struct KernelRegistry; struct KernelRegistry;
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
namespace infrt::kernel { namespace infrt {
namespace kernel {
/** /**
* Register all the test kernels to registry. * Register all the test kernels to registry.
*/ */
void RegisterTestKernels(host_context::KernelRegistry* registry); void RegisterTestKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -18,7 +18,9 @@ ...@@ -18,7 +18,9 @@
#include <string> #include <string>
#include <vector> #include <vector>
namespace infrt::paddle::cpp { namespace infrt {
namespace paddle {
namespace cpp {
/* /*
* Compatible interfaces for all the different kinds of XXXDesc. All the XXXDesc * Compatible interfaces for all the different kinds of XXXDesc. All the XXXDesc
...@@ -226,4 +228,6 @@ class ProgramDescAPI { ...@@ -226,4 +228,6 @@ class ProgramDescAPI {
virtual void SetVersion(int64_t version) = 0; virtual void SetVersion(int64_t version) = 0;
}; };
} // namespace infrt::paddle::cpp } // namespace cpp
} // namespace paddle
} // namespace infrt
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
#include "paddle/infrt/common/target.h" #include "paddle/infrt/common/target.h"
#include "paddle/infrt/common/type.h" #include "paddle/infrt/common/type.h"
namespace infrt::paddle { namespace infrt {
namespace paddle {
int SizeOfType(framework_proto::VarType::Type type) { int SizeOfType(framework_proto::VarType::Type type) {
using Type = framework_proto::VarType::Type; using Type = framework_proto::VarType::Type;
...@@ -169,4 +170,5 @@ void LoadParam(const std::string &path, _Variable *out, const Target &target) { ...@@ -169,4 +170,5 @@ void LoadParam(const std::string &path, _Variable *out, const Target &target) {
LoadLoDTensor(fin, out, target); LoadLoDTensor(fin, out, target);
} }
} // namespace infrt::paddle } // namespace paddle
} // namespace infrt
...@@ -25,7 +25,8 @@ ...@@ -25,7 +25,8 @@
#include "paddle/infrt/paddle/scope.h" #include "paddle/infrt/paddle/scope.h"
#include "paddle/infrt/paddle/tensor.h" #include "paddle/infrt/paddle/tensor.h"
namespace infrt::paddle { namespace infrt {
namespace paddle {
namespace framework_proto = ::paddle::framework::proto; namespace framework_proto = ::paddle::framework::proto;
// Read a __model__ file. // Read a __model__ file.
...@@ -52,4 +53,5 @@ void TensorFromStream( ...@@ -52,4 +53,5 @@ void TensorFromStream(
const common::Target& target = common::DefaultHostTarget()); const common::Target& target = common::DefaultHostTarget());
void ReadBinaryFile(const std::string& filename, std::string* contents); void ReadBinaryFile(const std::string& filename, std::string* contents);
} // namespace infrt::paddle } // namespace paddle
} // namespace infrt
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#include "paddle/infrt/paddle/pb/block_desc.h" #include "paddle/infrt/paddle/pb/block_desc.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
template <> template <>
framework_proto::VarDesc* BlockDesc::GetVar<framework_proto::VarDesc>( framework_proto::VarDesc* BlockDesc::GetVar<framework_proto::VarDesc>(
...@@ -40,4 +42,6 @@ framework_proto::OpDesc* BlockDesc::AddOp<framework_proto::OpDesc>() { ...@@ -40,4 +42,6 @@ framework_proto::OpDesc* BlockDesc::AddOp<framework_proto::OpDesc>() {
return desc_->add_ops(); return desc_->add_ops();
} }
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -18,7 +18,9 @@ ...@@ -18,7 +18,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h" #include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto; namespace framework_proto = ::paddle::framework::proto;
...@@ -74,4 +76,6 @@ class BlockDesc : public cpp::BlockDescAPI { ...@@ -74,4 +76,6 @@ class BlockDesc : public cpp::BlockDescAPI {
framework_proto::BlockDesc* desc_; // not_own framework_proto::BlockDesc* desc_; // not_own
}; };
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#include "paddle/infrt/paddle/pb/op_desc.h" #include "paddle/infrt/paddle/pb/op_desc.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
google::protobuf::internal::RepeatedPtrIterator<framework_proto::OpDesc_Attr> google::protobuf::internal::RepeatedPtrIterator<framework_proto::OpDesc_Attr>
FindAttr(framework_proto::OpDesc *desc, const std::string &name) { FindAttr(framework_proto::OpDesc *desc, const std::string &name) {
...@@ -136,4 +138,6 @@ GET_ATTRS_IMPL(std::vector<std::string>, strings); ...@@ -136,4 +138,6 @@ GET_ATTRS_IMPL(std::vector<std::string>, strings);
GET_ATTR_IMPL(std::string, s); GET_ATTR_IMPL(std::string, s);
GET_ATTRS_IMPL(std::vector<int64_t>, longs); GET_ATTRS_IMPL(std::vector<int64_t>, longs);
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -19,7 +19,9 @@ ...@@ -19,7 +19,9 @@
#include "paddle/infrt/paddle/framework.pb.h" #include "paddle/infrt/paddle/framework.pb.h"
#include "paddle/infrt/support/variant.h" #include "paddle/infrt/support/variant.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto; namespace framework_proto = ::paddle::framework::proto;
...@@ -195,4 +197,6 @@ template <> ...@@ -195,4 +197,6 @@ template <>
void OpDesc::SetAttr<std::vector<int>>(const std::string &name, void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
const std::vector<int> &v); const std::vector<int> &v);
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -17,7 +17,9 @@ ...@@ -17,7 +17,9 @@
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
template <> template <>
framework_proto::BlockDesc* ProgramDesc::GetBlock<framework_proto::BlockDesc>( framework_proto::BlockDesc* ProgramDesc::GetBlock<framework_proto::BlockDesc>(
...@@ -32,4 +34,6 @@ ProgramDesc::AddBlock<framework_proto::BlockDesc>() { ...@@ -32,4 +34,6 @@ ProgramDesc::AddBlock<framework_proto::BlockDesc>() {
return desc_->add_blocks(); return desc_->add_blocks();
} }
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -21,7 +21,9 @@ ...@@ -21,7 +21,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h" #include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto; namespace framework_proto = ::paddle::framework::proto;
class ProgramDesc : public cpp::ProgramDescAPI { class ProgramDesc : public cpp::ProgramDescAPI {
...@@ -58,4 +60,6 @@ class ProgramDesc : public cpp::ProgramDescAPI { ...@@ -58,4 +60,6 @@ class ProgramDesc : public cpp::ProgramDescAPI {
framework_proto::ProgramDesc *desc_; // not_own framework_proto::ProgramDesc *desc_; // not_own
}; };
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -19,7 +19,9 @@ ...@@ -19,7 +19,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h" #include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
cpp::VarDescAPI::Type VarDesc::GetType() const { cpp::VarDescAPI::Type VarDesc::GetType() const {
auto type = desc_->type().type(); auto type = desc_->type().type();
...@@ -364,4 +366,6 @@ VarDesc::mutable_tensor_descs() { ...@@ -364,4 +366,6 @@ VarDesc::mutable_tensor_descs() {
return std::vector<framework_proto::VarType::TensorDesc *>(); return std::vector<framework_proto::VarType::TensorDesc *>();
} }
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -23,7 +23,9 @@ ...@@ -23,7 +23,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h" #include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto; namespace framework_proto = ::paddle::framework::proto;
// convert between std::vector and protobuf repeated. // convert between std::vector and protobuf repeated.
...@@ -121,4 +123,6 @@ class VarDesc : public cpp::VarDescAPI { ...@@ -121,4 +123,6 @@ class VarDesc : public cpp::VarDescAPI {
framework_proto::VarDesc *desc_; framework_proto::VarDesc *desc_;
}; };
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -435,6 +435,10 @@ inline T* DenseTensor::mutable_data(const paddle::platform::Place& place, ...@@ -435,6 +435,10 @@ inline T* DenseTensor::mutable_data(const paddle::platform::Place& place,
} }
void DenseTensor::ShareBufferWith(const DenseTensor& tensor) { void DenseTensor::ShareBufferWith(const DenseTensor& tensor) {
if (storage_ == nullptr) {
storage_ = make_intrusive<paddle::experimental::SharedStorage>(
paddle::platform::CPUPlace());
}
if (storage_ != nullptr && tensor.storage_ != nullptr) { if (storage_ != nullptr && tensor.storage_ != nullptr) {
storage_->set_data_shared(tensor.storage_->data_shared()); storage_->set_data_shared(tensor.storage_->data_shared());
} }
......
...@@ -152,6 +152,9 @@ def ShardingScaler(scaler): ...@@ -152,6 +152,9 @@ def ShardingScaler(scaler):
param_grads = [] param_grads = []
param_grads_fp16 = [] param_grads_fp16 = []
param_grads_fp32 = [] param_grads_fp32 = []
if hasattr(optimizer, "update_slice"):
optimizer.update_slice()
optimizer.update_scaler = True
if getattr(optimizer._optim, '_param_groups', None) and isinstance( if getattr(optimizer._optim, '_param_groups', None) and isinstance(
optimizer._optim._param_groups[0], dict): optimizer._optim._param_groups[0], dict):
...@@ -161,27 +164,21 @@ def ShardingScaler(scaler): ...@@ -161,27 +164,21 @@ def ShardingScaler(scaler):
if param._grad_ivar() is not None: if param._grad_ivar() is not None:
param_grads.append(param._grad_ivar()) param_grads.append(param._grad_ivar())
if param._grad_ivar( if param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16: ).dtype in [core.VarDesc.VarType.FP16, paddle.float16]:
param_grads_fp16.append(param._grad_ivar()) param_grads_fp16.append(param._grad_ivar())
else: else:
param_grads_fp32.append(param._grad_ivar()) param_grads_fp32.append(param._grad_ivar())
else: else:
param_grads = [ for param in optimizer._optim._parameter_list:
param._grad_ivar() for param in optimizer._optim._parameter_list if param.grad is not None:
if param._grad_ivar() is not None param_grads.append(param.grad)
] if param.grad.dtype in [
param_grads_fp16 = [ core.VarDesc.VarType.FP16, paddle.float16
param._grad_ivar() for param in optimizer._optim._parameter_list ]:
if (param._grad_ivar() is not None param_grads_fp16.append(param.grad)
) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP16 else:
) param_grads_fp32.append(param.grad)
]
param_grads_fp32 = [
param._grad_ivar() for param in optimizer._optim._parameter_list
if (param._grad_ivar() is not None
) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32
)
]
temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool)) temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool)) temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
......
...@@ -34,6 +34,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel) ...@@ -34,6 +34,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2) list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2) list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage3)
list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer) list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers) list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper) list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper)
...@@ -250,6 +251,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) ...@@ -250,6 +251,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2) list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2) list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3)
list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer) list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
...@@ -1058,6 +1060,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) ...@@ -1058,6 +1060,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120) set_tests_properties(test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120) set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_sharding_stage3 PROPERTIES TIMEOUT 120)
set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120) set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120) set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120)
......
# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import ast
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
# Number of passes over the synthetic dataset; read by train_mlp's outer loop.
epoch = 10
# Samples per batch handed to paddle.batch in train_mlp.
batch_size = 32
# Fix both the Paddle and NumPy seeds to the same constant.
paddle.seed(2021)
np.random.seed(2021)
# NOTE(review): base_lr, momentum_rate and l2_decay are not referenced by the
# functions visible in this script (optimizer_setting hard-codes its own
# learning rate and weight decay) -- confirm whether they can be removed.
base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4
# Initialise fleet in collective mode at import time, before any model is built.
fleet.init(is_collective=True)
class MLP(fluid.Layer):
    """Three ``Linear`` layers applied back to back with no activation.

    The first two layers map ``linear_size -> linear_size``; the last one maps
    ``linear_size -> 10``.
    """

    def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
        # param_attr / bias_attr are accepted for signature compatibility but
        # are not forwarded to any of the Linear layers.
        super(MLP, self).__init__()
        self._linear1 = Linear(linear_size, linear_size)
        self._linear2 = Linear(linear_size, linear_size)
        self._linear3 = Linear(linear_size, 10)

    def forward(self, inputs):
        """Run ``inputs`` through the three layers in order."""
        hidden = self._linear1(inputs)
        hidden = self._linear2(hidden)
        return self._linear3(hidden)
def reader_decorator(linear_size=1000):
    """Build a sample generator factory for the synthetic training set.

    Args:
        linear_size: length of each generated feature vector.

    Returns:
        A zero-argument callable; each call yields 100 ``(img, label)`` pairs
        where ``img`` is a random float32 vector of length ``linear_size`` and
        ``label`` is a one-element int64 array holding 1.
    """

    def __reader__():
        sample_count = 100
        for _sample_idx in range(sample_count):
            feature = np.random.rand(linear_size).astype('float32')
            target = np.ones(1).astype('int64')
            yield feature, target

    return __reader__
def optimizer_setting(model, use_pure_fp16, opt_group=False):
    """Create the AdamW optimizer used by every training run in this script.

    Args:
        model: layer whose ``parameters()`` are optimized.
        use_pure_fp16: forwarded to AdamW as ``multi_precision``.
        opt_group: when true, the parameters are wrapped in a single
            ``{"params": ...}`` group dict instead of being passed as a list.

    Returns:
        The configured ``paddle.optimizer.AdamW`` instance.
    """
    grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
    if opt_group:
        trainable = [{"params": model.parameters()}]
    else:
        trainable = model.parameters()
    return paddle.optimizer.AdamW(
        parameters=trainable,
        learning_rate=0.001,
        weight_decay=0.00001,
        grad_clip=grad_clip,
        multi_precision=use_pure_fp16)
def train_mlp(model,
              sharding_stage,
              use_pure_fp16=False,
              accumulate_grad=False,
              opt_group=False,
              recompute=False):
    """Train ``model`` under the requested sharding stage.

    Args:
        model: layer to train; it is wrapped by ShardingStage2 or
            ShardingStage3 depending on ``sharding_stage``.
        sharding_stage: 2 or 3 selects the wrapper; any other value leaves the
            model and optimizer unwrapped.
        use_pure_fp16: when true, the model is decorated at AMP level 'O2' and
            updates go through a ShardingScaler-wrapped GradScaler.
        accumulate_grad: when false, backward/step/clear_grad run after every
            batch; when true, they run once per epoch on the loss of the last
            batch. Also forwarded to ShardingStage2 as ``accumulate_grads``.
        opt_group: forwarded to ``optimizer_setting``.
        recompute: forwarded to ShardingStage3 as ``sync_comm``.

    Returns:
        ``model.parameters()`` after training.
    """
    group = paddle.distributed.new_group([0, 1])

    # optimizer_setting only tests opt_group for truthiness, so forwarding it
    # directly selects the same parameter layout as an explicit if/else.
    optimizer = optimizer_setting(
        model=model, use_pure_fp16=use_pure_fp16, opt_group=opt_group)

    scaler = None
    if use_pure_fp16:
        model = paddle.amp.decorate(
            models=model, level='O2', save_dtype='float32')
        scaler = ShardingScaler(paddle.amp.GradScaler(init_loss_scaling=32768))

    if sharding_stage == 2:
        optimizer = ShardingOptimizerStage2(
            params=model.parameters(), optim=optimizer, group=group)
        model = ShardingStage2(
            model,
            optimizer,
            group=group,
            buffer_max_size=2**21,
            accumulate_grads=accumulate_grad)
    elif sharding_stage == 3:
        model = ShardingStage3(
            model, optimizer=optimizer, group=group, sync_comm=recompute)

    def _backward_and_step(loss_value):
        # One parameter update: scaled path for pure fp16, plain path otherwise.
        if use_pure_fp16:
            scaler.scale(loss_value).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_value.backward()
            optimizer.step()
        optimizer.clear_grad()

    batched_reader = paddle.batch(
        reader_decorator(), batch_size=batch_size, drop_last=True)
    train_loader = paddle.io.DataLoader.from_generator(
        capacity=32,
        use_double_buffer=True,
        iterable=True,
        return_list=True,
        use_multiprocess=True)
    train_loader.set_sample_list_generator(batched_reader)

    for _ in range(epoch):
        model.train()
        for img, label in train_loader():
            label.stop_gradient = True
            img.stop_gradient = True
            # auto_cast is entered for every configuration, including fp32 runs.
            with paddle.amp.auto_cast(True, level='O2'):
                out = model(img)
                loss = paddle.nn.functional.cross_entropy(
                    input=out, label=label)
            avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))

            if not accumulate_grad:
                _backward_and_step(avg_loss)

        # With accumulation enabled the update happens once per epoch, using
        # the loss left over from the final batch of the loop above.
        if accumulate_grad:
            _backward_and_step(avg_loss)

    if sharding_stage == 3:
        model.get_all_parameters()
    return model.parameters()
def test_stage2_stage3():
    """Compare parameters produced by stage-2 and stage-3 sharded training.

    Nine MLPs are created; the last eight are loaded with the first one's
    state dict. Four scenarios are then trained pairwise and their resulting
    parameters compared with ``rtol=1e-6``: fp32, fp32 with gradient
    accumulation, pure fp16, and pure fp16 stage 3 with/without ``recompute``.
    """

    def _assert_same_named_params(lhs_params, rhs_params):
        # Only parameters whose names are equal on both sides are compared.
        # NOTE(review): if the two models never share a parameter name this
        # check passes without asserting anything -- confirm names line up.
        for lhs in lhs_params:
            for rhs in rhs_params:
                if lhs.name == rhs.name:
                    np.testing.assert_allclose(
                        lhs.numpy(), rhs.numpy(), rtol=1e-6)

    # Build all models first, then copy the reference weights into the rest.
    models = [MLP() for _ in range(9)]
    reference_state = models[0].state_dict()
    for replica in models[1:]:
        replica.set_state_dict(reference_state)
    mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8 = models[1:]

    # fp32
    stage2_params = train_mlp(
        mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=True)
    stage3_params = train_mlp(
        mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=True)
    _assert_same_named_params(stage2_params, stage3_params)

    # fp32 accumulate grad
    stage2_params = train_mlp(
        mlp3,
        sharding_stage=2,
        use_pure_fp16=False,
        accumulate_grad=True,
        opt_group=True)
    stage3_params = train_mlp(
        mlp4,
        sharding_stage=3,
        use_pure_fp16=False,
        accumulate_grad=True,
        opt_group=True)
    _assert_same_named_params(stage2_params, stage3_params)

    # fp16
    stage2_params = train_mlp(
        mlp5, sharding_stage=2, use_pure_fp16=True, opt_group=False)
    stage3_params = train_mlp(
        mlp6, sharding_stage=3, use_pure_fp16=True, opt_group=False)
    _assert_same_named_params(stage2_params, stage3_params)

    # fp16 recompute
    stage3_params = train_mlp(
        mlp7, sharding_stage=3, use_pure_fp16=True, opt_group=False)
    stage3_params_re = train_mlp(
        mlp8,
        sharding_stage=3,
        use_pure_fp16=True,
        opt_group=False,
        recompute=True)
    _assert_same_named_params(stage3_params, stage3_params_re)
# Entry point: run the stage-2 / stage-3 comparison when executed as a script.
if __name__ == '__main__':
    test_stage2_stage3()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import unittest
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
class TrtConvertFlattenContiguousRangeTest(TrtLayerAutoScanTest):
    """Auto-scan test for the ``flatten_contiguous_range`` TensorRT converter."""

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        """Every generated program is accepted."""
        return True

    def sample_program_configs(self):
        """Yield one single-op program per (batch, start_axis, stop_axis).

        ``batch`` ranges over 1, 2, 4 and the axes over every pair with
        ``0 <= start_axis <= stop_axis < 5``.
        """

        def generate_input(batch):
            return np.random.random([2, batch, 4, 8, 3]).astype(np.float32)

        # Named op_type rather than `type` so the builtin is not shadowed.
        op_type = "flatten_contiguous_range"
        for batch in [1, 2, 4]:
            for start_axis in range(5):
                for stop_axis in range(start_axis, 5):
                    op_outputs = {
                        "Out": ["output_data"],
                        "XShape": ["xshape_data"]
                    }
                    ops_config = [{
                        "op_type": op_type,
                        "op_inputs": {
                            "X": ["input_data"]
                        },
                        "op_outputs": op_outputs,
                        "op_attrs": {
                            "start_axis": start_axis,
                            "stop_axis": stop_axis,
                        }
                    }]
                    ops = self.generate_op_config(ops_config)
                    program_config = ProgramConfig(
                        ops=ops,
                        weights={},
                        inputs={
                            "input_data": TensorConfig(
                                data_gen=partial(generate_input, batch))
                        },
                        outputs=["output_data"])
                    yield program_config

    def sample_predictor_configs(
            self, program_config) -> (paddle_infer.Config, List[int], float):
        """Yield (config, expected node counts, tolerance) for each mode.

        Four configurations are produced in order: static shape with the
        current precision, static shape in Half, dynamic shape in Float32 and
        dynamic shape in Half. The tolerance is 1e-5 for all of them.
        """

        def generate_dynamic_shape(attrs):
            self.dynamic_shape.min_input_shape = {"input_data": [2, 1, 4, 8, 3]}
            self.dynamic_shape.max_input_shape = {"input_data": [2, 4, 4, 8, 3]}
            self.dynamic_shape.opt_input_shape = {"input_data": [2, 2, 4, 8, 3]}

        def clear_dynamic_shape():
            self.dynamic_shape.max_input_shape = {}
            self.dynamic_shape.min_input_shape = {}
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(attrs, dynamic_shape):
            ver = paddle_infer.get_trt_compile_version()
            # Encode (major, minor, patch) as one integer. The last term must
            # use the patch component ver[2]; it previously reused ver[0].
            if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 >= 7000:
                if dynamic_shape:
                    return 1, 2
                # Static shape: flattening from axis 0 is expected to fall
                # back (0 TRT nodes); any other start axis is converted.
                if attrs[0]['start_axis'] == 0:
                    return 0, 3
                return 1, 2
            return 0, 3

        attrs = [
            program_config.ops[i].attrs
            for i in range(len(program_config.ops))
        ]

        # for static_shape
        clear_dynamic_shape()
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, False), 1e-5
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, False), 1e-5

        # for dynamic_shape
        generate_dynamic_shape(attrs)
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(attrs,
                                                                     True), 1e-5
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(attrs,
                                                                     True), 1e-5

    def test(self):
        """Run the auto-scan harness."""
        self.run_test()
# Entry point: hand control to unittest when executed as a script.
if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphShardingStage3(TestMultipleGpus):
    """Launches the stage-3 sharding script through the multi-GPU harness."""

    def test_dygraph_sharding_optimizer_stage3(self):
        # check sharding logic as well as the accuracy with single mode
        script_name = 'dygraph_sharding_stage3.py'
        self.run_mnist_2gpu(script_name)
# Entry point: hand control to unittest when executed as a script.
if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -66,6 +66,15 @@ class TestStackOpBase(XPUOpTest): ...@@ -66,6 +66,15 @@ class TestStackOpBase(XPUOpTest):
place = paddle.XPUPlace(0) place = paddle.XPUPlace(0)
self.check_output_with_place(place) self.check_output_with_place(place)
def test_check_grad(self):
if self.dtype == 'int64' or self.dtype == 'int32':
pass
else:
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, self.get_x_names(), 'Y')
class TestStackOp1(TestStackOpBase): class TestStackOp1(TestStackOpBase):
def initParameters(self): def initParameters(self):
...@@ -81,11 +90,17 @@ class TestStackOp3(TestStackOpBase): ...@@ -81,11 +90,17 @@ class TestStackOp3(TestStackOpBase):
def initParameters(self): def initParameters(self):
self.axis = -1 self.axis = -1
def test_check_grad(self):
pass
class TestStackOp4(TestStackOpBase): class TestStackOp4(TestStackOpBase):
def initParameters(self): def initParameters(self):
self.axis = -4 self.axis = -4
def test_check_grad(self):
pass
class TestStackOp5(TestStackOpBase): class TestStackOp5(TestStackOpBase):
def initParameters(self): def initParameters(self):
...@@ -113,7 +128,7 @@ class TestStackOpint(TestStackOpBase): ...@@ -113,7 +128,7 @@ class TestStackOpint(TestStackOpBase):
self.num_inputs = 4 self.num_inputs = 4
self.input_dim = (5, 6, 7) self.input_dim = (5, 6, 7)
self.axis = 0 self.axis = 0
self.dtype = 'int' self.dtype = 'int32'
def initParameters(self): def initParameters(self):
self.num_inputs = 16 self.num_inputs = 16
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册