Commit 04b9bfbd authored by: J jim19930609

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into develop

include(FetchContent)
set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/CINN/llvm11.tar.gz)
set(LLVM_MD5 39d32b6be466781dddf5869318dcba53)
set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/infrt/llvm_b5149f4e66a49a98b67e8e2de4e24a4af8e2781b.tar.gz)
set(LLVM_MD5 022819bb5760817013cf4b8a37e97d5e)
set(FETCHCONTENT_BASE_DIR ${THIRD_PARTY_PATH}/llvm)
set(FETCHCONTENT_QUIET OFF)
......@@ -51,7 +51,7 @@ message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
# To build with MLIR, LLVM is built from source with the following flags:
#[==[
cmake -G Ninja ../llvm \
cmake ../llvm -G "Unix Makefiles" \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_BUILD_EXAMPLES=OFF \
-DLLVM_TARGETS_TO_BUILD="X86" \
......@@ -59,8 +59,10 @@ cmake -G Ninja ../llvm \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_ZLIB=OFF \
-DLLVM_ENABLE_RTTI=ON \
-DLLVM_INSTALL_UTILS=ON \
-DCMAKE_INSTALL_PREFIX=./install
#]==]
# The matched llvm-project version is f9dc2b7079350d0fed3bb3775f496b90483c9e42 (currently a temporary commit)
# The matched llvm-project version is b5149f4e66a49a98b67e8e2de4e24a4af8e2781b (currently a temporary commit)
add_definitions(${LLVM_DEFINITIONS})
......@@ -75,7 +77,7 @@ add_definitions(${LLVM_DEFINITIONS})
# The minimum set of libraries needed for MLIR IR parsing and transformation.
set(MLIR_IR_LIBS MLIRAnalysis MLIRStandardOps MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib)
set(MLIR_IR_LIBS MLIRAnalysis MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib)
# td_base is the name of a xxx.td file (without the .td suffix)
......@@ -89,6 +91,7 @@ function(mlir_tablegen_on td_base)
mlir_tablegen(${td_base}.cpp.inc -gen-op-defs)
if (mlir_tablegen_on_DIALECT)
mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls -dialect=${mlir_tablegen_on_DIALECT})
mlir_tablegen(${td_base}_dialect.cpp.inc --gen-dialect-defs -dialect=${mlir_tablegen_on_DIALECT})
endif()
add_public_tablegen_target(${td_base}_IncGen)
add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen)
......
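For orientation, a hedged usage sketch tying together the pieces above (the MLIR_IR_LIBS list and the mlir_tablegen_on helper); the target and .td names below are illustrative and not part of this diff:

# Illustrative only: generate ODS .inc files for a hypothetical my_ops.td,
# then build a tool against the minimal MLIR libraries defined above.
mlir_tablegen_on(my_ops DIALECT my)        # emits my_ops.{hpp,cpp}.inc and my_ops_dialect.{hpp,cpp}.inc
add_executable(my-opt my_opt.cc)
add_dependencies(my-opt my_ops_inc)        # the custom target created by mlir_tablegen_on
target_link_libraries(my-opt ${MLIR_IR_LIBS})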
......@@ -46,8 +46,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
<< " is diabled by config in TensorRT";
return false;
}
return tensorrt::OpTeller::Global().Tell(node, no_calib_int8,
with_dynamic_shape);
bool is_ok = tensorrt::OpTeller::Global().Tell(node, no_calib_int8,
with_dynamic_shape);
if (!is_ok)
VLOG(3) << node->Op()->Type().c_str() << " op is not in TensorRT";
return is_ok;
};
framework::ir::SubGraphFuser fuser(
......
......@@ -1416,6 +1416,7 @@ USE_TRT_CONVERTER(elementwise_min_tensor);
USE_TRT_CONVERTER(elementwise_pow_tensor);
USE_TRT_CONVERTER(transpose);
USE_TRT_CONVERTER(flatten);
USE_TRT_CONVERTER(flatten_contiguous_range);
USE_TRT_CONVERTER(matmul);
USE_TRT_CONVERTER(conv2d);
USE_TRT_CONVERTER(relu);
......
......@@ -3,7 +3,7 @@ nv_library(tensorrt_converter
SRCS matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc group_norm_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc flatten_contiguous_range_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
gather_op.cc
anchor_generator_op.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* flatten_contiguous_range trt converter
*/
class FlattenContiguousRangeOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
int dims = input->getDimensions().nbDims;
int start_axis = BOOST_GET_CONST(int, op_desc.GetAttr("start_axis"));
int stop_axis = BOOST_GET_CONST(int, op_desc.GetAttr("stop_axis"));
nvinfer1::IShuffleLayer* layer = nullptr;
if (!engine_->with_dynamic_shape()) {
if (start_axis < 0) start_axis += dims + 1;
if (stop_axis < 0) stop_axis += dims + 1;
int dim_prod = 1;
nvinfer1::Dims flatten_dim;
flatten_dim.nbDims = dims - (stop_axis - start_axis);
for (int i = 0, j = 0; i < dims; ++i) {
if (start_axis <= i + 1 && i + 1 <= stop_axis) {
int dim_i = input->getDimensions().d[i];
PADDLE_ENFORCE_GT(dim_i, 0, platform::errors::InvalidArgument(
"flatten_contiguous_range input dim "
"should be > 0, but got %d.",
dim_i));
dim_prod *= dim_i;
if (i + 1 == stop_axis) {
flatten_dim.d[j++] = dim_prod;
}
} else {
flatten_dim.d[j++] = input->getDimensions().d[i];
}
}
layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
layer->setReshapeDimensions(flatten_dim);
} else {
if (start_axis < 0) start_axis += dims;
if (stop_axis < 0) stop_axis += dims;
auto* shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
auto* shape_layer_itensor = shape_layer->getOutput(0);
nvinfer1::Dims start_dim, size_dim, stride_dim;
start_dim.nbDims = 1;
size_dim.nbDims = 1;
stride_dim.nbDims = 1;
start_dim.d[0] = start_axis;
size_dim.d[0] = stop_axis - start_axis + 1;
stride_dim.d[0] = 1;
auto* slice_layer =
TRT_ENGINE_ADD_LAYER(engine_, Slice, *shape_layer_itensor, start_dim,
size_dim, stride_dim);
uint32_t reduce_dim = 1;
auto* reduce_prod_layer = TRT_ENGINE_ADD_LAYER(
engine_, Reduce, *(slice_layer->getOutput(0)),
nvinfer1::ReduceOperation::kPROD, reduce_dim, true);
nvinfer1::ITensor* input_shape = nullptr;
if (start_axis == 0 && stop_axis == dims - 1) {
input_shape = reduce_prod_layer->getOutput(0);
} else {
std::vector<nvinfer1::ITensor*> itensors;
if (start_axis > 0) {
nvinfer1::Dims left_start_dim, left_size_dim, left_stride_dim;
left_start_dim.nbDims = 1;
left_size_dim.nbDims = 1;
left_stride_dim.nbDims = 1;
left_start_dim.d[0] = 0;
left_size_dim.d[0] = start_axis;
left_stride_dim.d[0] = 1;
auto* slice_layer_left = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *shape_layer_itensor, left_start_dim,
left_size_dim, left_stride_dim);
itensors.push_back(slice_layer_left->getOutput(0));
}
itensors.push_back(reduce_prod_layer->getOutput(0));
if (stop_axis < dims - 1) {
nvinfer1::Dims right_start_dim, right_size_dim, right_stride_dim;
right_start_dim.nbDims = 1;
right_size_dim.nbDims = 1;
right_stride_dim.nbDims = 1;
right_start_dim.d[0] = stop_axis + 1;
right_size_dim.d[0] = dims - stop_axis - 1;
right_stride_dim.d[0] = 1;
auto* slice_layer_right = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *shape_layer_itensor, right_start_dim,
right_size_dim, right_stride_dim);
itensors.push_back(slice_layer_right->getOutput(0));
}
auto* concat_layer = TRT_ENGINE_ADD_LAYER(
engine_, Concatenation, itensors.data(), itensors.size());
concat_layer->setAxis(0);
input_shape = concat_layer->getOutput(0);
}
layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
layer->setInput(1, *input_shape);
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "flatten_contiguous_range", {output_name},
test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(flatten_contiguous_range,
FlattenContiguousRangeOpConverter);
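For reference, the static-shape branch above collapses every dimension in [start_axis, stop_axis] into a single product (in that branch TensorRT dims exclude the batch dimension, hence the i + 1 comparisons). A standalone sketch of the same computation on a plain shape vector, illustrative only and independent of nvinfer1::Dims:

#include <vector>

// Flatten dims in [start, stop] (0-based, inclusive) into one product.
std::vector<int> FlattenRange(const std::vector<int>& shape, int start, int stop) {
  std::vector<int> out;
  long long prod = 1;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    if (i < start || i > stop) {
      out.push_back(shape[i]);
    } else {
      prod *= shape[i];
      if (i == stop) out.push_back(static_cast<int>(prod));
    }
  }
  return out;
}

// Example: FlattenRange({2, 3, 4, 5}, 1, 2) returns {2, 12, 5}.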
......@@ -55,6 +55,7 @@ struct SimpleOpTypeSetTeller : public Teller {
// #endif
#if IS_TRT_VERSION_GE(7000)
teller_set.insert("tile");
teller_set.insert("flatten_contiguous_range");
#endif
#if CUDA_VERSION >= 10020
teller_set.insert("reshape");
......@@ -531,6 +532,37 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (axis != 1) return false;
}
}
if (op_type == "flatten_contiguous_range") {
if (!with_dynamic_shape) {
int start_axis = BOOST_GET_CONST(int, desc.GetAttr("start_axis"));
int stop_axis = BOOST_GET_CONST(int, desc.GetAttr("stop_axis"));
auto x_var_name = desc.Input("X")[0];
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
int dims = x_shape.size();
if (start_axis < 0) start_axis += dims;
if (start_axis == 0) {
VLOG(3) << "TRT flatten_contiguous_range not support the "
"batch-dimension being changed";
return false;
}
if (stop_axis < 0) stop_axis += dims;
for (int i = start_axis; i <= stop_axis; ++i) {
if (x_shape[i] < 0) {
VLOG(3) << "On TRT static shape,flatten_contiguous_range input dim "
"should be > 0";
return false;
}
}
}
}
if (op_type == "gather") {
auto gather_inputs = desc.Inputs();
......
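The OpTeller check added above boils down to two static-shape constraints: the batch dimension (axis 0) must not be folded into the flatten range, and every flattened dimension must be statically known. A distilled, hedged sketch of that predicate (names are illustrative; the real teller also walks the block and var descs):

#include <cstdint>
#include <vector>

// Returns true iff flatten_contiguous_range is accepted under TRT static shape.
bool StaticShapeFlattenSupported(const std::vector<int64_t>& x_shape,
                                 int start_axis, int stop_axis) {
  const int dims = static_cast<int>(x_shape.size());
  if (start_axis < 0) start_axis += dims;
  if (stop_axis < 0) stop_axis += dims;
  if (start_axis == 0) return false;  // would change the batch dimension
  for (int i = start_axis; i <= stop_axis; ++i) {
    if (x_shape[i] < 0) return false;  // dynamic (unknown) dimension
  }
  return true;
}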
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/stack_op.h"
#include <string>
#ifdef PADDLE_WITH_XPU
#include <vector>
#include "paddle/fluid/operators/concat_op.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace paddle {
namespace operators {
......@@ -59,14 +62,44 @@ class StackXPUKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class StackGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto dx = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
auto axis = ctx.Attr<int>("axis");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto dy_dims = dy->dims();
if (axis < 0) axis += dy_dims.size() + 1;
auto dy_shape = framework::vectorize<int>(dy_dims);
std::vector<int> dx_dims_list(dx.size(), 1);
std::vector<T*> dx_lists;
for (auto out : dx) {
dx_lists.push_back(out->mutable_data<T>(ctx.GetPlace()));
}
int r = xpu::split<T>(dev_ctx.x_context(), dy->data<T>(), dx_lists,
dy_shape, dx_dims_list, axis);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"The stack_grad XPU kernel return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(stack,
ops::StackXPUKernel<plat::XPUDeviceContext, int64_t>,
ops::StackXPUKernel<plat::XPUDeviceContext, float>,
ops::StackXPUKernel<plat::XPUDeviceContext, int>,
ops::StackXPUKernel<plat::XPUDeviceContext, float>);
ops::StackXPUKernel<plat::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(stack_grad,
ops::StackGradXPUKernel<plat::XPUDeviceContext, float>,
ops::StackGradXPUKernel<plat::XPUDeviceContext, int>);
#endif
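The stack_grad kernel above is a split of dY along the stacked axis, delegated to xpu::split. As a plain CPU reference of what that produces (a hypothetical helper for intuition, not the XPU code path):

#include <vector>

// dy is row-major with shape dy_shape, where dy_shape[axis] == n stacked inputs.
// Returns n gradients, each the corresponding slice with the stacked axis removed.
std::vector<std::vector<float>> StackGradRef(const std::vector<float>& dy,
                                             const std::vector<int>& dy_shape,
                                             int axis) {
  const int n = dy_shape[axis];
  int outer = 1, inner = 1;
  for (int i = 0; i < axis; ++i) outer *= dy_shape[i];
  for (int i = axis + 1; i < static_cast<int>(dy_shape.size()); ++i) inner *= dy_shape[i];
  std::vector<std::vector<float>> dx(n, std::vector<float>(outer * inner));
  for (int o = 0; o < outer; ++o)
    for (int k = 0; k < n; ++k)
      for (int i = 0; i < inner; ++i)
        dx[k][o * inner + i] = dy[(o * n + k) * inner + i];
  return dx;
}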
......@@ -300,6 +300,7 @@ XPUOpMap& get_kl1_ops() {
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"stack_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
......@@ -333,6 +333,8 @@ XPUOpMap& get_kl2_ops() {
{"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"stack_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
......
......@@ -77,7 +77,6 @@ add_subdirectory(paddle)
# MLIR td file generations
set(infrt_mlir_incs
ops_inc
basic_kernels_inc
test_kernels_inc
infrt_base_inc
......
......@@ -14,7 +14,7 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include <mlir/IR/MLIRContext.h>
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
......
......@@ -2,7 +2,6 @@ core_gather_headers()
gather_srcs(infrt_src SRCS
dialect.cc
types.cc
basic_kernels.cc
test_kernels.cc
infrt_base.cc
......@@ -14,8 +13,6 @@ gather_srcs(infrt_src SRCS
pd_types.cc
pd_ops.cc
)
mlir_tablegen_on(ops)
mlir_tablegen_on(basic_kernels)
mlir_tablegen_on(test_kernels)
mlir_tablegen_on(infrt_base DIALECT infrt)
......@@ -27,8 +24,7 @@ mlir_add_rewriter(rewrite)
# TODO(Superjomn) add a cmake function cc_executable to encapsulate the following code
add_executable(infrtopt opt.cc)
target_link_libraries(infrtopt infrt ${mlir_libs})
add_dependencies(infrtopt infrt)
target_link_libraries(infrtopt infrt)
add_executable(print-ir print_ir.cc)
target_link_libraries(print-ir infrt ${mlir_libs})
......
......@@ -17,17 +17,17 @@
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include "paddle/infrt/dialect/dense_tensor.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
using namespace mlir; // NOLINT
static ParseResult parseCallOp(OpAsmParser &parser, // NOLINT
......@@ -71,12 +71,12 @@ static ParseResult parseConstantF64Op(OpAsmParser &parser, // NOLINT
static ParseResult parseConstantI32Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
return parseConstantOp(
IntegerType::get(32, result.getContext()), parser, result);
IntegerType::get(result.getContext(), 32), parser, result);
}
static ParseResult parseConstantI64Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
return parseConstantOp(
IntegerType::get(64, result.getContext()), parser, result);
IntegerType::get(result.getContext(), 64), parser, result);
}
static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT
......@@ -90,10 +90,10 @@ static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT
}
static void print(OpAsmPrinter &p, CallOp op) { // NOLINT
p << "infrt.call " << op.getAttr("callee") << "(";
p << "infrt.call " << op->getAttr("callee") << "(";
p.printOperands(op.getOperands());
p << ")";
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
p.printOptionalAttrDict(op->getAttrs(), {"callee"});
p << " : ";
}
......@@ -145,7 +145,7 @@ static LogicalResult verify(ConstantF64Op op) { return success(); }
static LogicalResult verify(ConstantI64Op op) { return success(); }
static LogicalResult verify(ReturnOp op) {
auto function = dyn_cast<FuncOp>(op.getParentOp());
auto function = dyn_cast<FuncOp>(op->getParentOp());
if (!function) return success();
......@@ -157,8 +157,8 @@ static LogicalResult verify(ReturnOp op) {
return success();
}
} // namespace dialect
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.cpp.inc"
} // namespace infrt::dialect
......@@ -13,12 +13,9 @@
// limitations under the License.
#pragma once
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
using namespace mlir; // NOLINT
namespace infrt::dialect {
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.hpp.inc"
} // namespace infrt::dialect
......@@ -27,7 +27,7 @@ def CallOp : INFRT_Op<"call"> {
let results = (outs Variadic<AnyType>);
let extraClassDeclaration = [{
StringRef getCallee() { return callee(); }
mlir::StringRef getCallee() { return callee(); }
mlir::FunctionType getCalleeType();
}];
}
......@@ -57,9 +57,8 @@ def ReturnOp : INFRT_Op<"return", [Terminator]> {
let arguments = (ins Variadic<AnyType>:$operands);
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result",
[{ build(b, result, llvm::None); }]>];
let builders = [OpBuilder<(ins),
[{ build($_builder, $_state, llvm::None); }]>];
}
class AddOp<string suffix, Type type> : INFRT_Op<"add." # suffix, [NoSideEffect]> {
......
......@@ -17,12 +17,11 @@
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
......@@ -31,68 +30,37 @@
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt::dt {
namespace infrt {
namespace dt {
void DTDialect::initialize() {
allowUnknownTypes();
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/dense_tensor.cpp.inc"
>();
}
namespace detail {
struct TensorTypeStorage : public mlir::TypeStorage {
TensorTypeStorage(TargetType target,
LayoutType layout,
PrecisionType precision)
: target_(target), layout_(layout), precision_(precision) {}
using KeyTy = std::tuple<TargetType, LayoutType, PrecisionType>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(target_, layout_, precision_);
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static TensorTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, // NOLINT
const KeyTy &key) {
return new (allocator.allocate<TensorTypeStorage>())
TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key));
}
TargetType target_;
LayoutType layout_;
PrecisionType precision_;
};
} // namespace detail
llvm::Optional<TargetType> GetTargetType(mlir::StringRef key) {
if (key.equals_lower("x86"))
if (key.equals_insensitive("x86"))
return TargetType::X86;
else if (key.equals_lower("cuda"))
else if (key.equals_insensitive("cuda"))
return TargetType::CUDA;
else
return llvm::None;
}
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key) {
if (key.equals_lower("nchw"))
if (key.equals_insensitive("nchw"))
return LayoutType::NCHW;
else if (key.equals_lower("nhwc"))
else if (key.equals_insensitive("nhwc"))
return LayoutType::NHWC;
else
return llvm::None;
}
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key) {
if (key.equals_lower("i32"))
if (key.equals_insensitive("i32"))
return PrecisionType::I32;
else if (key.equals_lower("f32"))
else if (key.equals_insensitive("f32"))
return PrecisionType::F32;
else
return llvm::None;
......@@ -111,7 +79,7 @@ LayoutType TensorType::layout() { return getImpl()->layout_; }
PrecisionType TensorType::precision() { return getImpl()->precision_; }
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType) {
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType) {
os << "TensorType<" << tensorType.target() << ", " << tensorType.layout()
<< ", " << tensorType.precision() << ">";
return os;
......@@ -133,7 +101,7 @@ StringType StringType::get(mlir::MLIRContext *context) {
return Base::get(context);
}
raw_ostream &operator<<(raw_ostream &os, TargetType type) {
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type) {
switch (type) {
case (TargetType::X86):
os << "X86";
......@@ -147,7 +115,7 @@ raw_ostream &operator<<(raw_ostream &os, TargetType type) {
return os;
}
raw_ostream &operator<<(raw_ostream &os, LayoutType type) {
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type) {
switch (type) {
case (LayoutType::NCHW):
os << "NCHW";
......@@ -161,7 +129,7 @@ raw_ostream &operator<<(raw_ostream &os, LayoutType type) {
return os;
}
raw_ostream &operator<<(raw_ostream &os, PrecisionType type) {
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type) {
switch (type) {
case (PrecisionType::I32):
os << "I32";
......@@ -175,103 +143,69 @@ raw_ostream &operator<<(raw_ostream &os, PrecisionType type) {
return os;
}
static Type getTensorType(mlir::MLIRContext *context) {
auto t_dialect = Identifier::get("t", context);
return OpaqueType::get(t_dialect, "tensor", context);
static mlir::Type getTensorType(mlir::MLIRContext *context) {
auto t_dialect = mlir::Identifier::get("t", context);
return mlir::OpaqueType::get(t_dialect, "tensor");
}
static ParseResult parseCreateUninitTensorOp(
OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
static mlir::ParseResult parseCreateUninitTensorOp(
mlir::OpAsmParser &parser, // NOLINT
mlir::OperationState &result) { // NOLINT
auto loc = parser.getCurrentLocation();
::mlir::Type outputRawTypes[1];
::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes);
mlir::Type outputRawTypes[1];
::llvm::ArrayRef<mlir::Type> outputTypes(outputRawTypes);
mlir::ArrayAttr shapeAttr;
if (parser.parseAttribute(shapeAttr,
parser.getBuilder().getI64Type(),
"shape",
result.attributes))
return failure();
if (parser.parseOptionalAttrDict(result.attributes)) return failure();
return mlir::failure();
if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure();
if (parser.parseArrow()) return failure();
if (parser.parseType(outputRawTypes[0])) return failure();
if (parser.parseArrow()) return mlir::failure();
if (parser.parseType(outputRawTypes[0])) return mlir::failure();
if (!outputRawTypes[0].isa<TensorType>())
return parser.emitError(loc, "invalid kind of type specified");
result.addTypes(outputTypes);
return success();
return mlir::success();
}
template <typename CreateUninitTensorOp>
static void printCreateUninitTensorOp(OpAsmPrinter &p, // NOLINT
static void printCreateUninitTensorOp(mlir::OpAsmPrinter &p, // NOLINT
CreateUninitTensorOp op) {
p << CreateUninitTensorOp::getOperationName();
p << " ";
p.printAttributeWithoutType(op.shapeAttr());
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
p << " -> ";
p << op.getOperation()->getResultTypes();
}
// TODO(shibo): can be removed?
// static ParseResult parseFillTensorWithConstantOp(OpAsmParser& parser,
// OperationState& result) {
// auto loc = parser.getCurrentLocation();
// ::mlir::OpAsmParser::OperandType inputRawOperands[1];
// ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType>
// inputOperands(inputRawOperands);
// ::mlir::Type inputRawTypes[1];
// ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes);
//
// if (parser.parseOperand(inputRawOperands[0])) return failure();
//
// if (parser.parseColon()) return failure();
// if (parser.parseType(inputRawTypes[0])) return failure();
// if (!inputRawTypes[0].isa<TensorType>())
// return parser.emitError(loc, "invalid kind of type specified");
//
// Attribute value_attr;
// if (parser.resolveOperands(inputOperands, inputTypes, loc, result.operands))
// return failure();
// if (parser.parseAttribute(value_attr, "value", result.attributes)) return
// failure();
// return success();
//}
// TODO(shibo): can be removed?
// template <typename FillTensorOp>
// static void printFillTensorWithConstantOp(OpAsmPrinter& p, FillTensorOp op) {
// p << FillTensorOp::getOperationName();
// p << " ";
// p.printOperand(op.getOperand());
// p << " : ";
// p << op.getOperation()->getOperandTypes();
// p << " ";
// p << op.getAttr("value");
//}
static ParseResult parseSetTensorOp(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
SmallVector<OpAsmParser::OperandType, 1> operands;
if (parser.parseOperandList(operands, 1)) return failure();
static mlir::ParseResult parseSetTensorOp(
mlir::OpAsmParser &parser, // NOLINT
mlir::OperationState &result) { // NOLINT
llvm::SmallVector<mlir::OpAsmParser::OperandType, 1> operands;
if (parser.parseOperandList(operands, 1)) return mlir::failure();
auto tensor_type = getTensorType(result.getContext());
Attribute value_attr;
return failure(
mlir::Attribute value_attr;
return mlir::failure(
parser.resolveOperand(operands[0], tensor_type, result.operands) ||
parser.parseAttribute(value_attr, "values", result.attributes));
}
template <typename SetTensorOp>
static void printSetTensorOp(OpAsmPrinter &p, SetTensorOp op) { // NOLINT
static void printSetTensorOp(mlir::OpAsmPrinter &p, SetTensorOp op) { // NOLINT
p << SetTensorOp::getOperationName() << " ";
p.printOperand(op.getOperand());
p << " " << op.getAttr("values");
p << " " << op->getAttr("values");
}
} // namespace dt
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.cpp.inc" // NOLINT
} // namespace infrt::dt
#include "paddle/infrt/dialect/dense_tensor_dialect.cpp.inc"
......@@ -19,13 +19,8 @@
#include <string>
using namespace mlir; // NOLINT
namespace infrt::dt {
namespace detail {
struct TensorTypeStorage;
} // namespace detail
namespace infrt {
namespace dt {
enum class TargetType : uint8_t { X86, CUDA };
enum class LayoutType : uint8_t { NCHW, NHWC };
enum class PrecisionType : uint8_t { I32, F32 };
......@@ -34,9 +29,39 @@ llvm::Optional<TargetType> GetTargetType(mlir::StringRef key);
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key);
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key);
raw_ostream &operator<<(raw_ostream &os, TargetType type);
raw_ostream &operator<<(raw_ostream &os, LayoutType type);
raw_ostream &operator<<(raw_ostream &os, PrecisionType type);
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type);
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type);
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type);
namespace detail {
struct TensorTypeStorage : public mlir::TypeStorage {
TensorTypeStorage(TargetType target,
LayoutType layout,
PrecisionType precision)
: target_(target), layout_(layout), precision_(precision) {}
using KeyTy = std::tuple<TargetType, LayoutType, PrecisionType>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(target_, layout_, precision_);
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static TensorTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, // NOLINT
const KeyTy &key) {
return new (allocator.allocate<TensorTypeStorage>())
TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key));
}
TargetType target_;
LayoutType layout_;
PrecisionType precision_;
};
} // namespace detail
class TensorType : public mlir::Type::TypeBase<TensorType,
mlir::Type,
......@@ -52,7 +77,7 @@ class TensorType : public mlir::Type::TypeBase<TensorType,
PrecisionType precision();
};
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType);
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType);
class TensorMapType : public mlir::Type::TypeBase<TensorMapType,
mlir::Type,
......@@ -70,10 +95,10 @@ class StringType
static StringType get();
static StringType get(mlir::MLIRContext *context);
};
} // namespace dt
} // namespace infrt
#include "paddle/infrt/dialect/dense_tensor_dialect.hpp.inc"
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.hpp.inc"
} // namespace infrt::dt
......@@ -14,9 +14,11 @@
#include "paddle/infrt/dialect/diagnostic_utils.h"
#include <llvm/Support/raw_ostream.h>
#include <string>
namespace infrt::dialect {
namespace infrt {
namespace dialect {
struct MyScopedDiagnosicHandler::Impl {
Impl() : diag_stream_(diag_str_) {}
......@@ -49,4 +51,5 @@ mlir::LogicalResult MyScopedDiagnosicHandler::handler(mlir::Diagnostic *diag) {
return mlir::failure(true);
}
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -18,7 +18,8 @@
#include <memory>
namespace infrt::dialect {
namespace infrt {
namespace dialect {
/**
* A scoped diagnostic handler to help debug MLIR processing.
......@@ -36,4 +37,5 @@ class MyScopedDiagnosicHandler : public mlir::SourceMgrDiagnosticHandler {
std::unique_ptr<Impl> impl_;
};
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -13,24 +13,26 @@
// limitations under the License.
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LogicalResult.h>
namespace infrt::hlir::dialect {
namespace infrt {
namespace hlir {
namespace dialect {
class CinnDialect : public ::mlir::Dialect {
class CinnDialect : public mlir::Dialect {
public:
explicit CinnDialect(::mlir::MLIRContext* ctx);
explicit CinnDialect(mlir::MLIRContext* ctx);
//! We should register this function in dialect
static llvm::StringRef getDialectNamespace() {
return "infrt::hlir::dialect";
}
};
} // namespace infrt::hlir::dialect
} // namespace dialect
} // namespace hlir
} // namespace infrt
......@@ -18,7 +18,8 @@
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/test_kernels.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
// ----INFRTDialect definition begin----
void INFRTDialect::initialize() {
......@@ -124,4 +125,5 @@ void INFRTDialect::printType(mlir::Type type,
// ----INFRTDialect definition end----
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -18,19 +18,17 @@
#include <mlir/IR/Dialect.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h>
#include "paddle/infrt/dialect/infrt_base.hpp.inc"
namespace infrt::dialect {
class INFRTDialect : public ::mlir::Dialect {
explicit INFRTDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(),
context,
::mlir::TypeID::get<INFRTDialect>()) {
namespace infrt {
namespace dialect {
class INFRTDialect : public mlir::Dialect {
explicit INFRTDialect(mlir::MLIRContext *context)
: mlir::Dialect(
getDialectNamespace(), context, mlir::TypeID::get<INFRTDialect>()) {
initialize();
}
......@@ -41,15 +39,12 @@ class INFRTDialect : public ::mlir::Dialect {
mlir::DialectAsmPrinter &printer) const override;
void initialize();
friend class ::mlir::MLIRContext;
friend class mlir::MLIRContext;
public:
static ::llvm::StringRef getDialectNamespace() { return "infrt"; }
};
} // namespace infrt::dialect
namespace mlir {
} // namespace dialect
template <typename T>
static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT
......@@ -58,17 +53,16 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT
return b.getIntegerAttr(b.getI32Type(), constant);
}
static mlir::SmallVector<::mlir::Value, 4> cvtValueToValueRange(
static mlir::SmallVector<mlir::Value, 4> cvtValueToValueRange(
const mlir::Value &operand) {
return mlir::SmallVector<::mlir::Value, 4>(1, operand);
return mlir::SmallVector<mlir::Value, 4>(1, operand);
}
static mlir::SmallVector<::mlir::Value, 4> concatTwoValueRange(
static mlir::SmallVector<mlir::Value, 4> concatTwoValueRange(
mlir::ValueRange operand_0, mlir::ValueRange operand_1) {
mlir::SmallVector<::mlir::Value, 4> operands;
mlir::SmallVector<mlir::Value, 4> operands;
operands.append(operand_0.begin(), operand_0.end());
operands.append(operand_1.begin(), operand_1.end());
return operands;
}
} // namespace mlir
} // namespace infrt
......@@ -28,11 +28,11 @@ def TensorMapType :
def BufferType : OpaqueType<"b", "buffer", "buffer">;
class INFRT_createI32Attr<string value> : NativeCodeCall<
"mlir::createI32Attr($_builder, $_loc, " # value # ")">;
"infrt::createI32Attr($_builder, $_loc, " # value # ")">;
def INFRT_cvtValueToValueRange : NativeCodeCall<
"mlir::cvtValueToValueRange($0)">;
"infrt::cvtValueToValueRange($0)">;
def INFRT_concatTwoValueRange : NativeCodeCall<
"mlir::concatTwoValueRange($0, $1)">;
"infrt::concatTwoValueRange($0, $1)">;
#endif // INFRT_BASE
......@@ -23,12 +23,10 @@
#include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt {
void RegisterCinnDialects(mlir::DialectRegistry& registry) { // NOLINT
registry.insert<ts::TensorShapeDialect>();
registry.insert<dialect::INFRTDialect>();
registry.insert<dt::DTDialect>();
registry.insert<mlir::pd::PaddleDialect>();
void registerCinnDialects(mlir::DialectRegistry &registry) { // NOLINT
registry.insert<ts::TensorShapeDialect,
dialect::INFRTDialect,
dt::DTDialect,
mlir::pd::PaddleDialect>();
}
} // namespace infrt
......@@ -14,10 +14,8 @@
#pragma once
#include "mlir/IR/Dialect.h"
#include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h>
namespace infrt {
void RegisterCinnDialects(mlir::DialectRegistry& registry); // NOLINT
void registerCinnDialects(mlir::DialectRegistry &registry); // NOLINT
} // namespace infrt
......@@ -16,8 +16,8 @@
#include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h>
#include <unordered_map>
......@@ -30,12 +30,15 @@
#include "paddle/infrt/dialect/diagnostic_utils.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source) {
// context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry());
mlir::DialectRegistry registry;
registerCinnDialects(registry);
context->appendDialectRegistry(registry);
// Currently, the CinnDialect and mlir::BuiltinDialect are enough; the
// StandardOpsDialect is not needed.
// context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
......@@ -57,9 +60,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context) {
// context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry());
context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
mlir::DialectRegistry registry;
registerCinnDialects(registry);
context->appendDialectRegistry(registry);
mlir::ScopedDiagnosticHandler scope_handler(
context, [](mlir::Diagnostic& diag) {
if (diag.getSeverity() != mlir::DiagnosticSeverity::Error)
......@@ -71,4 +74,5 @@ mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
return mlir::parseSourceFile(std::string(file_name), context);
}
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -15,16 +15,17 @@
#pragma once
#include <glog/logging.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/BuiltinOps.h>
#include <string>
#include <memory>
namespace infrt::dialect {
namespace infrt {
namespace dialect {
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source);
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context);
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -17,14 +17,15 @@
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Parser.h>
#include <string>
#include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
TEST(MlirLoader, basic) {
mlir::MLIRContext context;
......@@ -42,8 +43,7 @@ func @main() -> f32 {
)ROC";
auto module = LoadMlirSource(&context, source);
module->verify();
EXPECT_TRUE(mlir::succeeded(module->verify()));
LOG(INFO) << "module name: " << module->getOperationName().data();
for (auto func : module->getOps<mlir::FuncOp>()) {
LOG(INFO) << "get func " << func.getName().str();
......@@ -54,4 +54,5 @@ func @main() -> f32 {
}
}
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -20,5 +20,5 @@ func @main() -> tensor<?xf32> {
%c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
infrt.return %e2 : tensor<?xf32>
"pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
}
\ No newline at end of file
......@@ -11,5 +11,5 @@ func @main() -> tensor<?xf32> {
%c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor<?x3x256x256xf32>, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
%d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor<?x3x256x256xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
infrt.return %d : tensor<?x3x256x256xf32>
"pd.fetch"(%d) {name="output"} :(tensor<?x3x256x256xf32>)->()
}
\ No newline at end of file
......@@ -18,5 +18,5 @@ func @main() -> tensor<?xf32> {
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
"pd.fetch"(%e2) :(tensor<?xf32>)->()
"pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
}
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/infrt_base.td"
class INFRT_Op<string mnemonic, list<OpTrait> traits = []> :
Op<INFRT_Dialect, mnemonic, traits>;
......@@ -12,34 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <llvm/Support/CommandLine.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Dialect.h>
#include <mlir/InitAllDialects.h>
#include <mlir/InitAllPasses.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/MlirOptMain.h>
#include <mlir/Transforms/Passes.h>
#include <iostream>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
#include "paddle/infrt/dialect/mlir_loader.h"
int main(int argc, char **argv) {
mlir::MLIRContext *context = infrt::Global::getMLIRContext();
auto &registry = context->getDialectRegistry();
infrt::RegisterCinnDialects(registry);
mlir::DialectRegistry registry;
infrt::registerCinnDialects(registry);
mlir::registerCanonicalizerPass();
return mlir::failed(
mlir::MlirOptMain(argc, argv, "INFRT mlir pass driver", registry));
mlir::MlirOptMain(argc, argv, "infrt mlir pass driver", registry));
}
......@@ -16,7 +16,7 @@ def PD_Dialect : Dialect {
This dialect contains the PaddlePaddle operators.
}];
let cppNamespace = "::mlir::pd";
let cppNamespace = "mlir::pd";
}
class PD_Op<string mnemonic, list<OpTrait> traits = []> :
......
......@@ -14,10 +14,15 @@
#include "paddle/infrt/dialect/pd_ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include <mlir/IR/Matchers.h>
#include <mlir/IR/PatternMatch.h>
#include "paddle/infrt/dialect/infrt_base.h"
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
namespace mlir {
namespace pd {
PaddleDialect::PaddleDialect(MLIRContext *context)
......@@ -36,12 +41,6 @@ mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder,
return builder.create<ConstantOp>(loc, value);
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
#undef GET_OP_CLASSES
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
void ConstantOp::build(OpBuilder &builder,
OperationState &state,
Attribute value) {
......@@ -66,8 +65,8 @@ LogicalResult ConstantOp::inferReturnTypes(
inferredReturnTypes.push_back(attributes.get("value").getType());
return success();
}
::mlir::OpFoldResult ConstantOp::fold(
::llvm::ArrayRef<::mlir::Attribute> operands) {
mlir::OpFoldResult ConstantOp::fold(
::llvm::ArrayRef<mlir::Attribute> operands) {
return value();
}
......@@ -82,11 +81,11 @@ LogicalResult ElementwiseAdd::inferReturnTypes(
return success();
}
void ElementwiseAdd::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseMulAdd>(context);
}
::mlir::OpFoldResult ElementwiseAdd::fold(
mlir::OpFoldResult ElementwiseAdd::fold(
llvm::ArrayRef<mlir::Attribute> operands) {
if (getElementTypeOrSelf(getType()).isa<FloatType>()) {
if (!operands[0] || !operands[1]) return {};
......@@ -154,17 +153,17 @@ LogicalResult MulOp::inferReturnTypes(
}
void ReluOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseFCRelu>(context);
}
void FusedRepeatedFCRelu::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseRepeatedFCRelu2>(context);
}
void BatchNormOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseBatchNormWithConvPattern>(context);
}
......
......@@ -14,21 +14,20 @@
#pragma once
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <mlir/Dialect/Traits.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace mlir {
namespace pd {
......@@ -53,9 +52,8 @@ class PaddleDialect : public Dialect {
}
};
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
#undef GET_OP_CLASSES
} // namespace pd
} // namespace mlir
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
......@@ -24,6 +24,16 @@ def PD_FeedOp : PD_Op<"feed"> {
def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
let summary = "fetch Op";
let description = [{
Fetch the named output tensor of the graph.
}];
let arguments = (ins PD_Tensor :$inputs, StrAttr:$name);
}
def PD_ReturnOp : PD_Op<"return", [Terminator]> {
let summary = "return Op";
let description = [{
Return the output tensors of a subgraph (the pd.graph terminator).
}];
......@@ -31,7 +41,7 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
let arguments = (ins Variadic<PD_Tensor>:$inputs);
}
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> {
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"ReturnOp">]> {
let summary = "paddle graph Op";
let description = [{
Describe a paddle graph or subgraph.
......@@ -50,7 +60,7 @@ def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInte
let hasFolder = 1;
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute value">,
OpBuilder<(ins "Attribute":$value)>,
];
}
......
......@@ -18,12 +18,11 @@
#pragma once
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h>
namespace mlir {
namespace PD {
......
......@@ -11,26 +11,25 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <llvm/ADT/Optional.h>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/ScopedPrinter.h>
#include <mlir/IR/BuiltinOps.h>
#include <llvm/Support/raw_os_ostream.h>
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Block.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/Region.h>
#include <mlir/IR/Verifier.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/Passes.h>
#include <iostream>
#include "llvm/ADT/Optional.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
......@@ -114,17 +113,15 @@ int main(int argc, char **argv) {
mlir::registerPassManagerCLOptions();
cl::ParseCommandLineOptions(argc, argv, "mlir demo");
mlir::MLIRContext *context = infrt::Global::getMLIRContext();
// context->allowUnregisteredDialects();
auto &registry = context->getDialectRegistry();
infrt::RegisterCinnDialects(registry);
mlir::DialectRegistry registry;
infrt::registerCinnDialects(registry);
mlir::MLIRContext context(registry);
// mlir will verify module automatically after parsing.
// https://github.com/llvm/llvm-project/blob/38d18d93534d290d045bbbfa86337e70f1139dc2/mlir/lib/Parser/Parser.cpp#L2051
// mlir::OwningModuleRef module_ref = mlir::parseSourceString(mlir_source,
// context);
mlir::OwningModuleRef module_ref =
mlir::parseSourceFile(inputFilename, context);
mlir::parseSourceFile(inputFilename, &context);
std::cout << "----------print IR Structure begin----------" << std::endl;
printOperation(module_ref->getOperation(), 0);
std::cout << "----------print IR Structure end----------" << std::endl;
......
......@@ -17,16 +17,16 @@
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
namespace infrt::ts {
namespace infrt {
namespace ts {
using namespace mlir; // NOLINT
void TensorShapeDialect::initialize() {
......@@ -48,8 +48,8 @@ Type TensorShapeDialect::parseType(DialectAsmParser &parser) const {
return Type();
}
void TensorShapeDialect::printType(::mlir::Type type,
::mlir::DialectAsmPrinter &os) const {
void TensorShapeDialect::printType(mlir::Type type,
mlir::DialectAsmPrinter &os) const {
if (type.isa<ShapeType>()) {
os << "shape";
return;
......@@ -61,8 +61,10 @@ void TensorShapeDialect::printType(::mlir::Type type,
}
llvm_unreachable("unexpected 'shape' type kind");
}
} // namespace ts
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensor_shape.cpp.inc" // NOLINT
} // namespace infrt::ts
#include "paddle/infrt/dialect/tensor_shape_dialect.cpp.inc"
......@@ -17,7 +17,8 @@
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt::ts {
namespace infrt {
namespace ts {
class ShapeType
: public mlir::Type::TypeBase<ShapeType, mlir::Type, mlir::TypeStorage> {
......@@ -31,10 +32,9 @@ class PartialShapeType : public mlir::Type::TypeBase<PartialShapeType,
public:
using Base::Base;
};
} // namespace ts
} // namespace infrt
using namespace mlir; // NOLINT
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensor_shape.hpp.inc"
#include "paddle/infrt/dialect/tensor_shape_dialect.hpp.inc"
} // namespace infrt::ts
......@@ -19,7 +19,7 @@ def TensorShapeDialect : Dialect {
def TS_Shape : DialectType<TensorShapeDialect,
CPred<"$_self.isa<::infrt::ts::ShapeType>()">, "!ts.shape type">,
BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> {
let typeDescription = [{
let description = [{
`!ts.shape type` represents a static tensor shape.
}];
}
......@@ -27,7 +27,7 @@ BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> {
def TS_PartialShape : DialectType<TensorShapeDialect,
CPred<"$_self.isa<::infrt::ts::PartialShapeType>()">, "!ts.partial_shape type">,
BuildableType<"$_builder.getType<::infrt::ts::PartialShapeType>()"> {
let typeDescription = [{
let description = [{
`!ts.partial_shape type` represents either a static tensor shape, unranked
tensor shape or a ranked tensor shape with unknown dimension sizes.
}];
......
......@@ -11,10 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <llvm/Support/CommandLine.h>
#include <mlir/Pass/PassManager.h>
#include <iostream>
#include <string>
#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/PassManager.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
......
......@@ -14,14 +14,13 @@
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <paddle/infrt/dialect/pd_ops.h>
#include <list>
#include <unordered_set>
#include <vector>
#include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
......@@ -32,9 +31,9 @@ namespace {
// References the function named "FlexibleDFS" defined in:
// paddle/fluid/framework/ir/subgraph_detector.cc.
bool reverseDfs(std::vector<::mlir::Operation *> source,
const std::function<bool(const ::mlir::Operation *)> &func) {
std::unordered_set<const ::mlir::Operation *> visited;
bool reverseDfs(std::vector<mlir::Operation *> source,
const std::function<bool(const mlir::Operation *)> &func) {
std::unordered_set<const mlir::Operation *> visited;
while (!source.empty()) {
auto node = source.back();
source.pop_back();
......@@ -44,7 +43,7 @@ bool reverseDfs(std::vector<::mlir::Operation *> source,
auto values = node->getOperands();
for (auto value : values) {
// if the value is a block argument, the node is nullptr.
::mlir::Operation *node = value.getDefiningOp();
mlir::Operation *node = value.getDefiningOp();
if (node != nullptr && !visited.count(node)) {
source.emplace_back(node);
}
......@@ -54,19 +53,19 @@ bool reverseDfs(std::vector<::mlir::Operation *> source,
}
// merge the first&second graph op to a new graph op.
void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
::mlir::pd::GraphOp first,
::mlir::pd::GraphOp second) {
void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
mlir::pd::GraphOp first,
mlir::pd::GraphOp second) {
// compute inputs and outputs
::llvm::SmallVector<::mlir::Value, 4> inputs(first.getOperands()), outputs;
for (::mlir::Value input : second.getOperands()) {
::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
for (mlir::Value input : second.getOperands()) {
if (input.getDefiningOp() != first) {
inputs.push_back(input);
}
}
::llvm::DenseMap<::mlir::Value, unsigned int> op_output_mapping;
for (::mlir::Value output : first.getResults()) {
for (::mlir::Operation *user : output.getUsers()) {
::llvm::DenseMap<mlir::Value, unsigned int> op_output_mapping;
for (mlir::Value output : first.getResults()) {
for (mlir::Operation *user : output.getUsers()) {
if (user != second && user->getParentOp() != second) {
op_output_mapping[output] = outputs.size();
outputs.push_back(output);
......@@ -74,19 +73,19 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
}
}
}
auto fetch_op = second.getBody()->getTerminator();
outputs.append(fetch_op->getOperands().begin(),
fetch_op->getOperands().end());
::llvm::SmallVector<::mlir::Type, 4> fetch_types;
auto return_op = second.getBody()->getTerminator();
outputs.append(return_op->getOperands().begin(),
return_op->getOperands().end());
::llvm::SmallVector<mlir::Type, 4> return_types;
for (auto value : outputs) {
fetch_types.push_back(value.getType());
return_types.push_back(value.getType());
}
// create the new graph op
builder.setInsertionPoint(first);
auto loc = first.getLoc();
auto graph_op = builder.create<::mlir::pd::GraphOp>(loc, fetch_types, inputs);
::mlir::Block *block = new ::mlir::Block;
auto graph_op = builder.create<mlir::pd::GraphOp>(loc, return_types, inputs);
mlir::Block *block = new mlir::Block;
auto copy_range = second.getBody()->without_terminator();
block->getOperations().splice(block->begin(),
second.getBody()->getOperations(),
......@@ -98,18 +97,18 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
copy_range.begin(),
copy_range.end());
builder.setInsertionPointToEnd(block);
builder.create<mlir::pd::FetchOp>(loc, outputs);
builder.create<mlir::pd::ReturnOp>(loc, outputs);
graph_op.body().push_back(block);
// mapping the output
unsigned int num_result = first.getNumResults();
fetch_op = first.getBody()->getTerminator();
return_op = first.getBody()->getTerminator();
for (unsigned int index = 0; index < num_result; ++index) {
auto origin_value = first.getResult(index);
if (op_output_mapping.find(origin_value) == op_output_mapping.end()) {
origin_value.replaceAllUsesWith(fetch_op->getOperand(index));
origin_value.replaceAllUsesWith(return_op->getOperand(index));
} else {
auto inner_value = fetch_op->getOperand(index);
auto inner_value = return_op->getOperand(index);
auto outer_value = graph_op.getResult(op_output_mapping[origin_value]);
while (!origin_value.use_empty()) {
auto replace_value =
......@@ -128,13 +127,13 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
// Topological sort the function op.
void topoSortBlock(mlir::Block &body) { // NOLINT
llvm::SetVector<Operation *> toSort;
llvm::SetVector<mlir::Operation *> toSort;
if (body.empty()) return;
for (auto it = body.rbegin(); it != body.rend(); ++it) {
toSort.insert(&*it);
}
llvm::SetVector<Operation *> result =
::mlir::topologicalSort(std::move(toSort));
llvm::SetVector<mlir::Operation *> result =
mlir::topologicalSort(std::move(toSort));
for (auto *op : result) {
op->moveBefore(body.getTerminator());
}
......@@ -145,21 +144,21 @@ void topoSortBlock(mlir::Block &body) { // NOLINT
// Implementation of the trtGraphFusePass.
void trtGraphFusePass::runOnFunction() {
mlir::Block &body = getFunction().front();
::mlir::OpBuilder builder(&body, body.begin());
mlir::OpBuilder builder(&body, body.begin());
bool changed = false;
do {
changed = false;
for (auto &op : body) {
::mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op);
mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr == graph_op) continue;
for (auto user_op : op.getUsers()) {
::mlir::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(user_op);
mlir::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(user_op);
if (nullptr == user_graph_op) continue;
// get all dst input nodes except src.
std::vector<::mlir::Operation *> source_nodes;
std::vector<mlir::Operation *> source_nodes;
for (auto operand : user_op->getOperands()) {
auto input = operand.getDefiningOp();
if (input != &op && input != nullptr) {
......@@ -167,9 +166,8 @@ void trtGraphFusePass::runOnFunction() {
}
}
// Reverse DFS from the source_nodes.
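// If `op` is reachable by walking backwards from those inputs, some op
// outside the pair sits on a path from `op` to the user graph, and merging
// the two graphs would create a cycle, so that pair is skipped.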
if (!reverseDfs(source_nodes, [&op](const ::mlir::Operation *n) {
return n == &op;
})) {
if (!reverseDfs(source_nodes,
[&op](const mlir::Operation *n) { return n == &op; })) {
mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
changed = true;
break;
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
#include <mlir/Pass/Pass.h>
namespace infrt {
namespace trt {
......@@ -28,15 +28,15 @@ namespace trt {
* %a = "pd.feed"()...
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.return" %m
* } ...
* %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* "pd.fetch" %m
* "pd.return" %m
* } ...
* %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.return" %m
* } ...
* "pd.fetch" %d, %f
*
......@@ -47,13 +47,13 @@ namespace trt {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "pd.fetch" %n, %s
* "pd.return" %n, %s
* } ...
* "pd.fetch" %d, %f
* }
*/
class trtGraphFusePass
: public ::mlir::PassWrapper<trtGraphFusePass, ::mlir::FunctionPass> {
: public mlir::PassWrapper<trtGraphFusePass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
void runOnFunction() override;
......
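For context, trtGraphFusePass is a plain MLIR function pass, so it can be driven by a standard pass manager. A minimal, hypothetical sketch (not part of this patch; the helper name runGraphFuse and the PassManager wiring are assumptions, using only the stock mlir::PassManager API):

#include <memory>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/PassManager.h>
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"

// Run the fuse pass over every function in a module (illustrative only).
static mlir::LogicalResult runGraphFuse(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::FuncOp>(std::make_unique<infrt::trt::trtGraphFusePass>());
  return pm.run(module);
}

In the real pipeline this pass is intended to run after trtOpTellerPass (further below), which first wraps each supported op in its own pd.graph.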
......@@ -14,7 +14,7 @@
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "mlir/IR/Builders.h"
#include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
......@@ -22,24 +22,24 @@ namespace infrt {
namespace trt {
// Implementation of the trtGraphSplitPass.
void trtGraphSplitPass::runOnFunction() {
std::vector<::mlir::pd::GraphOp> worklist;
::mlir::Block& block = getFunction().front();
std::vector<mlir::pd::GraphOp> worklist;
mlir::Block& block = getFunction().front();
for (auto& op : block) {
::mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op);
mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr != graph_op &&
graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
worklist.push_back(graph_op);
}
}
while (!worklist.empty()) {
::mlir::pd::GraphOp graph_op = worklist.back();
mlir::pd::GraphOp graph_op = worklist.back();
worklist.pop_back();
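// Inline the small graph op: forward each of its results to the value the
// inner pd.return yields, then splice its body ops back into the parent
// block in its place.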
::mlir::Block* body = graph_op.getBody();
auto fetch_op = body->getTerminator();
graph_op.replaceAllUsesWith(fetch_op->getOperands());
mlir::Block* body = graph_op.getBody();
auto return_op = body->getTerminator();
graph_op.replaceAllUsesWith(return_op->getOperands());
auto copy_range = body->without_terminator();
block.getOperations().splice(::mlir::Block::iterator(graph_op),
block.getOperations().splice(mlir::Block::iterator(graph_op),
body->getOperations(),
copy_range.begin(),
copy_range.end());
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
#include <mlir/Pass/Pass.h>
namespace infrt {
namespace trt {
......@@ -31,9 +31,9 @@ namespace trt {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "pd.fetch" %n, %s
* "pd.return" (%n, %s)
* } ...
* "pd.fetch" %d, %f
* "pd.fetch" (%d, %f)
* }
*
* destination func:
......@@ -42,11 +42,11 @@ namespace trt {
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f
* "pd.fetch" (%d, %f)
* }
*/
class trtGraphSplitPass
: public ::mlir::PassWrapper<trtGraphSplitPass, ::mlir::FunctionPass> {
: public mlir::PassWrapper<trtGraphSplitPass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }
void runOnFunction() override;
......
......@@ -14,49 +14,48 @@
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
#include "mlir/IR/Builders.h"
#include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
// Implementation of the trtOpTellerPass.
void trtOpTellerPass::runOnFunction() {
::mlir::Block &body = getFunction().front();
std::vector<::mlir::Operation *> worklist;
mlir::Block &body = getFunction().front();
std::vector<mlir::Operation *> worklist;
worklist.reserve(body.getOperations().size());
for (auto &op : body) {
worklist.push_back(&op);
}
// Build GraphOp.
::mlir::OpBuilder builder(&body, body.begin());
mlir::OpBuilder builder(&body, body.begin());
while (!worklist.empty()) {
auto *op = worklist.back();
worklist.pop_back();
if (op == nullptr) continue;
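// Feed, fetch and already-created graph ops are left untouched; every other
// op is wrapped into its own single-op pd.graph below, which later passes
// (e.g. trtGraphFusePass) can merge into larger subgraphs.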
auto op1 = ::llvm::dyn_cast_or_null<::mlir::pd::FeedOp>(op);
auto op1 = ::llvm::dyn_cast_or_null<mlir::pd::FeedOp>(op);
if (op1) continue;
auto op2 = ::llvm::dyn_cast_or_null<::mlir::pd::FetchOp>(op);
auto op2 = ::llvm::dyn_cast_or_null<mlir::pd::FetchOp>(op);
if (op2) continue;
auto op3 = ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(op);
auto op3 = ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(op);
if (op3) continue;
builder.setInsertionPoint(op);
auto loc = getFunction().getLoc();
auto graph_op = builder.create<::mlir::pd::GraphOp>(
auto graph_op = builder.create<mlir::pd::GraphOp>(
loc, op->getResultTypes(), op->getOperands());
::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;
::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values;
for (auto v :
::llvm::SmallVector<::mlir::Value, 4>{graph_op.getODSResults(0)}) {
::llvm::SmallVector<mlir::Value, 4>{graph_op.getODSResults(0)}) {
tblgen_repl_values.push_back(v);
}
op->replaceAllUsesWith(tblgen_repl_values);
// Build graph op.
::mlir::Block *block = new ::mlir::Block;
mlir::Block *block = new mlir::Block;
graph_op.body().push_back(block);
op->moveBefore(block, block->begin());
builder.setInsertionPointToEnd(block);
builder.create<mlir::pd::FetchOp>(loc, op->getResults());
builder.create<mlir::pd::ReturnOp>(loc, op->getResults());
}
}
} // namespace trt
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
#include <mlir/Pass/Pass.h>
namespace infrt {
namespace trt {
......@@ -29,7 +29,7 @@ namespace trt {
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f
* "pd.fetch" (%d, %f)
* }
*
* destination func:
......@@ -37,23 +37,23 @@ namespace trt {
* %a = "pd.feed"()...
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.return" (%m)
* } ...
* %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* "pd.fetch" %m
* "pd.return" (%m)
* } ...
* %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.return" (%m)
* } ...
* "pd.fetch" %d, %f
* "pd.fetch" (%d, %f)
* }
* TODO(winter-wang): Document how to judge whether an operator can be
* supported by TensorRT.
*/
class trtOpTellerPass
: public ::mlir::PassWrapper<trtOpTellerPass, ::mlir::FunctionPass> {
: public mlir::PassWrapper<trtOpTellerPass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
void runOnFunction() override;
......
......@@ -13,27 +13,25 @@
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <mlir/IR/Matchers.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt {
namespace trt {
TensorRTDialect::TensorRTDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect("trt", context, ::mlir::TypeID::get<TensorRTDialect>()) {
TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context)
: mlir::Dialect("trt", context, mlir::TypeID::get<TensorRTDialect>()) {
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
>();
#undef GET_OP_LIST
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
#undef GET_OP_CLASSES
} // namespace trt
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
......@@ -14,37 +14,32 @@
#pragma once
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <mlir/Dialect/Traits.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt {
namespace trt {
class TensorRTDialect : public ::mlir::Dialect {
class TensorRTDialect : public mlir::Dialect {
public:
explicit TensorRTDialect(::mlir::MLIRContext* context);
explicit TensorRTDialect(mlir::MLIRContext* context);
static llvm::StringRef getDialectNamespace() { return "trt"; }
};
// mlir bug. Can be removed safely once mlir is updated to llvm11.
using namespace mlir; // NOLINT
} // namespace trt
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.hpp.inc"
#undef GET_OP_CLASSES
} // namespace trt
} // namespace infrt
......@@ -14,14 +14,13 @@
#include "paddle/infrt/dialect/test_kernels.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
namespace infrt::dialect {
#include <mlir/IR/Builders.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
namespace infrt {
namespace dialect {
//===----------------------------------------------------------------------===//
// BenchmarkOp
//===----------------------------------------------------------------------===//
......@@ -32,65 +31,67 @@ namespace infrt::dialect {
// ...
// }
static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
StringAttr nameAttr;
static mlir::ParseResult parseBenchmarkOp(
mlir::OpAsmParser &parser, // NOLINT
mlir::OperationState &result) { // NOLINT
mlir::StringAttr nameAttr;
if (parser.parseAttribute(nameAttr, "name", result.attributes))
return failure();
return mlir::failure();
// Parse the operands, e.g. (%c : i32, %d : f32)
if (parser.parseLParen()) return failure();
if (parser.parseLParen()) return mlir::failure();
SmallVector<OpAsmParser::OperandType, 4> operands;
SmallVector<Type, 4> types;
llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands;
llvm::SmallVector<mlir::Type, 4> types;
llvm::SMLoc type_loc = parser.getCurrentLocation();
if (parser.parseOptionalRParen()) {
// Parse non-empty operands
do {
// Parse %c : i32,
OpAsmParser::OperandType operand;
Type type;
mlir::OpAsmParser::OperandType operand;
mlir::Type type;
if (parser.parseOperand(operand) || parser.parseColonType(type))
return failure();
return mlir::failure();
operands.push_back(operand);
types.push_back(type);
} while (succeeded(parser.parseOptionalComma()));
if (parser.parseRParen()) return failure();
if (parser.parseRParen()) return mlir::failure();
}
if (parser.resolveOperands(operands, types, type_loc, result.operands))
return failure();
return mlir::failure();
// Parse the keyword attribute, e.g. max_count = 100, duration_secs = 1
do {
StringRef attr;
Attribute resultAttr;
mlir::StringRef attr;
mlir::Attribute resultAttr;
if (parser.parseKeyword(&attr) || parser.parseEqual() ||
parser.parseAttribute(resultAttr,
parser.getBuilder().getIntegerType(32),
attr,
result.attributes))
return failure();
} while (succeeded(parser.parseOptionalComma()));
return mlir::failure();
} while (mlir::succeeded(parser.parseOptionalComma()));
// Set the default attribute num_warmup_runs to 1 if unset
auto setDefaultAttrIfUnset = [&](const char *attr_name, int value) {
bool found = llvm::any_of(result.attributes,
[attr_name](const NamedAttribute &attr) {
return attr.first == attr_name;
[attr_name](const mlir::NamedAttribute &attr) {
return attr.getName() == attr_name;
});
if (!found) {
IntegerAttr default_val = parser.getBuilder().getI32IntegerAttr(value);
mlir::IntegerAttr default_val =
parser.getBuilder().getI32IntegerAttr(value);
result.addAttribute(attr_name, default_val);
}
};
setDefaultAttrIfUnset("num_warmup_runs", 1);
Region *target = result.addRegion();
mlir::Region *target = result.addRegion();
return parser.parseRegion(*target,
operands,
types,
......@@ -102,11 +103,11 @@ static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT
// max_count = 100, duration_secs = 1 {
// ...
// }
static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
static void print(mlir::OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
p << "infrt.benchmark ";
// Print the name attribute, e.g "add.i32"
auto name_attr = op.getAttr("name");
auto name_attr = op->getAttr("name");
p << name_attr;
// Print the operands and types, e.g. (%c : i32, %d : f32)
......@@ -120,13 +121,13 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
bool need_comma = false;
// Print the attributes, e.g. max_count = 100, duration_secs = 1
for (auto &name_attr : op.getAttrs()) {
auto id = name_attr.first;
for (auto &name_attr : op->getAttrs()) {
auto id = name_attr.getName();
if (id == "name") continue;
if (need_comma) p << ", ";
auto attr = name_attr.second;
auto attr = name_attr.getValue();
p << id << " = ";
if (auto int_attr = attr.dyn_cast<IntegerAttr>()) {
if (auto int_attr = attr.dyn_cast<mlir::IntegerAttr>()) {
int_attr.getValue().print(p.getStream(), /*isSigned=*/false);
} else {
op.emitOpError("Unexpected attribute");
......@@ -142,7 +143,7 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
}
static LogicalResult verify(BenchmarkOp op) {
static mlir::LogicalResult verify(BenchmarkOp op) {
// Verify that the target benchmark region has exactly one return value.
auto &region = op.region();
auto &last_op = region.front().back();
......@@ -154,10 +155,10 @@ static LogicalResult verify(BenchmarkOp op) {
"incorrect number of return values. One return value is expected");
}
return success();
return mlir::success();
}
} // namespace dialect
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.cpp.inc"
} // namespace infrt::dialect
......@@ -13,11 +13,8 @@
// limitations under the License.
#pragma once
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt::dialect {
using namespace mlir; // NOLINT
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.hpp.inc"
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/types.h"
namespace infrt::hlir::mlir {} // namespace infrt::hlir::mlir
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/StandardTypes.h>
......@@ -23,7 +23,8 @@
#include "paddle/infrt/host_context/op_executable.h"
#include "paddle/infrt/host_context/symbol_table.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct CoreRuntime::Impl {
KernelRegistry* kernel_registry{};
......@@ -90,4 +91,5 @@ llvm::SmallVector<ValueRef, 4> CoreRuntime::GetResults(
CoreRuntime::~CoreRuntime() {}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -22,7 +22,8 @@
#include "paddle/infrt/host_context/value.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
class KernelRegistry;
class OpExecutable;
......@@ -83,4 +84,5 @@ class CoreRuntimeBuilder : public CoreRuntime {
OpExecutableBuilder* NewOpExecutable(const std::string& op_name);
};
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -21,7 +21,8 @@
#include "llvm/ADT/SmallVector.h"
#include "paddle/infrt/host_context/value.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
/**
* KernelFrame captures the states(input arguments, attributes, results)
......@@ -163,4 +164,5 @@ class KernelFrameBuilder : public KernelFrame {
}
};
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -18,7 +18,8 @@
#include "paddle/infrt/host_context/kernel_utils.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
int add_i32(int a, int b) { return a + b; }
......@@ -44,4 +45,5 @@ TEST(KernelRegistry, basic) {
ASSERT_EQ(results[0]->get<int>(), 3);
}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -16,7 +16,8 @@
#include <gtest/gtest.h>
namespace infrt::host_context {
namespace infrt {
namespace host_context {
int add_i32(int a, int b) { return a + b; }
float add_f32(float a, float b) { return a + b; }
......@@ -66,4 +67,5 @@ TEST(KernelImpl, pair) {
ASSERT_EQ(results[1]->get<float>(), 3.f);
}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -15,6 +15,7 @@
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include <glog/logging.h>
#include <mlir/IR/BuiltinOps.h>
#include <string> // NOLINT
......
......@@ -13,7 +13,8 @@
// limitations under the License.
#pragma once
#include <mlir/IR/Function.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Region.h>
#include <string>
#include <unordered_map>
......
......@@ -15,9 +15,9 @@
#pragma once
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OperationSupport.h>
#include <unordered_map>
......
......@@ -16,8 +16,9 @@
#include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h>
......@@ -40,7 +41,8 @@
#include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
template <typename T>
std::string DumpToString(T& op) { // NOLINT
......@@ -113,10 +115,10 @@ bool MlirToRuntimeTranslator::EmitConstantOp(mlir::Operation* op) {
template <>
boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::IntegerAttr>()) return boost::none;
if (attr->isa<mlir::IntegerAttr>()) {
auto val = attr->cast<mlir::IntegerAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::IntegerAttr>()) return boost::none;
if (attr.isa<mlir::IntegerAttr>()) {
auto val = attr.cast<mlir::IntegerAttr>();
if (val.getType().isInteger(32)) {
return val.getInt();
}
......@@ -125,10 +127,10 @@ boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute(
}
template <>
boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::IntegerAttr>()) return boost::none;
if (attr->isa<mlir::IntegerAttr>()) {
auto val = attr->cast<mlir::IntegerAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::IntegerAttr>()) return boost::none;
if (attr.isa<mlir::IntegerAttr>()) {
auto val = attr.cast<mlir::IntegerAttr>();
if (val.getType().isInteger(64)) {
return val.getInt();
}
......@@ -139,10 +141,10 @@ boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute(
// TODO(Superjomn) Make double and float parsing share something.
template <>
boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::FloatAttr>()) return boost::none;
if (attr->isa<mlir::FloatAttr>()) {
auto val = attr->cast<mlir::FloatAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::FloatAttr>()) return boost::none;
if (attr.isa<mlir::FloatAttr>()) {
auto val = attr.cast<mlir::FloatAttr>();
if (val.getType().isF32()) return val.getValueAsDouble();
}
return boost::none;
......@@ -150,10 +152,10 @@ boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
template <>
boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::FloatAttr>()) return boost::none;
if (attr->isa<mlir::FloatAttr>()) {
auto val = attr->cast<mlir::FloatAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::FloatAttr>()) return boost::none;
if (attr.isa<mlir::FloatAttr>()) {
auto val = attr.cast<mlir::FloatAttr>();
if (val.getType().isF64()) return val.getValueAsDouble();
}
return boost::none;
......@@ -161,17 +163,17 @@ boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
template <>
boost::optional<std::string> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::StringAttr>()) return boost::none;
return attr->cast<mlir::StringAttr>().getValue().str();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::StringAttr>()) return boost::none;
return attr.cast<mlir::StringAttr>().getValue().str();
}
#define PROCESS_ARRAY_INT(type__, bits__) \
template <> \
boost::optional<std::vector<type__>> MlirToRuntimeTranslator::EmitAttribute( \
const mlir::Attribute* attr) { \
if (!attr->isa<mlir::ArrayAttr>()) return boost::none; \
auto array = attr->cast<mlir::ArrayAttr>(); \
const mlir::Attribute& attr) { \
if (!attr.isa<mlir::ArrayAttr>()) return boost::none; \
auto array = attr.cast<mlir::ArrayAttr>(); \
CHECK(!array.empty()); \
\
if (!array[0].getType().isInteger(bits__)) { \
......@@ -191,9 +193,9 @@ PROCESS_ARRAY_INT(int64_t, 64);
template <>
boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr->cast<mlir::ArrayAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr.cast<mlir::ArrayAttr>();
CHECK(!array.empty());
if (!array[0].getType().isF32()) return boost::none;
......@@ -207,9 +209,9 @@ boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute(
template <>
boost::optional<std::vector<double>> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr->cast<mlir::ArrayAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr.cast<mlir::ArrayAttr>();
CHECK(!array.empty());
if (!array[0].getType().isF64()) return boost::none;
......@@ -236,7 +238,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
for (int i = 0, e = op->getNumOperands(); i < e; i++) {
// function argument as value
auto operand = op->getOperand(i);
if (operand.getKind() == mlir::Value::Kind::BlockArgument) {
/// if (operand.getKind() == mlir::Value::Kind::BlockArgument) {
if (operand.isa<mlir::BlockArgument>()) {
mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
Value* arg_value = GetValue(arg);
impl_->cur_op->AppendArgument(arg_value);
......@@ -283,25 +286,25 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
for (size_t i = 0; i < attrs.size(); i++) {
auto& attr = attrs[i];
if (auto v = EmitAttribute<int32_t>(&attr.second)) {
if (auto v = EmitAttribute<int32_t>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<int64_t>(&attr.second)) {
} else if (auto v = EmitAttribute<int64_t>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<float>(&attr.second)) {
} else if (auto v = EmitAttribute<float>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<double>(&attr.second)) {
} else if (auto v = EmitAttribute<double>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<std::string>(&attr.second)) {
} else if (auto v = EmitAttribute<std::string>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int16_t>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<int16_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int32_t>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<int32_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int64_t>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<int64_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<float>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<float>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<double>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<double>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else {
LOG(FATAL) << "Not supported attribute type";
......@@ -330,7 +333,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
llvm::SmallVector<mlir::Type, 0> results;
auto func_type =
mlir::FunctionType::get(inputs, results, region.getContext());
mlir::FunctionType::get(region.getContext(), inputs, results);
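// Note: the updated MLIR API takes the context as the first argument of
// FunctionType::get (previously it came last).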
auto* function = impl_->cur_op->CreateFunctionExecutable(
&region, func_type, &impl_->func_defs);
impl_->cur_op->AppendAttribute(new Value(function));
......@@ -555,4 +558,5 @@ void TestMlir(mlir::ModuleOp module, KernelRegistry* registry) {
execute.Run();
}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -29,7 +29,8 @@ class Attribute;
class Value;
} // namespace mlir
namespace infrt::host_context {
namespace infrt {
namespace host_context {
class CoreRuntimeBuilder;
class Value;
......@@ -73,7 +74,7 @@ class MlirToRuntimeTranslator {
bool EmitCallOp(mlir::Operation* op, function_defs_t* function_table);
template <typename T>
boost::optional<T> EmitAttribute(const mlir::Attribute* attr);
boost::optional<T> EmitAttribute(const mlir::Attribute& attr);
Value* GetOpResult(mlir::Operation* op);
......@@ -104,4 +105,5 @@ void MlirToRuntimeTranslate(mlir::ModuleOp module, CoreRuntimeBuilder* runtime);
*/
void TestMlir(mlir::ModuleOp module, KernelRegistry* registry);
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -29,7 +29,8 @@
#include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include "paddle/infrt/kernel/test_kernels.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
TEST(MlirToRuntimeTranslate, basic) {
mlir::MLIRContext context;
......@@ -48,7 +49,7 @@ func @main() -> () {
)ROC";
auto module = dialect::LoadMlirSource(&context, source);
module->verify();
EXPECT_TRUE(mlir::succeeded(module->verify()));
KernelRegistry registry;
kernel::RegisterFloatBasicKernels(&registry);
......@@ -74,7 +75,7 @@ func @main() -> () {
)ROC";
auto module = dialect::LoadMlirSource(&context, source);
module->verify();
EXPECT_TRUE(mlir::succeeded(module->verify()));
KernelRegistry registry;
kernel::RegisterFloatBasicKernels(&registry);
......@@ -115,7 +116,7 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F
// LOG(INFO) << "content: " << content << std::endl;
auto module = dialect::LoadMlirSource(context, content);
module->verify();
EXPECT_TRUE(mlir::succeeded(module->verify()));
host_context::KernelRegistry registry;
......@@ -157,4 +158,5 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F
}
}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -14,6 +14,7 @@
#include "paddle/infrt/host_context/op_executable.h"
#include <mlir/IR/BuiltinOps.h>
#include <string>
#include "paddle/infrt/host_context/kernel_frame.h"
......@@ -21,7 +22,8 @@
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/host_context/symbol_table.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct OpExecutable::Impl {
Impl(const std::string& op_name,
......@@ -148,4 +150,5 @@ void OpExecutable::Execute() {
OpExecutable::~OpExecutable() {}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -14,19 +14,18 @@
#pragma once
#include <llvm/ADT/ArrayRef.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Region.h>
#include <memory>
#include <string>
#include <unordered_map>
#include "mlir/IR/Function.h"
#include "mlir/IR/Region.h"
namespace mlir {
class FuncOp;
} // namespace mlir
namespace infrt::host_context {
namespace infrt {
namespace host_context {
class SymbolTable;
class KernelRegistry;
......@@ -89,4 +88,5 @@ class OpExecutableBuilder : public OpExecutable {
function_defs_t* function_defs);
};
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -23,7 +23,8 @@
using infrt::host_context::Attribute;
namespace infrt::kernel {
namespace infrt {
namespace kernel {
template <typename T>
T add(T a, T b) {
......@@ -82,4 +83,5 @@ void RegisterFloatBasicKernels(host_context::KernelRegistry *registry) {
registry->AddKernel("infrt.print.f32", INFRT_KERNEL(print<float>));
}
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -15,13 +15,16 @@
#pragma once
#include <string>
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct KernelRegistry;
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
namespace infrt::kernel {
namespace infrt {
namespace kernel {
/**
* Register all the basic kernels to \p registry.
......@@ -31,4 +34,5 @@ void RegisterBasicKernels(host_context::KernelRegistry* registry);
void RegisterIntBasicKernels(host_context::KernelRegistry* registry);
void RegisterFloatBasicKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -25,7 +25,8 @@
#include "paddle/infrt/tensor/tensor_map.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::kernel {
namespace infrt {
namespace kernel {
using namespace host_context; // NOLINT
using namespace tensor; // NOLINT
......@@ -76,4 +77,5 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
INFRT_KERNEL(ShallowCopyTensor));
}
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -14,12 +14,16 @@
#pragma once
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct KernelRegistry;
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
namespace infrt::kernel {
namespace infrt {
namespace kernel {
void RegisterTensorKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -24,7 +24,8 @@
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::kernel {
namespace infrt {
namespace kernel {
void PrintShape(const tensor::TensorShape& shape) {
llvm::raw_os_ostream oos(std::cout);
......@@ -35,4 +36,5 @@ void RegisterTensorShapeKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("ts.print_shape", INFRT_KERNEL(PrintShape));
}
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -14,14 +14,18 @@
#pragma once
namespace infrt::host_context {
namespace infrt {
namespace host_context {
class KernelRegistry;
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
namespace infrt::kernel {
namespace infrt {
namespace kernel {
void RegisterTensorShapeKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -33,7 +33,8 @@ using infrt::host_context::Attribute;
using infrt::host_context::MlirFunctionExecutable;
using infrt::host_context::RemainingArguments;
namespace infrt::kernel {
namespace infrt {
namespace kernel {
namespace {
class BenchmarkStats {
public:
......@@ -197,4 +198,5 @@ void RegisterTestKernels(host_context::KernelRegistry *registry) {
INFRT_KERNEL(ShadowCopyTensor));
}
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -15,17 +15,21 @@
#pragma once
#include <string>
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct KernelRegistry;
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
namespace infrt::kernel {
namespace infrt {
namespace kernel {
/**
* Register all the test kernels to registry.
*/
void RegisterTestKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -18,7 +18,9 @@
#include <string>
#include <vector>
namespace infrt::paddle::cpp {
namespace infrt {
namespace paddle {
namespace cpp {
/*
* Compatible interfaces for all the different kinds of XXXDesc. All the XXXDesc
......@@ -226,4 +228,6 @@ class ProgramDescAPI {
virtual void SetVersion(int64_t version) = 0;
};
} // namespace infrt::paddle::cpp
} // namespace cpp
} // namespace paddle
} // namespace infrt
......@@ -22,7 +22,8 @@
#include "paddle/infrt/common/target.h"
#include "paddle/infrt/common/type.h"
namespace infrt::paddle {
namespace infrt {
namespace paddle {
int SizeOfType(framework_proto::VarType::Type type) {
using Type = framework_proto::VarType::Type;
......@@ -169,4 +170,5 @@ void LoadParam(const std::string &path, _Variable *out, const Target &target) {
LoadLoDTensor(fin, out, target);
}
} // namespace infrt::paddle
} // namespace paddle
} // namespace infrt
......@@ -25,7 +25,8 @@
#include "paddle/infrt/paddle/scope.h"
#include "paddle/infrt/paddle/tensor.h"
namespace infrt::paddle {
namespace infrt {
namespace paddle {
namespace framework_proto = ::paddle::framework::proto;
// Read a __model__ file.
......@@ -52,4 +53,5 @@ void TensorFromStream(
const common::Target& target = common::DefaultHostTarget());
void ReadBinaryFile(const std::string& filename, std::string* contents);
} // namespace infrt::paddle
} // namespace paddle
} // namespace infrt
......@@ -14,7 +14,9 @@
#include "paddle/infrt/paddle/pb/block_desc.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
template <>
framework_proto::VarDesc* BlockDesc::GetVar<framework_proto::VarDesc>(
......@@ -40,4 +42,6 @@ framework_proto::OpDesc* BlockDesc::AddOp<framework_proto::OpDesc>() {
return desc_->add_ops();
}
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -18,7 +18,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto;
......@@ -74,4 +76,6 @@ class BlockDesc : public cpp::BlockDescAPI {
framework_proto::BlockDesc* desc_; // not_own
};
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -14,7 +14,9 @@
#include "paddle/infrt/paddle/pb/op_desc.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
google::protobuf::internal::RepeatedPtrIterator<framework_proto::OpDesc_Attr>
FindAttr(framework_proto::OpDesc *desc, const std::string &name) {
......@@ -136,4 +138,6 @@ GET_ATTRS_IMPL(std::vector<std::string>, strings);
GET_ATTR_IMPL(std::string, s);
GET_ATTRS_IMPL(std::vector<int64_t>, longs);
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -19,7 +19,9 @@
#include "paddle/infrt/paddle/framework.pb.h"
#include "paddle/infrt/support/variant.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto;
......@@ -195,4 +197,6 @@ template <>
void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
const std::vector<int> &v);
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -17,7 +17,9 @@
#include <algorithm>
#include <limits>
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
template <>
framework_proto::BlockDesc* ProgramDesc::GetBlock<framework_proto::BlockDesc>(
......@@ -32,4 +34,6 @@ ProgramDesc::AddBlock<framework_proto::BlockDesc>() {
return desc_->add_blocks();
}
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -21,7 +21,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto;
class ProgramDesc : public cpp::ProgramDescAPI {
......@@ -58,4 +60,6 @@ class ProgramDesc : public cpp::ProgramDescAPI {
framework_proto::ProgramDesc *desc_; // not_own
};
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -19,7 +19,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
cpp::VarDescAPI::Type VarDesc::GetType() const {
auto type = desc_->type().type();
......@@ -364,4 +366,6 @@ VarDesc::mutable_tensor_descs() {
return std::vector<framework_proto::VarType::TensorDesc *>();
}
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -23,7 +23,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto;
// convert between std::vector and protobuf repeated.
......@@ -121,4 +123,6 @@ class VarDesc : public cpp::VarDescAPI {
framework_proto::VarDesc *desc_;
};
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -435,6 +435,10 @@ inline T* DenseTensor::mutable_data(const paddle::platform::Place& place,
}
void DenseTensor::ShareBufferWith(const DenseTensor& tensor) {
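// Lazily create a CPU SharedStorage so that a tensor without storage can
// still take part in buffer sharing.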
if (storage_ == nullptr) {
storage_ = make_intrusive<paddle::experimental::SharedStorage>(
paddle::platform::CPUPlace());
}
if (storage_ != nullptr && tensor.storage_ != nullptr) {
storage_->set_data_shared(tensor.storage_->data_shared());
}
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import time
import contextlib
import logging
import functools
import numpy as np
from itertools import chain
from functools import reduce
from types import MethodType
from collections import deque, OrderedDict
import paddle
from paddle import nn
from paddle.autograd import PyLayer
import paddle.fluid.core as core
import paddle.distributed as dist
from paddle.fluid.framework import ParamBase
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.distributed.collective import _get_global_group
from .sharding_utils import Type, ShardingClipGrad
from ..pp_utils.utils import _all_gather
# CUDA alignment 256 bytes
alignment = {"gpu": 256, }
align = {
Type.fp16.value: 2,
Type.fp32.value: 4,
}
global CHECK_LAYER
CHECK_LAYER = dict() # Help to check layer's id -> layer's name
class ShardingStage3(nn.Layer):
"""
A wrapper for Sharding Stage3 Layer in Dygraph.
.. warning: ShardingStage3 encapsulates the layer strategy and integrates it into the nn.Layer.
.. ZeRO: https://arxiv.org/pdf/1910.02054.pdf.
"""
def __init__(self,
layer,
optimizer,
group=None,
sync_buffers=False,
device="gpu",
pertrain_sync_models=True,
accumulate_grads=False,
offload=False,
sync_comm=False):
super().__init__()
# Default configs
assert core.is_compiled_with_cuda(), "Only support CUDA."
self._layer = layer
self._default_device = device
self.__sync_buffers = sync_buffers
self._accumulate_grads = accumulate_grads
self._offload = offload
self._sync_comm = sync_comm
# Communication group establishment
self._group = dist.new_group(_get_global_group()
.ranks) if group is None else group
self._world_size_scaling = 1.0 / self._group.nranks
assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1."
self._rank = self._group.rank
self._global_root_rank = 0 # picking rank 0 as the reference
self._global_ranks = self._group.ranks
self._param2buffer_size = dict() # {param.name: size}
self._param2buffer = dict(
) # {param.name: [(start0, end0),(start1, end1), ...]}
self._trainable_params = dict() # {layer.name: [trainable_params]}
assert not isinstance(
optimizer, list), "Multiple optimizers are not supported now."
self._optim = _OptimizerWrapper(optimizer, self._offload, self._group,
self._update_params_slice)
self._ori_parameter_list = self._optim._parameter_list
self._ori_param_groups = self._optim._param_groups
# Replace optimizer's _grad_clip
if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm):
logging.warning(
"While using ClipGradByGlobalNorm in ShardingStage3, the grad clip of original optimizer will be changed."
)
self._optim._grad_clip = ShardingClipGrad(self._optim._grad_clip,
paddle.get_device(),
self._group)
# Synchronize models across all ranks
if pertrain_sync_models:
self._sync_params_and_buffers()
self._segment_rank_params(self._layer)
# In the first step, record the execution order of the layer
self._order_tracer = OrderedDict()
self._order_tracer["order"] = 0
self._order_tracer["layer"] = []
# Register task flow
self._task_flow = TaskFlow()
# Register forward hooks
self._register_forward_hooks(self._layer)
# Register backward parameter hooks
self._register_backward_hooks()
# Redefine optimizer step and clear function
self._redefine_opt_step()
self._redefine_opt_clear()
@paddle.no_grad()
def _sync_params_and_buffers(self):
"""
Sync all model states for all ranks
"""
for p in self._layer.parameters():
dist.broadcast(
p,
src=self._global_root_rank,
group=self._group,
use_calc_stream=True)
# Multi stream operation will be supported later
dist.wait(tensor=p, group=self._group, use_calc_stream=True)
def _clear_gradients(self):
assert len(self._trainable_params.keys()) > 0
current_layer_params = self._layer.parameters(include_sublayers=True)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
for param in trainable_params:
assert hasattr(
param, "fw_storage"
), "Find {} don't have fw_storage attribute.".format(param.name)
# param.bw_storage.zero_()
param.fw_storage.clear_gradient(False)
param.fw_storage._gradient_set_empty(False)
param.bw_storage._clear()
# Update param memory slice
def _update_params_slice(self):
update_list = self._update_params()
if not isinstance(self._optim._param_groups[0], dict):
slice_params = [param.fw_storage for param in update_list]
self._optim._parameter_list = slice_params
self._optim._param_groups = slice_params
else:
params_name_list = list(map(lambda p: p.name, update_list))
for param_group in self._optim._param_groups:
slice_p = []
for p in param_group['params']:
if p.name in params_name_list:
assert hasattr(
p, "fw_storage"
), "Find {} don't have fw_storage attribute.".format(
p.name)
slice_p.append(p.fw_storage)
param_group['params'] = slice_p
def forward(self, *inputs, **kwargs):
"""
A wrapper for Sharding Stage3 layer.
"""
# 1.Sync layer's buffers state
if self.__sync_buffers:
self._sync_buffers()
# 2.Normal FW on the base model
fw = self._layer(*inputs, **kwargs)
return fw
def _segment_rank_params(self, layer, name="last_layer"):
current_layer_params = _current_layer_params(layer)
if current_layer_params:
CHECK_LAYER[id(layer)] = name
self._flatten_layer_params(layer, current_layer_params)
for name, sub_layer in layer.named_children():
self._segment_rank_params(sub_layer, name)
def _flatten_layer_params(self, layer, current_layer_params):
def _add_manage_info(trainable_param):
return _PartitionParam(trainable_param)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
assert id(layer) not in self._trainable_params.keys()
self._trainable_params[id(layer)] = list(
map(_add_manage_info, trainable_params))
for param in self._trainable_params[id(layer)]:
if param.name in self._param2buffer.keys():
continue
self._param2buffer[param.name] = []
# 1.Params alignment
offset = 0
# CUDA alignment 256 bytes
size = param._numel() * align[param.dtype]
remaining = size % alignment[self._default_device]
ali = 0 if remaining == 0 else alignment[
self._default_device] - remaining
align_ = ali // align[param.dtype]
offset = align_ + param._numel()
buffer_size = offset if offset % self._group.nranks == 0 else offset + self._group.nranks - (
offset % self._group.nranks)
self._param2buffer_size[param.name] = buffer_size
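# Worked example (illustration only, assuming fp16 params, 256-byte CUDA
# alignment and nranks == 4): a parameter with 1000 elements takes
# 2000 bytes, so 24 padding elements (48 bytes) round it up to 1024
# elements (2048 bytes); 1024 is divisible by 4, hence buffer_size is 1024
# and each rank owns a 256-element slice.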
# 2.Combination param buffer
assert buffer_size % self._group.nranks == 0
pre_buffer = buffer_size // self._group.nranks
for rank_ in range(self._group.nranks):
self._param2buffer[param.name].append(
(rank_ * pre_buffer, (rank_ + 1) * pre_buffer))
# 3.Flatten layer params and release other rank buffer
self._param_storage(param, buffer_size)
def _param_storage(self, param, buffer_size):
assert isinstance(buffer_size, int)
value = np.zeros(
buffer_size,
dtype=np.float16) if Type.fp16.value == param.dtype else np.zeros(
buffer_size, dtype=np.float32)
buffer = core.VarBase(value=value, place=core.CPUPlace())
param_shape = param.shape
origin_state = param.stop_gradient
param.stop_gradient = True
param.flatten_()
param.stop_gradient = origin_state
start, end = self._param2buffer[param.name][self._rank]
# Copy the current param value
tmp_var = core.VarBase(
tensor=buffer._slice(0, param._numel()), place=core.CPUPlace())
param_cpu = param.cpu()
tmp_var.value().get_tensor().set(param_cpu.value().get_tensor(),
core.CPUPlace())
param.value().get_tensor()._set_dims(param_shape)
param._clear()
# Current rank param_storage
param.fw_storage = core.VarBase(
buffer._slice(start, end), "slice@" + param.name)
param.status = "part"
# Update optimizer master weights
if param.dtype == Type.fp16.value:
self._optim._master_weights[param.fw_storage.name] = paddle.cast(
param.fw_storage, Type.fp32.value)
def _register_forward_hooks(self, layer):
current_layer_params = _current_layer_params(layer)
if current_layer_params:
self._register_forward_all_hooks(layer, self._task_flow)
for _, sub_layer in layer.named_children():
self._register_forward_hooks(sub_layer)
def _register_forward_all_hooks(self, sub_layer, task_flow):
def _forward_pre_hook(layer, inputs):
return ForwardPreHooks(layer, self._order_tracer,
self._trainable_params, self._param2buffer,
self._rank, self._group, self._sync_comm,
task_flow)
def _forward_post_hook(layer, inputs, outputs):
return ForwardPostHooks.apply(
outputs, layer, self._order_tracer, self._trainable_params,
self._param2buffer, self._param2buffer_size, self._rank,
self._group, self._sync_comm, task_flow)
# Register forward pre-hooks
sub_layer.register_forward_pre_hook(_forward_pre_hook)
# Register forward post-hooks
sub_layer.register_forward_post_hook(_forward_post_hook)
@paddle.no_grad()
def _sync_buffers(self):
for buffer in self._layer.buffers(include_sublayers=True):
dist.broadcast(
buffer,
self._global_root_rank,
self._group,
use_calc_stream=True)
# Multi stream operation will be supported later
dist.wait(tensor=buffer, group=self._group, use_calc_stream=True)
def __getattr__(self, name):
"""Forward missing attributes to wrapped layer."""
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self._layer, name)
def _update_params(self):
update_list = []
assert len(self._trainable_params.keys()) > 0
current_layer_params = self._layer.parameters(include_sublayers=True)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
for param in trainable_params:
assert hasattr(
param,
"fw_storage"), "Find {} don't have fw_storage attribute".format(
param.name)
if self._accumulate_grads:
param.bw_storage.scale_(scale=self._world_size_scaling)
param.fw_storage = _VarBaseWrapper(param)
param.fw_storage._copy_gradient_from(param.bw_storage)
update_list.append(param)
return update_list
def get_all_parameters(self):
assert len(self._trainable_params.keys()) > 0
current_layer_params = self._layer.parameters(include_sublayers=True)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
for param in trainable_params:
if param.use_count > 0:
continue
assert hasattr(
param,
"fw_storage"), "Find {} don't have fw_storage attribute".format(
param.name)
full_param = _all_gather(
param.fw_storage, self._group, use_calc_stream=True)
dist.wait(
tensor=full_param, group=self._group, use_calc_stream=True)
core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
param)
param.value().get_tensor()._set_dims(param.shape)
param.fw_storage._clear()
param.fw_storage = None
param.status = "all"
param.use_count += 1
self._optim._parameter_list = self._ori_parameter_list
self._optim._param_groups = self._ori_param_groups
def _register_backward_hooks(self):
current_layer_params = self._layer.parameters(include_sublayers=True)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
for param in trainable_params:
allreduce_function = self._get_allreduce_fn(param)
param._register_backward_hook(allreduce_function)
def _get_allreduce_fn(self, param):
@paddle.no_grad()
def reduce(*_):
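# Once the full gradient of this parameter is ready: allreduce it across
# ranks, keep only this rank's slice in param.bw_storage, release the
# temporary full-gradient buffer, and shrink the parameter back to its
# local fw_storage slice.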
if param.name in self._task_flow.full_grad.keys():
full_grad = self._task_flow.full_grad[param.name]
with paddle.amp.auto_cast(enable=False):
if not self._accumulate_grads:
full_grad.scale_(scale=self._world_size_scaling)
# Only synchronous allreduce of the current rank's layer is supported for now
dist.all_reduce(
tensor=full_grad,
group=self._group,
use_calc_stream=True)
dist.wait(
tensor=full_grad,
group=self._group,
use_calc_stream=True)
start, end = self._param2buffer[param.name][self._rank]
if not self._accumulate_grads or param.bw_storage is None:
param.bw_storage = core.VarBase(
full_grad._slice(start, end)).detach().clone()
else:
param.bw_storage.add_(
core.VarBase(full_grad._slice(start, end)).detach()
.clone())
param.clear_gradient(False)
param._gradient_set_empty(False)
tmp_var = self._task_flow.full_grad.pop(param.name)
tmp_var._clear()
if param.name in self._task_flow.full_param.keys():
if param.status == "all":
param.use_count = 0
param._clear()
start, end = self._param2buffer[param.name][self._rank]
with paddle.amp.auto_cast(enable=False):
param.fw_storage = core.VarBase(
self._task_flow.full_param[param.name]._slice(start,
end),
param.name + "@slice").detach().clone()
param.status = "part"
tmp_var = self._task_flow.full_param.pop(param.name)
tmp_var._clear()
return reduce
def _redefine_opt_step(self):
params_slice_func = self._update_params_slice
opt_step = self._optim.step
update_scaler = self._optim.update_scaler
def _opt_step(self):
if not update_scaler:
params_slice_func()
opt_step()
self._optim.step = MethodType(_opt_step, self._optim)
def _redefine_opt_clear(self):
clear_func = self._clear_gradients
def _opt_clear(self):
clear_func()
self._optim.clear_grad = MethodType(_opt_clear, self._optim)
def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer, rank,
group, sync_comm, task_flow):
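# Make sure this layer's parameter slices have been gathered before its
# forward runs, then prefetch the next layer in the recorded execution
# order so that communication overlaps with computation.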
# Record layer's id
layer_id = id(layer)
use_calc, sync_wait = False, False
if layer_id not in order_tracer.keys() or sync_comm:
use_calc, sync_wait = True, True
task_flow.use_calc[layer_id] = use_calc
else:
task_flow.use_calc[layer_id] = use_calc
_wait_layer(trainable_params, layer_id, task_flow, group, use_calc)
if layer_id == order_tracer["layer"][-1]: return
order_ = order_tracer[layer_id]
layer_id = order_tracer["layer"][order_ + 1]
_allgather_buffer(
layer_id,
trainable_params,
group,
use_calc_stream=use_calc,
task_flow=task_flow,
sync_wait=sync_wait)
return
class ForwardPostHooks(PyLayer):
@staticmethod
def forward(ctx, inputs, layer, order_tracer, trainable_params,
param2buffer, param2buffer_size, rank, group, sync_comm,
task_flow):
_release_param(layer, trainable_params, param2buffer, rank, task_flow)
layer_id = id(layer)
if layer_id not in order_tracer.keys():
order_ = order_tracer["order"]
order_tracer[layer_id] = order_
order_tracer["order"] += 1
order_tracer["layer"].append(layer_id)
ctx.order_tracer = order_tracer
ctx.task_flow = task_flow
ctx.group = group
ctx.layer = layer
ctx.sync_comm = sync_comm
ctx.trainable_params = trainable_params
ctx.param2buffer_size = param2buffer_size
return inputs
@staticmethod
def backward(ctx, *args):
# Load context value
order_tracer = ctx.order_tracer
task_flow = ctx.task_flow
group = ctx.group
layer = ctx.layer
trainable_params = ctx.trainable_params
param2buffer_size = ctx.param2buffer_size
sync_comm = ctx.sync_comm
layer_id = id(layer)
use_calc, sync_wait = False, False
if sync_comm:
use_calc, sync_wait = True, True
_allgather_buffer(
layer_id,
trainable_params,
group,
use_calc_stream=use_calc,
task_flow=task_flow,
sync_wait=sync_wait)
else:
_wait_layer(trainable_params, layer_id, task_flow, group, use_calc)
_create_params_grad(layer, trainable_params, param2buffer_size,
task_flow)
task_flow.use_calc[layer_id] = use_calc
if layer_id != order_tracer["layer"][0] and not sync_comm:
layer_next_id = order_tracer["layer"][order_tracer[layer_id] - 1]
_allgather_buffer(
layer_next_id,
trainable_params,
group,
use_calc_stream=use_calc,
task_flow=task_flow,
sync_wait=sync_wait)
return args
class TaskFlow:
"""
Task flow: a one-way linked list used for task acquisition.
"""
def __init__(self,
full_param=dict(),
full_grad=dict(),
use_calc=dict(),
callback=None):
self.full_param = full_param
self.full_grad = full_grad
self.use_calc = use_calc
self.callback = callback
def _release_param(layer, trainable_params, param2buffer, rank, task_flow):
for param in trainable_params[id(layer)]:
# A weight shared by async communication is not cleared while still in use
param.use_count -= 1
if param.use_count == 0:
param._clear()
if param.name in task_flow.full_param.keys():
start, end = param2buffer[param.name][rank]
with paddle.amp.auto_cast(enable=False):
param.fw_storage = core.VarBase(
task_flow.full_param[param.name]._slice(start, end),
param.name + "@slice").detach().clone()
param.status = "part"
tmp_var = task_flow.full_param.pop(param.name)
tmp_var._clear()
return
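# Ensure the parameters of a layer are ready before computation: reuse tensors
# already gathered into task_flow if present, otherwise fall back to a
# synchronous all-gather.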
def _wait_layer(trainable_params, layer_id, task_flow, group, use_calc_stream):
for param in trainable_params[layer_id]:
if param.status == "all":
param.use_count += 1
continue
if param.name in task_flow.full_param.keys():
full_param = task_flow.full_param[param.name]
with paddle.amp.auto_cast(enable=False):
paddle.device.cuda.synchronize()
core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
param)
param.value().get_tensor()._set_dims(param.shape)
param.fw_storage._clear()
param.fw_storage = None
param.status = "all"
param.use_count += 1
else:
_allgather_buffer(
layer_id,
trainable_params,
group,
use_calc_stream,
task_flow,
sync_wait=True)
break
return task_flow
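# All-gather every parameter slice of the given layer into a full tensor and
# cache it in task_flow; when sync_wait is set, wait on the communication and
# share the gathered tensor back into the parameter immediately.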
def _allgather_buffer(layer_id,
trainable_params,
group,
use_calc_stream,
task_flow,
sync_wait=False):
for param in trainable_params[layer_id]:
if param.status == "all":
param.use_count += 1
continue
with paddle.amp.auto_cast(enable=False):
full_param = _all_gather(
param.fw_storage, group, use_calc_stream=use_calc_stream)
if sync_wait:
with paddle.amp.auto_cast(enable=False):
dist.wait(
tensor=full_param,
group=group,
use_calc_stream=use_calc_stream)
core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
param)
param.value().get_tensor()._set_dims(param.shape)
param.fw_storage._clear()
param.fw_storage = None
param.status = "all"
param.use_count += 1
task_flow.full_param[param.name] = full_param
return task_flow
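# Allocate a padded, zero-initialized gradient buffer for each parameter of the
# layer and point the parameter's gradient at its leading slice.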
@paddle.no_grad()
def _create_params_grad(layer, trainable_params, param2buffer_size, task_flow):
for param in trainable_params[id(layer)]:
if param.name in task_flow.full_grad.keys():
continue
assert isinstance(param2buffer_size[param.name], int)
temp_grad = paddle.zeros(
[param2buffer_size[param.name]], dtype=param.dtype)
param._copy_gradient_from(
core.VarBase(temp_grad._slice(0, param._numel())))
task_flow.full_grad[param.name] = temp_grad
return task_flow
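# Attach the bookkeeping attributes used by stage-3 sharding (slice storage,
# gradient storage, status and use count) to a parameter.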
def _PartitionParam(param):
if not hasattr(param, "fw_storage"):
setattr(param, "fw_storage", None)
setattr(param, "bw_storage", None)
setattr(param, "status", "all")
setattr(param, "use_count", 0)
return param
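# Wrap a parameter's local slice (fw_storage) into a ParamBase so the optimizer
# can update it directly, copying the regularizer and learning-rate attributes
# from the original parameter.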
def _VarBaseWrapper(param):
varbase = param.fw_storage
tmp_param = ParamBase(
shape=varbase.shape, dtype=varbase.dtype, name="slice@" + param.name)
varbase._share_buffer_to(tmp_param)
tmp_param.regularizer = param.regularizer
tmp_param.optimize_attr['learning_rate'] = param.optimize_attr[
'learning_rate']
varbase._clear()
return tmp_param
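# Attach sharding metadata (offload flag, communication group and the
# slice-update callback) to the inner optimizer.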
def _OptimizerWrapper(optimizer, offload, group, update_params_slice):
if not hasattr(optimizer, "_optim"):
setattr(optimizer, "_optim", optimizer)
setattr(optimizer, "offload", offload)
setattr(optimizer, "group", group)
setattr(optimizer, "update_scaler", None)
setattr(optimizer, "update_slice", update_params_slice)
return optimizer
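# Return the layer's own parameters (excluding sublayers), plus any extra
# parameters registered on the layer.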
def _current_layer_params(layer):
return layer.parameters(
include_sublayers=False) + list(layer.extra_parameters) if hasattr(
layer, "extra_parameters") else layer.parameters(
include_sublayers=False)
......@@ -152,6 +152,9 @@ def ShardingScaler(scaler):
param_grads = []
param_grads_fp16 = []
param_grads_fp32 = []
if hasattr(optimizer, "update_slice"):
optimizer.update_slice()
optimizer.update_scaler = True
if getattr(optimizer._optim, '_param_groups', None) and isinstance(
optimizer._optim._param_groups[0], dict):
......@@ -161,27 +164,21 @@ def ShardingScaler(scaler):
if param._grad_ivar() is not None:
param_grads.append(param._grad_ivar())
if param._grad_ivar(
).dtype == core.VarDesc.VarType.FP16:
).dtype in [core.VarDesc.VarType.FP16, paddle.float16]:
param_grads_fp16.append(param._grad_ivar())
else:
param_grads_fp32.append(param._grad_ivar())
else:
param_grads = [
param._grad_ivar() for param in optimizer._optim._parameter_list
if param._grad_ivar() is not None
]
param_grads_fp16 = [
param._grad_ivar() for param in optimizer._optim._parameter_list
if (param._grad_ivar() is not None
) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP16
)
]
param_grads_fp32 = [
param._grad_ivar() for param in optimizer._optim._parameter_list
if (param._grad_ivar() is not None
) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32
)
]
for param in optimizer._optim._parameter_list:
if param.grad is not None:
param_grads.append(param.grad)
if param.grad.dtype in [
core.VarDesc.VarType.FP16, paddle.float16
]:
param_grads_fp16.append(param.grad)
else:
param_grads_fp32.append(param.grad)
temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))
......
......@@ -34,6 +34,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_optimizer_stage2)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage2)
list(APPEND DIST_TEST_OPS test_dygraph_sharding_stage3)
list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper)
......@@ -250,6 +251,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_optimizer_stage2)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage2)
list(REMOVE_ITEM TEST_OPS test_dygraph_sharding_stage3)
list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
......@@ -1058,6 +1060,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_sharding_optimizer_stage2 PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_sharding_stage2 PROPERTIES TIMEOUT 120)
set_tests_properties(test_dygraph_sharding_stage3 PROPERTIES TIMEOUT 120)
set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120)
......
# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import ast
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
epoch = 10
batch_size = 32
paddle.seed(2021)
np.random.seed(2021)
base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4
fleet.init(is_collective=True)
class MLP(fluid.Layer):
def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
super(MLP, self).__init__()
self._linear1 = Linear(linear_size, linear_size)
self._linear2 = Linear(linear_size, linear_size)
self._linear3 = Linear(linear_size, 10)
def forward(self, inputs):
y = self._linear1(inputs)
y = self._linear2(y)
y = self._linear3(y)
return y
def reader_decorator(linear_size=1000):
def __reader__():
for _ in range(100):
img = np.random.rand(linear_size).astype('float32')
label = np.ones(1).astype('int64')
yield img, label
return __reader__
def optimizer_setting(model, use_pure_fp16, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
parameters=[{
"params": model.parameters()
}] if opt_group else model.parameters(),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
multi_precision=use_pure_fp16)
return optimizer
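# Train the MLP under the requested sharding stage (2 or 3), optionally with
# pure fp16, gradient accumulation and synchronous communication (recompute),
# and return the trained parameters for comparison.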
def train_mlp(model,
sharding_stage,
use_pure_fp16=False,
accumulate_grad=False,
opt_group=False,
recompute=False):
group = paddle.distributed.new_group([0, 1])
if opt_group:
optimizer = optimizer_setting(
model=model, use_pure_fp16=use_pure_fp16, opt_group=opt_group)
else:
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
if use_pure_fp16:
model = paddle.amp.decorate(
models=model, level='O2', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
scaler = ShardingScaler(scaler)
if sharding_stage == 2:
optimizer = ShardingOptimizerStage2(
params=model.parameters(), optim=optimizer, group=group)
model = ShardingStage2(
model,
optimizer,
group=group,
buffer_max_size=2**21,
accumulate_grads=accumulate_grad)
elif sharding_stage == 3:
model = ShardingStage3(
model, optimizer=optimizer, group=group, sync_comm=recompute)
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True)
train_loader = paddle.io.DataLoader.from_generator(
capacity=32,
use_double_buffer=True,
iterable=True,
return_list=True,
use_multiprocess=True)
train_loader.set_sample_list_generator(train_reader)
for eop in range(epoch):
model.train()
for batch_id, data in enumerate(train_loader()):
img, label = data
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(True, level='O2'):
out = model(img)
loss = paddle.nn.functional.cross_entropy(
input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if not accumulate_grad:
if not use_pure_fp16:
avg_loss.backward()
optimizer.step()
else:
scaler.scale(avg_loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
if accumulate_grad:
if not use_pure_fp16:
avg_loss.backward()
optimizer.step()
else:
scaler.scale(avg_loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
if sharding_stage == 3:
model.get_all_parameters()
return model.parameters()
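# Compare stage-2 and stage-3 sharding starting from the same initial weights:
# parameters must match under fp32, fp32 with gradient accumulation and pure
# fp16; finally, stage-3 with and without synchronous communication (recompute)
# are compared against each other.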
def test_stage2_stage3():
mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8 = MLP(), MLP(), MLP(
), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
state_dict = mlp.state_dict()
mlp1.set_state_dict(state_dict)
mlp2.set_state_dict(state_dict)
mlp3.set_state_dict(state_dict)
mlp4.set_state_dict(state_dict)
mlp5.set_state_dict(state_dict)
mlp6.set_state_dict(state_dict)
mlp7.set_state_dict(state_dict)
mlp8.set_state_dict(state_dict)
# fp32
stage2_params = train_mlp(
mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=True)
stage3_params = train_mlp(
mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=True)
for i in range(len(stage2_params)):
for j in range(len(stage3_params)):
if stage2_params[i].name == stage3_params[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[j].numpy(),
rtol=1e-6)
# fp32 accumulate grad
stage2_params = train_mlp(
mlp3,
sharding_stage=2,
use_pure_fp16=False,
accumulate_grad=True,
opt_group=True)
stage3_params = train_mlp(
mlp4,
sharding_stage=3,
use_pure_fp16=False,
accumulate_grad=True,
opt_group=True)
for i in range(len(stage2_params)):
for j in range(len(stage3_params)):
if stage2_params[i].name == stage3_params[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[j].numpy(),
rtol=1e-6)
# fp16
stage2_params = train_mlp(
mlp5, sharding_stage=2, use_pure_fp16=True, opt_group=False)
stage3_params = train_mlp(
mlp6, sharding_stage=3, use_pure_fp16=True, opt_group=False)
for i in range(len(stage2_params)):
for j in range(len(stage3_params)):
if stage2_params[i].name == stage3_params[j].name:
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[j].numpy(),
rtol=1e-6)
# fp16 recompute
stage3_params = train_mlp(
mlp7, sharding_stage=3, use_pure_fp16=True, opt_group=False)
stage3_params_re = train_mlp(
mlp8,
sharding_stage=3,
use_pure_fp16=True,
opt_group=False,
recompute=True)
for i in range(len(stage3_params)):
for j in range(len(stage3_params_re)):
if stage3_params[i].name == stage3_params_re[j].name:
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_re[j].numpy(),
rtol=1e-6)
return
if __name__ == '__main__':
test_stage2_stage3()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import unittest
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
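# Auto-scan test for the flatten_contiguous_range TensorRT converter: sweeps
# batch sizes and (start_axis, stop_axis) pairs, then checks the expected
# number of TensorRT engine nodes under static and dynamic shape in FP32 and
# FP16 precision.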
class TrtConvertFlattenContiguousRangeTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def generate_input(batch):
return np.random.random([2, batch, 4, 8, 3]).astype(np.float32)
for batch in [1, 2, 4]:
for start_axis in range(5):
for stop_axis in range(start_axis, 5):
type = "flatten_contiguous_range"
op_outputs = {
"Out": ["output_data"],
"XShape": ["xshape_data"]
}
ops_config = [{
"op_type": type,
"op_inputs": {
"X": ["input_data"]
},
"op_outputs": op_outputs,
"op_attrs": {
"start_axis": start_axis,
"stop_axis": stop_axis,
}
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input, batch))
},
outputs=["output_data"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"input_data": [2, 1, 4, 8, 3]}
self.dynamic_shape.max_input_shape = {"input_data": [2, 4, 4, 8, 3]}
self.dynamic_shape.opt_input_shape = {"input_data": [2, 2, 4, 8, 3]}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 >= 7000:
if dynamic_shape:
return 1, 2
else:
if attrs[0]['start_axis'] == 0:
return 0, 3
else:
return 1, 2
else:
return 0, 3
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
def test(self):
self.run_test()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestDygraphShardingStage3(TestMultipleGpus):
# check the sharding stage-3 logic as well as its accuracy against single-card mode
def test_dygraph_sharding_optimizer_stage3(self):
self.run_mnist_2gpu('dygraph_sharding_stage3.py')
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -66,6 +66,15 @@ class TestStackOpBase(XPUOpTest):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
if self.dtype == 'int64' or self.dtype == 'int32':
pass
else:
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, self.get_x_names(), 'Y')
class TestStackOp1(TestStackOpBase):
def initParameters(self):
......@@ -81,11 +90,17 @@ class TestStackOp3(TestStackOpBase):
def initParameters(self):
self.axis = -1
def test_check_grad(self):
pass
class TestStackOp4(TestStackOpBase):
def initParameters(self):
self.axis = -4
def test_check_grad(self):
pass
class TestStackOp5(TestStackOpBase):
def initParameters(self):
......@@ -113,7 +128,7 @@ class TestStackOpint(TestStackOpBase):
self.num_inputs = 4
self.input_dim = (5, 6, 7)
self.axis = 0
self.dtype = 'int'
self.dtype = 'int32'
def initParameters(self):
self.num_inputs = 16
......