Unverified commit 5a6182b8 authored by Wilber, committed by GitHub

infrt run once (A trick version) (#41634)

* temporarily run once

* update

* update

* update

* update

* fix ci problem
Parent 2ab986ae
@@ -14,7 +14,7 @@
 #include "paddle/infrt/dialect/phi/ir/phi_base.h"

-#include <llvm/include/llvm/ADT/TypeSwitch.h>
+#include <llvm/ADT/TypeSwitch.h>
 #include <mlir/IR/Builders.h>
 #include <mlir/IR/Dialect.h>
 #include <mlir/IR/DialectImplementation.h>
@@ -15,7 +15,7 @@
 #include <glog/logging.h>
 #include <llvm/Support/ErrorHandling.h>
-#include <llvm/include/mlir/IR/Attributes.h>
+#include <mlir/IR/Attributes.h>
 #include <mlir/IR/Builders.h>
 #include <mlir/IR/BuiltinAttributes.h>
 #include <mlir/IR/PatternMatch.h>
@@ -87,7 +87,7 @@ int main(int argc, char** argv) {
     std::cout << "\npass failed!\n" << std::endl;
     return 4;
   }
-  // module->dump();
+  module->dump();
   ::infrt::host_context::TestMlir(module.get(), &registry);
   return 0;
 }
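Note: re-enabling `module->dump()` makes the test driver print the lowered MLIR module (MLIR's `dump()` writes to stderr) just before `TestMlir` executes it, which is handy for inspecting pass output.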
@@ -16,6 +16,7 @@
 #include <mlir/IR/BuiltinOps.h>
 #include <string>
+#include <unordered_set>

 #include "paddle/infrt/host_context/kernel_frame.h"
 #include "paddle/infrt/host_context/kernel_registry.h"
@@ -71,7 +72,15 @@ OpExecutableBuilder::OpExecutableBuilder(const std::string& op_name,
   // TODO(Superjomn) support other device other than CPU.
   CHECK(impl_->kernel_impl) << "No CPU kernel called " << op_name;

-  if (op_name == "dt.get_param") {
+  // TODO(wilber): Maybe we can use the MLIR trait or other facilities to remove
+  // the run_once set.
+  std::unordered_set<std::string> run_once_set{
+      "dt.get_param",
+      "trt.create_engine",
+      "phi_dt.create_host_inited_dense_tensor.f32",
+      "phi_dt.create_context.cpu",
+      "phi_dt.create_context.gpu"};
+  if (run_once_set.count(op_name)) {
     impl_->run_once = true;
   }
 }
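For context, `run_once` marks ops whose kernels should execute only on the first invocation of a function, so expensive setup ops (loading parameters, building a TensorRT engine, creating device contexts) are not redone on every run. A minimal sketch of the idea, where the `executed_` bookkeeping and the simplified `Execute` wrapper are illustrative assumptions rather than infrt's actual executor:

// Illustrative sketch (not infrt's real executor): an op marked run_once
// launches its kernel only the first time it is invoked.
#include <iostream>

class OpExecutable {
 public:
  explicit OpExecutable(bool run_once) : run_once_(run_once) {}

  void Execute() {
    if (run_once_ && executed_) return;  // reuse the first run's results
    std::cout << "kernel launched\n";    // stand-in for the real kernel call
    executed_ = true;
  }

 private:
  bool run_once_;          // true for ops like "dt.get_param"
  bool executed_ = false;  // hypothetical per-op state
};

int main() {
  OpExecutable get_param(/*run_once=*/true);
  get_param.Execute();  // first call runs the kernel
  get_param.Execute();  // skipped: run_once ops execute only once
}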
@@ -22,6 +22,7 @@
 #include "paddle/infrt/tensor/tensor_map.h"
 #include "paddle/phi/backends/all_context.h"
 #include "paddle/phi/common/place.h"
+#include "paddle/phi/core/dense_tensor.h"

 #ifdef INFRT_WITH_GPU
 #include <cuda_runtime.h>
@@ -308,34 +309,50 @@ inline size_t SizeOfDataType(::phi::DataType data_type) {
   }
   return 0;
 }

-::phi::DenseTensor GpuMemCpy(const ::phi::DenseTensor& input,
-                             const ::phi::GPUContext& context,
-                             bool d2h) {
+void GpuMemCpy(const ::phi::DenseTensor& input,
+               const ::phi::GPUContext& context,
+               bool d2h,
+               ::phi::DenseTensor* output) {
   if (d2h) {
-    ::phi::DenseTensor ret(
-        const_cast<::phi::Allocator*>(&context.GetHostAllocator()),
-        input.meta());
     CHECK(input.place().GetType() == ::phi::AllocationType::GPU);
-    // TODO(wilber): Add sync op and stream.
-    cudaMemcpyAsync(ret.data(),
+
+    // TODO(wilber): Just a trick to avoid malloc.
+    if (input.numel() > output->numel()) {
+      // TODO(wilber): Use pinned memory.
+      output->Resize(input.dims());
+      context.HostAlloc(
+          output, input.dtype(), input.numel() * SizeOfDataType(input.dtype()));
+    }
+    cudaMemcpyAsync(output->data(),
                     input.data(),
                     SizeOfDataType(input.dtype()) * input.numel(),
                     cudaMemcpyDeviceToHost,
-                    nullptr);
-    return ret;
+                    context.stream());
+    // TODO(wilber): Ir add sync op.
+    cudaStreamSynchronize(context.stream());
   } else {
     // h2d
-    ::phi::DenseTensor ret(
-        const_cast<::phi::Allocator*>(&context.GetAllocator()), input.meta());
     CHECK(input.place().GetType() == ::phi::AllocationType::CPU ||
           input.place().GetType() == ::phi::AllocationType::GPUPINNED);
+    if (input.numel() > output->numel()) {
+      output->Resize(input.dims());
+      context.Alloc(output,
+                    input.dtype(),
+                    input.numel() * SizeOfDataType(input.dtype()),
+                    false);
+    } else {
+      output->Resize(input.dims());
+    }
+
     // TODO(wilber): Add sync op and stream.
-    cudaMemcpyAsync(ret.data(),
+    cudaMemcpyAsync(output->data(),
                     input.data(),
                     SizeOfDataType(input.dtype()) * input.numel(),
                     cudaMemcpyHostToDevice,
-                    nullptr);
-    return ret;
+                    context.stream());
   }
 }

 #endif
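The signature change turns `GpuMemCpy` from returning a freshly allocated `DenseTensor` into writing through an out-parameter: the destination is only reallocated when the source outgrows it, so a caller that executes repeatedly can keep one staging tensor alive and pay the allocation once. The copy also now runs on `context.stream()` instead of the default (nullptr) stream, with an explicit `cudaStreamSynchronize` until the IR gains a sync op. A self-contained sketch of the grow-only reuse trick, using plain host vectors in place of the phi allocators and CUDA calls:

#include <cstring>
#include <vector>

// Sketch of the grow-only reuse pattern: the destination reallocates only
// when the source outgrows it, so a buffer reused across iterations is
// allocated once (plain host memory stands in for phi/CUDA here).
void CopyReusing(const std::vector<float>& input, std::vector<float>* output) {
  if (input.size() > output->size()) {
    output->resize(input.size());  // analogous to Resize + HostAlloc/Alloc
  }
  std::memcpy(output->data(), input.data(), input.size() * sizeof(float));
}

int main() {
  std::vector<float> staging;  // persists across calls, like the out-param
  for (int step = 0; step < 3; ++step) {
    std::vector<float> src(1024, 1.0f);
    CopyReusing(src, &staging);  // allocates on the first iteration only
  }
}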
@@ -76,9 +76,10 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor);
 int32_t TensorMapGetSize(const ::infrt::phi::DenseTensorMap& map);

 #ifdef INFRT_WITH_GPU
-::phi::DenseTensor GpuMemCpy(const ::phi::DenseTensor& input,
-                             const ::phi::GPUContext& context,
-                             bool d2h);
+void GpuMemCpy(const ::phi::DenseTensor& input,
+               const ::phi::GPUContext& context,
+               bool d2h,
+               ::phi::DenseTensor* output);
 #endif

 }  // namespace phi
@@ -119,6 +119,7 @@ void NaiveMatmul(const DenseHostTensor &x,
   const int N = w.shape().GetDim(1);
   for (int i = 0; i < M; i++) {
     for (int j = 0; j < N; j++) {
+      out_data[i * N + j] = 0;
       for (int k = 0; k < K; k++) {
         out_data[i * N + j] += x_data[i * K + k] * w_data[k * N + j];
       }
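This one-line change is a correctness fix: without zeroing, the `+=` accumulation starts from whatever the output buffer already holds, which was only safe when the tensor was freshly zero-initialized and the kernel ran once. Explicitly resetting each element makes `NaiveMatmul` safe to re-execute on the same output tensor, which matters now that functions may run more than once.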
@@ -134,9 +135,11 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
                       {"shape"});
   registry->AddKernel("dt.print_tensor", INFRT_KERNEL(PrintTensor));
   registry->AddKernel("dt.fill_tensor_with_constant.f32",
-                      INFRT_KERNEL(FillTensorWithConstant<float>));
+                      INFRT_KERNEL(FillTensorWithConstant<float>),
+                      {"value"});
   registry->AddKernel("dt.fill_tensor_with_constant.f64",
-                      INFRT_KERNEL(FillTensorWithConstant<double>));
+                      INFRT_KERNEL(FillTensorWithConstant<double>),
+                      {"value"});

   // TensorMap related methods.
   registry->AddKernel("dt.load_params", INFRT_KERNEL(LoadParams));
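The added `{"value"}` argument appears to declare the kernel's named attribute, following the `{"shape"}` registration visible above, so the executor can route the op's `value` attribute into `FillTensorWithConstant` rather than treating it as a positional input.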
@@ -57,7 +57,7 @@ namespace tensorrt {
   // TODO(wilber): The build option should be filled from mlir info.
   backends::tensorrt::BuildOptions options;
   options.max_batch = 4;
-  options.workspace = 1024;
+  options.workspace = 128;

   // Parse mlir Region which only has one block.
   mlir::Operation& operation = *create_engine_op.operation;
@@ -115,6 +115,27 @@ inline void PoolFunc(trt::PoolingOp& op,  // NOLINT
   // TODO(Inference)
   // CHECK(false) << "Not supported adaptive pool";

+  // TODO(wilber): Reformat.
+  // global average pooling.
+  auto ksize_vec = ArrayAttrToVec<int>(ksize);
+  if (static_cast<nvinfer1::PoolingType>(pool_type) ==
+          nvinfer1::PoolingType::kAVERAGE &&
+      ksize_vec.size() == 2 && ksize_vec[0] == 1 && ksize_vec[1] == 1) {
+    nvinfer1::Dims dims;
+    dims.nbDims = 2;
+    dims.d[0] = input_shape.d[1];
+    dims.d[1] = input_shape.d[2];
+    auto* layer = network->addPoolingNd(
+        *input_itensor, static_cast<nvinfer1::PoolingType>(pool_type), dims);
+    CHECK_NOTNULL(layer);
+    mlir::Value out_repr = op.output_tensor();
+    nvinfer1::ITensor* out_tensor = layer->getOutput(0);
+    value_to_trt_tensor_map[out_repr] = out_tensor;
+    return;
+  }
+
+  // plugin...
   std::vector<int> input_shape_v;
   for (int i = 0; i < input_dims; i++) {
     input_shape_v.push_back(input_shape.d[i]);
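This fast path recognizes that an adaptive average pool whose target output is 1x1 is just global average pooling, so it can be lowered to a single native TensorRT pooling layer whose window covers the whole spatial extent (`input_shape.d[1]` x `input_shape.d[2]` in the batchless CHW layout). All other adaptive cases still fall through to the plugin path below.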