Unverified commit 5a6182b8, authored by Wilber and committed by GitHub

infrt run once (A trick version) (#41634)

* temporarily run once

* update

* update

* update

* update

* fix ci problem
Parent 2ab986ae
@@ -14,7 +14,7 @@
 #include "paddle/infrt/dialect/phi/ir/phi_base.h"
-#include <llvm/include/llvm/ADT/TypeSwitch.h>
+#include <llvm/ADT/TypeSwitch.h>
 #include <mlir/IR/Builders.h>
 #include <mlir/IR/Dialect.h>
 #include <mlir/IR/DialectImplementation.h>
...
@@ -15,7 +15,7 @@
 #include <glog/logging.h>
 #include <llvm/Support/ErrorHandling.h>
-#include <llvm/include/mlir/IR/Attributes.h>
+#include <mlir/IR/Attributes.h>
 #include <mlir/IR/Builders.h>
 #include <mlir/IR/BuiltinAttributes.h>
 #include <mlir/IR/PatternMatch.h>
...
@@ -87,7 +87,7 @@ int main(int argc, char** argv) {
     std::cout << "\npass failed!\n" << std::endl;
     return 4;
   }
-  // module->dump();
+  module->dump();
   ::infrt::host_context::TestMlir(module.get(), &registry);
   return 0;
 }
@@ -16,6 +16,7 @@
 #include <mlir/IR/BuiltinOps.h>
 #include <string>
+#include <unordered_set>
 #include "paddle/infrt/host_context/kernel_frame.h"
 #include "paddle/infrt/host_context/kernel_registry.h"
@@ -71,7 +72,15 @@ OpExecutableBuilder::OpExecutableBuilder(const std::string& op_name,
   // TODO(Superjomn) support other device other than CPU.
   CHECK(impl_->kernel_impl) << "No CPU kernel called " << op_name;
 
-  if (op_name == "dt.get_param") {
+  // TODO(wilber): Maybe we can use the MLIR trait or other facilities to remove
+  // the run_once set.
+  std::unordered_set<std::string> run_once_set{
+      "dt.get_param",
+      "trt.create_engine",
+      "phi_dt.create_host_inited_dense_tensor.f32",
+      "phi_dt.create_context.cpu",
+      "phi_dt.create_context.gpu"};
+  if (run_once_set.count(op_name)) {
     impl_->run_once = true;
   }
 }
...
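Note: this hunk replaces the single hard-coded check for "dt.get_param" with a lookup set, so every kernel that builds a persistent resource (loaded parameters, TRT engines, device contexts) is flagged run_once and is not re-executed on later runs. A minimal standalone sketch of that pattern follows; OpRecord and MaybeExecute are hypothetical names used for illustration, not the actual infrt classes.

#include <functional>
#include <string>
#include <unordered_set>

// Ops that create stateful resources (weights, engines, device contexts)
// only need to execute once per program, mirroring run_once_set above.
static const std::unordered_set<std::string> kRunOnceOps{
    "dt.get_param", "trt.create_engine",
    "phi_dt.create_host_inited_dense_tensor.f32",
    "phi_dt.create_context.cpu", "phi_dt.create_context.gpu"};

// Hypothetical per-op record: the executor launches the kernel the first
// time and skips it afterwards when run_once is set.
struct OpRecord {
  std::string op_name;
  bool run_once = false;
  bool has_run = false;
};

inline void MaybeExecute(OpRecord* op, const std::function<void()>& launch) {
  if (op->run_once && op->has_run) return;  // skip repeated run-once ops
  launch();
  op->has_run = true;
}

A caller would set op.run_once = kRunOnceOps.count(op.op_name) > 0 while building the executable, which is what the modified constructor above does with impl_->run_once.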
@@ -22,6 +22,7 @@
 #include "paddle/infrt/tensor/tensor_map.h"
 #include "paddle/phi/backends/all_context.h"
 #include "paddle/phi/common/place.h"
+#include "paddle/phi/core/dense_tensor.h"
 
 #ifdef INFRT_WITH_GPU
 #include <cuda_runtime.h>
@@ -308,34 +309,50 @@ inline size_t SizeOfDataType(::phi::DataType data_type) {
   }
   return 0;
 }
-::phi::DenseTensor GpuMemCpy(const ::phi::DenseTensor& input,
-                             const ::phi::GPUContext& context,
-                             bool d2h) {
+void GpuMemCpy(const ::phi::DenseTensor& input,
+               const ::phi::GPUContext& context,
+               bool d2h,
+               ::phi::DenseTensor* output) {
   if (d2h) {
-    ::phi::DenseTensor ret(
-        const_cast<::phi::Allocator*>(&context.GetHostAllocator()),
-        input.meta());
     CHECK(input.place().GetType() == ::phi::AllocationType::GPU);
-    // TODO(wilber): Add sync op and stream.
-    cudaMemcpyAsync(ret.data(),
+    // TODO(wilber): Just a trick to avoid malloc.
+    if (input.numel() > output->numel()) {
+      // TODO(wilber): Use pinned memory.
+      output->Resize(input.dims());
+      context.HostAlloc(
+          output, input.dtype(), input.numel() * SizeOfDataType(input.dtype()));
+    }
+    cudaMemcpyAsync(output->data(),
                     input.data(),
                     SizeOfDataType(input.dtype()) * input.numel(),
                     cudaMemcpyDeviceToHost,
-                    nullptr);
-    return ret;
+                    context.stream());
+    // TODO(wilber): Ir add sync op.
+    cudaStreamSynchronize(context.stream());
   } else {
     // h2d
-    ::phi::DenseTensor ret(
-        const_cast<::phi::Allocator*>(&context.GetAllocator()), input.meta());
     CHECK(input.place().GetType() == ::phi::AllocationType::CPU ||
           input.place().GetType() == ::phi::AllocationType::GPUPINNED);
+    if (input.numel() > output->numel()) {
+      output->Resize(input.dims());
+      context.Alloc(output,
+                    input.dtype(),
+                    input.numel() * SizeOfDataType(input.dtype()),
+                    false);
+    } else {
+      output->Resize(input.dims());
+    }
     // TODO(wilber): Add sync op and stream.
-    cudaMemcpyAsync(ret.data(),
+    cudaMemcpyAsync(output->data(),
                     input.data(),
                     SizeOfDataType(input.dtype()) * input.numel(),
                     cudaMemcpyHostToDevice,
-                    nullptr);
-    return ret;
+                    context.stream());
   }
 }
 #endif
...
@@ -76,9 +76,10 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor);
 int32_t TensorMapGetSize(const ::infrt::phi::DenseTensorMap& map);
 
 #ifdef INFRT_WITH_GPU
-::phi::DenseTensor GpuMemCpy(const ::phi::DenseTensor& input,
-                             const ::phi::GPUContext& context,
-                             bool d2h);
+void GpuMemCpy(const ::phi::DenseTensor& input,
+               const ::phi::GPUContext& context,
+               bool d2h,
+               ::phi::DenseTensor* output);
 #endif
 
 }  // namespace phi
...
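Note: GpuMemCpy now fills a caller-owned output tensor instead of returning a newly allocated one, reallocating only when the input has more elements than the current buffer (the "trick to avoid malloc" in the TODO), and it enqueues the copy on the context's stream with an explicit synchronize on the device-to-host path. A standalone illustration of that reuse pattern, with plain host memory and memcpy standing in for DenseTensor and cudaMemcpyAsync:

#include <cstring>
#include <vector>

// The destination buffer grows on demand but is otherwise reused, so
// steady-state copies perform no allocation.
void CopyWithReuse(const std::vector<float>& src, std::vector<float>* dst) {
  if (src.size() > dst->size()) {
    dst->resize(src.size());
  }
  std::memcpy(dst->data(), src.data(), src.size() * sizeof(float));
}

With the new signature the caller keeps the output tensor alive across runs, e.g. GpuMemCpy(input, ctx, /*d2h=*/true, &reused_output).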
@@ -119,6 +119,7 @@ void NaiveMatmul(const DenseHostTensor &x,
   const int N = w.shape().GetDim(1);
   for (int i = 0; i < M; i++) {
     for (int j = 0; j < N; j++) {
+      out_data[i * N + j] = 0;
       for (int k = 0; k < K; k++) {
         out_data[i * N + j] += x_data[i * K + k] * w_data[k * N + j];
       }
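Note: the single added line zero-initializes each output element before the inner accumulation loop. NaiveMatmul accumulates with +=, so reusing an output buffer across runs (which the run-once/reuse changes in this commit make the common case) would otherwise keep adding onto stale results. A self-contained sketch of the corrected loop, using plain vectors instead of DenseHostTensor:

#include <vector>

// out = x (M x K) * w (K x N); the output is cleared before accumulation so
// that calling the function twice on the same buffer gives the same result.
void NaiveMatmulSketch(const std::vector<float>& x, const std::vector<float>& w,
                       std::vector<float>* out, int M, int K, int N) {
  out->assign(static_cast<size_t>(M) * N, 0.0f);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      for (int k = 0; k < K; ++k) {
        (*out)[i * N + j] += x[i * K + k] * w[k * N + j];
      }
    }
  }
}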
@@ -134,9 +135,11 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
                       {"shape"});
   registry->AddKernel("dt.print_tensor", INFRT_KERNEL(PrintTensor));
   registry->AddKernel("dt.fill_tensor_with_constant.f32",
-                      INFRT_KERNEL(FillTensorWithConstant<float>));
+                      INFRT_KERNEL(FillTensorWithConstant<float>),
+                      {"value"});
   registry->AddKernel("dt.fill_tensor_with_constant.f64",
-                      INFRT_KERNEL(FillTensorWithConstant<double>));
+                      INFRT_KERNEL(FillTensorWithConstant<double>),
+                      {"value"});
 
   // TensorMap related methods.
   registry->AddKernel("dt.load_params", INFRT_KERNEL(LoadParams));
...
@@ -57,7 +57,7 @@ namespace tensorrt {
   // TODO(wilber): The build option shoule be fiiled from mlir info.
   backends::tensorrt::BuildOptions options;
   options.max_batch = 4;
-  options.workspace = 1024;
+  options.workspace = 128;
 
   // Parse mlir Region which only has one block.
   mlir::Operation& operation = *create_engine_op.operation;
...
@@ -115,6 +115,27 @@ inline void PoolFunc(trt::PoolingOp& op,  // NOLINT
     // TODO(Inference)
     // CHECK(false) << "Not supported adaptive pool";
+
+    // TODO(wilber): Reformat.
+    // global average pooling.
+    auto ksize_vec = ArrayAttrToVec<int>(ksize);
+    if (static_cast<nvinfer1::PoolingType>(pool_type) ==
+            nvinfer1::PoolingType::kAVERAGE &&
+        ksize_vec.size() == 2 && ksize_vec[0] == 1 && ksize_vec[1] == 1) {
+      nvinfer1::Dims dims;
+      dims.nbDims = 2;
+      dims.d[0] = input_shape.d[1];
+      dims.d[1] = input_shape.d[2];
+      auto* layer = network->addPoolingNd(
+          *input_itensor, static_cast<nvinfer1::PoolingType>(pool_type), dims);
+      CHECK_NOTNULL(layer);
+      mlir::Value out_repr = op.output_tensor();
+      nvinfer1::ITensor* out_tensor = layer->getOutput(0);
+      value_to_trt_tensor_map[out_repr] = out_tensor;
+      return;
+    }
+
+    // plugin...
     std::vector<int> input_shape_v;
     for (int i = 0; i < input_dims; i++) {
       input_shape_v.push_back(input_shape.d[i]);
...
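Note: the new branch lowers global average pooling directly to a regular TensorRT average-pooling layer whose window equals the full spatial size of the input (input_shape.d[1] x input_shape.d[2]), instead of falling through to the plugin path below. The arithmetic it relies on, shown as a standalone sketch in plain C++ rather than TensorRT:

#include <cstdio>

// Average pooling with a window covering the whole H x W map collapses each
// channel to a single value, i.e. global average pooling.
float GlobalAvgPool(const float* channel, int h, int w) {
  float sum = 0.0f;
  for (int i = 0; i < h * w; ++i) sum += channel[i];
  return sum / static_cast<float>(h * w);
}

int main() {
  const float feature_map[4] = {1.0f, 2.0f, 3.0f, 4.0f};  // H = W = 2
  std::printf("%f\n", GlobalAvgPool(feature_map, 2, 2));  // prints 2.500000
  return 0;
}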