diff --git a/cmake/external/llvm.cmake b/cmake/external/llvm.cmake index e080a7359af98276be2d6bfc53e6b5917f83bde9..27210e5260048a57cc442fce4c6cf8657e401568 100644 --- a/cmake/external/llvm.cmake +++ b/cmake/external/llvm.cmake @@ -1,7 +1,7 @@ include(FetchContent) -set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/CINN/llvm11.tar.gz) -set(LLVM_MD5 39d32b6be466781dddf5869318dcba53) +set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/infrt/llvm_b5149f4e66a49a98b67e8e2de4e24a4af8e2781b.tar.gz) +set(LLVM_MD5 022819bb5760817013cf4b8a37e97d5e) set(FETCHCONTENT_BASE_DIR ${THIRD_PARTY_PATH}/llvm) set(FETCHCONTENT_QUIET OFF) @@ -51,7 +51,7 @@ message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") # To build with MLIR, the LLVM is build from source code using the following flags: #[==[ -cmake -G Ninja ../llvm \ +cmake ../llvm -G "Unix Makefiles" \ -DLLVM_ENABLE_PROJECTS="mlir;clang" \ -DLLVM_BUILD_EXAMPLES=OFF \ -DLLVM_TARGETS_TO_BUILD="X86" \ @@ -59,8 +59,10 @@ cmake -G Ninja ../llvm \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_ZLIB=OFF \ -DLLVM_ENABLE_RTTI=ON \ + -DLLVM_INSTALL_UTILS=ON \ + -DCMAKE_INSTALL_PREFIX=./install #]==] -# The matched llvm-project version is f9dc2b7079350d0fed3bb3775f496b90483c9e42 (currently a temporary commit) +# The matched llvm-project version is b5149f4e66a49a98b67e8e2de4e24a4af8e2781b (currently a temporary commit) add_definitions(${LLVM_DEFINITIONS}) @@ -75,7 +77,7 @@ add_definitions(${LLVM_DEFINITIONS}) # The minimum needed libraries for MLIR IR parse and transform. -set(MLIR_IR_LIBS MLIRAnalysis MLIRStandardOps MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib) +set(MLIR_IR_LIBS MLIRAnalysis MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib) # tb_base is the name of a xxx.td file (without the .td suffix) @@ -89,6 +91,7 @@ function(mlir_tablegen_on td_base) mlir_tablegen(${td_base}.cpp.inc -gen-op-defs) if (mlir_tablegen_on_DIALECT) mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls -dialect=${mlir_tablegen_on_DIALECT}) + mlir_tablegen(${td_base}_dialect.cpp.inc --gen-dialect-defs -dialect=${mlir_tablegen_on_DIALECT}) endif() add_public_tablegen_target(${td_base}_IncGen) add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index ef50df3084f8cea7e0a137cc86e24f7e3c17fdd7..55bbc55450876c47dc0affb27323dbf397cc5c6c 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -46,8 +46,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( << " is diabled by config in TensorRT"; return false; } - return tensorrt::OpTeller::Global().Tell(node, no_calib_int8, - with_dynamic_shape); + bool is_ok = tensorrt::OpTeller::Global().Tell(node, no_calib_int8, + with_dynamic_shape); + if (!is_ok) + VLOG(3) << node->Op()->Type().c_str() << " op is not in TensorRT"; + return is_ok; }; framework::ir::SubGraphFuser fuser( diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 2799fb9e174d3209143d4be3a95250fb2eb882e6..d4b680288e347947c9ef5e86e2138cfdae9d1b83 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1416,6 +1416,7 @@ USE_TRT_CONVERTER(elementwise_min_tensor); USE_TRT_CONVERTER(elementwise_pow_tensor); USE_TRT_CONVERTER(transpose); 
USE_TRT_CONVERTER(flatten); +USE_TRT_CONVERTER(flatten_contiguous_range); USE_TRT_CONVERTER(matmul); USE_TRT_CONVERTER(conv2d); USE_TRT_CONVERTER(relu); diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index a885b69fa7fbcc19e4fe4825410d2f862ba8c568..017caca6adc814af32d6045ce0510099c5935ed8 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -3,7 +3,7 @@ nv_library(tensorrt_converter SRCS matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc group_norm_op.cc pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc - shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc + shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc flatten_contiguous_range_op.cc emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc gather_op.cc anchor_generator_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/flatten_contiguous_range_op.cc b/paddle/fluid/inference/tensorrt/convert/flatten_contiguous_range_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..706814340a0e9c06dfa11dd68500ceca040cbf00 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/flatten_contiguous_range_op.cc @@ -0,0 +1,136 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace framework { +class Scope; +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { +/* + * flatten_contiguous_range trt converter + */ +class FlattenContiguousRangeOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + int dims = input->getDimensions().nbDims; + int start_axis = BOOST_GET_CONST(int, op_desc.GetAttr("start_axis")); + int stop_axis = BOOST_GET_CONST(int, op_desc.GetAttr("stop_axis")); + + nvinfer1::IShuffleLayer* layer = nullptr; + if (!engine_->with_dynamic_shape()) { + if (start_axis < 0) start_axis += dims + 1; + if (stop_axis < 0) stop_axis += dims + 1; + int dim_prod = 1; + nvinfer1::Dims flatten_dim; + flatten_dim.nbDims = dims - (stop_axis - start_axis); + for (int i = 0, j = 0; i < dims; ++i) { + if (start_axis <= i + 1 && i + 1 <= stop_axis) { + int dim_i = input->getDimensions().d[i]; + PADDLE_ENFORCE_GT(dim_i, 0, platform::errors::InvalidArgument( + "flatten_contiguous_range input dim " + "should be > 0, but got %d.", + dim_i)); + dim_prod *= dim_i; + if (i + 1 == stop_axis) { + flatten_dim.d[j++] = dim_prod; + } + } else { + flatten_dim.d[j++] = input->getDimensions().d[i]; + } + } + layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + layer->setReshapeDimensions(flatten_dim); + } else { + if (start_axis < 0) start_axis += dims; + if (stop_axis < 0) stop_axis += dims; + auto* shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input); + auto* shape_layer_itensor = shape_layer->getOutput(0); + + nvinfer1::Dims start_dim, size_dim, stride_dim; + start_dim.nbDims = 1; + size_dim.nbDims = 1; + stride_dim.nbDims = 1; + start_dim.d[0] = start_axis; + size_dim.d[0] = stop_axis - start_axis + 1; + stride_dim.d[0] = 1; + auto* slice_layer = + TRT_ENGINE_ADD_LAYER(engine_, Slice, *shape_layer_itensor, start_dim, + size_dim, stride_dim); + uint32_t reduce_dim = 1; + auto* reduce_prod_layer = TRT_ENGINE_ADD_LAYER( + engine_, Reduce, *(slice_layer->getOutput(0)), + nvinfer1::ReduceOperation::kPROD, reduce_dim, true); + + nvinfer1::ITensor* input_shape = nullptr; + if (start_axis == 0 && stop_axis == dims - 1) { + input_shape = reduce_prod_layer->getOutput(0); + } else { + std::vector itensors; + if (start_axis > 0) { + nvinfer1::Dims left_start_dim, left_size_dim, left_stride_dim; + left_start_dim.nbDims = 1; + left_size_dim.nbDims = 1; + left_stride_dim.nbDims = 1; + left_start_dim.d[0] = 0; + left_size_dim.d[0] = start_axis; + left_stride_dim.d[0] = 1; + auto* slice_layer_left = TRT_ENGINE_ADD_LAYER( + engine_, Slice, *shape_layer_itensor, left_start_dim, + left_size_dim, left_stride_dim); + itensors.push_back(slice_layer_left->getOutput(0)); + } + itensors.push_back(reduce_prod_layer->getOutput(0)); + if (stop_axis < dims - 1) { + nvinfer1::Dims right_start_dim, right_size_dim, right_stride_dim; + right_start_dim.nbDims = 1; + right_size_dim.nbDims = 1; + right_stride_dim.nbDims = 1; + right_start_dim.d[0] = stop_axis + 1; + right_size_dim.d[0] = dims - stop_axis - 1; + right_stride_dim.d[0] = 1; + auto* slice_layer_right = TRT_ENGINE_ADD_LAYER( + engine_, Slice, *shape_layer_itensor, right_start_dim, + right_size_dim, 
right_stride_dim); + itensors.push_back(slice_layer_right->getOutput(0)); + } + auto* concat_layer = TRT_ENGINE_ADD_LAYER( + engine_, Concatenation, itensors.data(), itensors.size()); + concat_layer->setAxis(0); + input_shape = concat_layer->getOutput(0); + } + layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + layer->setInput(1, *input_shape); + } + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput(layer, "flatten_contiguous_range", {output_name}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(flatten_contiguous_range, + FlattenContiguousRangeOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index ddee4e0d682b0107589b7f1a267c67742e4ae074..6663103d4ca37445b96ce53fa39ddc3474988999 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -55,6 +55,7 @@ struct SimpleOpTypeSetTeller : public Teller { // #endif #if IS_TRT_VERSION_GE(7000) teller_set.insert("tile"); + teller_set.insert("flatten_contiguous_range"); #endif #if CUDA_VERSION >= 10020 teller_set.insert("reshape"); @@ -531,6 +532,37 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, if (axis != 1) return false; } } + if (op_type == "flatten_contiguous_range") { + if (!with_dynamic_shape) { + int start_axis = BOOST_GET_CONST(int, desc.GetAttr("start_axis")); + int stop_axis = BOOST_GET_CONST(int, desc.GetAttr("stop_axis")); + auto x_var_name = desc.Input("X")[0]; + auto* block = desc.Block(); + if (block == nullptr) { + VLOG(3) << "The block desc is nullptr, we can't continue to analyze. " + "Developers need to check whether block_desc is passed in " + "the pass."; + return false; + } + auto* x_var_desc = block->FindVar(x_var_name); + const auto x_shape = x_var_desc->GetShape(); + int dims = x_shape.size(); + if (start_axis < 0) start_axis += dims; + if (start_axis == 0) { + VLOG(3) << "TRT flatten_contiguous_range not support the " + "batch-dimension being changed"; + return false; + } + if (stop_axis < 0) stop_axis += dims; + for (int i = start_axis; i <= stop_axis; ++i) { + if (x_shape[i] < 0) { + VLOG(3) << "On TRT static shape,flatten_contiguous_range input dim " + "should be > 0"; + return false; + } + } + } + } if (op_type == "gather") { auto gather_inputs = desc.Inputs(); diff --git a/paddle/fluid/operators/stack_op_xpu.cc b/paddle/fluid/operators/stack_op_xpu.cc index 01ec4a2b16b4a4684c04d7f781a807ec69527644..a2590e1180c1a3e782190a6279b85a5fbdc2060a 100644 --- a/paddle/fluid/operators/stack_op_xpu.cc +++ b/paddle/fluid/operators/stack_op_xpu.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
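[Editor's note] Before the XPU changes below, a reviewer-oriented sketch of the dimension arithmetic implemented by the flatten_contiguous_range converter and guarded by the op_teller hunk above. This is a standalone illustration only; the helper name is hypothetical and nothing here is part of the patch:

#include <cassert>
#include <vector>

// Collapses the contiguous span [start_axis, stop_axis] into one product
// dimension. `dims` excludes the batch dimension, as in TRT static-shape
// mode, so the op attributes (which count the batch axis) are compared
// against i + 1, mirroring the converter's static-shape branch.
std::vector<int> FlattenContiguousRangeDims(const std::vector<int> &dims,
                                            int start_axis, int stop_axis) {
  const int rank = static_cast<int>(dims.size());
  if (start_axis < 0) start_axis += rank + 1;  // same normalization as above
  if (stop_axis < 0) stop_axis += rank + 1;
  std::vector<int> out;
  int prod = 1;
  for (int i = 0; i < rank; ++i) {
    if (start_axis <= i + 1 && i + 1 <= stop_axis) {
      assert(dims[i] > 0 && "dynamic dims are rejected by the teller above");
      prod *= dims[i];
      if (i + 1 == stop_axis) out.push_back(prod);  // emit the collapsed dim
    } else {
      out.push_back(dims[i]);
    }
  }
  return out;
}

For example, dims = {2, 3, 4} with start_axis = 1, stop_axis = 2 yields {6, 4}, which is exactly the reshape the IShuffleLayer above is configured with.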
+#ifdef PADDLE_WITH_XPU #include "paddle/fluid/operators/stack_op.h" #include -#ifdef PADDLE_WITH_XPU +#include +#include "paddle/fluid/operators/concat_op.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { namespace operators { @@ -59,14 +62,44 @@ class StackXPUKernel : public framework::OpKernel { } }; +template +class StackGradXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* dy = ctx.Input(framework::GradVarName("Y")); + auto dx = ctx.MultiOutput(framework::GradVarName("X")); + auto axis = ctx.Attr("axis"); + auto& dev_ctx = ctx.template device_context(); + auto dy_dims = dy->dims(); + + if (axis < 0) axis += dy_dims.size() + 1; + auto dy_shape = framework::vectorize(dy_dims); + + std::vector dx_dims_list(dx.size(), 1); + std::vector dx_lists; + for (auto out : dx) { + dx_lists.push_back(out->mutable_data(ctx.GetPlace())); + } + + int r = xpu::split(dev_ctx.x_context(), dy->data(), dx_lists, + dy_shape, dx_dims_list, axis); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External( + "The stack_grad XPU kernel return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); + } +}; + } // namespace operators } // namespace paddle namespace plat = paddle::platform; namespace ops = paddle::operators; - REGISTER_OP_XPU_KERNEL(stack, - ops::StackXPUKernel, + ops::StackXPUKernel, ops::StackXPUKernel, - ops::StackXPUKernel); + ops::StackXPUKernel); +REGISTER_OP_XPU_KERNEL(stack_grad, + ops::StackGradXPUKernel, + ops::StackGradXPUKernel); #endif diff --git a/paddle/fluid/platform/device/xpu/xpu1_op_list.h b/paddle/fluid/platform/device/xpu/xpu1_op_list.h index 26a1426bea0360735a59ba4e912482b0e8cc2b02..a76bdd4ae967987748abe4aefa144ce3ac83a545 100644 --- a/paddle/fluid/platform/device/xpu/xpu1_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu1_op_list.h @@ -300,6 +300,7 @@ XPUOpMap& get_kl1_ops() { pOpKernelType(vartype::UINT8, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, {"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"stack_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 79261a5d7bc88ea8007acbab07f4c21ea522f0f0..3d140b4693a6fa4e02f178638c528a7c90e71501 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -333,6 +333,8 @@ XPUOpMap& get_kl2_ops() { {"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace())})}, + {"stack_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, {"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, {"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), diff --git a/paddle/infrt/CMakeLists.txt b/paddle/infrt/CMakeLists.txt index 8f05d286bf0339e52eecdb043731bba41db7504d..8af3012a220ad1e06803b6832dc3c3558af7bb53 100644 --- a/paddle/infrt/CMakeLists.txt +++ b/paddle/infrt/CMakeLists.txt @@ -77,7 +77,6 @@ add_subdirectory(paddle) # MLIR td file generations set(infrt_mlir_incs - ops_inc basic_kernels_inc 
test_kernels_inc infrt_base_inc diff --git a/paddle/infrt/common/global.h b/paddle/infrt/common/global.h index f89164d03f31dedc81aca779f16fd42f979f3aab..e6586cb3a3c603ed352b360a45c3cce879978657 100644 --- a/paddle/infrt/common/global.h +++ b/paddle/infrt/common/global.h @@ -14,7 +14,7 @@ #pragma once -#include "mlir/IR/MLIRContext.h" +#include #include "paddle/infrt/tensor/dense_host_tensor.h" namespace infrt { diff --git a/paddle/infrt/dialect/CMakeLists.txt b/paddle/infrt/dialect/CMakeLists.txt index d145843684c6366897d4347d66998af71e4250c2..c064b2145266bfb44f05c0c118b03388fa1b8e8b 100644 --- a/paddle/infrt/dialect/CMakeLists.txt +++ b/paddle/infrt/dialect/CMakeLists.txt @@ -2,7 +2,6 @@ core_gather_headers() gather_srcs(infrt_src SRCS dialect.cc - types.cc basic_kernels.cc test_kernels.cc infrt_base.cc @@ -14,8 +13,6 @@ gather_srcs(infrt_src SRCS pd_types.cc pd_ops.cc ) - -mlir_tablegen_on(ops) mlir_tablegen_on(basic_kernels) mlir_tablegen_on(test_kernels) mlir_tablegen_on(infrt_base DIALECT infrt) @@ -27,8 +24,7 @@ mlir_add_rewriter(rewrite) # TODO(Superjomn) add a cmake function cc_executable to ecapsulate the following code add_executable(infrtopt opt.cc) -target_link_libraries(infrtopt infrt ${mlir_libs}) -add_dependencies(infrtopt infrt) +target_link_libraries(infrtopt infrt) add_executable(print-ir print_ir.cc) target_link_libraries(print-ir infrt ${mlir_libs}) diff --git a/paddle/infrt/dialect/basic_kernels.cc b/paddle/infrt/dialect/basic_kernels.cc index b4d2b9182b0c5035f829715c21970a47fb79e9cb..bad7e73ec5ae5c3216a912729637664bba17d3b0 100644 --- a/paddle/infrt/dialect/basic_kernels.cc +++ b/paddle/infrt/dialect/basic_kernels.cc @@ -17,17 +17,17 @@ #include #include #include -#include -#include +#include +#include #include #include -#include #include #include #include "paddle/infrt/dialect/dense_tensor.h" -namespace infrt::dialect { +namespace infrt { +namespace dialect { using namespace mlir; // NOLINT static ParseResult parseCallOp(OpAsmParser &parser, // NOLINT @@ -71,12 +71,12 @@ static ParseResult parseConstantF64Op(OpAsmParser &parser, // NOLINT static ParseResult parseConstantI32Op(OpAsmParser &parser, // NOLINT OperationState &result) { // NOLINT return parseConstantOp( - IntegerType::get(32, result.getContext()), parser, result); + IntegerType::get(result.getContext(), 32), parser, result); } static ParseResult parseConstantI64Op(OpAsmParser &parser, // NOLINT OperationState &result) { // NOLINT return parseConstantOp( - IntegerType::get(64, result.getContext()), parser, result); + IntegerType::get(result.getContext(), 64), parser, result); } static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT @@ -90,10 +90,10 @@ static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT } static void print(OpAsmPrinter &p, CallOp op) { // NOLINT - p << "infrt.call " << op.getAttr("callee") << "("; + p << "infrt.call " << op->getAttr("callee") << "("; p.printOperands(op.getOperands()); p << ")"; - p.printOptionalAttrDict(op.getAttrs(), {"callee"}); + p.printOptionalAttrDict(op->getAttrs(), {"callee"}); p << " : "; } @@ -145,7 +145,7 @@ static LogicalResult verify(ConstantF64Op op) { return success(); } static LogicalResult verify(ConstantI64Op op) { return success(); } static LogicalResult verify(ReturnOp op) { - auto function = dyn_cast(op.getParentOp()); + auto function = dyn_cast(op->getParentOp()); if (!function) return success(); @@ -157,8 +157,8 @@ static LogicalResult verify(ReturnOp op) { return success(); } +} // namespace dialect +} // namespace infrt #define 
GET_OP_CLASSES #include "paddle/infrt/dialect/basic_kernels.cpp.inc" - -} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/basic_kernels.h b/paddle/infrt/dialect/basic_kernels.h index 65316bc1437c027a03e629e8f5cab868b5470758..b82abcd52d28f45b18824d9ea6f9e12c2ec1c574 100644 --- a/paddle/infrt/dialect/basic_kernels.h +++ b/paddle/infrt/dialect/basic_kernels.h @@ -13,12 +13,9 @@ // limitations under the License. #pragma once +#include #include #include -using namespace mlir; // NOLINT - -namespace infrt::dialect { #define GET_OP_CLASSES #include "paddle/infrt/dialect/basic_kernels.hpp.inc" -} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/basic_kernels.td b/paddle/infrt/dialect/basic_kernels.td index df5e4d8a2c6a1c50bb959ec5ec4a18b6bf451d59..7d8de79fbae2b0cb36ca354b8f6f39fc94851ebe 100644 --- a/paddle/infrt/dialect/basic_kernels.td +++ b/paddle/infrt/dialect/basic_kernels.td @@ -27,7 +27,7 @@ def CallOp : INFRT_Op<"call"> { let results = (outs Variadic); let extraClassDeclaration = [{ - StringRef getCallee() { return callee(); } + mlir::StringRef getCallee() { return callee(); } mlir::FunctionType getCalleeType(); }]; } @@ -57,9 +57,8 @@ def ReturnOp : INFRT_Op<"return", [Terminator]> { let arguments = (ins Variadic:$operands); - let builders = [OpBuilder< - "OpBuilder &b, OperationState &result", - [{ build(b, result, llvm::None); }]>]; + let builders = [OpBuilder<(ins), + [{ build($_builder, $_state, llvm::None); }]>]; } class AddOp : INFRT_Op<"add." # suffix, [NoSideEffect]> { diff --git a/paddle/infrt/dialect/dense_tensor.cc b/paddle/infrt/dialect/dense_tensor.cc index 629a7b16523fcaabe789b7a5f8d2146c6cd7633d..7685cdc65b9ad00492e0ca8a084ac7c686c94d89 100644 --- a/paddle/infrt/dialect/dense_tensor.cc +++ b/paddle/infrt/dialect/dense_tensor.cc @@ -17,12 +17,11 @@ #include #include #include +#include +#include #include -#include -#include #include #include -#include #include #include @@ -31,68 +30,37 @@ #include "paddle/infrt/common/global.h" #include "paddle/infrt/dialect/tensor_shape.h" -namespace infrt::dt { - +namespace infrt { +namespace dt { void DTDialect::initialize() { - allowUnknownTypes(); addOperations< #define GET_OP_LIST #include "paddle/infrt/dialect/dense_tensor.cpp.inc" >(); } -namespace detail { -struct TensorTypeStorage : public mlir::TypeStorage { - TensorTypeStorage(TargetType target, - LayoutType layout, - PrecisionType precision) - : target_(target), layout_(layout), precision_(precision) {} - - using KeyTy = std::tuple; - - bool operator==(const KeyTy &key) const { - return key == KeyTy(target_, layout_, precision_); - } - - static llvm::hash_code hashKey(const KeyTy &key) { - return llvm::hash_value(key); - } - - static TensorTypeStorage *construct( - mlir::TypeStorageAllocator &allocator, // NOLINT - const KeyTy &key) { - return new (allocator.allocate()) - TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key)); - } - - TargetType target_; - LayoutType layout_; - PrecisionType precision_; -}; -} // namespace detail - llvm::Optional GetTargetType(mlir::StringRef key) { - if (key.equals_lower("x86")) + if (key.equals_insensitive("x86")) return TargetType::X86; - else if (key.equals_lower("cuda")) + else if (key.equals_insensitive("cuda")) return TargetType::CUDA; else return llvm::None; } llvm::Optional GetLayoutType(mlir::StringRef key) { - if (key.equals_lower("nchw")) + if (key.equals_insensitive("nchw")) return LayoutType::NCHW; - else if (key.equals_lower("nhwc")) + else if (key.equals_insensitive("nhwc")) return 
LayoutType::NHWC; else return llvm::None; } llvm::Optional GetPrecisionType(mlir::StringRef key) { - if (key.equals_lower("i32")) + if (key.equals_insensitive("i32")) return PrecisionType::I32; - else if (key.equals_lower("f32")) + else if (key.equals_insensitive("f32")) return PrecisionType::F32; else return llvm::None; @@ -111,7 +79,7 @@ LayoutType TensorType::layout() { return getImpl()->layout_; } PrecisionType TensorType::precision() { return getImpl()->precision_; } -raw_ostream &operator<<(raw_ostream &os, TensorType tensorType) { +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType) { os << "TensorType<" << tensorType.target() << ", " << tensorType.layout() << ", " << tensorType.precision() << ">"; return os; @@ -133,7 +101,7 @@ StringType StringType::get(mlir::MLIRContext *context) { return Base::get(context); } -raw_ostream &operator<<(raw_ostream &os, TargetType type) { +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type) { switch (type) { case (TargetType::X86): os << "X86"; @@ -147,7 +115,7 @@ raw_ostream &operator<<(raw_ostream &os, TargetType type) { return os; } -raw_ostream &operator<<(raw_ostream &os, LayoutType type) { +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type) { switch (type) { case (LayoutType::NCHW): os << "NCHW"; @@ -161,7 +129,7 @@ raw_ostream &operator<<(raw_ostream &os, LayoutType type) { return os; } -raw_ostream &operator<<(raw_ostream &os, PrecisionType type) { +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type) { switch (type) { case (PrecisionType::I32): os << "I32"; @@ -175,103 +143,69 @@ raw_ostream &operator<<(raw_ostream &os, PrecisionType type) { return os; } -static Type getTensorType(mlir::MLIRContext *context) { - auto t_dialect = Identifier::get("t", context); - return OpaqueType::get(t_dialect, "tensor", context); +static mlir::Type getTensorType(mlir::MLIRContext *context) { + auto t_dialect = mlir::Identifier::get("t", context); + return mlir::OpaqueType::get(t_dialect, "tensor"); } -static ParseResult parseCreateUninitTensorOp( - OpAsmParser &parser, // NOLINT - OperationState &result) { // NOLINT +static mlir::ParseResult parseCreateUninitTensorOp( + mlir::OpAsmParser &parser, // NOLINT + mlir::OperationState &result) { // NOLINT auto loc = parser.getCurrentLocation(); - ::mlir::Type outputRawTypes[1]; - ::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes); + mlir::Type outputRawTypes[1]; + ::llvm::ArrayRef outputTypes(outputRawTypes); mlir::ArrayAttr shapeAttr; if (parser.parseAttribute(shapeAttr, parser.getBuilder().getI64Type(), "shape", result.attributes)) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) return failure(); + return mlir::failure(); + if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure(); - if (parser.parseArrow()) return failure(); - if (parser.parseType(outputRawTypes[0])) return failure(); + if (parser.parseArrow()) return mlir::failure(); + if (parser.parseType(outputRawTypes[0])) return mlir::failure(); if (!outputRawTypes[0].isa()) return parser.emitError(loc, "invalid kind of type specified"); result.addTypes(outputTypes); - return success(); + return mlir::success(); } template -static void printCreateUninitTensorOp(OpAsmPrinter &p, // NOLINT +static void printCreateUninitTensorOp(mlir::OpAsmPrinter &p, // NOLINT CreateUninitTensorOp op) { p << CreateUninitTensorOp::getOperationName(); p << " "; p.printAttributeWithoutType(op.shapeAttr()); - 
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); + p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); p << " -> "; p << op.getOperation()->getResultTypes(); } -// TODO(shibo): can be removed? -// static ParseResult parseFillTensorWithConstantOp(OpAsmParser& parser, -// OperationState& result) { -// auto loc = parser.getCurrentLocation(); -// ::mlir::OpAsmParser::OperandType inputRawOperands[1]; -// ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> -// inputOperands(inputRawOperands); -// ::mlir::Type inputRawTypes[1]; -// ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); -// -// if (parser.parseOperand(inputRawOperands[0])) return failure(); -// -// if (parser.parseColon()) return failure(); -// if (parser.parseType(inputRawTypes[0])) return failure(); -// if (!inputRawTypes[0].isa()) -// return parser.emitError(loc, "invalid kind of type specified"); -// -// Attribute value_attr; -// if (parser.resolveOperands(inputOperands, inputTypes, loc, result.operands)) -// return failure(); -// if (parser.parseAttribute(value_attr, "value", result.attributes)) return -// failure(); -// return success(); -//} - -// TODO(shibo): can be removed? -// template -// static void printFillTensorWithConstantOp(OpAsmPrinter& p, FillTensorOp op) { -// p << FillTensorOp::getOperationName(); -// p << " "; -// p.printOperand(op.getOperand()); -// p << " : "; -// p << op.getOperation()->getOperandTypes(); -// p << " "; -// p << op.getAttr("value"); -//} - -static ParseResult parseSetTensorOp(OpAsmParser &parser, // NOLINT - OperationState &result) { // NOLINT - SmallVector operands; - if (parser.parseOperandList(operands, 1)) return failure(); +static mlir::ParseResult parseSetTensorOp( + mlir::OpAsmParser &parser, // NOLINT + mlir::OperationState &result) { // NOLINT + llvm::SmallVector operands; + if (parser.parseOperandList(operands, 1)) return mlir::failure(); auto tensor_type = getTensorType(result.getContext()); - Attribute value_attr; - return failure( + mlir::Attribute value_attr; + return mlir::failure( parser.resolveOperand(operands[0], tensor_type, result.operands) || parser.parseAttribute(value_attr, "values", result.attributes)); } template -static void printSetTensorOp(OpAsmPrinter &p, SetTensorOp op) { // NOLINT +static void printSetTensorOp(mlir::OpAsmPrinter &p, SetTensorOp op) { // NOLINT p << SetTensorOp::getOperationName() << " "; p.printOperand(op.getOperand()); - p << " " << op.getAttr("values"); + p << " " << op->getAttr("values"); } +} // namespace dt +} // namespace infrt #define GET_OP_CLASSES #include "paddle/infrt/dialect/dense_tensor.cpp.inc" // NOLINT -} // namespace infrt::dt +#include "paddle/infrt/dialect/dense_tensor_dialect.cpp.inc" diff --git a/paddle/infrt/dialect/dense_tensor.h b/paddle/infrt/dialect/dense_tensor.h index 866c62213ab058037bafb116602cc0d609fd3bec..416925d3382bad640753b77e5516d6e45a425eef 100644 --- a/paddle/infrt/dialect/dense_tensor.h +++ b/paddle/infrt/dialect/dense_tensor.h @@ -19,13 +19,8 @@ #include -using namespace mlir; // NOLINT -namespace infrt::dt { - -namespace detail { -struct TensorTypeStorage; -} // namespace detail - +namespace infrt { +namespace dt { enum class TargetType : uint8_t { X86, CUDA }; enum class LayoutType : uint8_t { NCHW, NHWC }; enum class PrecisionType : uint8_t { I32, F32 }; @@ -34,9 +29,39 @@ llvm::Optional GetTargetType(mlir::StringRef key); llvm::Optional GetLayoutType(mlir::StringRef key); llvm::Optional GetPrecisionType(mlir::StringRef key); -raw_ostream &operator<<(raw_ostream 
&os, TargetType type); -raw_ostream &operator<<(raw_ostream &os, LayoutType type); -raw_ostream &operator<<(raw_ostream &os, PrecisionType type); +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type); +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type); +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type); + +namespace detail { +struct TensorTypeStorage : public mlir::TypeStorage { + TensorTypeStorage(TargetType target, + LayoutType layout, + PrecisionType precision) + : target_(target), layout_(layout), precision_(precision) {} + + using KeyTy = std::tuple; + + bool operator==(const KeyTy &key) const { + return key == KeyTy(target_, layout_, precision_); + } + + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + static TensorTypeStorage *construct( + mlir::TypeStorageAllocator &allocator, // NOLINT + const KeyTy &key) { + return new (allocator.allocate()) + TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key)); + } + + TargetType target_; + LayoutType layout_; + PrecisionType precision_; +}; +} // namespace detail class TensorType : public mlir::Type::TypeBase #include -namespace infrt::dialect { +namespace infrt { +namespace dialect { struct MyScopedDiagnosicHandler::Impl { Impl() : diag_stream_(diag_str_) {} @@ -49,4 +51,5 @@ mlir::LogicalResult MyScopedDiagnosicHandler::handler(mlir::Diagnostic *diag) { return mlir::failure(true); } -} // namespace infrt::dialect +} // namespace dialect +} // namespace infrt diff --git a/paddle/infrt/dialect/diagnostic_utils.h b/paddle/infrt/dialect/diagnostic_utils.h index 3a8098cf751812d35dc3eac1041bed0536055288..746e61c8fe5c3151f3c6ea1da5bd105d1492082e 100644 --- a/paddle/infrt/dialect/diagnostic_utils.h +++ b/paddle/infrt/dialect/diagnostic_utils.h @@ -18,7 +18,8 @@ #include -namespace infrt::dialect { +namespace infrt { +namespace dialect { /** * A scoped diagnostic handler to help debug MLIR process. @@ -36,4 +37,5 @@ class MyScopedDiagnosicHandler : public mlir::SourceMgrDiagnosticHandler { std::unique_ptr impl_; }; -} // namespace infrt::dialect +} // namespace dialect +} // namespace infrt diff --git a/paddle/infrt/dialect/dialect.cc b/paddle/infrt/dialect/dialect.cc index cbcd5d0f0fa785a21c78d0ae25f40e6211a504ee..fe07b91d22ed54ae576d828f208ec766f5b719da 100644 --- a/paddle/infrt/dialect/dialect.cc +++ b/paddle/infrt/dialect/dialect.cc @@ -13,24 +13,26 @@ // limitations under the License. #include +#include #include -#include #include #include -#include #include #include -namespace infrt::hlir::dialect { +namespace infrt { +namespace hlir { +namespace dialect { -class CinnDialect : public ::mlir::Dialect { +class CinnDialect : public mlir::Dialect { public: - explicit CinnDialect(::mlir::MLIRContext* ctx); + explicit CinnDialect(mlir::MLIRContext* ctx); //! 
We should register this function in dialect static llvm::StringRef getDialectNamespace() { return "infrt::hlir::dialect"; } }; - -} // namespace infrt::hlir::dialect +} // namespace dialect +} // namespace hlir +} // namespace infrt diff --git a/paddle/infrt/dialect/infrt_base.cc b/paddle/infrt/dialect/infrt_base.cc index b28ad5ad4b5a59c898cc08303626df09b2ef70c9..e8005661bbd6527f6c21076fd0f3a362a5541968 100644 --- a/paddle/infrt/dialect/infrt_base.cc +++ b/paddle/infrt/dialect/infrt_base.cc @@ -18,7 +18,8 @@ #include "paddle/infrt/dialect/dense_tensor.h" #include "paddle/infrt/dialect/test_kernels.h" -namespace infrt::dialect { +namespace infrt { +namespace dialect { // ----INFRTDialect definition begin---- void INFRTDialect::initialize() { @@ -124,4 +125,5 @@ void INFRTDialect::printType(mlir::Type type, // ----INFRTDialect definition end---- -} // namespace infrt::dialect +} // namespace dialect +} // namespace infrt diff --git a/paddle/infrt/dialect/infrt_base.h b/paddle/infrt/dialect/infrt_base.h index 58acd7c9a409a5a13f31bf3bd1688f0bb26b3e0b..1a7fbcf395a6e9be70de6021f8e60f94922f32c3 100644 --- a/paddle/infrt/dialect/infrt_base.h +++ b/paddle/infrt/dialect/infrt_base.h @@ -18,19 +18,17 @@ #include #include #include -#include #include #include #include "paddle/infrt/dialect/infrt_base.hpp.inc" -namespace infrt::dialect { - -class INFRTDialect : public ::mlir::Dialect { - explicit INFRTDialect(::mlir::MLIRContext *context) - : ::mlir::Dialect(getDialectNamespace(), - context, - ::mlir::TypeID::get()) { +namespace infrt { +namespace dialect { +class INFRTDialect : public mlir::Dialect { + explicit INFRTDialect(mlir::MLIRContext *context) + : mlir::Dialect( + getDialectNamespace(), context, mlir::TypeID::get()) { initialize(); } @@ -41,15 +39,12 @@ class INFRTDialect : public ::mlir::Dialect { mlir::DialectAsmPrinter &printer) const override; void initialize(); - friend class ::mlir::MLIRContext; + friend class mlir::MLIRContext; public: static ::llvm::StringRef getDialectNamespace() { return "infrt"; } }; - -} // namespace infrt::dialect - -namespace mlir { +} // namespace dialect template static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT @@ -58,17 +53,16 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT return b.getIntegerAttr(b.getI32Type(), constant); } -static mlir::SmallVector<::mlir::Value, 4> cvtValueToValueRange( +static mlir::SmallVector cvtValueToValueRange( const mlir::Value &operand) { - return mlir::SmallVector<::mlir::Value, 4>(1, operand); + return mlir::SmallVector(1, operand); } -static mlir::SmallVector<::mlir::Value, 4> concatTwoValueRange( +static mlir::SmallVector concatTwoValueRange( mlir::ValueRange operand_0, mlir::ValueRange operand_1) { - mlir::SmallVector<::mlir::Value, 4> operands; + mlir::SmallVector operands; operands.append(operand_0.begin(), operand_0.end()); operands.append(operand_1.begin(), operand_1.end()); return operands; } - -} // namespace mlir +} // namespace infrt diff --git a/paddle/infrt/dialect/infrt_base.td b/paddle/infrt/dialect/infrt_base.td index 7d6fdbbbf2f68f6629c2299f807cbb9fa7605f74..1abd294236d93cfb0aa7ce70db25f2acfb57a06a 100644 --- a/paddle/infrt/dialect/infrt_base.td +++ b/paddle/infrt/dialect/infrt_base.td @@ -28,11 +28,11 @@ def TensorMapType : def BufferType : OpaqueType<"b", "buffer", "buffer">; class INFRT_createI32Attr : NativeCodeCall< - "mlir::createI32Attr($_builder, $_loc, " # value # ")">; + "infrt::createI32Attr($_builder, $_loc, " # value # ")">; def INFRT_cvtValueToValueRange : 
NativeCodeCall< - "mlir::cvtValueToValueRange($0)">; + "infrt::cvtValueToValueRange($0)">; def INFRT_concatTwoValueRange : NativeCodeCall< - "mlir::concatTwoValueRange($0, $1)">; + "infrt::concatTwoValueRange($0, $1)">; #endif // INFRT_BASE diff --git a/paddle/infrt/dialect/init_infrt_dialects.cc b/paddle/infrt/dialect/init_infrt_dialects.cc index 4bc2bf70942d29723c731f90da446ee0acc257f5..c3769414dbb390566a177cfcec0b62009b53018a 100644 --- a/paddle/infrt/dialect/init_infrt_dialects.cc +++ b/paddle/infrt/dialect/init_infrt_dialects.cc @@ -23,12 +23,10 @@ #include "paddle/infrt/dialect/tensor_shape.h" namespace infrt { - -void RegisterCinnDialects(mlir::DialectRegistry& registry) { // NOLINT - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); +void registerCinnDialects(mlir::DialectRegistry ®istry) { // NOLINT + registry.insert(); } - } // namespace infrt diff --git a/paddle/infrt/dialect/init_infrt_dialects.h b/paddle/infrt/dialect/init_infrt_dialects.h index 50caca018980d05b112459ecf27f81e538cf9e2a..0912e9ef2555b49a7fd2d22c5e3ab6a457cbb05b 100644 --- a/paddle/infrt/dialect/init_infrt_dialects.h +++ b/paddle/infrt/dialect/init_infrt_dialects.h @@ -14,10 +14,8 @@ #pragma once -#include "mlir/IR/Dialect.h" - +#include +#include namespace infrt { - -void RegisterCinnDialects(mlir::DialectRegistry& registry); // NOLINT - +void registerCinnDialects(mlir::DialectRegistry ®istry); // NOLINT } // namespace infrt diff --git a/paddle/infrt/dialect/mlir_loader.cc b/paddle/infrt/dialect/mlir_loader.cc index b318a6a763483141de7c1521614cb82538615bb6..1d0696e77dcda612eeb8c367958e44e2efed5354 100644 --- a/paddle/infrt/dialect/mlir_loader.cc +++ b/paddle/infrt/dialect/mlir_loader.cc @@ -16,8 +16,8 @@ #include #include +#include #include -#include #include #include #include @@ -30,12 +30,15 @@ #include "paddle/infrt/dialect/diagnostic_utils.h" #include "paddle/infrt/dialect/init_infrt_dialects.h" -namespace infrt::dialect { +namespace infrt { +namespace dialect { mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, const std::string& mlir_source) { // context->allowUnregisteredDialects(); - RegisterCinnDialects(context->getDialectRegistry()); + mlir::DialectRegistry registry; + registerCinnDialects(registry); + context->appendDialectRegistry(registry); // Currenetly, We only used the CinnDialect and mlir::BuiltinDialect is // enough。Don't need StandardOpsDialect. 
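// Editor's note (sketch, not part of the patch): with the llvm-project
// revision pinned in llvm.cmake, dialects are registered through a
// standalone registry instead of by mutating the context's registry in
// place. The idiom this patch adopts throughout, in miniature:
//   mlir::DialectRegistry registry;
//   infrt::registerCinnDialects(registry);
//   context->appendDialectRegistry(registry);  // existing context, as here
//   mlir::MLIRContext fresh(registry);         // or a fresh one (print_ir.cc)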
// context->getDialectRegistry().insert(); @@ -57,9 +60,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, mlir::MLIRContext* context) { // context->allowUnregisteredDialects(); - RegisterCinnDialects(context->getDialectRegistry()); - context->getDialectRegistry().insert(); - + mlir::DialectRegistry registry; + registerCinnDialects(registry); + context->appendDialectRegistry(registry); mlir::ScopedDiagnosticHandler scope_handler( context, [](mlir::Diagnostic& diag) { if (diag.getSeverity() != mlir::DiagnosticSeverity::Error) @@ -71,4 +74,5 @@ mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, return mlir::parseSourceFile(std::string(file_name), context); } -} // namespace infrt::dialect +} // namespace dialect +} // namespace infrt diff --git a/paddle/infrt/dialect/mlir_loader.h b/paddle/infrt/dialect/mlir_loader.h index 092da7d9ce03f64f43a2bfa237c7fa60983959a1..5e50ad9e5a27176a1bea32356b0cf343140bb441 100644 --- a/paddle/infrt/dialect/mlir_loader.h +++ b/paddle/infrt/dialect/mlir_loader.h @@ -15,16 +15,17 @@ #pragma once #include -#include +#include #include #include -namespace infrt::dialect { +namespace infrt { +namespace dialect { mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, const std::string& mlir_source); mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, mlir::MLIRContext* context); - -} // namespace infrt::dialect +} // namespace dialect +} // namespace infrt diff --git a/paddle/infrt/dialect/mlir_loader_test.cc b/paddle/infrt/dialect/mlir_loader_test.cc index 1b622d585ad8ee556ea8f35eb64560f49fb5710d..11150530730444ed74f547b9bb8abef5473c61b0 100644 --- a/paddle/infrt/dialect/mlir_loader_test.cc +++ b/paddle/infrt/dialect/mlir_loader_test.cc @@ -17,14 +17,15 @@ #include #include #include -#include +#include #include #include #include "paddle/infrt/dialect/init_infrt_dialects.h" -namespace infrt::dialect { +namespace infrt { +namespace dialect { TEST(MlirLoader, basic) { mlir::MLIRContext context; @@ -42,8 +43,7 @@ func @main() -> f32 { )ROC"; auto module = LoadMlirSource(&context, source); - module->verify(); - + EXPECT_TRUE(mlir::succeeded(module->verify())); LOG(INFO) << "module name: " << module->getOperationName().data(); for (auto func : module->getOps()) { LOG(INFO) << "get func " << func.getName().str(); @@ -54,4 +54,5 @@ func @main() -> f32 { } } -} // namespace infrt::dialect +} // namespace dialect +} // namespace infrt diff --git a/paddle/infrt/dialect/mlir_tests/rewrite.mlir b/paddle/infrt/dialect/mlir_tests/rewrite.mlir index bfad9d1f6924d4da7b818968ebb796cf8f346935..5e207634da8e4bb96719254700d7f30e4cdfe52a 100644 --- a/paddle/infrt/dialect/mlir_tests/rewrite.mlir +++ b/paddle/infrt/dialect/mlir_tests/rewrite.mlir @@ -20,5 +20,5 @@ func @main() -> tensor { %c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor, tensor) -> tensor %e2 = "pd.relu"(%d2) {} : (tensor) -> tensor - infrt.return %e2 : tensor + "pd.fetch"(%e2) {name="output"} :(tensor)->() } \ No newline at end of file diff --git a/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir b/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir index 9ea1ec0ebca365b42be8d310793dc3c5f7dd4cf4..2889b92b18ef08fb6014eff948e2a5fc3d50c7f3 100644 --- a/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir +++ b/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir @@ -11,5 +11,5 @@ func @main() -> 
tensor { %c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor %d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor - infrt.return %d : tensor + "pd.fetch"(%d) {name="output"} :(tensor)->() } \ No newline at end of file diff --git a/paddle/infrt/dialect/mlir_tests/trt_ops.mlir b/paddle/infrt/dialect/mlir_tests/trt_ops.mlir index 009b6d1c19653e52a0ef0174892cdcbeccf18154..d98f107bab41e959d82acfd681d762d7981eab51 100644 --- a/paddle/infrt/dialect/mlir_tests/trt_ops.mlir +++ b/paddle/infrt/dialect/mlir_tests/trt_ops.mlir @@ -18,5 +18,5 @@ func @main() -> tensor { %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor, tensor) -> tensor %e2 = "pd.relu"(%d2) {} : (tensor) -> tensor - "pd.fetch"(%e2) :(tensor)->() + "pd.fetch"(%e2) {name="output"} :(tensor)->() } diff --git a/paddle/infrt/dialect/ops.td b/paddle/infrt/dialect/ops.td deleted file mode 100644 index 264134a447c63f637090e2f9919f2b97cad1ab4f..0000000000000000000000000000000000000000 --- a/paddle/infrt/dialect/ops.td +++ /dev/null @@ -1,6 +0,0 @@ -include "mlir/IR/OpBase.td" -include "paddle/infrt/dialect/infrt_base.td" - - -class INFRT_Op traits = []> : - Op; diff --git a/paddle/infrt/dialect/opt.cc b/paddle/infrt/dialect/opt.cc index d90d25230d0c24fb84ccdcf2cd282ba814b9a665..5bcf5a23f4c532b1056ceaa54c80902b32e4061a 100644 --- a/paddle/infrt/dialect/opt.cc +++ b/paddle/infrt/dialect/opt.cc @@ -12,34 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include #include - -#include - -#include "paddle/infrt/common/global.h" #include "paddle/infrt/dialect/init_infrt_dialects.h" -#include "paddle/infrt/dialect/mlir_loader.h" int main(int argc, char **argv) { - mlir::MLIRContext *context = infrt::Global::getMLIRContext(); - - auto ®istry = context->getDialectRegistry(); - infrt::RegisterCinnDialects(registry); - + mlir::DialectRegistry registry; + infrt::registerCinnDialects(registry); mlir::registerCanonicalizerPass(); - return mlir::failed( - mlir::MlirOptMain(argc, argv, "INFRT mlir pass driver", registry)); + mlir::MlirOptMain(argc, argv, "infrt mlir pass driver", registry)); } diff --git a/paddle/infrt/dialect/pd_op_base.td b/paddle/infrt/dialect/pd_op_base.td index af53df113dfb3e908d5066fed984a8c37942df25..a3e3c4ae592779c36f175ecfc20c154724be0863 100644 --- a/paddle/infrt/dialect/pd_op_base.td +++ b/paddle/infrt/dialect/pd_op_base.td @@ -16,7 +16,7 @@ def PD_Dialect : Dialect { This dialect contains the PaddlePaddle operators. 
}]; - let cppNamespace = "::mlir::pd"; + let cppNamespace = "mlir::pd"; } class PD_Op traits = []> : diff --git a/paddle/infrt/dialect/pd_ops.cc b/paddle/infrt/dialect/pd_ops.cc index ce10be6d100f82b3a431b45098121fc5011496e6..fe3899688384628b2f1a5cba577f5f46515275e0 100644 --- a/paddle/infrt/dialect/pd_ops.cc +++ b/paddle/infrt/dialect/pd_ops.cc @@ -14,10 +14,15 @@ #include "paddle/infrt/dialect/pd_ops.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" +#include +#include #include "paddle/infrt/dialect/infrt_base.h" +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT + +#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT + namespace mlir { namespace pd { PaddleDialect::PaddleDialect(MLIRContext *context) @@ -36,12 +41,6 @@ mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder, return builder.create(loc, value); } -#define GET_OP_CLASSES -#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT -#undef GET_OP_CLASSES - -#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT - void ConstantOp::build(OpBuilder &builder, OperationState &state, Attribute value) { @@ -66,8 +65,8 @@ LogicalResult ConstantOp::inferReturnTypes( inferredReturnTypes.push_back(attributes.get("value").getType()); return success(); } -::mlir::OpFoldResult ConstantOp::fold( - ::llvm::ArrayRef<::mlir::Attribute> operands) { +mlir::OpFoldResult ConstantOp::fold( + ::llvm::ArrayRef operands) { return value(); } @@ -82,11 +81,11 @@ LogicalResult ElementwiseAdd::inferReturnTypes( return success(); } void ElementwiseAdd::getCanonicalizationPatterns( - ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { results.insert(context); } -::mlir::OpFoldResult ElementwiseAdd::fold( +mlir::OpFoldResult ElementwiseAdd::fold( llvm::ArrayRef operands) { if (getElementTypeOrSelf(getType()).isa()) { if (!operands[0] || !operands[1]) return {}; @@ -154,17 +153,17 @@ LogicalResult MulOp::inferReturnTypes( } void ReluOp::getCanonicalizationPatterns( - ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { results.insert(context); } void FusedRepeatedFCRelu::getCanonicalizationPatterns( - ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { results.insert(context); } void BatchNormOp::getCanonicalizationPatterns( - ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { results.insert(context); } diff --git a/paddle/infrt/dialect/pd_ops.h b/paddle/infrt/dialect/pd_ops.h index 71e0a53988d1ac8dbd9e1031f830360dc4167cc4..7d1d1d6f58451321a7edae50df4c19a043bf6b29 100644 --- a/paddle/infrt/dialect/pd_ops.h +++ b/paddle/infrt/dialect/pd_ops.h @@ -14,21 +14,20 @@ #pragma once -#include "mlir/Dialect/Traits.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/CallInterfaces.h" -#include "mlir/Interfaces/DerivedAttributeOpInterface.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/LoopLikeInterface.h" -#include 
"mlir/Interfaces/SideEffectInterfaces.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace mlir { namespace pd { @@ -53,9 +52,8 @@ class PaddleDialect : public Dialect { } }; -#define GET_OP_CLASSES -#include "paddle/infrt/dialect/pd_ops.hpp.inc" -#undef GET_OP_CLASSES - } // namespace pd } // namespace mlir + +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/pd_ops.hpp.inc" diff --git a/paddle/infrt/dialect/pd_ops.td b/paddle/infrt/dialect/pd_ops.td index b020b7ad5dbc783c1dba192bcb64f02080fbf93c..3addf15082a12c28341da53add36ab1541721b67 100644 --- a/paddle/infrt/dialect/pd_ops.td +++ b/paddle/infrt/dialect/pd_ops.td @@ -24,6 +24,16 @@ def PD_FeedOp : PD_Op<"feed"> { def PD_FetchOp : PD_Op<"fetch", [Terminator]> { let summary = "fetch Op"; + let description = [{ + Return the output tensor from the subgraph. + }]; + + let arguments = (ins PD_Tensor :$inputs, StrAttr:$name); +} + +def PD_ReturnOp : PD_Op<"return", [Terminator]> { + let summary = "return Op"; + let description = [{ Fetch tensor from the graph. }]; @@ -31,7 +41,7 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> { let arguments = (ins Variadic:$inputs); } -def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> { +def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"ReturnOp">]> { let summary = "paddle graph Op"; let description = [{ Describe a paddle graph or subgraph. @@ -50,7 +60,7 @@ def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInte let hasFolder = 1; let builders = [ - OpBuilder<"OpBuilder &builder, OperationState &state, Attribute value">, + OpBuilder<(ins "Attribute":$value)>, ]; } diff --git a/paddle/infrt/dialect/pd_types.h b/paddle/infrt/dialect/pd_types.h index 6f9fe56338a9fd7e5a6b1d532d396cc75efe0415..0da888a9c076922fc21d5cce004dc839bd705762 100644 --- a/paddle/infrt/dialect/pd_types.h +++ b/paddle/infrt/dialect/pd_types.h @@ -18,12 +18,11 @@ #pragma once -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" +#include +#include +#include +#include +#include namespace mlir { namespace PD { diff --git a/paddle/infrt/dialect/print_ir.cc b/paddle/infrt/dialect/print_ir.cc index 43a3577b90f109c638aa08c00de3feb6e8150a7d..5cfd16ee859438c891d6ccf77b97e663620e584c 100644 --- a/paddle/infrt/dialect/print_ir.cc +++ b/paddle/infrt/dialect/print_ir.cc @@ -11,26 +11,25 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
- +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include -#include "llvm/ADT/Optional.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/ScopedPrinter.h" -#include "llvm/Support/raw_os_ostream.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/AsmState.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Region.h" -#include "mlir/IR/Verifier.h" -#include "mlir/Parser.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Passes.h" #include "paddle/infrt/common/global.h" #include "paddle/infrt/dialect/init_infrt_dialects.h" @@ -114,17 +113,15 @@ int main(int argc, char **argv) { mlir::registerPassManagerCLOptions(); cl::ParseCommandLineOptions(argc, argv, "mlir demo"); - mlir::MLIRContext *context = infrt::Global::getMLIRContext(); - // context->allowUnregisteredDialects(); - auto ®istry = context->getDialectRegistry(); - infrt::RegisterCinnDialects(registry); - + mlir::DialectRegistry registry; + infrt::registerCinnDialects(registry); + mlir::MLIRContext context(registry); // mlir will verify module automatically after parsing. // https://github.com/llvm/llvm-project/blob/38d18d93534d290d045bbbfa86337e70f1139dc2/mlir/lib/Parser/Parser.cpp#L2051 // mlir::OwningModuleRef module_ref = mlir::parseSourceString(mlir_source, // context); mlir::OwningModuleRef module_ref = - mlir::parseSourceFile(inputFilename, context); + mlir::parseSourceFile(inputFilename, &context); std::cout << "----------print IR Structure begin----------" << std::endl; printOperation(module_ref->getOperation(), 0); std::cout << "----------print IR Structure end----------" << std::endl; diff --git a/paddle/infrt/dialect/tensor_shape.cc b/paddle/infrt/dialect/tensor_shape.cc index ef5a5525cb22f337f6111823283fadde7c6aff22..92c03818264ee7c44626042dd1de53b66bb8c54b 100644 --- a/paddle/infrt/dialect/tensor_shape.cc +++ b/paddle/infrt/dialect/tensor_shape.cc @@ -17,16 +17,16 @@ #include #include #include +#include +#include #include -#include -#include #include #include -#include #include #include -namespace infrt::ts { +namespace infrt { +namespace ts { using namespace mlir; // NOLINT void TensorShapeDialect::initialize() { @@ -48,8 +48,8 @@ Type TensorShapeDialect::parseType(DialectAsmParser &parser) const { return Type(); } -void TensorShapeDialect::printType(::mlir::Type type, - ::mlir::DialectAsmPrinter &os) const { +void TensorShapeDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &os) const { if (type.isa()) { os << "shape"; return; @@ -61,8 +61,10 @@ void TensorShapeDialect::printType(::mlir::Type type, } llvm_unreachable("unexpected 'shape' type kind"); } +} // namespace ts +} // namespace infrt #define GET_OP_CLASSES #include "paddle/infrt/dialect/tensor_shape.cpp.inc" // NOLINT -} // namespace infrt::ts +#include "paddle/infrt/dialect/tensor_shape_dialect.cpp.inc" diff --git a/paddle/infrt/dialect/tensor_shape.h b/paddle/infrt/dialect/tensor_shape.h index bd3fa8853675af4f1a19d2bdcf413cc0f80809fb..af892af735d2a4e2a8e97ac90e5fb2ba0e9fd1d8 100644 --- a/paddle/infrt/dialect/tensor_shape.h +++ b/paddle/infrt/dialect/tensor_shape.h @@ -17,7 +17,8 @@ #include #include -namespace infrt::ts { +namespace infrt { +namespace ts { class ShapeType : public mlir::Type::TypeBase { @@ -31,10 +32,9 @@ 
class PartialShapeType : public mlir::Type::TypeBase()">, "!ts.shape type">, BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> { - let typeDescription = [{ + let description = [{ `!ts.shape type` represents a static tensor shape. }]; } @@ -27,7 +27,7 @@ BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> { def TS_PartialShape : DialectType()">, "!ts.partial_shape type">, BuildableType<"$_builder.getType<::infrt::ts::PartialShapeType>()"> { - let typeDescription = [{ + let description = [{ `!ts.partial_shape type` represents either a static tensor shape, unranked tensor shape or a ranked tensor shape with unknown dimension sizes. }]; diff --git a/paddle/infrt/dialect/tensorrt/trt_exec.cc b/paddle/infrt/dialect/tensorrt/trt_exec.cc index dc0f2acb2b733e0f9d35f8153d6ac7f8ab0610cc..1baef7a3f77fdd9d3e363110ea3679aa942e222f 100644 --- a/paddle/infrt/dialect/tensorrt/trt_exec.cc +++ b/paddle/infrt/dialect/tensorrt/trt_exec.cc @@ -11,10 +11,10 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include #include -#include "llvm/Support/CommandLine.h" -#include "mlir/Pass/PassManager.h" #include "paddle/infrt/common/global.h" #include "paddle/infrt/dialect/mlir_loader.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc index 181f462962aeefa91ee572716090b86946a4cd42..1da80ef2c3b1000c045327510a03081f8aa954ca 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc @@ -14,14 +14,13 @@ #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" +#include +#include +#include +#include #include #include #include -#include "llvm/ADT/SetVector.h" -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/IR/Builders.h" -#include "paddle/infrt/dialect/pd_ops.h" -#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { @@ -32,9 +31,9 @@ namespace { // Reference the function nameed "FlexibleDFS" but defined in: // paddle/fluid/framework/ir/subgraph_detector.cc. -bool reverseDfs(std::vector<::mlir::Operation *> source, - const std::function &func) { - std::unordered_set visited; +bool reverseDfs(std::vector source, + const std::function &func) { + std::unordered_set visited; while (!source.empty()) { auto node = source.back(); source.pop_back(); @@ -44,7 +43,7 @@ bool reverseDfs(std::vector<::mlir::Operation *> source, auto values = node->getOperands(); for (auto value : values) { // if the value is a block argument, the node is nullptr. - ::mlir::Operation *node = value.getDefiningOp(); + mlir::Operation *node = value.getDefiningOp(); if (node != nullptr && !visited.count(node)) { source.emplace_back(node); } @@ -54,19 +53,19 @@ bool reverseDfs(std::vector<::mlir::Operation *> source, } // merge the first&second graph op to a new graph op. 
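// Editor's note (illustration, not in the patch): runOnFunction below only
// performs this merge when reverseDfs above cannot reach the producer from
// the consumer's other inputs. The case that must be rejected looks like:
//   %0 = "pd.graph"(...)         // candidate `first`
//   %1 = "pd.conv2d"(%0)         // outside op on a second path
//   %2 = "pd.graph"(%0, %1)      // candidate `second`
// Walking operand edges backwards from %1's defining op reaches `first`,
// so fusing the two graphs would make the fused op both feed and consume
// the %1 chain, i.e. create a cycle, and the merge is skipped.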
-void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder,  // NOLINT
-                             ::mlir::pd::GraphOp first,
-                             ::mlir::pd::GraphOp second) {
+void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder,  // NOLINT
+                             mlir::pd::GraphOp first,
+                             mlir::pd::GraphOp second) {
   // compute inputs and outputs
-  ::llvm::SmallVector<::mlir::Value, 4> inputs(first.getOperands()), outputs;
-  for (::mlir::Value input : second.getOperands()) {
+  ::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
+  for (mlir::Value input : second.getOperands()) {
     if (input.getDefiningOp() != first) {
       inputs.push_back(input);
     }
   }
-  ::llvm::DenseMap<::mlir::Value, unsigned int> op_output_mapping;
-  for (::mlir::Value output : first.getResults()) {
-    for (::mlir::Operation *user : output.getUsers()) {
+  ::llvm::DenseMap<mlir::Value, unsigned int> op_output_mapping;
+  for (mlir::Value output : first.getResults()) {
+    for (mlir::Operation *user : output.getUsers()) {
       if (user != second && user->getParentOp() != second) {
         op_output_mapping[output] = outputs.size();
         outputs.push_back(output);
@@ -74,19 +73,19 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder,  // NOLINT
       }
     }
   }
-  auto fetch_op = second.getBody()->getTerminator();
-  outputs.append(fetch_op->getOperands().begin(),
-                 fetch_op->getOperands().end());
-  ::llvm::SmallVector<::mlir::Type, 4> fetch_types;
+  auto return_op = second.getBody()->getTerminator();
+  outputs.append(return_op->getOperands().begin(),
+                 return_op->getOperands().end());
+  ::llvm::SmallVector<mlir::Type, 4> return_types;
   for (auto value : outputs) {
-    fetch_types.push_back(value.getType());
+    return_types.push_back(value.getType());
   }

   // create the new graph op
   builder.setInsertionPoint(first);
   auto loc = first.getLoc();
-  auto graph_op = builder.create<::mlir::pd::GraphOp>(loc, fetch_types, inputs);
-  ::mlir::Block *block = new ::mlir::Block;
+  auto graph_op = builder.create<mlir::pd::GraphOp>(loc, return_types, inputs);
+  mlir::Block *block = new mlir::Block;
   auto copy_range = second.getBody()->without_terminator();
   block->getOperations().splice(block->begin(),
                                 second.getBody()->getOperations(),
@@ -98,18 +97,18 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder,  // NOLINT
                                 copy_range.begin(),
                                 copy_range.end());
   builder.setInsertionPointToEnd(block);
-  builder.create<mlir::pd::FetchOp>(loc, outputs);
+  builder.create<mlir::pd::ReturnOp>(loc, outputs);
   graph_op.body().push_back(block);

   // mapping the output
   unsigned int num_result = first.getNumResults();
-  fetch_op = first.getBody()->getTerminator();
+  return_op = first.getBody()->getTerminator();
   for (unsigned int index = 0; index < num_result; ++index) {
     auto origin_value = first.getResult(index);
     if (op_output_mapping.find(origin_value) == op_output_mapping.end()) {
-      origin_value.replaceAllUsesWith(fetch_op->getOperand(index));
+      origin_value.replaceAllUsesWith(return_op->getOperand(index));
     } else {
-      auto inner_value = fetch_op->getOperand(index);
+      auto inner_value = return_op->getOperand(index);
       auto outer_value = graph_op.getResult(op_output_mapping[origin_value]);
       while (!origin_value.use_empty()) {
         auto replace_value =
@@ -128,13 +127,13 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder,  // NOLINT

// Topological sort the function op.
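Before the topological-sort helper below, an aside on the merge function above: operations are moved between blocks with a list splice instead of cloning, so values and their uses stay intact. A minimal standalone sketch (helper name is ours):

#include <mlir/IR/Block.h>

// Move every operation except the terminator from `from` into the front of
// `to`; splice relinks the intrusive list nodes, so no op is copied.
void moveBodyExceptTerminator(mlir::Block *from, mlir::Block *to) {
  auto range = from->without_terminator();
  to->getOperations().splice(to->begin(), from->getOperations(),
                             range.begin(), range.end());
}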
 void topoSortBlock(mlir::Block &body) {  // NOLINT
-  llvm::SetVector<Operation *> toSort;
+  llvm::SetVector<mlir::Operation *> toSort;
   if (body.empty()) return;
   for (auto it = body.rbegin(); it != body.rend(); ++it) {
     toSort.insert(&*it);
   }
-  llvm::SetVector<Operation *> result =
-      ::mlir::topologicalSort(std::move(toSort));
+  llvm::SetVector<mlir::Operation *> result =
+      mlir::topologicalSort(std::move(toSort));
   for (auto *op : result) {
     op->moveBefore(body.getTerminator());
   }
@@ -145,21 +144,21 @@ void topoSortBlock(mlir::Block &body) {  // NOLINT

 // Implementation of the trtGraphFusePass.
 void trtGraphFusePass::runOnFunction() {
   mlir::Block &body = getFunction().front();
-  ::mlir::OpBuilder builder(&body, body.begin());
+  mlir::OpBuilder builder(&body, body.begin());
   bool changed = false;
   do {
     changed = false;
     for (auto &op : body) {
-      ::mlir::pd::GraphOp graph_op =
-          ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op);
+      mlir::pd::GraphOp graph_op =
+          ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
       if (nullptr == graph_op) continue;
       for (auto user_op : op.getUsers()) {
-        ::mlir::pd::GraphOp user_graph_op =
-            ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(user_op);
+        mlir::pd::GraphOp user_graph_op =
+            ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(user_op);
         if (nullptr == user_graph_op) continue;
         // get all dst input nodes except src.
-        std::vector<::mlir::Operation *> source_nodes;
+        std::vector<mlir::Operation *> source_nodes;
         for (auto operand : user_op->getOperands()) {
           auto input = operand.getDefiningOp();
           if (input != &op && input != nullptr) {
@@ -167,9 +166,8 @@ void trtGraphFusePass::runOnFunction() {
           }
         }
         // Reverse DFS from the source_nodes.
-        if (!reverseDfs(source_nodes, [&op](const ::mlir::Operation *n) {
-              return n == &op;
-            })) {
+        if (!reverseDfs(source_nodes,
+                        [&op](const mlir::Operation *n) { return n == &op; })) {
           mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
           changed = true;
           break;
diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
index e7134e88f316c916787e7faba7f34432922d36c6..f1e555c6f67ecaadff76fb17f68ebaae1a6528e1 100644
--- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
+++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
@@ -13,7 +13,7 @@
 // limitations under the License.

 #pragma once
-#include "mlir/Pass/Pass.h"
+#include <mlir/Pass/Pass.h>

 namespace infrt {
 namespace trt {
@@ -28,15 +28,15 @@ namespace trt {
  *   %a = "pd.feed"()...
  *   %c = "pd.graph"(%a) {
  *      %m = "pd.conv2d"(%a)...
- *      "pd.fetch" %m
+ *      "pd.return" %m
  *   } ...
  *   %d = "pd.graph"(%c) {
  *      %m = "pd.conv3d"(%c)...
- *      "pd.fetch" %m
+ *      "pd.return" %m
  *   } ...
  *   %f = "pd.graph"(%a) {
  *      %m = "pd.conv2d"(%a)...
- *      "pd.fetch" %m
+ *      "pd.return" %m
  *   } ...
  *   "pd.fetch" %d, %f
  *
@@ -47,13 +47,13 @@ namespace trt {
  *      %m = "pd.conv2d"(%a)...
  *      %n = "pd.conv3d"(%m)...
  *      %s = "pd.conv2d"(%a)...
- *      "pd.fetch" %n, %s
+ *      "pd.return" %n, %s
  *   } ...
* "pd.fetch" %d, %f * } */ class trtGraphFusePass - : public ::mlir::PassWrapper { + : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtGraphFusePass"; } void runOnFunction() override; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc index 2b45364de2036f0fc1747b42e860bf2a22b80b51..257f2b528542557db33121a4c304eb8e6f657007 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc @@ -14,7 +14,7 @@ #include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h" -#include "mlir/IR/Builders.h" +#include #include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h" @@ -22,24 +22,24 @@ namespace infrt { namespace trt { // Implementation of the trtGraphSplitPass。 void trtGraphSplitPass::runOnFunction() { - std::vector<::mlir::pd::GraphOp> worklist; - ::mlir::Block& block = getFunction().front(); + std::vector worklist; + mlir::Block& block = getFunction().front(); for (auto& op : block) { - ::mlir::pd::GraphOp graph_op = - ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op); + mlir::pd::GraphOp graph_op = + ::llvm::dyn_cast_or_null(&op); if (nullptr != graph_op && graph_op.getBody()->getOperations().size() <= min_subgraph_size_) { worklist.push_back(graph_op); } } while (!worklist.empty()) { - ::mlir::pd::GraphOp graph_op = worklist.back(); + mlir::pd::GraphOp graph_op = worklist.back(); worklist.pop_back(); - ::mlir::Block* body = graph_op.getBody(); - auto fetch_op = body->getTerminator(); - graph_op.replaceAllUsesWith(fetch_op->getOperands()); + mlir::Block* body = graph_op.getBody(); + auto return_op = body->getTerminator(); + graph_op.replaceAllUsesWith(return_op->getOperands()); auto copy_range = body->without_terminator(); - block.getOperations().splice(::mlir::Block::iterator(graph_op), + block.getOperations().splice(mlir::Block::iterator(graph_op), body->getOperations(), copy_range.begin(), copy_range.end()); diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h index 092df0cf834e5995cf0c3c693a3cb4949856ca58..d30d186647fc32aa4e16047000ee4071effb900d 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once -#include "mlir/Pass/Pass.h" +#include namespace infrt { namespace trt { @@ -31,9 +31,9 @@ namespace trt { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... - * "pd.fetch" %n, %s + * "pd.return" (%n, %s) * } ... - * "pd.fetch" %d, %f + * "pd.fetch" (%d, %f) * } * * destination func: @@ -42,11 +42,11 @@ namespace trt { * %c = "pd.conv2d"(%a) ... * %d = "pd.conv3d"(%c) ... * %f = "pd.conv2d"(%a) ... 
- * "pd.fetch" %d, %f + * "pd.fetch" (%d, %f) * } */ class trtGraphSplitPass - : public ::mlir::PassWrapper { + : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; } void runOnFunction() override; diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc index 7b7fbb05c1d13b834447932f63c7b394e14b9715..4e8d40b982b2eaf13aeef4f026d783c3f353c14b 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc @@ -14,49 +14,48 @@ #include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h" -#include "mlir/IR/Builders.h" +#include #include "paddle/infrt/dialect/pd_ops.h" -#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { // Implementation of the trtOpTellerPass。 void trtOpTellerPass::runOnFunction() { - ::mlir::Block &body = getFunction().front(); - std::vector<::mlir::Operation *> worklist; + mlir::Block &body = getFunction().front(); + std::vector worklist; worklist.reserve(body.getOperations().size()); for (auto &op : body) { worklist.push_back(&op); } // Build GraphOp. - ::mlir::OpBuilder builder(&body, body.begin()); + mlir::OpBuilder builder(&body, body.begin()); while (!worklist.empty()) { auto *op = worklist.back(); worklist.pop_back(); if (op == nullptr) continue; - auto op1 = ::llvm::dyn_cast_or_null<::mlir::pd::FeedOp>(op); + auto op1 = ::llvm::dyn_cast_or_null(op); if (op1) continue; - auto op2 = ::llvm::dyn_cast_or_null<::mlir::pd::FetchOp>(op); + auto op2 = ::llvm::dyn_cast_or_null(op); if (op2) continue; - auto op3 = ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(op); + auto op3 = ::llvm::dyn_cast_or_null(op); if (op3) continue; builder.setInsertionPoint(op); auto loc = getFunction().getLoc(); - auto graph_op = builder.create<::mlir::pd::GraphOp>( + auto graph_op = builder.create( loc, op->getResultTypes(), op->getOperands()); - ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + ::llvm::SmallVector tblgen_repl_values; for (auto v : - ::llvm::SmallVector<::mlir::Value, 4>{graph_op.getODSResults(0)}) { + ::llvm::SmallVector{graph_op.getODSResults(0)}) { tblgen_repl_values.push_back(v); } op->replaceAllUsesWith(tblgen_repl_values); // Build graph op. - ::mlir::Block *block = new ::mlir::Block; + mlir::Block *block = new mlir::Block; graph_op.body().push_back(block); op->moveBefore(block, block->begin()); builder.setInsertionPointToEnd(block); - builder.create(loc, op->getResults()); + builder.create(loc, op->getResults()); } } } // namespace trt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h index b03945b3459c0237343006019fe15e8a2e508492..fb16c974f7fb3f923bdc460d62d8e5b9f628fff9 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once -#include "mlir/Pass/Pass.h" +#include namespace infrt { namespace trt { @@ -29,7 +29,7 @@ namespace trt { * %c = "pd.conv2d"(%a) ... * %d = "pd.conv3d"(%c) ... * %f = "pd.conv2d"(%a) ... - * "pd.fetch" %d, %f + * "pd.fetch" (%d, %f) * } * * destination func: @@ -37,23 +37,23 @@ namespace trt { * %a = "pd.feed"()... * %c = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.fetch" %m + * "pd.return" (%m) * } ... * %d = "pd.graph"(%c) { * %m = "pd.conv3d"(%c)... - * "pd.fetch" %m + * "pd.return" (%m) * } ... 
* %f = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.fetch" %m + * "pd.return" (%m) * } ... - * "pd.fetch" %d, %f + * "pd.fetch" (%d, %f) * } * TODO(winter-wang): Supplementary how to judge the operators can be supported * by tensorrt. */ class trtOpTellerPass - : public ::mlir::PassWrapper { + : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtOpTellerPass"; } void runOnFunction() override; diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.cc b/paddle/infrt/dialect/tensorrt/trt_ops.cc index 4c02238b10e1da770454682d889addbe078b0a54..35b7967892cafcea66c382e5681ee43480b02735 100644 --- a/paddle/infrt/dialect/tensorrt/trt_ops.cc +++ b/paddle/infrt/dialect/tensorrt/trt_ops.cc @@ -13,27 +13,25 @@ // limitations under the License. #include "paddle/infrt/dialect/tensorrt/trt_ops.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/CallInterfaces.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include +#include +#include +#include +#include namespace infrt { namespace trt { -TensorRTDialect::TensorRTDialect(::mlir::MLIRContext *context) - : ::mlir::Dialect("trt", context, ::mlir::TypeID::get()) { +TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context) + : mlir::Dialect("trt", context, mlir::TypeID::get()) { addOperations< #define GET_OP_LIST #include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT >(); -#undef GET_OP_LIST } -#define GET_OP_CLASSES -#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT -#undef GET_OP_CLASSES - } // namespace trt } // namespace infrt + +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.h b/paddle/infrt/dialect/tensorrt/trt_ops.h index c9043c2280de0f7970fb323876b10c68c6a63de7..a37491ec1abc7fd423fef23df5170936d2a769c7 100644 --- a/paddle/infrt/dialect/tensorrt/trt_ops.h +++ b/paddle/infrt/dialect/tensorrt/trt_ops.h @@ -14,37 +14,32 @@ #pragma once -#include "mlir/Dialect/Traits.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/CallInterfaces.h" -#include "mlir/Interfaces/DerivedAttributeOpInterface.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/LoopLikeInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace infrt { namespace trt { -class TensorRTDialect : public ::mlir::Dialect { +class TensorRTDialect : public mlir::Dialect { public: - explicit TensorRTDialect(::mlir::MLIRContext* context); + explicit TensorRTDialect(mlir::MLIRContext* context); static llvm::StringRef getDialectNamespace() { return "trt"; } }; -// mlir bug。 can be removed safety when update mlir to llvm11. 
-using namespace mlir; // NOLINT +} // namespace trt +} // namespace infrt #define GET_OP_CLASSES #include "paddle/infrt/dialect/tensorrt/trt_ops.hpp.inc" -#undef GET_OP_CLASSES - -} // namespace trt -} // namespace infrt diff --git a/paddle/infrt/dialect/test_kernels.cc b/paddle/infrt/dialect/test_kernels.cc index 894d96f95ad5cb291ced0c71ecb94ec9ab879423..c4588d7cf8bab748832865fc3aaab1913f33d11b 100644 --- a/paddle/infrt/dialect/test_kernels.cc +++ b/paddle/infrt/dialect/test_kernels.cc @@ -14,14 +14,13 @@ #include "paddle/infrt/dialect/test_kernels.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" - -namespace infrt::dialect { +#include +#include +#include +#include +namespace infrt { +namespace dialect { //===----------------------------------------------------------------------===// // BenchmarkOp //===----------------------------------------------------------------------===// @@ -32,65 +31,67 @@ namespace infrt::dialect { // ... // } -static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT - OperationState &result) { // NOLINT - StringAttr nameAttr; +static mlir::ParseResult parseBenchmarkOp( + mlir::OpAsmParser &parser, // NOLINT + mlir::OperationState &result) { // NOLINT + mlir::StringAttr nameAttr; if (parser.parseAttribute(nameAttr, "name", result.attributes)) - return failure(); + return mlir::failure(); // Parse the operands, e.g. (%c : i32, %d : f32) - if (parser.parseLParen()) return failure(); + if (parser.parseLParen()) return mlir::failure(); - SmallVector operands; - SmallVector types; + llvm::SmallVector operands; + llvm::SmallVector types; llvm::SMLoc type_loc = parser.getCurrentLocation(); if (parser.parseOptionalRParen()) { // Parse non-empty operands do { // Parse %c : i32, - OpAsmParser::OperandType operand; - Type type; + mlir::OpAsmParser::OperandType operand; + mlir::Type type; if (parser.parseOperand(operand) || parser.parseColonType(type)) - return failure(); + return mlir::failure(); operands.push_back(operand); types.push_back(type); } while (succeeded(parser.parseOptionalComma())); - if (parser.parseRParen()) return failure(); + if (parser.parseRParen()) return mlir::failure(); } if (parser.resolveOperands(operands, types, type_loc, result.operands)) - return failure(); + return mlir::failure(); // Parse the keyword attribute, e.g. 
max_count = 100, duration_secs = 1 do { - StringRef attr; - Attribute resultAttr; + mlir::StringRef attr; + mlir::Attribute resultAttr; if (parser.parseKeyword(&attr) || parser.parseEqual() || parser.parseAttribute(resultAttr, parser.getBuilder().getIntegerType(32), attr, result.attributes)) - return failure(); - } while (succeeded(parser.parseOptionalComma())); + return mlir::failure(); + } while (mlir::succeeded(parser.parseOptionalComma())); // Set the default attribute num_warmup_runs to 1 if unset auto setDefaultAttrIfUnset = [&](const char *attr_name, int value) { bool found = llvm::any_of(result.attributes, - [attr_name](const NamedAttribute &attr) { - return attr.first == attr_name; + [attr_name](const mlir::NamedAttribute &attr) { + return attr.getName() == attr_name; }); if (!found) { - IntegerAttr default_val = parser.getBuilder().getI32IntegerAttr(value); + mlir::IntegerAttr default_val = + parser.getBuilder().getI32IntegerAttr(value); result.addAttribute(attr_name, default_val); } }; setDefaultAttrIfUnset("num_warmup_runs", 1); - Region *target = result.addRegion(); + mlir::Region *target = result.addRegion(); return parser.parseRegion(*target, operands, types, @@ -102,11 +103,11 @@ static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT // max_count = 100, duration_secs = 1 { // ... // } -static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT +static void print(mlir::OpAsmPrinter &p, BenchmarkOp op) { // NOLINT p << "infrt.benchmark "; // Print the name attribute, e.g "add.i32" - auto name_attr = op.getAttr("name"); + auto name_attr = op->getAttr("name"); p << name_attr; // Print the operands and types, e.g. (%c : i32, %d : f32) @@ -120,13 +121,13 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT bool need_comma = false; // Print the attributes, e.g. max_count = 100, duration_secs = 1 - for (auto &name_attr : op.getAttrs()) { - auto id = name_attr.first; + for (auto &name_attr : op->getAttrs()) { + auto id = name_attr.getName(); if (id == "name") continue; if (need_comma) p << ", "; - auto attr = name_attr.second; + auto attr = name_attr.getValue(); p << id << " = "; - if (auto int_attr = attr.dyn_cast()) { + if (auto int_attr = attr.dyn_cast()) { int_attr.getValue().print(p.getStream(), /*isSigned=*/false); } else { op.emitOpError("Unexpected attribute"); @@ -142,7 +143,7 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT p.printRegion(op.region(), /*printEntryBlockArgs=*/false); } -static LogicalResult verify(BenchmarkOp op) { +static mlir::LogicalResult verify(BenchmarkOp op) { // Verify that the target benchmark region has exactly one return value. auto ®ion = op.region(); auto &last_op = region.front().back(); @@ -154,10 +155,10 @@ static LogicalResult verify(BenchmarkOp op) { "incorrect number of return values. One return value is expected"); } - return success(); + return mlir::success(); } +} // namespace dialect +} // namespace infrt #define GET_OP_CLASSES #include "paddle/infrt/dialect/test_kernels.cpp.inc" - -} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/test_kernels.h b/paddle/infrt/dialect/test_kernels.h index 29d4209cb7280e0d3d9947c1a9d0cfff75ade01b..73c8a6fb387bca6ebc7ae393e4bba32ab94aa951 100644 --- a/paddle/infrt/dialect/test_kernels.h +++ b/paddle/infrt/dialect/test_kernels.h @@ -13,11 +13,8 @@ // limitations under the License. 
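The hand-written parser above is the standard OpAsmParser recipe; as a minimal sketch, this is the shape of one `key = value` step of its attribute loop (helper name is ours, i32 attribute typing as in the patch):

#include <mlir/IR/OpImplementation.h>

static mlir::ParseResult parseOneKeywordAttr(mlir::OpAsmParser &parser,
                                             mlir::OperationState &result) {
  llvm::StringRef key;
  mlir::Attribute value;
  // Parse `key = <integer>` and record it under `key` in the op's attributes.
  if (parser.parseKeyword(&key) || parser.parseEqual() ||
      parser.parseAttribute(value,
                            parser.getBuilder().getIntegerType(32),
                            key, result.attributes))
    return mlir::failure();
  return mlir::success();
}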
#pragma once -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include +#include -namespace infrt::dialect { -using namespace mlir; // NOLINT #define GET_OP_CLASSES #include "paddle/infrt/dialect/test_kernels.hpp.inc" -} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/types.cc b/paddle/infrt/dialect/types.cc deleted file mode 100644 index 6d6f6a20b46c90d0bdbb79e5b732255b4a6e27bf..0000000000000000000000000000000000000000 --- a/paddle/infrt/dialect/types.cc +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/infrt/dialect/types.h" - -namespace infrt::hlir::mlir {} // namespace infrt::hlir::mlir diff --git a/paddle/infrt/dialect/types.h b/paddle/infrt/dialect/types.h deleted file mode 100644 index a9a2b61871cc0911b756deddda8ba60fade4ac94..0000000000000000000000000000000000000000 --- a/paddle/infrt/dialect/types.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once -#include diff --git a/paddle/infrt/host_context/core_runtime.cc b/paddle/infrt/host_context/core_runtime.cc index cdb8cc99ecb2631d1b9cdf1b8adb830fe9e826a5..e3917bd07d24248becb013e2d6ef6546608285f9 100644 --- a/paddle/infrt/host_context/core_runtime.cc +++ b/paddle/infrt/host_context/core_runtime.cc @@ -23,7 +23,8 @@ #include "paddle/infrt/host_context/op_executable.h" #include "paddle/infrt/host_context/symbol_table.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { struct CoreRuntime::Impl { KernelRegistry* kernel_registry{}; @@ -90,4 +91,5 @@ llvm::SmallVector CoreRuntime::GetResults( CoreRuntime::~CoreRuntime() {} -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/core_runtime.h b/paddle/infrt/host_context/core_runtime.h index 802f8b17bb0105169c269e6dae9f37331655a1de..acb6a66cac630f695afbdcc527d7b397973aa84f 100644 --- a/paddle/infrt/host_context/core_runtime.h +++ b/paddle/infrt/host_context/core_runtime.h @@ -22,7 +22,8 @@ #include "paddle/infrt/host_context/value.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { class KernelRegistry; class OpExecutable; @@ -83,4 +84,5 @@ class CoreRuntimeBuilder : public CoreRuntime { OpExecutableBuilder* NewOpExecutable(const std::string& op_name); }; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/kernel_frame.h b/paddle/infrt/host_context/kernel_frame.h index 20cb17dc7fbe241557f70e5d0e2f6cf15dc69b56..5186b88fe2c41a8b4939dd70fde9123549764856 100644 --- a/paddle/infrt/host_context/kernel_frame.h +++ b/paddle/infrt/host_context/kernel_frame.h @@ -21,7 +21,8 @@ #include "llvm/ADT/SmallVector.h" #include "paddle/infrt/host_context/value.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { /** * KernelFrame captures the states(input arguments, attributes, results) @@ -163,4 +164,5 @@ class KernelFrameBuilder : public KernelFrame { } }; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/kernel_registry_test.cc b/paddle/infrt/host_context/kernel_registry_test.cc index f36ec2a1cac7ded8bd1fc6c30061ce001bdeda1c..7fca56343041c2827f0dce57ca98fb9158ef66f4 100644 --- a/paddle/infrt/host_context/kernel_registry_test.cc +++ b/paddle/infrt/host_context/kernel_registry_test.cc @@ -18,7 +18,8 @@ #include "paddle/infrt/host_context/kernel_utils.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { int add_i32(int a, int b) { return a + b; } @@ -44,4 +45,5 @@ TEST(KernelRegistry, basic) { ASSERT_EQ(results[0]->get(), 3); } -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/kernel_utils_test.cc b/paddle/infrt/host_context/kernel_utils_test.cc index 1904eb106a29375f4997ec099151835d409e09b8..bebd8d86e50bbd6a2d80325f9fbd8254718c8d0a 100644 --- a/paddle/infrt/host_context/kernel_utils_test.cc +++ b/paddle/infrt/host_context/kernel_utils_test.cc @@ -16,7 +16,8 @@ #include -namespace infrt::host_context { +namespace infrt { +namespace host_context { int add_i32(int a, int b) { return a + b; } float add_f32(float a, float b) { return a + b; } @@ -66,4 +67,5 @@ TEST(KernelImpl, pair) { ASSERT_EQ(results[1]->get(), 3.f); } -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git 
a/paddle/infrt/host_context/mlir_function_executable.cc b/paddle/infrt/host_context/mlir_function_executable.cc index 5f8dacf8e448acc494856fb1c7117d61b3075190..47ec27ebec300f1cedd57b11e0dd1e6b37611141 100644 --- a/paddle/infrt/host_context/mlir_function_executable.cc +++ b/paddle/infrt/host_context/mlir_function_executable.cc @@ -15,6 +15,7 @@ #include "paddle/infrt/host_context/mlir_function_executable.h" #include +#include #include // NOLINT diff --git a/paddle/infrt/host_context/mlir_function_executable.h b/paddle/infrt/host_context/mlir_function_executable.h index ba5fa154d6fcc3183c3a882e1eb1bd05daa66129..a6428df86e6b27061d92856970682bc29499d825 100644 --- a/paddle/infrt/host_context/mlir_function_executable.h +++ b/paddle/infrt/host_context/mlir_function_executable.h @@ -13,7 +13,8 @@ // limitations under the License. #pragma once -#include +#include +#include #include #include diff --git a/paddle/infrt/host_context/mlir_program_executor.h b/paddle/infrt/host_context/mlir_program_executor.h index b2af4d2d79db54aa02a34b0371c63f992c055f58..c2ccb90640b21bcfb675a707d6cb60cf5028ab36 100644 --- a/paddle/infrt/host_context/mlir_program_executor.h +++ b/paddle/infrt/host_context/mlir_program_executor.h @@ -15,9 +15,9 @@ #pragma once #include +#include +#include #include -#include -#include #include #include diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate.cc b/paddle/infrt/host_context/mlir_to_runtime_translate.cc index 25324b1291582b406eb5b33c1241609a9e2ed5d6..3dbc7a702be38d986b6f77b345abe85f939370e6 100644 --- a/paddle/infrt/host_context/mlir_to_runtime_translate.cc +++ b/paddle/infrt/host_context/mlir_to_runtime_translate.cc @@ -16,8 +16,9 @@ #include #include +#include +#include #include -#include #include #include @@ -40,7 +41,8 @@ #include "paddle/infrt/host_context/value.h" #include "paddle/infrt/tensor/tensor_shape.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { template std::string DumpToString(T& op) { // NOLINT @@ -113,10 +115,10 @@ bool MlirToRuntimeTranslator::EmitConstantOp(mlir::Operation* op) { template <> boost::optional MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - if (attr->isa()) { - auto val = attr->cast(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + if (attr.isa()) { + auto val = attr.cast(); if (val.getType().isInteger(32)) { return val.getInt(); } @@ -125,10 +127,10 @@ boost::optional MlirToRuntimeTranslator::EmitAttribute( } template <> boost::optional MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - if (attr->isa()) { - auto val = attr->cast(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + if (attr.isa()) { + auto val = attr.cast(); if (val.getType().isInteger(64)) { return val.getInt(); } @@ -139,10 +141,10 @@ boost::optional MlirToRuntimeTranslator::EmitAttribute( // TODO(Superjomn) Make double and float parsing share some thing. 
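Since mlir::Attribute is a cheap value-semantic handle, the pointer-to-reference migration above lets each EmitAttribute specialization test and cast the attribute directly. A minimal standalone sketch of the same dispatch (function name is ours):

#include <boost/optional.hpp>
#include <mlir/IR/BuiltinAttributes.h>

boost::optional<int64_t> emitI64Attr(const mlir::Attribute &attr) {
  // dyn_cast yields a null handle on mismatch, so one branch covers both
  // the isa check and the cast.
  if (auto val = attr.dyn_cast<mlir::IntegerAttr>())
    if (val.getType().isInteger(64)) return val.getInt();
  return boost::none;
}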
template <> boost::optional MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - if (attr->isa()) { - auto val = attr->cast(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + if (attr.isa()) { + auto val = attr.cast(); if (val.getType().isF32()) return val.getValueAsDouble(); } return boost::none; @@ -150,10 +152,10 @@ boost::optional MlirToRuntimeTranslator::EmitAttribute( template <> boost::optional MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - if (attr->isa()) { - auto val = attr->cast(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + if (attr.isa()) { + auto val = attr.cast(); if (val.getType().isF64()) return val.getValueAsDouble(); } return boost::none; @@ -161,17 +163,17 @@ boost::optional MlirToRuntimeTranslator::EmitAttribute( template <> boost::optional MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - return attr->cast().getValue().str(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + return attr.cast().getValue().str(); } #define PROCESS_ARRAY_INT(type__, bits__) \ template <> \ boost::optional> MlirToRuntimeTranslator::EmitAttribute( \ - const mlir::Attribute* attr) { \ - if (!attr->isa()) return boost::none; \ - auto array = attr->cast(); \ + const mlir::Attribute& attr) { \ + if (!attr.isa()) return boost::none; \ + auto array = attr.cast(); \ CHECK(!array.empty()); \ \ if (!array[0].getType().isInteger(bits__)) { \ @@ -191,9 +193,9 @@ PROCESS_ARRAY_INT(int64_t, 64); template <> boost::optional> MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - auto array = attr->cast(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + auto array = attr.cast(); CHECK(!array.empty()); if (!array[0].getType().isF32()) return boost::none; @@ -207,9 +209,9 @@ boost::optional> MlirToRuntimeTranslator::EmitAttribute( template <> boost::optional> MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - auto array = attr->cast(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + auto array = attr.cast(); CHECK(!array.empty()); if (!array[0].getType().isF64()) return boost::none; @@ -236,7 +238,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { for (int i = 0, e = op->getNumOperands(); i < e; i++) { // function argument as value auto operand = op->getOperand(i); - if (operand.getKind() == mlir::Value::Kind::BlockArgument) { + /// if (operand.getKind() == mlir::Value::Kind::BlockArgument) { + if (operand.isa()) { mlir::BlockArgument arg = operand.dyn_cast(); Value* arg_value = GetValue(arg); impl_->cur_op->AppendArgument(arg_value); @@ -283,25 +286,25 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { for (size_t i = 0; i < attrs.size(); i++) { auto& attr = attrs[i]; - if (auto v = EmitAttribute(&attr.second)) { + if (auto v = EmitAttribute(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(*v)); - } else if (auto v = EmitAttribute(&attr.second)) { + } else if (auto v = EmitAttribute(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(*v)); - } else if (auto v = EmitAttribute(&attr.second)) { + } else if (auto v = EmitAttribute(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(*v)); - } else 
if (auto v = EmitAttribute(&attr.second)) { + } else if (auto v = EmitAttribute(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(*v)); - } else if (auto v = EmitAttribute(&attr.second)) { + } else if (auto v = EmitAttribute(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); - } else if (auto v = EmitAttribute>(&attr.second)) { + } else if (auto v = EmitAttribute>(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); - } else if (auto v = EmitAttribute>(&attr.second)) { + } else if (auto v = EmitAttribute>(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); - } else if (auto v = EmitAttribute>(&attr.second)) { + } else if (auto v = EmitAttribute>(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); - } else if (auto v = EmitAttribute>(&attr.second)) { + } else if (auto v = EmitAttribute>(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); - } else if (auto v = EmitAttribute>(&attr.second)) { + } else if (auto v = EmitAttribute>(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); } else { LOG(FATAL) << "Not supported attribute type"; @@ -330,7 +333,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { llvm::SmallVector results; auto func_type = - mlir::FunctionType::get(inputs, results, region.getContext()); + mlir::FunctionType::get(region.getContext(), inputs, results); auto* function = impl_->cur_op->CreateFunctionExecutable( ®ion, func_type, &impl_->func_defs); impl_->cur_op->AppendAttribute(new Value(function)); @@ -555,4 +558,5 @@ void TestMlir(mlir::ModuleOp module, KernelRegistry* registry) { execute.Run(); } -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate.h b/paddle/infrt/host_context/mlir_to_runtime_translate.h index 598e81bfd96d8acbc6d7eeba046df701a955b628..fcd79eaf386eed5a6a8eaa31712e344bab56dbd4 100644 --- a/paddle/infrt/host_context/mlir_to_runtime_translate.h +++ b/paddle/infrt/host_context/mlir_to_runtime_translate.h @@ -29,7 +29,8 @@ class Attribute; class Value; } // namespace mlir -namespace infrt::host_context { +namespace infrt { +namespace host_context { class CoreRuntimeBuilder; class Value; @@ -73,7 +74,7 @@ class MlirToRuntimeTranslator { bool EmitCallOp(mlir::Operation* op, function_defs_t* function_table); template - boost::optional EmitAttribute(const mlir::Attribute* attr); + boost::optional EmitAttribute(const mlir::Attribute& attr); Value* GetOpResult(mlir::Operation* op); @@ -104,4 +105,5 @@ void MlirToRuntimeTranslate(mlir::ModuleOp module, CoreRuntimeBuilder* runtime); */ void TestMlir(mlir::ModuleOp module, KernelRegistry* registry); -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate_test.cc b/paddle/infrt/host_context/mlir_to_runtime_translate_test.cc index 9b85be977ab6c1964b006385dfdc78414f1e482b..375daa4515e17fe1618c71d642825d112a3f788f 100644 --- a/paddle/infrt/host_context/mlir_to_runtime_translate_test.cc +++ b/paddle/infrt/host_context/mlir_to_runtime_translate_test.cc @@ -29,7 +29,8 @@ #include "paddle/infrt/kernel/tensor_shape_kernels.h" #include "paddle/infrt/kernel/test_kernels.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { TEST(MlirToRuntimeTranslate, basic) { mlir::MLIRContext context; @@ -48,7 +49,7 @@ func @main() -> () { 
)ROC"; auto module = dialect::LoadMlirSource(&context, source); - module->verify(); + EXPECT_TRUE(mlir::succeeded(module->verify())); KernelRegistry registry; kernel::RegisterFloatBasicKernels(®istry); @@ -74,7 +75,7 @@ func @main() -> () { )ROC"; auto module = dialect::LoadMlirSource(&context, source); - module->verify(); + EXPECT_TRUE(mlir::succeeded(module->verify())); KernelRegistry registry; kernel::RegisterFloatBasicKernels(®istry); @@ -115,7 +116,7 @@ infrt.return %a0, %b0: !infrt.tensor, !infrt.tensorverify(); + EXPECT_TRUE(mlir::succeeded(module->verify())); host_context::KernelRegistry registry; @@ -157,4 +158,5 @@ infrt.return %a0, %b0: !infrt.tensor, !infrt.tensor #include #include "paddle/infrt/host_context/kernel_frame.h" @@ -21,7 +22,8 @@ #include "paddle/infrt/host_context/mlir_function_executable.h" #include "paddle/infrt/host_context/symbol_table.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { struct OpExecutable::Impl { Impl(const std::string& op_name, @@ -148,4 +150,5 @@ void OpExecutable::Execute() { OpExecutable::~OpExecutable() {} -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/op_executable.h b/paddle/infrt/host_context/op_executable.h index e2248225a5cafa44be27604ad3b5f606c37cf6c7..550f6ab6349ed2f3f503ea7b0b425f7dbc1aea2c 100644 --- a/paddle/infrt/host_context/op_executable.h +++ b/paddle/infrt/host_context/op_executable.h @@ -14,19 +14,18 @@ #pragma once #include - +#include +#include #include #include #include -#include "mlir/IR/Function.h" -#include "mlir/IR/Region.h" - namespace mlir { class FuncOp; } // namespace mlir -namespace infrt::host_context { +namespace infrt { +namespace host_context { class SymbolTable; class KernelRegistry; @@ -89,4 +88,5 @@ class OpExecutableBuilder : public OpExecutable { function_defs_t* function_defs); }; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/kernel/basic_kernels.cc b/paddle/infrt/kernel/basic_kernels.cc index d7f2c3865157dd973e21f8527a4804e4e3209bf3..b186cfcfd2b355f97711ecc916e497c2916d4060 100644 --- a/paddle/infrt/kernel/basic_kernels.cc +++ b/paddle/infrt/kernel/basic_kernels.cc @@ -23,7 +23,8 @@ using infrt::host_context::Attribute; -namespace infrt::kernel { +namespace infrt { +namespace kernel { template T add(T a, T b) { @@ -82,4 +83,5 @@ void RegisterFloatBasicKernels(host_context::KernelRegistry *registry) { registry->AddKernel("infrt.print.f32", INFRT_KERNEL(print)); } -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/basic_kernels.h b/paddle/infrt/kernel/basic_kernels.h index 9e98885cf6ebfb8e000424874da70f3a34e2e127..feb66be61f530676cf79a32be1e52d69017d21bc 100644 --- a/paddle/infrt/kernel/basic_kernels.h +++ b/paddle/infrt/kernel/basic_kernels.h @@ -15,13 +15,16 @@ #pragma once #include -namespace infrt::host_context { +namespace infrt { +namespace host_context { struct KernelRegistry; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt -namespace infrt::kernel { +namespace infrt { +namespace kernel { /** * Register all the basic kernels to \p registry. 
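One detail worth noting in the test changes above: module->verify() returns mlir::LogicalResult, which must be unwrapped with the succeeded()/failed() helpers rather than used as a bool or ignored. A minimal sketch using the equivalent free function:

#include <llvm/Support/raw_ostream.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Verifier.h>

bool moduleIsValid(mlir::ModuleOp module) {
  // mlir::verify is the free-function form of module->verify().
  if (mlir::failed(mlir::verify(module))) {
    llvm::errs() << "module verification failed\n";
    return false;
  }
  return true;
}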
@@ -31,4 +34,5 @@ void RegisterBasicKernels(host_context::KernelRegistry* registry); void RegisterIntBasicKernels(host_context::KernelRegistry* registry); void RegisterFloatBasicKernels(host_context::KernelRegistry* registry); -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/tensor_kernels.cc b/paddle/infrt/kernel/tensor_kernels.cc index 2fa477aa4dbda6f7282e65d705c27d433f2839c1..51e000492237435de555bc53bb63d23ce7ecbeb2 100644 --- a/paddle/infrt/kernel/tensor_kernels.cc +++ b/paddle/infrt/kernel/tensor_kernels.cc @@ -25,7 +25,8 @@ #include "paddle/infrt/tensor/tensor_map.h" #include "paddle/infrt/tensor/tensor_shape.h" -namespace infrt::kernel { +namespace infrt { +namespace kernel { using namespace host_context; // NOLINT using namespace tensor; // NOLINT @@ -76,4 +77,5 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) { INFRT_KERNEL(ShallowCopyTensor)); } -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/tensor_kernels.h b/paddle/infrt/kernel/tensor_kernels.h index 8f2180ba80a4f81c910aafe915d61288da99c930..df8e25c32393c903c3e6801e23095aeff6eca9b4 100644 --- a/paddle/infrt/kernel/tensor_kernels.h +++ b/paddle/infrt/kernel/tensor_kernels.h @@ -14,12 +14,16 @@ #pragma once -namespace infrt::host_context { +namespace infrt { +namespace host_context { struct KernelRegistry; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt -namespace infrt::kernel { +namespace infrt { +namespace kernel { void RegisterTensorKernels(host_context::KernelRegistry* registry); -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/tensor_shape_kernels.cc b/paddle/infrt/kernel/tensor_shape_kernels.cc index a04b492819298b3b6673324856d6277cc1ab6bea..4edbecfa108869ee1f8181c8efd42adc91224d6d 100644 --- a/paddle/infrt/kernel/tensor_shape_kernels.cc +++ b/paddle/infrt/kernel/tensor_shape_kernels.cc @@ -24,7 +24,8 @@ #include "paddle/infrt/host_context/kernel_utils.h" #include "paddle/infrt/tensor/tensor_shape.h" -namespace infrt::kernel { +namespace infrt { +namespace kernel { void PrintShape(const tensor::TensorShape& shape) { llvm::raw_os_ostream oos(std::cout); @@ -35,4 +36,5 @@ void RegisterTensorShapeKernels(host_context::KernelRegistry* registry) { registry->AddKernel("ts.print_shape", INFRT_KERNEL(PrintShape)); } -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/tensor_shape_kernels.h b/paddle/infrt/kernel/tensor_shape_kernels.h index e87c6c37e88a08fa2b2c85d35621786e9a46e65e..e31a37463be43bcc997368bd9693b3d866eff454 100644 --- a/paddle/infrt/kernel/tensor_shape_kernels.h +++ b/paddle/infrt/kernel/tensor_shape_kernels.h @@ -14,14 +14,18 @@ #pragma once -namespace infrt::host_context { +namespace infrt { +namespace host_context { class KernelRegistry; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt -namespace infrt::kernel { +namespace infrt { +namespace kernel { void RegisterTensorShapeKernels(host_context::KernelRegistry* registry); -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/test_kernels.cc b/paddle/infrt/kernel/test_kernels.cc index d5f64d09b602fd8696eba966f717d440214043c8..ccfb3356a855f418f14e42ed8a368f31d2fe8b27 100644 --- a/paddle/infrt/kernel/test_kernels.cc +++ b/paddle/infrt/kernel/test_kernels.cc @@ -33,7 +33,8 @@ using 
infrt::host_context::Attribute; using infrt::host_context::MlirFunctionExecutable; using infrt::host_context::RemainingArguments; -namespace infrt::kernel { +namespace infrt { +namespace kernel { namespace { class BenchmarkStats { public: @@ -197,4 +198,5 @@ void RegisterTestKernels(host_context::KernelRegistry *registry) { INFRT_KERNEL(ShadowCopyTensor)); } -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/test_kernels.h b/paddle/infrt/kernel/test_kernels.h index f42884dfaf2c9005b31e0ef335d1316625337a6f..f5639ec1afaad769d62530c4ef91eafa35779218 100644 --- a/paddle/infrt/kernel/test_kernels.h +++ b/paddle/infrt/kernel/test_kernels.h @@ -15,17 +15,21 @@ #pragma once #include -namespace infrt::host_context { +namespace infrt { +namespace host_context { struct KernelRegistry; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt -namespace infrt::kernel { +namespace infrt { +namespace kernel { /** * Register all the test kernels to registry. */ void RegisterTestKernels(host_context::KernelRegistry* registry); -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/paddle/cpp/desc_api.h b/paddle/infrt/paddle/cpp/desc_api.h index ccd79c048ab14593838b5173cadcbe979019045a..3b2dcb0018b2fcf733585ce28dac16aadffd7639 100644 --- a/paddle/infrt/paddle/cpp/desc_api.h +++ b/paddle/infrt/paddle/cpp/desc_api.h @@ -18,7 +18,9 @@ #include #include -namespace infrt::paddle::cpp { +namespace infrt { +namespace paddle { +namespace cpp { /* * Compatible interfaces for all the different kinds of XXXDesc. All the XXXDesc @@ -226,4 +228,6 @@ class ProgramDescAPI { virtual void SetVersion(int64_t version) = 0; }; -} // namespace infrt::paddle::cpp +} // namespace cpp +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/model_parser.cc b/paddle/infrt/paddle/model_parser.cc index 285280e69435b046a3c3073faf575de31662a2b5..f3de1a630451cc387765040191be8715768be510 100644 --- a/paddle/infrt/paddle/model_parser.cc +++ b/paddle/infrt/paddle/model_parser.cc @@ -22,7 +22,8 @@ #include "paddle/infrt/common/target.h" #include "paddle/infrt/common/type.h" -namespace infrt::paddle { +namespace infrt { +namespace paddle { int SizeOfType(framework_proto::VarType::Type type) { using Type = framework_proto::VarType::Type; @@ -169,4 +170,5 @@ void LoadParam(const std::string &path, _Variable *out, const Target &target) { LoadLoDTensor(fin, out, target); } -} // namespace infrt::paddle +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/model_parser.h b/paddle/infrt/paddle/model_parser.h index 73125fadedb82b9dd628fc0fb65a3b5c54d24a42..373f77033dcefa1a81cd8756da859b6d232337a0 100644 --- a/paddle/infrt/paddle/model_parser.h +++ b/paddle/infrt/paddle/model_parser.h @@ -25,7 +25,8 @@ #include "paddle/infrt/paddle/scope.h" #include "paddle/infrt/paddle/tensor.h" -namespace infrt::paddle { +namespace infrt { +namespace paddle { namespace framework_proto = ::paddle::framework::proto; // Read a __model__ file. 
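The same mechanical rewrite recurs through all of these hunks: the C++17 nested-namespace shorthand is expanded into nested blocks that also compile as C++14. Schematically:

// Before (C++17 only):
namespace infrt::paddle::cpp { /* declarations */ }  // namespace infrt::paddle::cpp

// After (valid C++14 as well):
namespace infrt {
namespace paddle {
namespace cpp { /* declarations */ }  // namespace cpp
}  // namespace paddle
}  // namespace infrt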
@@ -52,4 +53,5 @@ void TensorFromStream( const common::Target& target = common::DefaultHostTarget()); void ReadBinaryFile(const std::string& filename, std::string* contents); -} // namespace infrt::paddle +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/block_desc.cc b/paddle/infrt/paddle/pb/block_desc.cc index 11186bc68af1640e5a83559d3eda4ca958eab8b4..5b28fa5464c547a9badeefef0ef5888fc10ccaaf 100644 --- a/paddle/infrt/paddle/pb/block_desc.cc +++ b/paddle/infrt/paddle/pb/block_desc.cc @@ -14,7 +14,9 @@ #include "paddle/infrt/paddle/pb/block_desc.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { template <> framework_proto::VarDesc* BlockDesc::GetVar( @@ -40,4 +42,6 @@ framework_proto::OpDesc* BlockDesc::AddOp() { return desc_->add_ops(); } -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/block_desc.h b/paddle/infrt/paddle/pb/block_desc.h index 9c1b7f9adf172fa615415f786b9a94e6ee03e22e..c9e325699a4bc4bd18eaf76a5f44cc37aa8c17d9 100644 --- a/paddle/infrt/paddle/pb/block_desc.h +++ b/paddle/infrt/paddle/pb/block_desc.h @@ -18,7 +18,9 @@ #include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/framework.pb.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { namespace framework_proto = ::paddle::framework::proto; @@ -74,4 +76,6 @@ class BlockDesc : public cpp::BlockDescAPI { framework_proto::BlockDesc* desc_; // not_own }; -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/op_desc.cc b/paddle/infrt/paddle/pb/op_desc.cc index c7b1e66f506425db53fd0dfdbbf43c2dc2bc4b2a..32dcefb1ac684a647d978e7d92351ae46a58f9d6 100644 --- a/paddle/infrt/paddle/pb/op_desc.cc +++ b/paddle/infrt/paddle/pb/op_desc.cc @@ -14,7 +14,9 @@ #include "paddle/infrt/paddle/pb/op_desc.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { google::protobuf::internal::RepeatedPtrIterator FindAttr(framework_proto::OpDesc *desc, const std::string &name) { @@ -136,4 +138,6 @@ GET_ATTRS_IMPL(std::vector, strings); GET_ATTR_IMPL(std::string, s); GET_ATTRS_IMPL(std::vector, longs); -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/op_desc.h b/paddle/infrt/paddle/pb/op_desc.h index 81d57d9f32252773626db6bf554c388253c99a1f..2829f2aca2e08dd186c4a38c3b26d808cc1e1138 100644 --- a/paddle/infrt/paddle/pb/op_desc.h +++ b/paddle/infrt/paddle/pb/op_desc.h @@ -19,7 +19,9 @@ #include "paddle/infrt/paddle/framework.pb.h" #include "paddle/infrt/support/variant.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { namespace framework_proto = ::paddle::framework::proto; @@ -195,4 +197,6 @@ template <> void OpDesc::SetAttr>(const std::string &name, const std::vector &v); -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/program_desc.cc b/paddle/infrt/paddle/pb/program_desc.cc index ed8a7e36e0129c7b8b121989fcb80c363f73fc8d..9d725485a974d3f6800a4bb3cca661d8653333c3 100644 --- a/paddle/infrt/paddle/pb/program_desc.cc +++ b/paddle/infrt/paddle/pb/program_desc.cc @@ -17,7 +17,9 @@ #include #include -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { template <> framework_proto::BlockDesc* 
ProgramDesc::GetBlock( @@ -32,4 +34,6 @@ ProgramDesc::AddBlock() { return desc_->add_blocks(); } -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/program_desc.h b/paddle/infrt/paddle/pb/program_desc.h index 4adad650c974dfc4cffe57ff70ac01965a3e733d..b1e64f8e86611fd8ef4e8be8a2064ceb1cd7a5ae 100644 --- a/paddle/infrt/paddle/pb/program_desc.h +++ b/paddle/infrt/paddle/pb/program_desc.h @@ -21,7 +21,9 @@ #include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/framework.pb.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { namespace framework_proto = ::paddle::framework::proto; class ProgramDesc : public cpp::ProgramDescAPI { @@ -58,4 +60,6 @@ class ProgramDesc : public cpp::ProgramDescAPI { framework_proto::ProgramDesc *desc_; // not_own }; -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/var_desc.cc b/paddle/infrt/paddle/pb/var_desc.cc index cf80df4f1b845b1f89d971a353563d934144b7ca..7ea2e24da3446c22e5f359122eb2d8d1ef5b12b4 100644 --- a/paddle/infrt/paddle/pb/var_desc.cc +++ b/paddle/infrt/paddle/pb/var_desc.cc @@ -19,7 +19,9 @@ #include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/framework.pb.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { cpp::VarDescAPI::Type VarDesc::GetType() const { auto type = desc_->type().type(); @@ -364,4 +366,6 @@ VarDesc::mutable_tensor_descs() { return std::vector(); } -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/var_desc.h b/paddle/infrt/paddle/pb/var_desc.h index 4cff5fdee0375d02e5fd014e287fe74f2c9a0d77..7215ba6bb6aa7b52af69ed76562d3c65422c95a5 100644 --- a/paddle/infrt/paddle/pb/var_desc.h +++ b/paddle/infrt/paddle/pb/var_desc.h @@ -23,7 +23,9 @@ #include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/framework.pb.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { namespace framework_proto = ::paddle::framework::proto; // convert between std::vector and protobuf repeated. @@ -121,4 +123,6 @@ class VarDesc : public cpp::VarDescAPI { framework_proto::VarDesc *desc_; }; -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/pten/core/dense_tensor.cc b/paddle/pten/core/dense_tensor.cc index 0b5f5cb18e13dbd79166e5f1a96608eb9a9411dc..eb6f834d7277901b41ad797625a583f4212e98c0 100644 --- a/paddle/pten/core/dense_tensor.cc +++ b/paddle/pten/core/dense_tensor.cc @@ -435,6 +435,10 @@ inline T* DenseTensor::mutable_data(const paddle::platform::Place& place, } void DenseTensor::ShareBufferWith(const DenseTensor& tensor) { + if (storage_ == nullptr) { + storage_ = make_intrusive( + paddle::platform::CPUPlace()); + } if (storage_ != nullptr && tensor.storage_ != nullptr) { storage_->set_data_shared(tensor.storage_->data_shared()); } diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py new file mode 100644 index 0000000000000000000000000000000000000000..e5d04aac1551e64f63625722b08088eb3d8552b6 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py @@ -0,0 +1,675 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import copy +import time +import contextlib +import logging +import functools +import numpy as np +from itertools import chain +from functools import reduce +from types import MethodType +from collections import deque, OrderedDict + +import paddle +from paddle import nn +from paddle.autograd import PyLayer +import paddle.fluid.core as core +import paddle.distributed as dist +from paddle.fluid.framework import ParamBase +from paddle.fluid.clip import ClipGradByGlobalNorm +from paddle.distributed.collective import _get_global_group + +from .sharding_utils import Type, ShardingClipGrad +from ..pp_utils.utils import _all_gather + +# CUDA alignment 256 bytes +alignment = {"gpu": 256, } +align = { + Type.fp16.value: 2, + Type.fp32.value: 4, +} + +global CHECK_LAYER +CHECK_LAYER = dict() # Help to check layer's id -> layer's name + + +class ShardingStage3(nn.Layer): + """ + A wrapper for Sharding Stage3 Layer in Dygraph. + + .. warning: ShardingStage3 encapsulates the layer strategy and integrates it into the nn.Layer. + + .. ZeRO: https://arxiv.org/pdf/1910.02054.pdf. + """ + + def __init__(self, + layer, + optimizer, + group=None, + sync_buffers=False, + device="gpu", + pertrain_sync_models=True, + accumulate_grads=False, + offload=False, + sync_comm=False): + super().__init__() + + # Default configs + assert core.is_compiled_with_cuda(), "Only support CUDA." + self._layer = layer + self._default_device = device + self.__sync_buffers = sync_buffers + self._accumulate_grads = accumulate_grads + self._offload = offload + self._sync_comm = sync_comm + + # Communication group establishment + self._group = dist.new_group(_get_global_group() + .ranks) if group is None else group + self._world_size_scaling = 1.0 / self._group.nranks + assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1." + self._rank = self._group.rank + self._global_root_rank = 0 # picking rank 0 as the reference + self._global_ranks = self._group.ranks + self._param2buffer_size = dict() # {param.name: size} + self._param2buffer = dict( + ) # {param.name: [(start0, end0),(start1, end1), ...]} + self._trainable_params = dict() # {layer.name: [trainable_params]} + + assert not isinstance( + optimizer, list), "Multiple optimizers are not supported now." + self._optim = _OptimizerWrapper(optimizer, self._offload, self._group, + self._update_params_slice) + self._ori_parameter_list = self._optim._parameter_list + self._ori_param_groups = self._optim._param_groups + + # Replace optimizer's _grad_clip + if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm): + logging.warning( + "While using ClipGradByGlobalNorm in ShardingStage3, the grad clip of original optimizer will be changed." 
+
+    @paddle.no_grad()
+    def _sync_params_and_buffers(self):
+        """
+        Sync all model states for all ranks
+        """
+
+        for p in self._layer.parameters():
+            dist.broadcast(
+                p,
+                src=self._global_root_rank,
+                group=self._group,
+                use_calc_stream=True)
+
+            # Multi stream operation will be supported later
+            dist.wait(tensor=p, group=self._group, use_calc_stream=True)
+
+    def _clear_gradients(self):
+        assert len(self._trainable_params.keys()) > 0
+        current_layer_params = self._layer.parameters(include_sublayers=True)
+        trainable_params = list(
+            filter(lambda x: x.trainable, current_layer_params))
+        for param in trainable_params:
+            assert hasattr(
+                param, "fw_storage"
+            ), "Param {} has no fw_storage attribute.".format(param.name)
+
+            # param.bw_storage.zero_()
+            param.fw_storage.clear_gradient(False)
+            param.fw_storage._gradient_set_empty(False)
+            param.bw_storage._clear()
+
+    # Update param memory slice
+    def _update_params_slice(self):
+        update_list = self._update_params()
+
+        if not isinstance(self._optim._param_groups[0], dict):
+            slice_params = [param.fw_storage for param in update_list]
+            self._optim._parameter_list = slice_params
+            self._optim._param_groups = slice_params
+        else:
+            params_name_list = list(map(lambda p: p.name, update_list))
+            for param_group in self._optim._param_groups:
+                slice_p = []
+                for p in param_group['params']:
+                    if p.name in params_name_list:
+                        assert hasattr(
+                            p, "fw_storage"
+                        ), "Param {} has no fw_storage attribute.".format(
+                            p.name)
+                        slice_p.append(p.fw_storage)
+                param_group['params'] = slice_p
+
+    def forward(self, *inputs, **kwargs):
+        """
+        A wrapper for the forward pass of the Sharding Stage3 layer.
+        """
+        # 1. Sync the layer's buffer state
+        if self.__sync_buffers:
+            self._sync_buffers()
+
+        # 2. Normal forward on the base model
+        fw = self._layer(*inputs, **kwargs)
+
+        return fw
+
+    def _segment_rank_params(self, layer, name="last_layer"):
+        current_layer_params = _current_layer_params(layer)
+        if current_layer_params:
+            CHECK_LAYER[id(layer)] = name
+            self._flatten_layer_params(layer, current_layer_params)
+
+        for name, sub_layer in layer.named_children():
+            self._segment_rank_params(sub_layer, name)
+
+    def _flatten_layer_params(self, layer, current_layer_params):
+        def _add_manage_info(trainable_param):
+            return _PartitionParam(trainable_param)
+
+        trainable_params = list(
+            filter(lambda x: x.trainable, current_layer_params))
+        assert id(layer) not in self._trainable_params.keys()
+        self._trainable_params[id(layer)] = list(
+            map(_add_manage_info, trainable_params))
+
+        for param in self._trainable_params[id(layer)]:
+            if param.name in self._param2buffer.keys():
+                continue
+            self._param2buffer[param.name] = []
+            # 1. Param alignment
+            offset = 0
+            # CUDA alignment 256 bytes
+            size = param._numel() * align[param.dtype]
+            remaining = size % alignment[self._default_device]
+            ali = 0 if remaining == 0 else alignment[
+                self._default_device] - remaining
+            align_ = ali // align[param.dtype]
+
+            offset = align_ + param._numel()
+            buffer_size = offset if offset % self._group.nranks == 0 else offset + self._group.nranks - (
+                offset % self._group.nranks)
+            self._param2buffer_size[param.name] = buffer_size
+
+            # 2. Combine param buffers and split them into per-rank slices
+            assert buffer_size % self._group.nranks == 0
+            pre_buffer = buffer_size // self._group.nranks
+
+            for rank_ in range(self._group.nranks):
+                self._param2buffer[param.name].append(
+                    (rank_ * pre_buffer, (rank_ + 1) * pre_buffer))
+
+            # 3. Flatten layer params and release other ranks' buffers
+            self._param_storage(param, buffer_size)
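+
+            # Worked example of the sizing above (a sketch, assuming fp32
+            # and nranks == 2): a param with _numel() == 1000 gives
+            #   size = 1000 * 4 = 4000 bytes; 4000 % 256 = 160,
+            #   so ali = 96 bytes and align_ = 96 // 4 = 24 elements;
+            #   offset = 24 + 1000 = 1024, already divisible by 2,
+            #   so buffer_size = 1024 and pre_buffer = 512:
+            #   rank 0 owns elements [0, 512), rank 1 owns [512, 1024).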
+ """ + # 1.Sync layer's buffers state + if self.__sync_buffers: + self._sync_buffers() + + # 2.Normal FW on the base model + fw = self._layer(*inputs, **kwargs) + + return fw + + def _segment_rank_params(self, layer, name="last_layer"): + current_layer_params = _current_layer_params(layer) + if current_layer_params: + CHECK_LAYER[id(layer)] = name + self._flatten_layer_params(layer, current_layer_params) + + for name, sub_layer in layer.named_children(): + self._segment_rank_params(sub_layer, name) + + def _flatten_layer_params(self, layer, current_layer_params): + def _add_manage_info(trainable_param): + return _PartitionParam(trainable_param) + + trainable_params = list( + filter(lambda x: x.trainable, current_layer_params)) + assert id(layer) not in self._trainable_params.keys() + self._trainable_params[id(layer)] = list( + map(_add_manage_info, trainable_params)) + + for param in self._trainable_params[id(layer)]: + if param.name in self._param2buffer.keys(): + continue + self._param2buffer[param.name] = [] + # 1.Params alignment + offset = 0 + # CUDA alignment 256 bytes + size = param._numel() * align[param.dtype] + remaining = size % alignment[self._default_device] + ali = 0 if remaining == 0 else alignment[ + self._default_device] - remaining + align_ = ali // align[param.dtype] + + offset = align_ + param._numel() + buffer_size = offset if offset % self._group.nranks == 0 else offset + self._group.nranks - ( + offset % self._group.nranks) + self._param2buffer_size[param.name] = buffer_size + + # 2.Combination param buffer + assert buffer_size % self._group.nranks == 0 + pre_buffer = buffer_size // self._group.nranks + + for rank_ in range(self._group.nranks): + self._param2buffer[param.name].append( + (rank_ * pre_buffer, (rank_ + 1) * pre_buffer)) + + # 3.Flatten layer params and release other rank buffer + self._param_storage(param, buffer_size) + + def _param_storage(self, param, buffer_size): + assert isinstance(buffer_size, int) + value = np.zeros( + buffer_size, + dtype=np.float16) if Type.fp16.value == param.dtype else np.zeros( + buffer_size, dtype=np.float32) + buffer = core.VarBase(value=value, place=core.CPUPlace()) + + param_shape = param.shape + origin_state = param.stop_gradient + param.stop_gradient = True + param.flatten_() + param.stop_gradient = origin_state + start, end = self._param2buffer[param.name][self._rank] + + # Copy the current param value + tmp_var = core.VarBase( + tensor=buffer._slice(0, param._numel()), place=core.CPUPlace()) + param_cpu = param.cpu() + tmp_var.value().get_tensor().set(param_cpu.value().get_tensor(), + core.CPUPlace()) + param.value().get_tensor()._set_dims(param_shape) + param._clear() + + # Current rank param_storage + param.fw_storage = core.VarBase( + buffer._slice(start, end), "slice@" + param.name) + param.status = "part" + + # Updata optimizer master weights + if param.dtype == Type.fp16.value: + self._optim._master_weights[param.fw_storage.name] = paddle.cast( + param.fw_storage, Type.fp32.value) + + def _register_forward_hooks(self, layer): + current_layer_params = _current_layer_params(layer) + if current_layer_params: + self._register_forward_all_hooks(layer, self._task_flow) + + for _, sub_layer in layer.named_children(): + self._register_forward_hooks(sub_layer) + + def _register_forward_all_hooks(self, sub_layer, task_flow): + def _forward_pre_hook(layer, inputs): + return ForwardPreHooks(layer, self._order_tracer, + self._trainable_params, self._param2buffer, + self._rank, self._group, self._sync_comm, + task_flow) + 
+
+    @paddle.no_grad()
+    def _sync_buffers(self):
+        for buffer in self._layer.buffers(include_sublayers=True):
+            dist.broadcast(
+                buffer,
+                self._global_root_rank,
+                self._group,
+                use_calc_stream=True)
+            # Multi stream operation will be supported later
+            dist.wait(tensor=buffer, group=self._group, use_calc_stream=True)
+
+    def __getattr__(self, name):
+        """Forward missing attributes to the wrapped layer."""
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self._layer, name)
+
+    def _update_params(self):
+        update_list = []
+        assert len(self._trainable_params.keys()) > 0
+        current_layer_params = self._layer.parameters(include_sublayers=True)
+        trainable_params = list(
+            filter(lambda x: x.trainable, current_layer_params))
+        for param in trainable_params:
+            assert hasattr(
+                param,
+                "fw_storage"), "Param {} has no fw_storage attribute".format(
+                    param.name)
+
+            if self._accumulate_grads:
+                param.bw_storage.scale_(scale=self._world_size_scaling)
+            param.fw_storage = _VarBaseWrapper(param)
+            param.fw_storage._copy_gradient_from(param.bw_storage)
+            update_list.append(param)
+        return update_list
+
+    def get_all_parameters(self):
+        assert len(self._trainable_params.keys()) > 0
+        current_layer_params = self._layer.parameters(include_sublayers=True)
+        trainable_params = list(
+            filter(lambda x: x.trainable, current_layer_params))
+        for param in trainable_params:
+            if param.use_count > 0:
+                continue
+            assert hasattr(
+                param,
+                "fw_storage"), "Param {} has no fw_storage attribute".format(
+                    param.name)
+
+            full_param = _all_gather(
+                param.fw_storage, self._group, use_calc_stream=True)
+            dist.wait(
+                tensor=full_param, group=self._group, use_calc_stream=True)
+            core.VarBase(full_param._slice(0,
+                                           param._numel()))._share_buffer_to(
+                                               param)
+            param.value().get_tensor()._set_dims(param.shape)
+            param.fw_storage._clear()
+            param.fw_storage = None
+            param.status = "all"
+            param.use_count += 1
+
+        self._optim._parameter_list = self._ori_parameter_list
+        self._optim._param_groups = self._ori_param_groups
+
+    def _register_backward_hooks(self):
+        current_layer_params = self._layer.parameters(include_sublayers=True)
+        trainable_params = list(
+            filter(lambda x: x.trainable, current_layer_params))
+
+        for param in trainable_params:
+            allreduce_function = self._get_allreduce_fn(param)
+            param._register_backward_hook(allreduce_function)
+
+    def _get_allreduce_fn(self, param):
+        @paddle.no_grad()
+        def reduce(*_):
+            if param.name in self._task_flow.full_grad.keys():
+                full_grad = self._task_flow.full_grad[param.name]
+                with paddle.amp.auto_cast(enable=False):
+                    if not self._accumulate_grads:
+                        full_grad.scale_(scale=self._world_size_scaling)
+                    # Only synchronous all-reduce of the current rank's layer is supported for now
+                    dist.all_reduce(
+                        tensor=full_grad,
+                        group=self._group,
+                        use_calc_stream=True)
+                    dist.wait(
+                        tensor=full_grad,
+                        group=self._group,
+                        use_calc_stream=True)
+
+                    start, end = self._param2buffer[param.name][self._rank]
+                    if not self._accumulate_grads or param.bw_storage is None:
+                        param.bw_storage = core.VarBase(
+                            full_grad._slice(start, end)).detach().clone()
+                    else:
+                        param.bw_storage.add_(
+                            core.VarBase(full_grad._slice(start, end)).detach()
+                            .clone())
+                param.clear_gradient(False)
+                param._gradient_set_empty(False)
+                tmp_var = self._task_flow.full_grad.pop(param.name)
+                tmp_var._clear()
+
+            if param.name in self._task_flow.full_param.keys():
+                if param.status == "all":
+                    param.use_count = 0
+                    param._clear()
+                    start, end = self._param2buffer[param.name][self._rank]
+                    with paddle.amp.auto_cast(enable=False):
+                        param.fw_storage = core.VarBase(
+                            self._task_flow.full_param[param.name]._slice(
+                                start, end),
+                            param.name + "@slice").detach().clone()
+                        param.status = "part"
+                    tmp_var = self._task_flow.full_param.pop(param.name)
+                    tmp_var._clear()
+
+        return reduce
+
+    def _redefine_opt_step(self):
+        params_slice_func = self._update_params_slice
+        opt_step = self._optim.step
+        update_scaler = self._optim.update_scaler
+
+        def _opt_step(self):
+            if not update_scaler:
+                params_slice_func()
+            opt_step()
+
+        self._optim.step = MethodType(_opt_step, self._optim)
+
+    def _redefine_opt_clear(self):
+        clear_func = self._clear_gradients
+
+        def _opt_clear(self):
+            clear_func()
+
+        self._optim.clear_grad = MethodType(_opt_clear, self._optim)
+
+
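+# Param life cycle used by the free functions below (a descriptive sketch):
+# each managed param carries `status` ("part" = only this rank's slice is
+# alive in fw_storage, "all" = the full tensor is materialized) and a
+# `use_count`, so layers sharing a weight release it only once every
+# consumer is done. TaskFlow carries the in-flight full params and grads.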
+def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer, rank,
+                    group, sync_comm, task_flow):
+
+    # Record the layer's id
+    layer_id = id(layer)
+    use_calc, sync_wait = False, False
+
+    if layer_id not in order_tracer.keys() or sync_comm:
+        use_calc, sync_wait = True, True
+        task_flow.use_calc[layer_id] = use_calc
+        # First pass (or fully synchronous mode): gather this layer's params
+        # now, since no prefetch has been issued for it yet
+        _allgather_buffer(
+            layer_id,
+            trainable_params,
+            group,
+            use_calc_stream=use_calc,
+            task_flow=task_flow,
+            sync_wait=sync_wait)
+    else:
+        task_flow.use_calc[layer_id] = use_calc
+        _wait_layer(trainable_params, layer_id, task_flow, group, use_calc)
+
+        if layer_id == order_tracer["layer"][-1]: return
+        order_ = order_tracer[layer_id]
+        layer_id = order_tracer["layer"][order_ + 1]
+        _allgather_buffer(
+            layer_id,
+            trainable_params,
+            group,
+            use_calc_stream=use_calc,
+            task_flow=task_flow,
+            sync_wait=sync_wait)
+    return
+
+
+class ForwardPostHooks(PyLayer):
+    @staticmethod
+    def forward(ctx, inputs, layer, order_tracer, trainable_params,
+                param2buffer, param2buffer_size, rank, group, sync_comm,
+                task_flow):
+        _release_param(layer, trainable_params, param2buffer, rank, task_flow)
+
+        layer_id = id(layer)
+        if layer_id not in order_tracer.keys():
+            order_ = order_tracer["order"]
+            order_tracer[layer_id] = order_
+            order_tracer["order"] += 1
+            order_tracer["layer"].append(layer_id)
+        ctx.order_tracer = order_tracer
+        ctx.task_flow = task_flow
+        ctx.group = group
+        ctx.layer = layer
+        ctx.sync_comm = sync_comm
+        ctx.trainable_params = trainable_params
+        ctx.param2buffer_size = param2buffer_size
+
+        return inputs
+
+    @staticmethod
+    def backward(ctx, *args):
+        # Load context values
+        order_tracer = ctx.order_tracer
+        task_flow = ctx.task_flow
+        group = ctx.group
+        layer = ctx.layer
+        trainable_params = ctx.trainable_params
+        param2buffer_size = ctx.param2buffer_size
+        sync_comm = ctx.sync_comm
+        layer_id = id(layer)
+        use_calc, sync_wait = False, False
+        if sync_comm:
+            use_calc, sync_wait = True, True
+            _allgather_buffer(
+                layer_id,
+                trainable_params,
+                group,
+                use_calc_stream=use_calc,
+                task_flow=task_flow,
+                sync_wait=sync_wait)
+        else:
+            _wait_layer(trainable_params, layer_id, task_flow, group, use_calc)
+        _create_params_grad(layer, trainable_params, param2buffer_size,
+                            task_flow)
+        task_flow.use_calc[layer_id] = use_calc
+        if layer_id != order_tracer["layer"][0] and not sync_comm:
+            layer_next_id = order_tracer["layer"][order_tracer[layer_id] - 1]
+            _allgather_buffer(
+                layer_next_id,
+                trainable_params,
+                group,
+                use_calc_stream=use_calc,
+                task_flow=task_flow,
+                sync_wait=sync_wait)
+
+        return args
+
+
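+# Backward prefetch (a descriptive sketch): the first forward pass records
+# the layer order in order_tracer, so backward can walk that order in
+# reverse. For the layer at position k, backward() above prefetches
+# order_tracer["layer"][k - 1], overlapping the all-gather of the
+# next-needed params with the current layer's gradient computation
+# whenever sync_comm is off.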
+class TaskFlow:
+    """
+    Task flow: a one-way linked list for task acquisition.
+    """
+
+    def __init__(self,
+                 full_param=None,
+                 full_grad=None,
+                 use_calc=None,
+                 callback=None):
+        # Use None defaults so instances never share one mutable dict
+        self.full_param = {} if full_param is None else full_param
+        self.full_grad = {} if full_grad is None else full_grad
+        self.use_calc = {} if use_calc is None else use_calc
+        self.callback = callback
+
+
+def _release_param(layer, trainable_params, param2buffer, rank, task_flow):
+    for param in trainable_params[id(layer)]:
+        # Shared weights still in use by async communication are not cleared
+        param.use_count -= 1
+        if param.use_count == 0:
+            param._clear()
+            if param.name in task_flow.full_param.keys():
+                start, end = param2buffer[param.name][rank]
+                with paddle.amp.auto_cast(enable=False):
+                    param.fw_storage = core.VarBase(
+                        task_flow.full_param[param.name]._slice(start, end),
+                        param.name + "@slice").detach().clone()
+                param.status = "part"
+                tmp_var = task_flow.full_param.pop(param.name)
+                tmp_var._clear()
+    return
+
+
+def _wait_layer(trainable_params, layer_id, task_flow, group, use_calc_stream):
+    for param in trainable_params[layer_id]:
+        if param.status == "all":
+            param.use_count += 1
+            continue
+        if param.name in task_flow.full_param.keys():
+            full_param = task_flow.full_param[param.name]
+            with paddle.amp.auto_cast(enable=False):
+                paddle.device.cuda.synchronize()
+            core.VarBase(full_param._slice(0,
+                                           param._numel()))._share_buffer_to(
+                                               param)
+            param.value().get_tensor()._set_dims(param.shape)
+            param.fw_storage._clear()
+            param.fw_storage = None
+            param.status = "all"
+            param.use_count += 1
+        else:
+            _allgather_buffer(
+                layer_id,
+                trainable_params,
+                group,
+                use_calc_stream,
+                task_flow,
+                sync_wait=True)
+            break
+    return task_flow
+
+
+def _allgather_buffer(layer_id,
+                      trainable_params,
+                      group,
+                      use_calc_stream,
+                      task_flow,
+                      sync_wait=False):
+    for param in trainable_params[layer_id]:
+        if param.status == "all":
+            param.use_count += 1
+            continue
+        with paddle.amp.auto_cast(enable=False):
+            full_param = _all_gather(
+                param.fw_storage, group, use_calc_stream=use_calc_stream)
+        if sync_wait:
+            with paddle.amp.auto_cast(enable=False):
+                dist.wait(
+                    tensor=full_param,
+                    group=group,
+                    use_calc_stream=use_calc_stream)
+            core.VarBase(full_param._slice(0,
+                                           param._numel()))._share_buffer_to(
+                                               param)
+            param.value().get_tensor()._set_dims(param.shape)
+            param.fw_storage._clear()
+            param.fw_storage = None
+            param.status = "all"
+            param.use_count += 1
+        task_flow.full_param[param.name] = full_param
+    return task_flow
+
+
+@paddle.no_grad()
+def _create_params_grad(layer, trainable_params, param2buffer_size, task_flow):
+    for param in trainable_params[id(layer)]:
+        if param.name in task_flow.full_grad.keys():
+            continue
+        assert isinstance(param2buffer_size[param.name], int)
+        temp_grad = paddle.zeros(
+            [param2buffer_size[param.name]], dtype=param.dtype)
+        param._copy_gradient_from(
+            core.VarBase(temp_grad._slice(0, param._numel())))
+        task_flow.full_grad[param.name] = temp_grad
+    return task_flow
+
+
+def _PartitionParam(param):
+    if not hasattr(param, "fw_storage"):
+        setattr(param, "fw_storage", None)
+        setattr(param, "bw_storage", None)
+        setattr(param, "status", "all")
+        setattr(param, "use_count", 0)
+    return param
+
+
+def _VarBaseWrapper(param):
+    varbase = param.fw_storage
+    tmp_param = ParamBase(
+        shape=varbase.shape, dtype=varbase.dtype, name="slice@" + param.name)
+    varbase._share_buffer_to(tmp_param)
+    tmp_param.regularizer = param.regularizer
+    tmp_param.optimize_attr['learning_rate'] = param.optimize_attr[
+        'learning_rate']
+    varbase._clear()
+    return tmp_param
+
+
+def _OptimizerWrapper(optimizer, offload, group, update_params_slice):
+    if not hasattr(optimizer, "_optim"):
+        setattr(optimizer, "_optim", optimizer)
+        setattr(optimizer, "offload", offload)
+        setattr(optimizer, "group", group)
+        setattr(optimizer, "update_scaler", None)
+        setattr(optimizer, "update_slice", update_params_slice)
+    return optimizer
+
+
+def _current_layer_params(layer):
+    return layer.parameters(
+        include_sublayers=False) + list(layer.extra_parameters) if hasattr(
+            layer, "extra_parameters") else layer.parameters(
+                include_sublayers=False)
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
index 272aada576be8a5182187fb6fe2e80bc6ac757bb..5f696195c1abcd4921b4358b8971fdbc982609da 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
@@ -152,6 +152,9 @@ def ShardingScaler(scaler):
         param_grads = []
         param_grads_fp16 = []
         param_grads_fp32 = []
+        if hasattr(optimizer, "update_slice"):
+            optimizer.update_slice()
+            optimizer.update_scaler = True
 
         if getattr(optimizer._optim, '_param_groups', None) and isinstance(
                 optimizer._optim._param_groups[0], dict):
@@ -161,27 +164,21 @@
                     if param._grad_ivar() is not None:
                         param_grads.append(param._grad_ivar())
                         if param._grad_ivar(
-                        ).dtype == core.VarDesc.VarType.FP16:
+                        ).dtype in [core.VarDesc.VarType.FP16, paddle.float16]:
                             param_grads_fp16.append(param._grad_ivar())
                         else:
                             param_grads_fp32.append(param._grad_ivar())
         else:
-            param_grads = [
-                param._grad_ivar() for param in optimizer._optim._parameter_list
-                if param._grad_ivar() is not None
-            ]
-            param_grads_fp16 = [
-                param._grad_ivar() for param in optimizer._optim._parameter_list
-                if (param._grad_ivar() is not None
-                    ) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP16
-                           )
-            ]
-            param_grads_fp32 = [
-                param._grad_ivar() for param in optimizer._optim._parameter_list
-                if (param._grad_ivar() is not None
-                    ) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32
-                           )
-            ]
+            for param in optimizer._optim._parameter_list:
+                if param.grad is not None:
+                    param_grads.append(param.grad)
+                    if param.grad.dtype in [
+                            core.VarDesc.VarType.FP16, paddle.float16
+                    ]:
+                        param_grads_fp16.append(param.grad)
+                    else:
+                        param_grads_fp32.append(param.grad)
+
         temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
         temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 67697fcfd839871c4795f334ee4ff1fe0e178332..c0c13866ccd55da0cef95a389cc5bf9221e68df1 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -34,6 +34,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
 list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2)
 list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2)
+list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage3)
 list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
 list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper)
@@ -250,6 +251,7 @@
 if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
     list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel)
     list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2)
     list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2)
+    list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3)
     list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer)
     list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
     LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
@@ -1058,6 +1060,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
     set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120)
     set_tests_properties(test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120)
     set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_dygraph_sharding_stage3 PROPERTIES TIMEOUT 120)
     set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120)
     set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
     set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b0bec9c454b0fdfaea4d96ac821bfe8f859eff5
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
@@ -0,0 +1,233 @@
+# -*- coding: UTF-8 -*-
+
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
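+
+# Test strategy (a descriptive note): trains structurally identical MLPs
+# under sharding stage 2 and stage 3 and asserts parameter parity, covering
+# fp32, fp16, gradient accumulation, and a sync_comm (recompute-style)
+# variant. It is driven by test_dygraph_sharding_stage3.py through
+# run_mnist_2gpu, which is expected to spawn it on two GPUs, roughly:
+#     python -m paddle.distributed.launch --gpus=0,1 dygraph_sharding_stage3.py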
+
+import numpy as np
+import argparse
+import ast
+import time
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.nn import Linear
+from paddle.distributed import fleet
+from paddle.fluid.dygraph import nn
+
+from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
+from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
+from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
+from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
+
+epoch = 10
+batch_size = 32
+paddle.seed(2021)
+np.random.seed(2021)
+base_lr = 0.1
+momentum_rate = 0.9
+l2_decay = 1e-4
+fleet.init(is_collective=True)
+
+
+class MLP(fluid.Layer):
+    def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
+        super(MLP, self).__init__()
+
+        self._linear1 = Linear(linear_size, linear_size)
+        self._linear2 = Linear(linear_size, linear_size)
+        self._linear3 = Linear(linear_size, 10)
+
+    def forward(self, inputs):
+        y = self._linear1(inputs)
+        y = self._linear2(y)
+        y = self._linear3(y)
+        return y
+
+
+def reader_decorator(linear_size=1000):
+    def __reader__():
+        for _ in range(100):
+            img = np.random.rand(linear_size).astype('float32')
+            label = np.ones(1).astype('int64')
+            yield img, label
+
+    return __reader__
+
+
+def optimizer_setting(model, use_pure_fp16, opt_group=False):
+    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
+    optimizer = paddle.optimizer.AdamW(
+        parameters=[{
+            "params": model.parameters()
+        }] if opt_group else model.parameters(),
+        learning_rate=0.001,
+        weight_decay=0.00001,
+        grad_clip=clip,
+        multi_precision=use_pure_fp16)
+
+    return optimizer
+
+
+def train_mlp(model,
+              sharding_stage,
+              use_pure_fp16=False,
+              accumulate_grad=False,
+              opt_group=False,
+              recompute=False):
+    group = paddle.distributed.new_group([0, 1])
+    if opt_group:
+        optimizer = optimizer_setting(
+            model=model, use_pure_fp16=use_pure_fp16, opt_group=opt_group)
+    else:
+        optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
+
+    if use_pure_fp16:
+        model = paddle.amp.decorate(
+            models=model, level='O2', save_dtype='float32')
+        scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
+        scaler = ShardingScaler(scaler)
+    if sharding_stage == 2:
+        optimizer = ShardingOptimizerStage2(
+            params=model.parameters(), optim=optimizer, group=group)
+        model = ShardingStage2(
+            model,
+            optimizer,
+            group=group,
+            buffer_max_size=2**21,
+            accumulate_grads=accumulate_grad)
+    elif sharding_stage == 3:
+        model = ShardingStage3(
+            model, optimizer=optimizer, group=group, sync_comm=recompute)
+
+    train_reader = paddle.batch(
+        reader_decorator(), batch_size=batch_size, drop_last=True)
+
+    train_loader = paddle.io.DataLoader.from_generator(
+        capacity=32,
+        use_double_buffer=True,
+        iterable=True,
+        return_list=True,
+        use_multiprocess=True)
+    train_loader.set_sample_list_generator(train_reader)
+
+    for eop in range(epoch):
+        model.train()
+        for batch_id, data in enumerate(train_loader()):
+            img, label = data
+            label.stop_gradient = True
+            img.stop_gradient = True
+            with paddle.amp.auto_cast(True, level='O2'):
+                out = model(img)
+                loss = paddle.nn.functional.cross_entropy(
+                    input=out, label=label)
+            avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
+            if not accumulate_grad:
+                if not use_pure_fp16:
+                    avg_loss.backward()
+                    optimizer.step()
+                else:
+                    scaler.scale(avg_loss).backward()
+                    scaler.step(optimizer)
+                    scaler.update()
+                optimizer.clear_grad()
+        if accumulate_grad:
+            if not use_pure_fp16:
+                avg_loss.backward()
+                optimizer.step()
+            else:
+                scaler.scale(avg_loss).backward()
+                scaler.step(optimizer)
+                scaler.update()
+            optimizer.clear_grad()
+    if sharding_stage == 3:
+        model.get_all_parameters()
+    return model.parameters()
+
+
+def test_stage2_stage3():
+    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8 = MLP(), MLP(), MLP(
+    ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
+    state_dict = mlp.state_dict()
+    mlp1.set_state_dict(state_dict)
+    mlp2.set_state_dict(state_dict)
+    mlp3.set_state_dict(state_dict)
+    mlp4.set_state_dict(state_dict)
+    mlp5.set_state_dict(state_dict)
+    mlp6.set_state_dict(state_dict)
+    mlp7.set_state_dict(state_dict)
+    mlp8.set_state_dict(state_dict)
+    # fp32
+    stage2_params = train_mlp(
+        mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=True)
+    stage3_params = train_mlp(
+        mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=True)
+    for i in range(len(stage2_params)):
+        for j in range(len(stage3_params)):
+            if stage2_params[i].name == stage3_params[j].name:
+                np.testing.assert_allclose(
+                    stage2_params[i].numpy(),
+                    stage3_params[j].numpy(),
+                    rtol=1e-6)
+    # fp32 accumulate grad
+    stage2_params = train_mlp(
+        mlp3,
+        sharding_stage=2,
+        use_pure_fp16=False,
+        accumulate_grad=True,
+        opt_group=True)
+    stage3_params = train_mlp(
+        mlp4,
+        sharding_stage=3,
+        use_pure_fp16=False,
+        accumulate_grad=True,
+        opt_group=True)
+    for i in range(len(stage2_params)):
+        for j in range(len(stage3_params)):
+            if stage2_params[i].name == stage3_params[j].name:
+                np.testing.assert_allclose(
+                    stage2_params[i].numpy(),
+                    stage3_params[j].numpy(),
+                    rtol=1e-6)
+    # fp16
+    stage2_params = train_mlp(
+        mlp5, sharding_stage=2, use_pure_fp16=True, opt_group=False)
+    stage3_params = train_mlp(
+        mlp6, sharding_stage=3, use_pure_fp16=True, opt_group=False)
+    for i in range(len(stage2_params)):
+        for j in range(len(stage3_params)):
+            if stage2_params[i].name == stage3_params[j].name:
+                np.testing.assert_allclose(
+                    stage2_params[i].numpy(),
+                    stage3_params[j].numpy(),
+                    rtol=1e-6)
+    # fp16 recompute
+    stage3_params = train_mlp(
+        mlp7, sharding_stage=3, use_pure_fp16=True, opt_group=False)
+    stage3_params_re = train_mlp(
+        mlp8,
+        sharding_stage=3,
+        use_pure_fp16=True,
+        opt_group=False,
+        recompute=True)
+    for i in range(len(stage3_params)):
+        for j in range(len(stage3_params_re)):
+            if stage3_params[i].name == stage3_params_re[j].name:
+                np.testing.assert_allclose(
+                    stage3_params[i].numpy(),
+                    stage3_params_re[j].numpy(),
+                    rtol=1e-6)
+    return
+
+
+if __name__ == '__main__':
+    test_stage2_stage3()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_flatten_contiguous_range.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_flatten_contiguous_range.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4060349d4bed495011cfae7fa367ca23ee5d8eb
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_flatten_contiguous_range.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
+from program_config import TensorConfig, ProgramConfig
+import unittest
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import Optional, List, Callable, Dict, Any, Set
+
+
+class TrtConvertFlattenContiguousRangeTest(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        return True
+
+    def sample_program_configs(self):
+        def generate_input(batch):
+            return np.random.random([2, batch, 4, 8, 3]).astype(np.float32)
+
+        for batch in [1, 2, 4]:
+            for start_axis in range(5):
+                for stop_axis in range(start_axis, 5):
+                    op_type = "flatten_contiguous_range"
+                    op_outputs = {
+                        "Out": ["output_data"],
+                        "XShape": ["xshape_data"]
+                    }
+                    ops_config = [{
+                        "op_type": op_type,
+                        "op_inputs": {
+                            "X": ["input_data"]
+                        },
+                        "op_outputs": op_outputs,
+                        "op_attrs": {
+                            "start_axis": start_axis,
+                            "stop_axis": stop_axis,
+                        }
+                    }]
+                    ops = self.generate_op_config(ops_config)
+
+                    program_config = ProgramConfig(
+                        ops=ops,
+                        weights={},
+                        inputs={
+                            "input_data": TensorConfig(
+                                data_gen=partial(generate_input, batch))
+                        },
+                        outputs=["output_data"])
+                    yield program_config
+
+    def sample_predictor_configs(
+            self, program_config) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape(attrs):
+            self.dynamic_shape.min_input_shape = {"input_data": [2, 1, 4, 8, 3]}
+            self.dynamic_shape.max_input_shape = {"input_data": [2, 4, 4, 8, 3]}
+            self.dynamic_shape.opt_input_shape = {"input_data": [2, 2, 4, 8, 3]}
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            ver = paddle_infer.get_trt_compile_version()
+            # Encode (major, minor, patch) as major*1000 + minor*100 + patch*10
+            if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 >= 7000:
+                if dynamic_shape:
+                    return 1, 2
+                else:
+                    if attrs[0]['start_axis'] == 0:
+                        return 0, 3
+                    else:
+                        return 1, 2
+            else:
+                return 0, 3
+
+        attrs = [
+            program_config.ops[i].attrs
+            for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape
+        clear_dynamic_shape()
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False), 1e-5
+
+        # for dynamic_shape
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(attrs,
+                                                                      True), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(attrs,
+                                                                      True), 1e-5
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage3.py b/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage3.py
new file mode 100644
index 0000000000000000000000000000000000000000..89d5f2e8c7b292592369651887fc72bcabcb77ea
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dygraph_sharding_stage3.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle.fluid as fluid
+
+from test_parallel_dygraph_dataparallel import TestMultipleGpus
+
+
+class TestDygraphShardingStage3(TestMultipleGpus):
+
+    # check the sharding logic as well as accuracy against single-card mode
+    def test_dygraph_sharding_optimizer_stage3(self):
+        self.run_mnist_2gpu('dygraph_sharding_stage3.py')
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_stack_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_stack_op_xpu.py
index 68e5a6ccdbfb73fbc44fbfc07503d6e8752523a5..20446aee41ec7e11c8b9a6963e2a70b108086314 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_stack_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_stack_op_xpu.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -66,6 +66,15 @@ class TestStackOpBase(XPUOpTest):
         place = paddle.XPUPlace(0)
         self.check_output_with_place(place)
 
+    def test_check_grad(self):
+        if self.dtype == 'int64' or self.dtype == 'int32':
+            pass
+        else:
+            if paddle.is_compiled_with_xpu():
+                paddle.enable_static()
+                place = paddle.XPUPlace(0)
+                self.check_grad_with_place(place, self.get_x_names(), 'Y')
+
 
 class TestStackOp1(TestStackOpBase):
     def initParameters(self):
@@ -81,11 +90,17 @@ class TestStackOp3(TestStackOpBase):
     def initParameters(self):
         self.axis = -1
 
+    def test_check_grad(self):
+        pass
+
 
 class TestStackOp4(TestStackOpBase):
     def initParameters(self):
         self.axis = -4
 
+    def test_check_grad(self):
+        pass
+
 
 class TestStackOp5(TestStackOpBase):
     def initParameters(self):
@@ -113,7 +128,7 @@ class TestStackOpint(TestStackOpBase):
         self.num_inputs = 4
         self.input_dim = (5, 6, 7)
         self.axis = 0
-        self.dtype = 'int'
+        self.dtype = 'int32'
 
     def initParameters(self):
         self.num_inputs = 16