Commit 04b9bfbd authored by: J jim19930609

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into develop

include(FetchContent)
set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/CINN/llvm11.tar.gz)
set(LLVM_MD5 39d32b6be466781dddf5869318dcba53)
set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/infrt/llvm_b5149f4e66a49a98b67e8e2de4e24a4af8e2781b.tar.gz)
set(LLVM_MD5 022819bb5760817013cf4b8a37e97d5e)
set(FETCHCONTENT_BASE_DIR ${THIRD_PARTY_PATH}/llvm)
set(FETCHCONTENT_QUIET OFF)
......@@ -51,7 +51,7 @@ message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
# To build with MLIR, LLVM is built from source with the following flags:
#[==[
cmake -G Ninja ../llvm \
cmake ../llvm -G "Unix Makefiles" \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_BUILD_EXAMPLES=OFF \
-DLLVM_TARGETS_TO_BUILD="X86" \
......@@ -59,8 +59,10 @@ cmake -G Ninja ../llvm \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_ZLIB=OFF \
-DLLVM_ENABLE_RTTI=ON \
-DLLVM_INSTALL_UTILS=ON \
-DCMAKE_INSTALL_PREFIX=./install
#]==]
# The matched llvm-project version is f9dc2b7079350d0fed3bb3775f496b90483c9e42 (currently a temporary commit)
# The matched llvm-project version is b5149f4e66a49a98b67e8e2de4e24a4af8e2781b (currently a temporary commit)
add_definitions(${LLVM_DEFINITIONS})
......@@ -75,7 +77,7 @@ add_definitions(${LLVM_DEFINITIONS})
# The minimum needed libraries for MLIR IR parse and transform.
set(MLIR_IR_LIBS MLIRAnalysis MLIRStandardOps MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib)
set(MLIR_IR_LIBS MLIRAnalysis MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib)
# tb_base is the name of a xxx.td file (without the .td suffix)
......@@ -89,6 +91,7 @@ function(mlir_tablegen_on td_base)
mlir_tablegen(${td_base}.cpp.inc -gen-op-defs)
if (mlir_tablegen_on_DIALECT)
mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls -dialect=${mlir_tablegen_on_DIALECT})
mlir_tablegen(${td_base}_dialect.cpp.inc --gen-dialect-defs -dialect=${mlir_tablegen_on_DIALECT})
endif()
add_public_tablegen_target(${td_base}_IncGen)
add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen)
......
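For orientation, the .inc files that mlir_tablegen_on emits are consumed from C++ behind the GET_OP_CLASSES guard, the same pattern this commit uses for basic_kernels and dense_tensor. A minimal sketch, with <td_base> standing in for the actual file name:

// Declarations generated from <td_base>.td are pulled into the header:
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/<td_base>.hpp.inc"

// ...and the matching definitions into the .cc file:
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/<td_base>.cpp.inc"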
......@@ -46,8 +46,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
<< " is diabled by config in TensorRT";
return false;
}
return tensorrt::OpTeller::Global().Tell(node, no_calib_int8,
with_dynamic_shape);
bool is_ok = tensorrt::OpTeller::Global().Tell(node, no_calib_int8,
with_dynamic_shape);
if (!is_ok)
VLOG(3) << node->Op()->Type().c_str() << " op is not in TensorRT";
return is_ok;
};
framework::ir::SubGraphFuser fuser(
......
......@@ -1416,6 +1416,7 @@ USE_TRT_CONVERTER(elementwise_min_tensor);
USE_TRT_CONVERTER(elementwise_pow_tensor);
USE_TRT_CONVERTER(transpose);
USE_TRT_CONVERTER(flatten);
USE_TRT_CONVERTER(flatten_contiguous_range);
USE_TRT_CONVERTER(matmul);
USE_TRT_CONVERTER(conv2d);
USE_TRT_CONVERTER(relu);
......
......@@ -3,7 +3,7 @@ nv_library(tensorrt_converter
SRCS matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc group_norm_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc flatten_contiguous_range_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
gather_op.cc
anchor_generator_op.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* flatten_contiguous_range trt converter
*/
class FlattenContiguousRangeOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
int dims = input->getDimensions().nbDims;
int start_axis = BOOST_GET_CONST(int, op_desc.GetAttr("start_axis"));
int stop_axis = BOOST_GET_CONST(int, op_desc.GetAttr("stop_axis"));
nvinfer1::IShuffleLayer* layer = nullptr;
if (!engine_->with_dynamic_shape()) {
if (start_axis < 0) start_axis += dims + 1;
if (stop_axis < 0) stop_axis += dims + 1;
int dim_prod = 1;
nvinfer1::Dims flatten_dim;
flatten_dim.nbDims = dims - (stop_axis - start_axis);
for (int i = 0, j = 0; i < dims; ++i) {
if (start_axis <= i + 1 && i + 1 <= stop_axis) {
int dim_i = input->getDimensions().d[i];
PADDLE_ENFORCE_GT(dim_i, 0, platform::errors::InvalidArgument(
"flatten_contiguous_range input dim "
"should be > 0, but got %d.",
dim_i));
dim_prod *= dim_i;
if (i + 1 == stop_axis) {
flatten_dim.d[j++] = dim_prod;
}
} else {
flatten_dim.d[j++] = input->getDimensions().d[i];
}
}
layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
layer->setReshapeDimensions(flatten_dim);
} else {
if (start_axis < 0) start_axis += dims;
if (stop_axis < 0) stop_axis += dims;
auto* shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
auto* shape_layer_itensor = shape_layer->getOutput(0);
nvinfer1::Dims start_dim, size_dim, stride_dim;
start_dim.nbDims = 1;
size_dim.nbDims = 1;
stride_dim.nbDims = 1;
start_dim.d[0] = start_axis;
size_dim.d[0] = stop_axis - start_axis + 1;
stride_dim.d[0] = 1;
auto* slice_layer =
TRT_ENGINE_ADD_LAYER(engine_, Slice, *shape_layer_itensor, start_dim,
size_dim, stride_dim);
uint32_t reduce_dim = 1;
auto* reduce_prod_layer = TRT_ENGINE_ADD_LAYER(
engine_, Reduce, *(slice_layer->getOutput(0)),
nvinfer1::ReduceOperation::kPROD, reduce_dim, true);
nvinfer1::ITensor* input_shape = nullptr;
if (start_axis == 0 && stop_axis == dims - 1) {
input_shape = reduce_prod_layer->getOutput(0);
} else {
std::vector<nvinfer1::ITensor*> itensors;
if (start_axis > 0) {
nvinfer1::Dims left_start_dim, left_size_dim, left_stride_dim;
left_start_dim.nbDims = 1;
left_size_dim.nbDims = 1;
left_stride_dim.nbDims = 1;
left_start_dim.d[0] = 0;
left_size_dim.d[0] = start_axis;
left_stride_dim.d[0] = 1;
auto* slice_layer_left = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *shape_layer_itensor, left_start_dim,
left_size_dim, left_stride_dim);
itensors.push_back(slice_layer_left->getOutput(0));
}
itensors.push_back(reduce_prod_layer->getOutput(0));
if (stop_axis < dims - 1) {
nvinfer1::Dims right_start_dim, right_size_dim, right_stride_dim;
right_start_dim.nbDims = 1;
right_size_dim.nbDims = 1;
right_stride_dim.nbDims = 1;
right_start_dim.d[0] = stop_axis + 1;
right_size_dim.d[0] = dims - stop_axis - 1;
right_stride_dim.d[0] = 1;
auto* slice_layer_right = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *shape_layer_itensor, right_start_dim,
right_size_dim, right_stride_dim);
itensors.push_back(slice_layer_right->getOutput(0));
}
auto* concat_layer = TRT_ENGINE_ADD_LAYER(
engine_, Concatenation, itensors.data(), itensors.size());
concat_layer->setAxis(0);
input_shape = concat_layer->getOutput(0);
}
layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
layer->setInput(1, *input_shape);
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "flatten_contiguous_range", {output_name},
test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(flatten_contiguous_range,
FlattenContiguousRangeOpConverter);
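The static-shape branch of the converter above collapses axes [start_axis, stop_axis] into a single dimension. A self-contained sketch of that computation (sample shapes are hypothetical; the batch dimension is excluded, as in TRT static shape):

// Mirrors the static-shape branch of FlattenContiguousRangeOpConverter:
// collapse dims in [start_axis, stop_axis] into their product.
#include <iostream>
#include <vector>

std::vector<int> FlattenDims(const std::vector<int> &dims, int start_axis,
                             int stop_axis) {
  std::vector<int> out;
  int dim_prod = 1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    // The converter compares against i + 1 because axis 0 is the implicit
    // batch dimension, which dims here does not contain.
    if (start_axis <= i + 1 && i + 1 <= stop_axis) {
      dim_prod *= dims[i];
      if (i + 1 == stop_axis) out.push_back(dim_prod);
    } else {
      out.push_back(dims[i]);
    }
  }
  return out;
}

int main() {
  // e.g. per-sample shape [3, 4, 5], start_axis=1, stop_axis=2 -> [12, 5]
  for (int d : FlattenDims({3, 4, 5}, 1, 2)) std::cout << d << " ";
}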
......@@ -55,6 +55,7 @@ struct SimpleOpTypeSetTeller : public Teller {
// #endif
#if IS_TRT_VERSION_GE(7000)
teller_set.insert("tile");
teller_set.insert("flatten_contiguous_range");
#endif
#if CUDA_VERSION >= 10020
teller_set.insert("reshape");
......@@ -531,6 +532,37 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (axis != 1) return false;
}
}
if (op_type == "flatten_contiguous_range") {
if (!with_dynamic_shape) {
int start_axis = BOOST_GET_CONST(int, desc.GetAttr("start_axis"));
int stop_axis = BOOST_GET_CONST(int, desc.GetAttr("stop_axis"));
auto x_var_name = desc.Input("X")[0];
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
int dims = x_shape.size();
if (start_axis < 0) start_axis += dims;
if (start_axis == 0) {
VLOG(3) << "TRT flatten_contiguous_range not support the "
"batch-dimension being changed";
return false;
}
if (stop_axis < 0) stop_axis += dims;
for (int i = start_axis; i <= stop_axis; ++i) {
if (x_shape[i] < 0) {
VLOG(3) << "On TRT static shape,flatten_contiguous_range input dim "
"should be > 0";
return false;
}
}
}
}
if (op_type == "gather") {
auto gather_inputs = desc.Inputs();
......
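The static-shape rule in the flatten_contiguous_range teller above distills to the following standalone sketch (function name is hypothetical): reject the op when the batch dimension would be flattened, or when any dimension in the flattened range is dynamic.

#include <cstdint>
#include <vector>

// Returns true iff flatten_contiguous_range is convertible under TRT
// static shape, per the checks added in OpTeller::Tell above.
bool TellFlattenContiguousRange(const std::vector<int64_t> &x_shape,
                                int start_axis, int stop_axis) {
  int dims = static_cast<int>(x_shape.size());
  if (start_axis < 0) start_axis += dims;
  if (start_axis == 0) return false;  // batch dimension must stay untouched
  if (stop_axis < 0) stop_axis += dims;
  for (int i = start_axis; i <= stop_axis; ++i) {
    if (x_shape[i] < 0) return false;  // dynamic dim under static shape
  }
  return true;
}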
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/stack_op.h"
#include <string>
#ifdef PADDLE_WITH_XPU
#include <vector>
#include "paddle/fluid/operators/concat_op.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace paddle {
namespace operators {
......@@ -59,14 +62,44 @@ class StackXPUKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class StackGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
auto axis = ctx.Attr<int>("axis");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto dy_dims = dy->dims();
if (axis < 0) axis += dy_dims.size() + 1;
auto dy_shape = framework::vectorize<int>(dy_dims);
std::vector<int> dx_dims_list(dx.size(), 1);
std::vector<T*> dx_lists;
for (auto out : dx) {
dx_lists.push_back(out->mutable_data<T>(ctx.GetPlace()));
}
int r = xpu::split<T>(dev_ctx.x_context(), dy->data<T>(), dx_lists,
dy_shape, dx_dims_list, axis);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"The stack_grad XPU kernel return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(stack,
ops::StackXPUKernel<plat::XPUDeviceContext, int64_t>,
ops::StackXPUKernel<plat::XPUDeviceContext, float>,
ops::StackXPUKernel<plat::XPUDeviceContext, int>,
ops::StackXPUKernel<plat::XPUDeviceContext, float>);
ops::StackXPUKernel<plat::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(stack_grad,
ops::StackGradXPUKernel<plat::XPUDeviceContext, float>,
ops::StackGradXPUKernel<plat::XPUDeviceContext, int>);
#endif
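Shape-wise, the new stack_grad kernel inverts stack: dy carries an extra dimension of size N at axis (one slot per stacked input), and xpu::split slices it back into N per-input gradients. A hedged shape-level sketch:

#include <cassert>
#include <vector>

// Computes the shapes xpu::split produces for stack_grad: N slices of
// size 1 along `axis` (cf. dx_dims_list filled with 1s in the kernel).
std::vector<std::vector<int>> StackGradShapes(std::vector<int> dy_shape,
                                              int axis) {
  if (axis < 0) axis += static_cast<int>(dy_shape.size());
  int n = dy_shape[axis];
  std::vector<int> piece = dy_shape;
  piece[axis] = 1;
  return std::vector<std::vector<int>>(n, piece);
}

int main() {
  // Stacking 4 tensors of shape [2, 3] at axis 0 gives dy of [4, 2, 3];
  // the gradient splits back into 4 pieces of shape [1, 2, 3].
  auto shapes = StackGradShapes({4, 2, 3}, 0);
  assert(shapes.size() == 4 && shapes[0][0] == 1);
}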
......@@ -300,6 +300,7 @@ XPUOpMap& get_kl1_ops() {
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"stack_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
......@@ -333,6 +333,8 @@ XPUOpMap& get_kl2_ops() {
{"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"stack_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
......
......@@ -77,7 +77,6 @@ add_subdirectory(paddle)
# MLIR td file generations
set(infrt_mlir_incs
ops_inc
basic_kernels_inc
test_kernels_inc
infrt_base_inc
......
......@@ -14,7 +14,7 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include <mlir/IR/MLIRContext.h>
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
......
......@@ -2,7 +2,6 @@ core_gather_headers()
gather_srcs(infrt_src SRCS
dialect.cc
types.cc
basic_kernels.cc
test_kernels.cc
infrt_base.cc
......@@ -14,8 +13,6 @@ gather_srcs(infrt_src SRCS
pd_types.cc
pd_ops.cc
)
mlir_tablegen_on(ops)
mlir_tablegen_on(basic_kernels)
mlir_tablegen_on(test_kernels)
mlir_tablegen_on(infrt_base DIALECT infrt)
......@@ -27,8 +24,7 @@ mlir_add_rewriter(rewrite)
# TODO(Superjomn) add a cmake function cc_executable to encapsulate the following code
add_executable(infrtopt opt.cc)
target_link_libraries(infrtopt infrt ${mlir_libs})
add_dependencies(infrtopt infrt)
target_link_libraries(infrtopt infrt)
add_executable(print-ir print_ir.cc)
target_link_libraries(print-ir infrt ${mlir_libs})
......
......@@ -17,17 +17,17 @@
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include "paddle/infrt/dialect/dense_tensor.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
using namespace mlir; // NOLINT
static ParseResult parseCallOp(OpAsmParser &parser, // NOLINT
......@@ -71,12 +71,12 @@ static ParseResult parseConstantF64Op(OpAsmParser &parser, // NOLINT
static ParseResult parseConstantI32Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
return parseConstantOp(
IntegerType::get(32, result.getContext()), parser, result);
IntegerType::get(result.getContext(), 32), parser, result);
}
static ParseResult parseConstantI64Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
return parseConstantOp(
IntegerType::get(64, result.getContext()), parser, result);
IntegerType::get(result.getContext(), 64), parser, result);
}
static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT
......@@ -90,10 +90,10 @@ static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT
}
static void print(OpAsmPrinter &p, CallOp op) { // NOLINT
p << "infrt.call " << op.getAttr("callee") << "(";
p << "infrt.call " << op->getAttr("callee") << "(";
p.printOperands(op.getOperands());
p << ")";
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
p.printOptionalAttrDict(op->getAttrs(), {"callee"});
p << " : ";
}
......@@ -145,7 +145,7 @@ static LogicalResult verify(ConstantF64Op op) { return success(); }
static LogicalResult verify(ConstantI64Op op) { return success(); }
static LogicalResult verify(ReturnOp op) {
auto function = dyn_cast<FuncOp>(op.getParentOp());
auto function = dyn_cast<FuncOp>(op->getParentOp());
if (!function) return success();
......@@ -157,8 +157,8 @@ static LogicalResult verify(ReturnOp op) {
return success();
}
} // namespace dialect
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.cpp.inc"
} // namespace infrt::dialect
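Several hunks above track upstream MLIR API moves: Function.h/Module.h/StandardTypes.h were folded into BuiltinOps.h/BuiltinTypes.h, op attributes are now reached through operator-> (op->getAttr), and IntegerType::get takes the context first. A minimal sketch of the last change (assuming the recent LLVM pinned by this commit):

#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/MLIRContext.h>

mlir::IntegerType MakeI32(mlir::MLIRContext *ctx) {
  // New argument order: context first, then bit width.
  return mlir::IntegerType::get(ctx, /*width=*/32);
  // The pre-change order, IntegerType::get(32, ctx), is what the removed
  // lines above used.
}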
......@@ -13,12 +13,9 @@
// limitations under the License.
#pragma once
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
using namespace mlir; // NOLINT
namespace infrt::dialect {
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.hpp.inc"
} // namespace infrt::dialect
......@@ -27,7 +27,7 @@ def CallOp : INFRT_Op<"call"> {
let results = (outs Variadic<AnyType>);
let extraClassDeclaration = [{
StringRef getCallee() { return callee(); }
mlir::StringRef getCallee() { return callee(); }
mlir::FunctionType getCalleeType();
}];
}
......@@ -57,9 +57,8 @@ def ReturnOp : INFRT_Op<"return", [Terminator]> {
let arguments = (ins Variadic<AnyType>:$operands);
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result",
[{ build(b, result, llvm::None); }]>];
let builders = [OpBuilder<(ins),
[{ build($_builder, $_state, llvm::None); }]>];
}
class AddOp<string suffix, Type type> : INFRT_Op<"add." # suffix, [NoSideEffect]> {
......
......@@ -17,12 +17,11 @@
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
......@@ -31,68 +30,37 @@
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt::dt {
namespace infrt {
namespace dt {
void DTDialect::initialize() {
allowUnknownTypes();
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/dense_tensor.cpp.inc"
>();
}
namespace detail {
struct TensorTypeStorage : public mlir::TypeStorage {
TensorTypeStorage(TargetType target,
LayoutType layout,
PrecisionType precision)
: target_(target), layout_(layout), precision_(precision) {}
using KeyTy = std::tuple<TargetType, LayoutType, PrecisionType>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(target_, layout_, precision_);
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static TensorTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, // NOLINT
const KeyTy &key) {
return new (allocator.allocate<TensorTypeStorage>())
TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key));
}
TargetType target_;
LayoutType layout_;
PrecisionType precision_;
};
} // namespace detail
llvm::Optional<TargetType> GetTargetType(mlir::StringRef key) {
if (key.equals_lower("x86"))
if (key.equals_insensitive("x86"))
return TargetType::X86;
else if (key.equals_lower("cuda"))
else if (key.equals_insensitive("cuda"))
return TargetType::CUDA;
else
return llvm::None;
}
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key) {
if (key.equals_lower("nchw"))
if (key.equals_insensitive("nchw"))
return LayoutType::NCHW;
else if (key.equals_lower("nhwc"))
else if (key.equals_insensitive("nhwc"))
return LayoutType::NHWC;
else
return llvm::None;
}
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key) {
if (key.equals_lower("i32"))
if (key.equals_insensitive("i32"))
return PrecisionType::I32;
else if (key.equals_lower("f32"))
else if (key.equals_insensitive("f32"))
return PrecisionType::F32;
else
return llvm::None;
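The equals_lower to equals_insensitive changes above follow an upstream llvm::StringRef rename; the behavior (case-insensitive comparison) is unchanged. A minimal sketch:

#include <llvm/ADT/StringRef.h>

#include <cassert>

bool IsCudaKey(llvm::StringRef key) {
  return key.equals_insensitive("cuda");  // matches "cuda", "CUDA", "Cuda", ...
}

int main() { assert(IsCudaKey("CUDA") && !IsCudaKey("x86")); }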
......@@ -111,7 +79,7 @@ LayoutType TensorType::layout() { return getImpl()->layout_; }
PrecisionType TensorType::precision() { return getImpl()->precision_; }
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType) {
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType) {
os << "TensorType<" << tensorType.target() << ", " << tensorType.layout()
<< ", " << tensorType.precision() << ">";
return os;
......@@ -133,7 +101,7 @@ StringType StringType::get(mlir::MLIRContext *context) {
return Base::get(context);
}
raw_ostream &operator<<(raw_ostream &os, TargetType type) {
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type) {
switch (type) {
case (TargetType::X86):
os << "X86";
......@@ -147,7 +115,7 @@ raw_ostream &operator<<(raw_ostream &os, TargetType type) {
return os;
}
raw_ostream &operator<<(raw_ostream &os, LayoutType type) {
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type) {
switch (type) {
case (LayoutType::NCHW):
os << "NCHW";
......@@ -161,7 +129,7 @@ raw_ostream &operator<<(raw_ostream &os, LayoutType type) {
return os;
}
raw_ostream &operator<<(raw_ostream &os, PrecisionType type) {
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type) {
switch (type) {
case (PrecisionType::I32):
os << "I32";
......@@ -175,103 +143,69 @@ raw_ostream &operator<<(raw_ostream &os, PrecisionType type) {
return os;
}
static Type getTensorType(mlir::MLIRContext *context) {
auto t_dialect = Identifier::get("t", context);
return OpaqueType::get(t_dialect, "tensor", context);
static mlir::Type getTensorType(mlir::MLIRContext *context) {
auto t_dialect = mlir::Identifier::get("t", context);
return mlir::OpaqueType::get(t_dialect, "tensor");
}
static ParseResult parseCreateUninitTensorOp(
OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
static mlir::ParseResult parseCreateUninitTensorOp(
mlir::OpAsmParser &parser, // NOLINT
mlir::OperationState &result) { // NOLINT
auto loc = parser.getCurrentLocation();
::mlir::Type outputRawTypes[1];
::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes);
mlir::Type outputRawTypes[1];
::llvm::ArrayRef<mlir::Type> outputTypes(outputRawTypes);
mlir::ArrayAttr shapeAttr;
if (parser.parseAttribute(shapeAttr,
parser.getBuilder().getI64Type(),
"shape",
result.attributes))
return failure();
if (parser.parseOptionalAttrDict(result.attributes)) return failure();
return mlir::failure();
if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure();
if (parser.parseArrow()) return failure();
if (parser.parseType(outputRawTypes[0])) return failure();
if (parser.parseArrow()) return mlir::failure();
if (parser.parseType(outputRawTypes[0])) return mlir::failure();
if (!outputRawTypes[0].isa<TensorType>())
return parser.emitError(loc, "invalid kind of type specified");
result.addTypes(outputTypes);
return success();
return mlir::success();
}
template <typename CreateUninitTensorOp>
static void printCreateUninitTensorOp(OpAsmPrinter &p, // NOLINT
static void printCreateUninitTensorOp(mlir::OpAsmPrinter &p, // NOLINT
CreateUninitTensorOp op) {
p << CreateUninitTensorOp::getOperationName();
p << " ";
p.printAttributeWithoutType(op.shapeAttr());
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
p << " -> ";
p << op.getOperation()->getResultTypes();
}
// TODO(shibo): can be removed?
// static ParseResult parseFillTensorWithConstantOp(OpAsmParser& parser,
// OperationState& result) {
// auto loc = parser.getCurrentLocation();
// ::mlir::OpAsmParser::OperandType inputRawOperands[1];
// ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType>
// inputOperands(inputRawOperands);
// ::mlir::Type inputRawTypes[1];
// ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes);
//
// if (parser.parseOperand(inputRawOperands[0])) return failure();
//
// if (parser.parseColon()) return failure();
// if (parser.parseType(inputRawTypes[0])) return failure();
// if (!inputRawTypes[0].isa<TensorType>())
// return parser.emitError(loc, "invalid kind of type specified");
//
// Attribute value_attr;
// if (parser.resolveOperands(inputOperands, inputTypes, loc, result.operands))
// return failure();
// if (parser.parseAttribute(value_attr, "value", result.attributes)) return
// failure();
// return success();
//}
// TODO(shibo): can be removed?
// template <typename FillTensorOp>
// static void printFillTensorWithConstantOp(OpAsmPrinter& p, FillTensorOp op) {
// p << FillTensorOp::getOperationName();
// p << " ";
// p.printOperand(op.getOperand());
// p << " : ";
// p << op.getOperation()->getOperandTypes();
// p << " ";
// p << op.getAttr("value");
//}
static ParseResult parseSetTensorOp(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
SmallVector<OpAsmParser::OperandType, 1> operands;
if (parser.parseOperandList(operands, 1)) return failure();
static mlir::ParseResult parseSetTensorOp(
mlir::OpAsmParser &parser, // NOLINT
mlir::OperationState &result) { // NOLINT
llvm::SmallVector<mlir::OpAsmParser::OperandType, 1> operands;
if (parser.parseOperandList(operands, 1)) return mlir::failure();
auto tensor_type = getTensorType(result.getContext());
Attribute value_attr;
return failure(
mlir::Attribute value_attr;
return mlir::failure(
parser.resolveOperand(operands[0], tensor_type, result.operands) ||
parser.parseAttribute(value_attr, "values", result.attributes));
}
template <typename SetTensorOp>
static void printSetTensorOp(OpAsmPrinter &p, SetTensorOp op) { // NOLINT
static void printSetTensorOp(mlir::OpAsmPrinter &p, SetTensorOp op) { // NOLINT
p << SetTensorOp::getOperationName() << " ";
p.printOperand(op.getOperand());
p << " " << op.getAttr("values");
p << " " << op->getAttr("values");
}
} // namespace dt
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.cpp.inc" // NOLINT
} // namespace infrt::dt
#include "paddle/infrt/dialect/dense_tensor_dialect.cpp.inc"
......@@ -19,13 +19,8 @@
#include <string>
using namespace mlir; // NOLINT
namespace infrt::dt {
namespace detail {
struct TensorTypeStorage;
} // namespace detail
namespace infrt {
namespace dt {
enum class TargetType : uint8_t { X86, CUDA };
enum class LayoutType : uint8_t { NCHW, NHWC };
enum class PrecisionType : uint8_t { I32, F32 };
......@@ -34,9 +29,39 @@ llvm::Optional<TargetType> GetTargetType(mlir::StringRef key);
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key);
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key);
raw_ostream &operator<<(raw_ostream &os, TargetType type);
raw_ostream &operator<<(raw_ostream &os, LayoutType type);
raw_ostream &operator<<(raw_ostream &os, PrecisionType type);
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type);
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type);
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type);
namespace detail {
struct TensorTypeStorage : public mlir::TypeStorage {
TensorTypeStorage(TargetType target,
LayoutType layout,
PrecisionType precision)
: target_(target), layout_(layout), precision_(precision) {}
using KeyTy = std::tuple<TargetType, LayoutType, PrecisionType>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(target_, layout_, precision_);
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static TensorTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, // NOLINT
const KeyTy &key) {
return new (allocator.allocate<TensorTypeStorage>())
TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key));
}
TargetType target_;
LayoutType layout_;
PrecisionType precision_;
};
} // namespace detail
class TensorType : public mlir::Type::TypeBase<TensorType,
mlir::Type,
......@@ -52,7 +77,7 @@ class TensorType : public mlir::Type::TypeBase<TensorType,
PrecisionType precision();
};
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType);
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType);
class TensorMapType : public mlir::Type::TypeBase<TensorMapType,
mlir::Type,
......@@ -70,10 +95,10 @@ class StringType
static StringType get();
static StringType get(mlir::MLIRContext *context);
};
} // namespace dt
} // namespace infrt
#include "paddle/infrt/dialect/dense_tensor_dialect.hpp.inc"
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.hpp.inc"
} // namespace infrt::dt
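TensorTypeStorage, now moved into this header, is MLIR's standard parameterized-type storage: KeyTy identifies a (target, layout, precision) triple, hashKey and operator== drive uniquing, and construct allocates one storage instance per distinct key in the MLIRContext. A hedged sketch of how a getter would use it (the exact signature is an assumption, since TensorType::get is not shown in this diff):

// TensorType TensorType::get(mlir::MLIRContext *ctx, TargetType target,
//                            LayoutType layout, PrecisionType precision) {
//   // Base::get hashes the key, reuses an existing storage on a hit, and
//   // calls TensorTypeStorage::construct only for an unseen key.
//   return Base::get(ctx, target, layout, precision);
// }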
......@@ -14,9 +14,11 @@
#include "paddle/infrt/dialect/diagnostic_utils.h"
#include <llvm/Support/raw_ostream.h>
#include <string>
namespace infrt::dialect {
namespace infrt {
namespace dialect {
struct MyScopedDiagnosicHandler::Impl {
Impl() : diag_stream_(diag_str_) {}
......@@ -49,4 +51,5 @@ mlir::LogicalResult MyScopedDiagnosicHandler::handler(mlir::Diagnostic *diag) {
return mlir::failure(true);
}
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -18,7 +18,8 @@
#include <memory>
namespace infrt::dialect {
namespace infrt {
namespace dialect {
/**
* A scoped diagnostic handler to help debug MLIR process.
......@@ -36,4 +37,5 @@ class MyScopedDiagnosicHandler : public mlir::SourceMgrDiagnosticHandler {
std::unique_ptr<Impl> impl_;
};
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -13,24 +13,26 @@
// limitations under the License.
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LogicalResult.h>
namespace infrt::hlir::dialect {
namespace infrt {
namespace hlir {
namespace dialect {
class CinnDialect : public ::mlir::Dialect {
class CinnDialect : public mlir::Dialect {
public:
explicit CinnDialect(::mlir::MLIRContext* ctx);
explicit CinnDialect(mlir::MLIRContext* ctx);
//! We should register this function in the dialect.
static llvm::StringRef getDialectNamespace() {
return "infrt::hlir::dialect";
}
};
} // namespace infrt::hlir::dialect
} // namespace dialect
} // namespace hlir
} // namespace infrt
......@@ -18,7 +18,8 @@
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/test_kernels.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
// ----INFRTDialect definition begin----
void INFRTDialect::initialize() {
......@@ -124,4 +125,5 @@ void INFRTDialect::printType(mlir::Type type,
// ----INFRTDialect definition end----
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -18,19 +18,17 @@
#include <mlir/IR/Dialect.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h>
#include "paddle/infrt/dialect/infrt_base.hpp.inc"
namespace infrt::dialect {
class INFRTDialect : public ::mlir::Dialect {
explicit INFRTDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(),
context,
::mlir::TypeID::get<INFRTDialect>()) {
namespace infrt {
namespace dialect {
class INFRTDialect : public mlir::Dialect {
explicit INFRTDialect(mlir::MLIRContext *context)
: mlir::Dialect(
getDialectNamespace(), context, mlir::TypeID::get<INFRTDialect>()) {
initialize();
}
......@@ -41,15 +39,12 @@ class INFRTDialect : public ::mlir::Dialect {
mlir::DialectAsmPrinter &printer) const override;
void initialize();
friend class ::mlir::MLIRContext;
friend class mlir::MLIRContext;
public:
static ::llvm::StringRef getDialectNamespace() { return "infrt"; }
};
} // namespace infrt::dialect
namespace mlir {
} // namespace dialect
template <typename T>
static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT
......@@ -58,17 +53,16 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT
return b.getIntegerAttr(b.getI32Type(), constant);
}
static mlir::SmallVector<::mlir::Value, 4> cvtValueToValueRange(
static mlir::SmallVector<mlir::Value, 4> cvtValueToValueRange(
const mlir::Value &operand) {
return mlir::SmallVector<::mlir::Value, 4>(1, operand);
return mlir::SmallVector<mlir::Value, 4>(1, operand);
}
static mlir::SmallVector<::mlir::Value, 4> concatTwoValueRange(
static mlir::SmallVector<mlir::Value, 4> concatTwoValueRange(
mlir::ValueRange operand_0, mlir::ValueRange operand_1) {
mlir::SmallVector<::mlir::Value, 4> operands;
mlir::SmallVector<mlir::Value, 4> operands;
operands.append(operand_0.begin(), operand_0.end());
operands.append(operand_1.begin(), operand_1.end());
return operands;
}
} // namespace mlir
} // namespace infrt
......@@ -28,11 +28,11 @@ def TensorMapType :
def BufferType : OpaqueType<"b", "buffer", "buffer">;
class INFRT_createI32Attr<string value> : NativeCodeCall<
"mlir::createI32Attr($_builder, $_loc, " # value # ")">;
"infrt::createI32Attr($_builder, $_loc, " # value # ")">;
def INFRT_cvtValueToValueRange : NativeCodeCall<
"mlir::cvtValueToValueRange($0)">;
"infrt::cvtValueToValueRange($0)">;
def INFRT_concatTwoValueRange : NativeCodeCall<
"mlir::concatTwoValueRange($0, $1)">;
"infrt::concatTwoValueRange($0, $1)">;
#endif // INFRT_BASE
......@@ -23,12 +23,10 @@
#include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt {
void RegisterCinnDialects(mlir::DialectRegistry& registry) { // NOLINT
registry.insert<ts::TensorShapeDialect>();
registry.insert<dialect::INFRTDialect>();
registry.insert<dt::DTDialect>();
registry.insert<mlir::pd::PaddleDialect>();
void registerCinnDialects(mlir::DialectRegistry &registry) { // NOLINT
registry.insert<ts::TensorShapeDialect,
dialect::INFRTDialect,
dt::DTDialect,
mlir::pd::PaddleDialect>();
}
} // namespace infrt
......@@ -14,10 +14,8 @@
#pragma once
#include "mlir/IR/Dialect.h"
#include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h>
namespace infrt {
void RegisterCinnDialects(mlir::DialectRegistry& registry); // NOLINT
void registerCinnDialects(mlir::DialectRegistry &registry); // NOLINT
} // namespace infrt
......@@ -16,8 +16,8 @@
#include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h>
#include <unordered_map>
......@@ -30,12 +30,15 @@
#include "paddle/infrt/dialect/diagnostic_utils.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source) {
// context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry());
mlir::DialectRegistry registry;
registerCinnDialects(registry);
context->appendDialectRegistry(registry);
// Currently the CinnDialect plus mlir::BuiltinDialect is enough;
// StandardOpsDialect is not needed.
// context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
......@@ -57,9 +60,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context) {
// context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry());
context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
mlir::DialectRegistry registry;
registerCinnDialects(registry);
context->appendDialectRegistry(registry);
mlir::ScopedDiagnosticHandler scope_handler(
context, [](mlir::Diagnostic& diag) {
if (diag.getSeverity() != mlir::DiagnosticSeverity::Error)
......@@ -71,4 +74,5 @@ mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
return mlir::parseSourceFile(std::string(file_name), context);
}
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
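A hedged usage sketch of the refactored loader (the verify call mirrors the mlir_loader test later in this commit; the file path is hypothetical):

#include <mlir/IR/MLIRContext.h>

#include "paddle/infrt/dialect/mlir_loader.h"

void LoadAndCheck() {
  mlir::MLIRContext context;  // dialects are registered inside the loader
  auto module = infrt::dialect::LoadMlirFile("/tmp/model.mlir", &context);
  if (module && mlir::succeeded(module->verify())) {
    // Safe to walk or transform the module here.
  }
}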
......@@ -15,16 +15,17 @@
#pragma once
#include <glog/logging.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/BuiltinOps.h>
#include <string>
#include <memory>
namespace infrt::dialect {
namespace infrt {
namespace dialect {
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source);
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context);
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -17,14 +17,15 @@
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Parser.h>
#include <string>
#include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
TEST(MlirLoader, basic) {
mlir::MLIRContext context;
......@@ -42,8 +43,7 @@ func @main() -> f32 {
)ROC";
auto module = LoadMlirSource(&context, source);
module->verify();
EXPECT_TRUE(mlir::succeeded(module->verify()));
LOG(INFO) << "module name: " << module->getOperationName().data();
for (auto func : module->getOps<mlir::FuncOp>()) {
LOG(INFO) << "get func " << func.getName().str();
......@@ -54,4 +54,5 @@ func @main() -> f32 {
}
}
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -20,5 +20,5 @@ func @main() -> tensor<?xf32> {
%c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
infrt.return %e2 : tensor<?xf32>
"pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
}
\ No newline at end of file
......@@ -11,5 +11,5 @@ func @main() -> tensor<?xf32> {
%c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor<?x3x256x256xf32>, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
%d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor<?x3x256x256xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
infrt.return %d : tensor<?x3x256x256xf32>
"pd.fetch"(%d) {name="output"} :(tensor<?x3x256x256xf32>)->()
}
\ No newline at end of file
......@@ -18,5 +18,5 @@ func @main() -> tensor<?xf32> {
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
"pd.fetch"(%e2) :(tensor<?xf32>)->()
"pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
}
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/infrt_base.td"
class INFRT_Op<string mnemonic, list<OpTrait> traits = []> :
Op<INFRT_Dialect, mnemonic, traits>;
......@@ -12,34 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <llvm/Support/CommandLine.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Dialect.h>
#include <mlir/InitAllDialects.h>
#include <mlir/InitAllPasses.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/MlirOptMain.h>
#include <mlir/Transforms/Passes.h>
#include <iostream>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
#include "paddle/infrt/dialect/mlir_loader.h"
int main(int argc, char **argv) {
mlir::MLIRContext *context = infrt::Global::getMLIRContext();
auto &registry = context->getDialectRegistry();
infrt::RegisterCinnDialects(registry);
mlir::DialectRegistry registry;
infrt::registerCinnDialects(registry);
mlir::registerCanonicalizerPass();
return mlir::failed(
mlir::MlirOptMain(argc, argv, "INFRT mlir pass driver", registry));
mlir::MlirOptMain(argc, argv, "infrt mlir pass driver", registry));
}
......@@ -16,7 +16,7 @@ def PD_Dialect : Dialect {
This dialect contains the PaddlePaddle operators.
}];
let cppNamespace = "::mlir::pd";
let cppNamespace = "mlir::pd";
}
class PD_Op<string mnemonic, list<OpTrait> traits = []> :
......
......@@ -14,10 +14,15 @@
#include "paddle/infrt/dialect/pd_ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include <mlir/IR/Matchers.h>
#include <mlir/IR/PatternMatch.h>
#include "paddle/infrt/dialect/infrt_base.h"
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
namespace mlir {
namespace pd {
PaddleDialect::PaddleDialect(MLIRContext *context)
......@@ -36,12 +41,6 @@ mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder,
return builder.create<ConstantOp>(loc, value);
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
#undef GET_OP_CLASSES
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
void ConstantOp::build(OpBuilder &builder,
OperationState &state,
Attribute value) {
......@@ -66,8 +65,8 @@ LogicalResult ConstantOp::inferReturnTypes(
inferredReturnTypes.push_back(attributes.get("value").getType());
return success();
}
::mlir::OpFoldResult ConstantOp::fold(
::llvm::ArrayRef<::mlir::Attribute> operands) {
mlir::OpFoldResult ConstantOp::fold(
::llvm::ArrayRef<mlir::Attribute> operands) {
return value();
}
......@@ -82,11 +81,11 @@ LogicalResult ElementwiseAdd::inferReturnTypes(
return success();
}
void ElementwiseAdd::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseMulAdd>(context);
}
::mlir::OpFoldResult ElementwiseAdd::fold(
mlir::OpFoldResult ElementwiseAdd::fold(
llvm::ArrayRef<mlir::Attribute> operands) {
if (getElementTypeOrSelf(getType()).isa<FloatType>()) {
if (!operands[0] || !operands[1]) return {};
......@@ -154,17 +153,17 @@ LogicalResult MulOp::inferReturnTypes(
}
void ReluOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseFCRelu>(context);
}
void FusedRepeatedFCRelu::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseRepeatedFCRelu2>(context);
}
void BatchNormOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseBatchNormWithConvPattern>(context);
}
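The getCanonicalizationPatterns hooks above only register rewrites such as FuseMulAdd; the patterns fire when MLIR's canonicalizer actually runs. A hedged sketch of wiring that up:

#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/Passes.h>

mlir::LogicalResult Canonicalize(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Runs the registered canonicalization patterns on every function.
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  return pm.run(module);
}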
......
......@@ -14,21 +14,20 @@
#pragma once
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <mlir/Dialect/Traits.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace mlir {
namespace pd {
......@@ -53,9 +52,8 @@ class PaddleDialect : public Dialect {
}
};
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
#undef GET_OP_CLASSES
} // namespace pd
} // namespace mlir
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
......@@ -24,6 +24,16 @@ def PD_FeedOp : PD_Op<"feed"> {
def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
let summary = "fetch Op";
let description = [{
Fetch the named output tensor from the graph.
}];
let arguments = (ins PD_Tensor :$inputs, StrAttr:$name);
}
def PD_ReturnOp : PD_Op<"return", [Terminator]> {
let summary = "return Op";
let description = [{
Return tensors from the enclosing subgraph.
}];
......@@ -31,7 +41,7 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
let arguments = (ins Variadic<PD_Tensor>:$inputs);
}
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> {
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"ReturnOp">]> {
let summary = "paddle graph Op";
let description = [{
Describe a paddle graph or subgraph.
......@@ -50,7 +60,7 @@ def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInte
let hasFolder = 1;
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute value">,
OpBuilder<(ins "Attribute":$value)>,
];
}
......
......@@ -18,12 +18,11 @@
#pragma once
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h>
namespace mlir {
namespace PD {
......
......@@ -11,26 +11,25 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <llvm/ADT/Optional.h>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/ScopedPrinter.h>
#include <mlir/IR/BuiltinOps.h>
#include <llvm/Support/raw_os_ostream.h>
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Block.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/Region.h>
#include <mlir/IR/Verifier.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/Passes.h>
#include <iostream>
#include "llvm/ADT/Optional.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
......@@ -114,17 +113,15 @@ int main(int argc, char **argv) {
mlir::registerPassManagerCLOptions();
cl::ParseCommandLineOptions(argc, argv, "mlir demo");
mlir::MLIRContext *context = infrt::Global::getMLIRContext();
// context->allowUnregisteredDialects();
auto &registry = context->getDialectRegistry();
infrt::RegisterCinnDialects(registry);
mlir::DialectRegistry registry;
infrt::registerCinnDialects(registry);
mlir::MLIRContext context(registry);
// mlir will verify module automatically after parsing.
// https://github.com/llvm/llvm-project/blob/38d18d93534d290d045bbbfa86337e70f1139dc2/mlir/lib/Parser/Parser.cpp#L2051
// mlir::OwningModuleRef module_ref = mlir::parseSourceString(mlir_source,
// context);
mlir::OwningModuleRef module_ref =
mlir::parseSourceFile(inputFilename, context);
mlir::parseSourceFile(inputFilename, &context);
std::cout << "----------print IR Structure begin----------" << std::endl;
printOperation(module_ref->getOperation(), 0);
std::cout << "----------print IR Structure end----------" << std::endl;
......
......@@ -17,16 +17,16 @@
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
namespace infrt::ts {
namespace infrt {
namespace ts {
using namespace mlir; // NOLINT
void TensorShapeDialect::initialize() {
......@@ -48,8 +48,8 @@ Type TensorShapeDialect::parseType(DialectAsmParser &parser) const {
return Type();
}
void TensorShapeDialect::printType(::mlir::Type type,
::mlir::DialectAsmPrinter &os) const {
void TensorShapeDialect::printType(mlir::Type type,
mlir::DialectAsmPrinter &os) const {
if (type.isa<ShapeType>()) {
os << "shape";
return;
......@@ -61,8 +61,10 @@ void TensorShapeDialect::printType(::mlir::Type type,
}
llvm_unreachable("unexpected 'shape' type kind");
}
} // namespace ts
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensor_shape.cpp.inc" // NOLINT
} // namespace infrt::ts
#include "paddle/infrt/dialect/tensor_shape_dialect.cpp.inc"
......@@ -17,7 +17,8 @@
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt::ts {
namespace infrt {
namespace ts {
class ShapeType
: public mlir::Type::TypeBase<ShapeType, mlir::Type, mlir::TypeStorage> {
......@@ -31,10 +32,9 @@ class PartialShapeType : public mlir::Type::TypeBase<PartialShapeType,
public:
using Base::Base;
};
} // namespace ts
} // namespace infrt
using namespace mlir; // NOLINT
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensor_shape.hpp.inc"
#include "paddle/infrt/dialect/tensor_shape_dialect.hpp.inc"
} // namespace infrt::ts
......@@ -19,7 +19,7 @@ def TensorShapeDialect : Dialect {
def TS_Shape : DialectType<TensorShapeDialect,
CPred<"$_self.isa<::infrt::ts::ShapeType>()">, "!ts.shape type">,
BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> {
let typeDescription = [{
let description = [{
`!ts.shape type` represents a static tensor shape.
}];
}
......@@ -27,7 +27,7 @@ BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> {
def TS_PartialShape : DialectType<TensorShapeDialect,
CPred<"$_self.isa<::infrt::ts::PartialShapeType>()">, "!ts.partial_shape type">,
BuildableType<"$_builder.getType<::infrt::ts::PartialShapeType>()"> {
let typeDescription = [{
let description = [{
`!ts.partial_shape type` represents either a static tensor shape, unranked
tensor shape or a ranked tensor shape with unknown dimension sizes.
}];
......
......@@ -11,10 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <llvm/Support/CommandLine.h>
#include <mlir/Pass/PassManager.h>
#include <iostream>
#include <string>
#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/PassManager.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
......
......@@ -14,14 +14,13 @@
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <paddle/infrt/dialect/pd_ops.h>
#include <list>
#include <unordered_set>
#include <vector>
#include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
......@@ -32,9 +31,9 @@ namespace {
// References the function named "FlexibleDFS" defined in:
// paddle/fluid/framework/ir/subgraph_detector.cc.
bool reverseDfs(std::vector<::mlir::Operation *> source,
const std::function<bool(const ::mlir::Operation *)> &func) {
std::unordered_set<const ::mlir::Operation *> visited;
bool reverseDfs(std::vector<mlir::Operation *> source,
const std::function<bool(const mlir::Operation *)> &func) {
std::unordered_set<const mlir::Operation *> visited;
while (!source.empty()) {
auto node = source.back();
source.pop_back();
......@@ -44,7 +43,7 @@ bool reverseDfs(std::vector<::mlir::Operation *> source,
auto values = node->getOperands();
for (auto value : values) {
// if the value is a block argument, the node is nullptr.
::mlir::Operation *node = value.getDefiningOp();
mlir::Operation *node = value.getDefiningOp();
if (node != nullptr && !visited.count(node)) {
source.emplace_back(node);
}
......@@ -54,19 +53,19 @@ bool reverseDfs(std::vector<::mlir::Operation *> source,
}
// merge the first&second graph op to a new graph op.
void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
::mlir::pd::GraphOp first,
::mlir::pd::GraphOp second) {
void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
mlir::pd::GraphOp first,
mlir::pd::GraphOp second) {
// compute inputs and outputs
::llvm::SmallVector<::mlir::Value, 4> inputs(first.getOperands()), outputs;
for (::mlir::Value input : second.getOperands()) {
::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
for (mlir::Value input : second.getOperands()) {
if (input.getDefiningOp() != first) {
inputs.push_back(input);
}
}
::llvm::DenseMap<::mlir::Value, unsigned int> op_output_mapping;
for (::mlir::Value output : first.getResults()) {
for (::mlir::Operation *user : output.getUsers()) {
::llvm::DenseMap<mlir::Value, unsigned int> op_output_mapping;
for (mlir::Value output : first.getResults()) {
for (mlir::Operation *user : output.getUsers()) {
if (user != second && user->getParentOp() != second) {
op_output_mapping[output] = outputs.size();
outputs.push_back(output);
......@@ -74,19 +73,19 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
}
}
}
auto fetch_op = second.getBody()->getTerminator();
outputs.append(fetch_op->getOperands().begin(),
fetch_op->getOperands().end());
::llvm::SmallVector<::mlir::Type, 4> fetch_types;
auto return_op = second.getBody()->getTerminator();
outputs.append(return_op->getOperands().begin(),
return_op->getOperands().end());
::llvm::SmallVector<mlir::Type, 4> return_types;
for (auto value : outputs) {
fetch_types.push_back(value.getType());
return_types.push_back(value.getType());
}
// create the new graph op
builder.setInsertionPoint(first);
auto loc = first.getLoc();
auto graph_op = builder.create<::mlir::pd::GraphOp>(loc, fetch_types, inputs);
::mlir::Block *block = new ::mlir::Block;
auto graph_op = builder.create<mlir::pd::GraphOp>(loc, return_types, inputs);
mlir::Block *block = new mlir::Block;
auto copy_range = second.getBody()->without_terminator();
block->getOperations().splice(block->begin(),
second.getBody()->getOperations(),
......@@ -98,18 +97,18 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
copy_range.begin(),
copy_range.end());
builder.setInsertionPointToEnd(block);
builder.create<mlir::pd::FetchOp>(loc, outputs);
builder.create<mlir::pd::ReturnOp>(loc, outputs);
graph_op.body().push_back(block);
// mapping the output
unsigned int num_result = first.getNumResults();
fetch_op = first.getBody()->getTerminator();
return_op = first.getBody()->getTerminator();
for (unsigned int index = 0; index < num_result; ++index) {
auto origin_value = first.getResult(index);
if (op_output_mapping.find(origin_value) == op_output_mapping.end()) {
origin_value.replaceAllUsesWith(fetch_op->getOperand(index));
origin_value.replaceAllUsesWith(return_op->getOperand(index));
} else {
auto inner_value = fetch_op->getOperand(index);
auto inner_value = return_op->getOperand(index);
auto outer_value = graph_op.getResult(op_output_mapping[origin_value]);
while (!origin_value.use_empty()) {
auto replace_value =
......@@ -128,13 +127,13 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
// Topological sort the function op.
void topoSortBlock(mlir::Block &body) { // NOLINT
llvm::SetVector<Operation *> toSort;
llvm::SetVector<mlir::Operation *> toSort;
if (body.empty()) return;
for (auto it = body.rbegin(); it != body.rend(); ++it) {
toSort.insert(&*it);
}
llvm::SetVector<Operation *> result =
::mlir::topologicalSort(std::move(toSort));
llvm::SetVector<mlir::Operation *> result =
mlir::topologicalSort(std::move(toSort));
for (auto *op : result) {
op->moveBefore(body.getTerminator());
}
......@@ -145,21 +144,21 @@ void topoSortBlock(mlir::Block &body) { // NOLINT
// Implementation of the trtGraphFusePass.
void trtGraphFusePass::runOnFunction() {
mlir::Block &body = getFunction().front();
::mlir::OpBuilder builder(&body, body.begin());
mlir::OpBuilder builder(&body, body.begin());
bool changed = false;
do {
changed = false;
for (auto &op : body) {
::mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op);
mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr == graph_op) continue;
for (auto user_op : op.getUsers()) {
::mlir::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(user_op);
mlir::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(user_op);
if (nullptr == user_graph_op) continue;
// get all dst input nodes except src.
std::vector<::mlir::Operation *> source_nodes;
std::vector<mlir::Operation *> source_nodes;
for (auto operand : user_op->getOperands()) {
auto input = operand.getDefiningOp();
if (input != &op && input != nullptr) {
......@@ -167,9 +166,8 @@ void trtGraphFusePass::runOnFunction() {
}
}
// Reverse DFS from the source_nodes.
if (!reverseDfs(source_nodes, [&op](const ::mlir::Operation *n) {
return n == &op;
})) {
if (!reverseDfs(source_nodes,
[&op](const mlir::Operation *n) { return n == &op; })) {
mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
changed = true;
break;
......
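The reverseDfs helper used above is defined elsewhere in this file; for reference, a minimal sketch of the behavior it is assumed to have (walk upstream through the defining ops of the given sources and report whether any visited node satisfies the predicate — if the producer graph op is reachable from the consumer's other inputs, merging the two would create a cycle):
// Sketch only (assumes <functional> and <unordered_set>); the real helper
// lives in trt_graph_fuse_pass.cc.
static bool reverseDfs(
    std::vector<mlir::Operation *> sources,
    const std::function<bool(const mlir::Operation *)> &func) {
  std::unordered_set<const mlir::Operation *> visited;
  while (!sources.empty()) {
    mlir::Operation *node = sources.back();
    sources.pop_back();
    if (!visited.insert(node).second) continue;  // already explored
    if (func(node)) return true;                 // reached the target op
    for (mlir::Value operand : node->getOperands())
      if (mlir::Operation *def = operand.getDefiningOp())
        sources.push_back(def);                  // keep walking upstream
  }
  return false;
}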
......@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
#include <mlir/Pass/Pass.h>
namespace infrt {
namespace trt {
......@@ -28,15 +28,15 @@ namespace trt {
* %a = "pd.feed"()...
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.return" %m
* } ...
* %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* "pd.fetch" %m
* "pd.return" %m
* } ...
* %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.return" %m
* } ...
* "pd.fetch" %d, %f
*
......@@ -47,13 +47,13 @@ namespace trt {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "pd.fetch" %n, %s
* "pd.return" %n, %s
* } ...
* "pd.fetch" %d, %f
* }
*/
class trtGraphFusePass
: public ::mlir::PassWrapper<trtGraphFusePass, ::mlir::FunctionPass> {
: public mlir::PassWrapper<trtGraphFusePass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
void runOnFunction() override;
......
......@@ -14,7 +14,7 @@
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "mlir/IR/Builders.h"
#include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
......@@ -22,24 +22,24 @@ namespace infrt {
namespace trt {
// Implementation of the trtGraphSplitPass.
void trtGraphSplitPass::runOnFunction() {
std::vector<::mlir::pd::GraphOp> worklist;
::mlir::Block& block = getFunction().front();
std::vector<mlir::pd::GraphOp> worklist;
mlir::Block& block = getFunction().front();
for (auto& op : block) {
::mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op);
mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr != graph_op &&
graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
worklist.push_back(graph_op);
}
}
while (!worklist.empty()) {
::mlir::pd::GraphOp graph_op = worklist.back();
mlir::pd::GraphOp graph_op = worklist.back();
worklist.pop_back();
::mlir::Block* body = graph_op.getBody();
auto fetch_op = body->getTerminator();
graph_op.replaceAllUsesWith(fetch_op->getOperands());
mlir::Block* body = graph_op.getBody();
auto return_op = body->getTerminator();
graph_op.replaceAllUsesWith(return_op->getOperands());
auto copy_range = body->without_terminator();
block.getOperations().splice(::mlir::Block::iterator(graph_op),
block.getOperations().splice(mlir::Block::iterator(graph_op),
body->getOperations(),
copy_range.begin(),
copy_range.end());
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
#include <mlir/Pass/Pass.h>
namespace infrt {
namespace trt {
......@@ -31,9 +31,9 @@ namespace trt {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "pd.fetch" %n, %s
* "pd.return" (%n, %s)
* } ...
* "pd.fetch" %d, %f
* "pd.fetch" (%d, %f)
* }
*
* destination func:
......@@ -42,11 +42,11 @@ namespace trt {
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f
* "pd.fetch" (%d, %f)
* }
*/
class trtGraphSplitPass
: public ::mlir::PassWrapper<trtGraphSplitPass, ::mlir::FunctionPass> {
: public mlir::PassWrapper<trtGraphSplitPass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }
void runOnFunction() override;
......
......@@ -14,49 +14,48 @@
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
#include "mlir/IR/Builders.h"
#include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
// Implementation of the trtOpTellerPass.
void trtOpTellerPass::runOnFunction() {
::mlir::Block &body = getFunction().front();
std::vector<::mlir::Operation *> worklist;
mlir::Block &body = getFunction().front();
std::vector<mlir::Operation *> worklist;
worklist.reserve(body.getOperations().size());
for (auto &op : body) {
worklist.push_back(&op);
}
// Build GraphOp.
::mlir::OpBuilder builder(&body, body.begin());
mlir::OpBuilder builder(&body, body.begin());
while (!worklist.empty()) {
auto *op = worklist.back();
worklist.pop_back();
if (op == nullptr) continue;
auto op1 = ::llvm::dyn_cast_or_null<::mlir::pd::FeedOp>(op);
auto op1 = ::llvm::dyn_cast_or_null<mlir::pd::FeedOp>(op);
if (op1) continue;
auto op2 = ::llvm::dyn_cast_or_null<::mlir::pd::FetchOp>(op);
auto op2 = ::llvm::dyn_cast_or_null<mlir::pd::FetchOp>(op);
if (op2) continue;
auto op3 = ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(op);
auto op3 = ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(op);
if (op3) continue;
builder.setInsertionPoint(op);
auto loc = getFunction().getLoc();
auto graph_op = builder.create<::mlir::pd::GraphOp>(
auto graph_op = builder.create<mlir::pd::GraphOp>(
loc, op->getResultTypes(), op->getOperands());
::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;
::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values;
for (auto v :
::llvm::SmallVector<::mlir::Value, 4>{graph_op.getODSResults(0)}) {
::llvm::SmallVector<mlir::Value, 4>{graph_op.getODSResults(0)}) {
tblgen_repl_values.push_back(v);
}
op->replaceAllUsesWith(tblgen_repl_values);
// Build graph op.
::mlir::Block *block = new ::mlir::Block;
mlir::Block *block = new mlir::Block;
graph_op.body().push_back(block);
op->moveBefore(block, block->begin());
builder.setInsertionPointToEnd(block);
builder.create<mlir::pd::FetchOp>(loc, op->getResults());
builder.create<mlir::pd::ReturnOp>(loc, op->getResults());
}
}
} // namespace trt
......
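Note that runOnFunction above wraps every op other than feed/fetch/graph. A natural refinement (an assumption here, not part of this commit — see the TODO in the header below) would be to consult a support check before wrapping, for example:
#include <llvm/ADT/StringSet.h>

// Hypothetical whitelist check; a real teller would query TensorRT support.
static bool isTrtSupported(mlir::Operation *op) {
  static const llvm::StringSet<> supported = {"pd.conv2d", "pd.conv3d"};
  return supported.count(op->getName().getStringRef()) > 0;
}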
......@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
#include <mlir/Pass/Pass.h>
namespace infrt {
namespace trt {
......@@ -29,7 +29,7 @@ namespace trt {
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f
* "pd.fetch" (%d, %f)
* }
*
* destination func:
......@@ -37,23 +37,23 @@ namespace trt {
* %a = "pd.feed"()...
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.return" (%m)
* } ...
* %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* "pd.fetch" %m
* "pd.return" (%m)
* } ...
* %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.return" (%m)
* } ...
* "pd.fetch" %d, %f
* "pd.fetch" (%d, %f)
* }
* TODO(winter-wang): Document how to judge whether an operator can be
* supported by tensorrt.
*/
class trtOpTellerPass
: public ::mlir::PassWrapper<trtOpTellerPass, ::mlir::FunctionPass> {
: public mlir::PassWrapper<trtOpTellerPass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
void runOnFunction() override;
......
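Taken together, the three passes are intended to run as teller -> fuse -> split. A minimal driver sketch under that assumption (constructor signatures and pass registration may differ in the actual tree):
#include <memory>
#include <mlir/Pass/PassManager.h>

void runTrtLoweringSketch(mlir::ModuleOp module, mlir::MLIRContext *context) {
  mlir::PassManager pm(context);
  // Wrap ops, fuse adjacent graphs, then inline graphs that are too small.
  pm.addNestedPass<mlir::FuncOp>(std::make_unique<infrt::trt::trtOpTellerPass>());
  pm.addNestedPass<mlir::FuncOp>(std::make_unique<infrt::trt::trtGraphFusePass>());
  pm.addNestedPass<mlir::FuncOp>(std::make_unique<infrt::trt::trtGraphSplitPass>());
  if (mlir::failed(pm.run(module)))
    module.emitError("trt lowering pipeline failed");
}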
......@@ -13,27 +13,25 @@
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <mlir/IR/Matchers.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt {
namespace trt {
TensorRTDialect::TensorRTDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect("trt", context, ::mlir::TypeID::get<TensorRTDialect>()) {
TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context)
: mlir::Dialect("trt", context, mlir::TypeID::get<TensorRTDialect>()) {
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
>();
#undef GET_OP_LIST
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
#undef GET_OP_CLASSES
} // namespace trt
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
......@@ -14,37 +14,32 @@
#pragma once
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <mlir/Dialect/Traits.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt {
namespace trt {
class TensorRTDialect : public ::mlir::Dialect {
class TensorRTDialect : public mlir::Dialect {
public:
explicit TensorRTDialect(::mlir::MLIRContext* context);
explicit TensorRTDialect(mlir::MLIRContext* context);
static llvm::StringRef getDialectNamespace() { return "trt"; }
};
// mlir bug; can be removed safely once mlir is updated to llvm11.
using namespace mlir; // NOLINT
} // namespace trt
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.hpp.inc"
#undef GET_OP_CLASSES
} // namespace trt
} // namespace infrt
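Before parsing or building IR that uses this dialect, it has to be loaded into the context; a usage sketch, assuming the MLIR API of the LLVM commit matched above:
mlir::MLIRContext context;
context.getOrLoadDialect<infrt::trt::TensorRTDialect>();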
......@@ -14,14 +14,13 @@
#include "paddle/infrt/dialect/test_kernels.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
namespace infrt::dialect {
#include <mlir/IR/Builders.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
namespace infrt {
namespace dialect {
//===----------------------------------------------------------------------===//
// BenchmarkOp
//===----------------------------------------------------------------------===//
......@@ -32,65 +31,67 @@ namespace infrt::dialect {
// ...
// }
static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
StringAttr nameAttr;
static mlir::ParseResult parseBenchmarkOp(
mlir::OpAsmParser &parser, // NOLINT
mlir::OperationState &result) { // NOLINT
mlir::StringAttr nameAttr;
if (parser.parseAttribute(nameAttr, "name", result.attributes))
return failure();
return mlir::failure();
// Parse the operands, e.g. (%c : i32, %d : f32)
if (parser.parseLParen()) return failure();
if (parser.parseLParen()) return mlir::failure();
SmallVector<OpAsmParser::OperandType, 4> operands;
SmallVector<Type, 4> types;
llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands;
llvm::SmallVector<mlir::Type, 4> types;
llvm::SMLoc type_loc = parser.getCurrentLocation();
if (parser.parseOptionalRParen()) {
// Parse non-empty operands
do {
// Parse %c : i32,
OpAsmParser::OperandType operand;
Type type;
mlir::OpAsmParser::OperandType operand;
mlir::Type type;
if (parser.parseOperand(operand) || parser.parseColonType(type))
return failure();
return mlir::failure();
operands.push_back(operand);
types.push_back(type);
} while (succeeded(parser.parseOptionalComma()));
if (parser.parseRParen()) return failure();
if (parser.parseRParen()) return mlir::failure();
}
if (parser.resolveOperands(operands, types, type_loc, result.operands))
return failure();
return mlir::failure();
// Parse the keyword attribute, e.g. max_count = 100, duration_secs = 1
do {
StringRef attr;
Attribute resultAttr;
mlir::StringRef attr;
mlir::Attribute resultAttr;
if (parser.parseKeyword(&attr) || parser.parseEqual() ||
parser.parseAttribute(resultAttr,
parser.getBuilder().getIntegerType(32),
attr,
result.attributes))
return failure();
} while (succeeded(parser.parseOptionalComma()));
return mlir::failure();
} while (mlir::succeeded(parser.parseOptionalComma()));
// Set the default attribute num_warmup_runs to 1 if unset
auto setDefaultAttrIfUnset = [&](const char *attr_name, int value) {
bool found = llvm::any_of(result.attributes,
[attr_name](const NamedAttribute &attr) {
return attr.first == attr_name;
[attr_name](const mlir::NamedAttribute &attr) {
return attr.getName() == attr_name;
});
if (!found) {
IntegerAttr default_val = parser.getBuilder().getI32IntegerAttr(value);
mlir::IntegerAttr default_val =
parser.getBuilder().getI32IntegerAttr(value);
result.addAttribute(attr_name, default_val);
}
};
setDefaultAttrIfUnset("num_warmup_runs", 1);
Region *target = result.addRegion();
mlir::Region *target = result.addRegion();
return parser.parseRegion(*target,
operands,
types,
......@@ -102,11 +103,11 @@ static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT
// max_count = 100, duration_secs = 1 {
// ...
// }
static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
static void print(mlir::OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
p << "infrt.benchmark ";
// Print the name attribute, e.g "add.i32"
auto name_attr = op.getAttr("name");
auto name_attr = op->getAttr("name");
p << name_attr;
// Print the operands and types, e.g. (%c : i32, %d : f32)
......@@ -120,13 +121,13 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
bool need_comma = false;
// Print the attributes, e.g. max_count = 100, duration_secs = 1
for (auto &name_attr : op.getAttrs()) {
auto id = name_attr.first;
for (auto &name_attr : op->getAttrs()) {
auto id = name_attr.getName();
if (id == "name") continue;
if (need_comma) p << ", ";
auto attr = name_attr.second;
auto attr = name_attr.getValue();
p << id << " = ";
if (auto int_attr = attr.dyn_cast<IntegerAttr>()) {
if (auto int_attr = attr.dyn_cast<mlir::IntegerAttr>()) {
int_attr.getValue().print(p.getStream(), /*isSigned=*/false);
} else {
op.emitOpError("Unexpected attribute");
......@@ -142,7 +143,7 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
}
static LogicalResult verify(BenchmarkOp op) {
static mlir::LogicalResult verify(BenchmarkOp op) {
// Verify that the target benchmark region has exactly one return value.
auto &region = op.region();
auto &last_op = region.front().back();
......@@ -154,10 +155,10 @@ static LogicalResult verify(BenchmarkOp op) {
"incorrect number of return values. One return value is expected");
}
return success();
return mlir::success();
}
} // namespace dialect
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.cpp.inc"
} // namespace infrt::dialect
......@@ -13,11 +13,8 @@
// limitations under the License.
#pragma once
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt::dialect {
using namespace mlir; // NOLINT
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.hpp.inc"
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/types.h"
namespace infrt::hlir::mlir {} // namespace infrt::hlir::mlir
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/StandardTypes.h>
......@@ -23,7 +23,8 @@
#include "paddle/infrt/host_context/op_executable.h"
#include "paddle/infrt/host_context/symbol_table.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct CoreRuntime::Impl {
KernelRegistry* kernel_registry{};
......@@ -90,4 +91,5 @@ llvm::SmallVector<ValueRef, 4> CoreRuntime::GetResults(
CoreRuntime::~CoreRuntime() {}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -22,7 +22,8 @@
#include "paddle/infrt/host_context/value.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
class KernelRegistry;
class OpExecutable;
......@@ -83,4 +84,5 @@ class CoreRuntimeBuilder : public CoreRuntime {
OpExecutableBuilder* NewOpExecutable(const std::string& op_name);
};
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -21,7 +21,8 @@
#include "llvm/ADT/SmallVector.h"
#include "paddle/infrt/host_context/value.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
/**
* KernelFrame captures the states (input arguments, attributes, results)
......@@ -163,4 +164,5 @@ class KernelFrameBuilder : public KernelFrame {
}
};
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -18,7 +18,8 @@
#include "paddle/infrt/host_context/kernel_utils.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
int add_i32(int a, int b) { return a + b; }
......@@ -44,4 +45,5 @@ TEST(KernelRegistry, basic) {
ASSERT_EQ(results[0]->get<int>(), 3);
}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -16,7 +16,8 @@
#include <gtest/gtest.h>
namespace infrt::host_context {
namespace infrt {
namespace host_context {
int add_i32(int a, int b) { return a + b; }
float add_f32(float a, float b) { return a + b; }
......@@ -66,4 +67,5 @@ TEST(KernelImpl, pair) {
ASSERT_EQ(results[1]->get<float>(), 3.f);
}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -15,6 +15,7 @@
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include <glog/logging.h>
#include <mlir/IR/BuiltinOps.h>
#include <string> // NOLINT
......
......@@ -13,7 +13,8 @@
// limitations under the License.
#pragma once
#include <mlir/IR/Function.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Region.h>
#include <string>
#include <unordered_map>
......
......@@ -15,9 +15,9 @@
#pragma once
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OperationSupport.h>
#include <unordered_map>
......
......@@ -16,8 +16,9 @@
#include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h>
......@@ -40,7 +41,8 @@
#include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
template <typename T>
std::string DumpToString(T& op) { // NOLINT
......@@ -113,10 +115,10 @@ bool MlirToRuntimeTranslator::EmitConstantOp(mlir::Operation* op) {
template <>
boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::IntegerAttr>()) return boost::none;
if (attr->isa<mlir::IntegerAttr>()) {
auto val = attr->cast<mlir::IntegerAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::IntegerAttr>()) return boost::none;
if (attr.isa<mlir::IntegerAttr>()) {
auto val = attr.cast<mlir::IntegerAttr>();
if (val.getType().isInteger(32)) {
return val.getInt();
}
......@@ -125,10 +127,10 @@ boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute(
}
template <>
boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::IntegerAttr>()) return boost::none;
if (attr->isa<mlir::IntegerAttr>()) {
auto val = attr->cast<mlir::IntegerAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::IntegerAttr>()) return boost::none;
if (attr.isa<mlir::IntegerAttr>()) {
auto val = attr.cast<mlir::IntegerAttr>();
if (val.getType().isInteger(64)) {
return val.getInt();
}
......@@ -139,10 +141,10 @@ boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute(
// TODO(Superjomn) Make double and float parsing share something.
template <>
boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::FloatAttr>()) return boost::none;
if (attr->isa<mlir::FloatAttr>()) {
auto val = attr->cast<mlir::FloatAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::FloatAttr>()) return boost::none;
if (attr.isa<mlir::FloatAttr>()) {
auto val = attr.cast<mlir::FloatAttr>();
if (val.getType().isF32()) return val.getValueAsDouble();
}
return boost::none;
......@@ -150,10 +152,10 @@ boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
template <>
boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::FloatAttr>()) return boost::none;
if (attr->isa<mlir::FloatAttr>()) {
auto val = attr->cast<mlir::FloatAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::FloatAttr>()) return boost::none;
if (attr.isa<mlir::FloatAttr>()) {
auto val = attr.cast<mlir::FloatAttr>();
if (val.getType().isF64()) return val.getValueAsDouble();
}
return boost::none;
......@@ -161,17 +163,17 @@ boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
template <>
boost::optional<std::string> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::StringAttr>()) return boost::none;
return attr->cast<mlir::StringAttr>().getValue().str();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::StringAttr>()) return boost::none;
return attr.cast<mlir::StringAttr>().getValue().str();
}
#define PROCESS_ARRAY_INT(type__, bits__) \
template <> \
boost::optional<std::vector<type__>> MlirToRuntimeTranslator::EmitAttribute( \
const mlir::Attribute* attr) { \
if (!attr->isa<mlir::ArrayAttr>()) return boost::none; \
auto array = attr->cast<mlir::ArrayAttr>(); \
const mlir::Attribute& attr) { \
if (!attr.isa<mlir::ArrayAttr>()) return boost::none; \
auto array = attr.cast<mlir::ArrayAttr>(); \
CHECK(!array.empty()); \
\
if (!array[0].getType().isInteger(bits__)) { \
......@@ -191,9 +193,9 @@ PROCESS_ARRAY_INT(int64_t, 64);
template <>
boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr->cast<mlir::ArrayAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr.cast<mlir::ArrayAttr>();
CHECK(!array.empty());
if (!array[0].getType().isF32()) return boost::none;
......@@ -207,9 +209,9 @@ boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute(
template <>
boost::optional<std::vector<double>> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr->cast<mlir::ArrayAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr.cast<mlir::ArrayAttr>();
CHECK(!array.empty());
if (!array[0].getType().isF64()) return boost::none;
......@@ -236,7 +238,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
for (int i = 0, e = op->getNumOperands(); i < e; i++) {
// function argument as value
auto operand = op->getOperand(i);
if (operand.getKind() == mlir::Value::Kind::BlockArgument) {
/// if (operand.getKind() == mlir::Value::Kind::BlockArgument) {
if (operand.isa<mlir::BlockArgument>()) {
mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
Value* arg_value = GetValue(arg);
impl_->cur_op->AppendArgument(arg_value);
......@@ -283,25 +286,25 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
for (size_t i = 0; i < attrs.size(); i++) {
auto& attr = attrs[i];
if (auto v = EmitAttribute<int32_t>(&attr.second)) {
if (auto v = EmitAttribute<int32_t>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<int64_t>(&attr.second)) {
} else if (auto v = EmitAttribute<int64_t>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<float>(&attr.second)) {
} else if (auto v = EmitAttribute<float>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<double>(&attr.second)) {
} else if (auto v = EmitAttribute<double>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<std::string>(&attr.second)) {
} else if (auto v = EmitAttribute<std::string>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int16_t>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<int16_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int32_t>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<int32_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int64_t>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<int64_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<float>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<float>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<double>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<double>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else {
LOG(FATAL) << "Not supported attribute type";
......@@ -330,7 +333,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
llvm::SmallVector<mlir::Type, 0> results;
auto func_type =
mlir::FunctionType::get(inputs, results, region.getContext());
mlir::FunctionType::get(region.getContext(), inputs, results);
auto* function = impl_->cur_op->CreateFunctionExecutable(
&region, func_type, &impl_->func_defs);
impl_->cur_op->AppendAttribute(new Value(function));
......@@ -555,4 +558,5 @@ void TestMlir(mlir::ModuleOp module, KernelRegistry* registry) {
execute.Run();
}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
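The EmitAttribute specializations above extend mechanically to other attribute kinds; for instance, bool support (hypothetical, not part of this commit) would follow the same shape:
template <>
boost::optional<bool> MlirToRuntimeTranslator::EmitAttribute(
    const mlir::Attribute& attr) {
  if (!attr.isa<mlir::BoolAttr>()) return boost::none;
  return attr.cast<mlir::BoolAttr>().getValue();
}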
......@@ -29,7 +29,8 @@ class Attribute;
class Value;
} // namespace mlir
namespace infrt::host_context {
namespace infrt {
namespace host_context {
class CoreRuntimeBuilder;
class Value;
......@@ -73,7 +74,7 @@ class MlirToRuntimeTranslator {
bool EmitCallOp(mlir::Operation* op, function_defs_t* function_table);
template <typename T>
boost::optional<T> EmitAttribute(const mlir::Attribute* attr);
boost::optional<T> EmitAttribute(const mlir::Attribute& attr);
Value* GetOpResult(mlir::Operation* op);
......@@ -104,4 +105,5 @@ void MlirToRuntimeTranslate(mlir::ModuleOp module, CoreRuntimeBuilder* runtime);
*/
void TestMlir(mlir::ModuleOp module, KernelRegistry* registry);
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -29,7 +29,8 @@
#include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include "paddle/infrt/kernel/test_kernels.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
TEST(MlirToRuntimeTranslate, basic) {
mlir::MLIRContext context;
......@@ -48,7 +49,7 @@ func @main() -> () {
)ROC";
auto module = dialect::LoadMlirSource(&context, source);
module->verify();
EXPECT_TRUE(mlir::succeeded(module->verify()));
KernelRegistry registry;
kernel::RegisterFloatBasicKernels(&registry);
......@@ -74,7 +75,7 @@ func @main() -> () {
)ROC";
auto module = dialect::LoadMlirSource(&context, source);
module->verify();
EXPECT_TRUE(mlir::succeeded(module->verify()));
KernelRegistry registry;
kernel::RegisterFloatBasicKernels(&registry);
......@@ -115,7 +116,7 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F
// LOG(INFO) << "content: " << content << std::endl;
auto module = dialect::LoadMlirSource(context, content);
module->verify();
EXPECT_TRUE(mlir::succeeded(module->verify()));
host_context::KernelRegistry registry;
......@@ -157,4 +158,5 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F
}
}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -14,6 +14,7 @@
#include "paddle/infrt/host_context/op_executable.h"
#include <mlir/IR/BuiltinOps.h>
#include <string>
#include "paddle/infrt/host_context/kernel_frame.h"
......@@ -21,7 +22,8 @@
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/host_context/symbol_table.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct OpExecutable::Impl {
Impl(const std::string& op_name,
......@@ -148,4 +150,5 @@ void OpExecutable::Execute() {
OpExecutable::~OpExecutable() {}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -14,19 +14,18 @@
#pragma once
#include <llvm/ADT/ArrayRef.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Region.h>
#include <memory>
#include <string>
#include <unordered_map>
#include "mlir/IR/Function.h"
#include "mlir/IR/Region.h"
namespace mlir {
class FuncOp;
} // namespace mlir
namespace infrt::host_context {
namespace infrt {
namespace host_context {
class SymbolTable;
class KernelRegistry;
......@@ -89,4 +88,5 @@ class OpExecutableBuilder : public OpExecutable {
function_defs_t* function_defs);
};
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -23,7 +23,8 @@
using infrt::host_context::Attribute;
namespace infrt::kernel {
namespace infrt {
namespace kernel {
template <typename T>
T add(T a, T b) {
......@@ -82,4 +83,5 @@ void RegisterFloatBasicKernels(host_context::KernelRegistry *registry) {
registry->AddKernel("infrt.print.f32", INFRT_KERNEL(print<float>));
}
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
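The registration pattern above is uniform across kernels; a hypothetical extra kernel (names assumed for illustration) would be added the same way:
int mul_i32(int a, int b) { return a * b; }

void RegisterMulKernel(host_context::KernelRegistry *registry) {
  registry->AddKernel("infrt.mul.i32", INFRT_KERNEL(mul_i32));
}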
......@@ -15,13 +15,16 @@
#pragma once
#include <string>
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct KernelRegistry;
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
namespace infrt::kernel {
namespace infrt {
namespace kernel {
/**
* Register all the basic kernels to \p registry.
......@@ -31,4 +34,5 @@ void RegisterBasicKernels(host_context::KernelRegistry* registry);
void RegisterIntBasicKernels(host_context::KernelRegistry* registry);
void RegisterFloatBasicKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -25,7 +25,8 @@
#include "paddle/infrt/tensor/tensor_map.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::kernel {
namespace infrt {
namespace kernel {
using namespace host_context; // NOLINT
using namespace tensor; // NOLINT
......@@ -76,4 +77,5 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
INFRT_KERNEL(ShallowCopyTensor));
}
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -14,12 +14,16 @@
#pragma once
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct KernelRegistry;
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
namespace infrt::kernel {
namespace infrt {
namespace kernel {
void RegisterTensorKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -24,7 +24,8 @@
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::kernel {
namespace infrt {
namespace kernel {
void PrintShape(const tensor::TensorShape& shape) {
llvm::raw_os_ostream oos(std::cout);
......@@ -35,4 +36,5 @@ void RegisterTensorShapeKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("ts.print_shape", INFRT_KERNEL(PrintShape));
}
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -14,14 +14,18 @@
#pragma once
namespace infrt::host_context {
namespace infrt {
namespace host_context {
class KernelRegistry;
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
namespace infrt::kernel {
namespace infrt {
namespace kernel {
void RegisterTensorShapeKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -33,7 +33,8 @@ using infrt::host_context::Attribute;
using infrt::host_context::MlirFunctionExecutable;
using infrt::host_context::RemainingArguments;
namespace infrt::kernel {
namespace infrt {
namespace kernel {
namespace {
class BenchmarkStats {
public:
......@@ -197,4 +198,5 @@ void RegisterTestKernels(host_context::KernelRegistry *registry) {
INFRT_KERNEL(ShadowCopyTensor));
}
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -15,17 +15,21 @@
#pragma once
#include <string>
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct KernelRegistry;
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
namespace infrt::kernel {
namespace infrt {
namespace kernel {
/**
* Register all the test kernels to registry.
*/
void RegisterTestKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -18,7 +18,9 @@
#include <string>
#include <vector>
namespace infrt::paddle::cpp {
namespace infrt {
namespace paddle {
namespace cpp {
/*
* Compatible interfaces for all the different kinds of XXXDesc. All the XXXDesc
......@@ -226,4 +228,6 @@ class ProgramDescAPI {
virtual void SetVersion(int64_t version) = 0;
};
} // namespace infrt::paddle::cpp
} // namespace cpp
} // namespace paddle
} // namespace infrt
......@@ -22,7 +22,8 @@
#include "paddle/infrt/common/target.h"
#include "paddle/infrt/common/type.h"
namespace infrt::paddle {
namespace infrt {
namespace paddle {
int SizeOfType(framework_proto::VarType::Type type) {
using Type = framework_proto::VarType::Type;
......@@ -169,4 +170,5 @@ void LoadParam(const std::string &path, _Variable *out, const Target &target) {
LoadLoDTensor(fin, out, target);
}
} // namespace infrt::paddle
} // namespace paddle
} // namespace infrt
......@@ -25,7 +25,8 @@
#include "paddle/infrt/paddle/scope.h"
#include "paddle/infrt/paddle/tensor.h"
namespace infrt::paddle {
namespace infrt {
namespace paddle {
namespace framework_proto = ::paddle::framework::proto;
// Read a __model__ file.
......@@ -52,4 +53,5 @@ void TensorFromStream(
const common::Target& target = common::DefaultHostTarget());
void ReadBinaryFile(const std::string& filename, std::string* contents);
} // namespace infrt::paddle
} // namespace paddle
} // namespace infrt
......@@ -14,7 +14,9 @@
#include "paddle/infrt/paddle/pb/block_desc.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
template <>
framework_proto::VarDesc* BlockDesc::GetVar<framework_proto::VarDesc>(
......@@ -40,4 +42,6 @@ framework_proto::OpDesc* BlockDesc::AddOp<framework_proto::OpDesc>() {
return desc_->add_ops();
}
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -18,7 +18,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto;
......@@ -74,4 +76,6 @@ class BlockDesc : public cpp::BlockDescAPI {
framework_proto::BlockDesc* desc_; // not_own
};
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -14,7 +14,9 @@
#include "paddle/infrt/paddle/pb/op_desc.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
google::protobuf::internal::RepeatedPtrIterator<framework_proto::OpDesc_Attr>
FindAttr(framework_proto::OpDesc *desc, const std::string &name) {
......@@ -136,4 +138,6 @@ GET_ATTRS_IMPL(std::vector<std::string>, strings);
GET_ATTR_IMPL(std::string, s);
GET_ATTRS_IMPL(std::vector<int64_t>, longs);
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -19,7 +19,9 @@
#include "paddle/infrt/paddle/framework.pb.h"
#include "paddle/infrt/support/variant.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto;
......@@ -195,4 +197,6 @@ template <>
void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
const std::vector<int> &v);
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -17,7 +17,9 @@
#include <algorithm>
#include <limits>
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
template <>
framework_proto::BlockDesc* ProgramDesc::GetBlock<framework_proto::BlockDesc>(
......@@ -32,4 +34,6 @@ ProgramDesc::AddBlock<framework_proto::BlockDesc>() {
return desc_->add_blocks();
}
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -21,7 +21,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto;
class ProgramDesc : public cpp::ProgramDescAPI {
......@@ -58,4 +60,6 @@ class ProgramDesc : public cpp::ProgramDescAPI {
framework_proto::ProgramDesc *desc_; // not_own
};
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -19,7 +19,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
cpp::VarDescAPI::Type VarDesc::GetType() const {
auto type = desc_->type().type();
......@@ -364,4 +366,6 @@ VarDesc::mutable_tensor_descs() {
return std::vector<framework_proto::VarType::TensorDesc *>();
}
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -23,7 +23,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto;
// convert between std::vector and protobuf repeated.
......@@ -121,4 +123,6 @@ class VarDesc : public cpp::VarDescAPI {
framework_proto::VarDesc *desc_;
};
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -435,6 +435,10 @@ inline T* DenseTensor::mutable_data(const paddle::platform::Place& place,
}
void DenseTensor::ShareBufferWith(const DenseTensor& tensor) {
if (storage_ == nullptr) {
storage_ = make_intrusive<paddle::experimental::SharedStorage>(
paddle::platform::CPUPlace());
}
if (storage_ != nullptr && tensor.storage_ != nullptr) {
storage_->set_data_shared(tensor.storage_->data_shared());
}
......
......@@ -152,6 +152,9 @@ def ShardingScaler(scaler):
param_grads = []
param_grads_fp16 = []
param_grads_fp32 = []
if hasattr(optimizer, "update_slice"):
optimizer.update_slice()
optimizer.update_scaler = True
if getattr(optimizer._optim, '_param_groups', None) and isinstance(
optimizer._optim._param_groups[0], dict):
......@@ -161,27 +164,21 @@ def ShardingScaler(scaler):
if param._grad_ivar() is not None:
param_grads.append(param._grad_ivar())
if param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16:
).dtype in [core.VarDesc.VarType.FP16, paddle.float16]:
param_grads_fp16.append(param._grad_ivar())
else:
param_grads_fp32.append(param._grad_ivar())
else:
param_grads = [
param._grad_ivar() for param in optimizer._optim._parameter_list
if param._grad_ivar() is not None
]
param_grads_fp16 = [
param._grad_ivar() for param in optimizer._optim._parameter_list
if (param._grad_ivar() is not None
) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP16
)
]
param_grads_fp32 = [
param._grad_ivar() for param in optimizer._optim._parameter_list
if (param._grad_ivar() is not None
) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32
)
]
for param in optimizer._optim._parameter_list:
if param.grad is not None:
param_grads.append(param.grad)
if param.grad.dtype in [
core.VarDesc.VarType.FP16, paddle.float16
]:
param_grads_fp16.append(param.grad)
else:
param_grads_fp32.append(param.grad)
temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool))
temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool))
......
......@@ -34,6 +34,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage3)
list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper)
......@@ -250,6 +251,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3)
list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
......@@ -1058,6 +1060,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_sharding_stage3 PROPERTIES TIMEOUT 120)
set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120)
......
# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import ast
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
epoch = 10
batch_size = 32
paddle.seed(2021)
np.random.seed(2021)
base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4
fleet.init(is_collective=True)
class MLP(fluid.Layer):
def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
super(MLP, self).__init__()
self._linear1 = Linear(linear_size, linear_size)
self._linear2 = Linear(linear_size, linear_size)
self._linear3 = Linear(linear_size, 10)
def forward(self, inputs):
y = self._linear1(inputs)
y = self._linear2(y)
y = self._linear3(y)
return y
def reader_decorator(linear_size=1000):
def __reader__():
for _ in range(100):
img = np.random.rand(linear_size).astype('float32')
label = np.ones(1).astype('int64')
yield img, label
return __reader__
def optimizer_setting(model, use_pure_fp16, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
parameters=[{
"params": model.parameters()
}] if opt_group else model.parameters(),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
multi_precision=use_pure_fp16)
return optimizer
def train_mlp(model,
sharding_stage,
use_pure_fp16=False,
accumulate_grad=False,
opt_group=False,
recompute=False):
group = paddle.distributed.new_group([0, 1])
if opt_group:
optimizer = optimizer_setting(
model=model, use_pure_fp16=use_pure_fp16, opt_group=opt_group)
else:
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
if use_pure_fp16:
model = paddle.amp.decorate(
models=model, level='O2', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
scaler = ShardingScaler(scaler)
if sharding_stage == 2:
optimizer = ShardingOptimizerStage2(
params=model.parameters(), optim=optimizer, group=group)
model = ShardingStage2(
model,
optimizer,
group=group,
buffer_max_size=2**21,
accumulate_grads=accumulate_grad)
elif sharding_stage == 3:
model = ShardingStage3(
model, optimizer=optimizer, group=group, sync_comm=recompute)
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True)
train_loader = paddle.io.DataLoader.from_generator(
capacity=32,
use_double_buffer=True,
iterable=True,
return_list=True,
use_multiprocess=True)
train_loader.set_sample_list_generator(train_reader)
for eop in range(epoch):
model.train()
for batch_id, data in enumerate(train_loader()):
img, label = data
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(True, level='O2'):
out = model(img)
loss = paddle.nn.functional.cross_entropy(
input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if not accumulate_grad:
if not use_pure_fp16:
avg_loss.backward()
optimizer.step()
else:
scaler.scale(avg_loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
if accumulate_grad:
if not use_pure_fp16:
avg_loss.backward()
optimizer.step()
else:
scaler.scale(avg_loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
if sharding_stage == 3:
model.get_all_parameters()
return model.parameters()
def test_stage2_stage3():
mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8 = MLP(), MLP(), MLP(
), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
state_dict = mlp.state_dict()
mlp1.set_state_dict(state_dict)
mlp2.set_state_dict(state_dict)
mlp3.set_state_dict(state_dict)
mlp4.set_state_dict(state_dict)
mlp5.set_state_dict(state_dict)
mlp6.set_state_dict(state_dict)
mlp7.set_state_dict(state_dict)
mlp8.set_state_dict(state_dict)
# fp32
stage2_params = train_mlp(
mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=True)
stage3_params = train_mlp(
mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=True)
for i in range(len(stage2_params)):
for j in range(len(stage3_params)):
if stage2_params[i].name == stage3_params[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[j].numpy(),
rtol=1e-6)
# fp32 accumulate grad
stage2_params = train_mlp(
mlp3,
sharding_stage=2,
use_pure_fp16=False,
accumulate_grad=True,
opt_group=True)
stage3_params = train_mlp(
mlp4,
sharding_stage=3,
use_pure_fp16=False,
accumulate_grad=True,
opt_group=True)
for i in range(len(stage2_params)):
for j in range(len(stage3_params)):
if stage2_params[i].name == stage3_params[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[j].numpy(),
rtol=1e-6)
# fp16
stage2_params = train_mlp(
mlp5, sharding_stage=2, use_pure_fp16=True, opt_group=False)
stage3_params = train_mlp(
mlp6, sharding_stage=3, use_pure_fp16=True, opt_group=False)
for i in range(len(stage2_params)):
for j in range(len(stage3_params)):
if stage2_params[i].name == stage3_params[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[j].numpy(),
rtol=1e-6)
# fp16 recompute
stage3_params = train_mlp(
mlp7, sharding_stage=3, use_pure_fp16=True, opt_group=False)
stage3_params_re = train_mlp(
mlp8,
sharding_stage=3,
use_pure_fp16=True,
opt_group=False,
recompute=True)
for i in range(len(stage3_params)):
for j in range(len(stage3_params_re)):
if stage3_params[i].name == stage3_params_re[j].name:
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_re[j].numpy(),
rtol=1e-6)
return
if __name__ == '__main__':
test_stage2_stage3()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import unittest
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
class TrtConvertFlattenContiguousRangeTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input(batch):
return np.random.random([2, batch, 4, 8, 3]).astype(np.float32)
for batch in [1, 2, 4]:
for start_axis in range(5):
for stop_axis in range(start_axis, 5):
type = "flatten_contiguous_range"
op_outputs = {
"Out": ["output_data"],
"XShape": ["xshape_data"]
}
ops_config = [{
"op_type": type,
"op_inputs": {
"X": ["input_data"]
},
"op_outputs": op_outputs,
"op_attrs": {
"start_axis": start_axis,
"stop_axis": stop_axis,
}
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input, batch))
},
outputs=["output_data"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"input_data": [2, 1, 4, 8, 3]}
self.dynamic_shape.max_input_shape = {"input_data": [2, 4, 4, 8, 3]}
self.dynamic_shape.opt_input_shape = {"input_data": [2, 2, 4, 8, 3]}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
ver = paddle_infer.get_trt_compile_version()
# TRT version encoded as major * 1000 + minor * 100 + patch * 10
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 >= 7000:
if dynamic_shape:
return 1, 2
else:
if attrs[0]['start_axis'] == 0:
return 0, 3
else:
return 1, 2
else:
return 0, 3
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
def test(self):
self.run_test()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphShardingStage3(TestMultipleGpus):
# check sharding logic as well as its accuracy against single-GPU mode
def test_dygraph_sharding_optimizer_stage3(self):
self.run_mnist_2gpu('dygraph_sharding_stage3.py')
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -66,6 +66,15 @@ class TestStackOpBase(XPUOpTest):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
if self.dtype == 'int64' or self.dtype == 'int32':
pass
else:
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, self.get_x_names(), 'Y')
class TestStackOp1(TestStackOpBase):
def initParameters(self):
......@@ -81,11 +90,17 @@ class TestStackOp3(TestStackOpBase):
def initParameters(self):
self.axis = -1
def test_check_grad(self):
pass
class TestStackOp4(TestStackOpBase):
def initParameters(self):
self.axis = -4
def test_check_grad(self):
pass
class TestStackOp5(TestStackOpBase):
def initParameters(self):
......@@ -113,7 +128,7 @@ class TestStackOpint(TestStackOpBase):
self.num_inputs = 4
self.input_dim = (5, 6, 7)
self.axis = 0
self.dtype = 'int'
self.dtype = 'int32'
def initParameters(self):
self.num_inputs = 16
......