diff --git a/paddle/fluid/lite/CMakeLists.txt b/paddle/fluid/lite/CMakeLists.txt
index 804a0bcda3562d89ee95ca96b55fbd0cb98f6976..98357922eb8d537b67fb6bec627d172cdf9903ea 100644
--- a/paddle/fluid/lite/CMakeLists.txt
+++ b/paddle/fluid/lite/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_subdirectory(core)
 add_subdirectory(x86)
+add_subdirectory(host)
 add_subdirectory(cuda)
 add_subdirectory(operators)
 add_subdirectory(kernels)
diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt
index c2c40c36bcce21075611a4c7dfc67f6efe0b4189..73989dda91b796612d5c87bb5630365185420a2b 100644
--- a/paddle/fluid/lite/api/CMakeLists.txt
+++ b/paddle/fluid/lite/api/CMakeLists.txt
@@ -1,3 +1,3 @@
-cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite)
+cc_library(cxx_api_lite SRCS cxx_api.cc DEPS scope_lite op_executor_lite host_kernels ops_lite optimizer_lite target_wrapper_host)
 cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
diff --git a/paddle/fluid/lite/api/cxx_api.h b/paddle/fluid/lite/api/cxx_api.h
index 218f24c7a3cdd4f5a43e3dcecccc1a45e2fb6d7d..d6b5f09dcbc08d24507be9a269d700bb221fa8a9 100644
--- a/paddle/fluid/lite/api/cxx_api.h
+++ b/paddle/fluid/lite/api/cxx_api.h
@@ -58,7 +58,7 @@ class Predictor {
   const Tensor* GetOutput(size_t offset) {
     auto* _fetch_list = program_->exec_scope()->FindVar("fetch");
     CHECK(_fetch_list) << "no fetch variable in exec_scope";
-    auto fetch_list = _fetch_list->Get<std::vector<lite::Tensor>>();
+    auto& fetch_list = *_fetch_list->GetMutable<std::vector<lite::Tensor>>();
     CHECK_LT(offset, fetch_list.size()) << "offset " << offset << " overflow";
     return &fetch_list.at(offset);
   }
diff --git a/paddle/fluid/lite/api/cxx_api_test.cc b/paddle/fluid/lite/api/cxx_api_test.cc
index 515ef9c682e2024f3b04f331f48ce30ab1f2c686..43bbc69a92036217d65998ace45cf708a587ba8b 100644
--- a/paddle/fluid/lite/api/cxx_api_test.cc
+++ b/paddle/fluid/lite/api/cxx_api_test.cc
@@ -28,8 +28,22 @@ TEST(CXXApi, test) {
 
   auto* input_tensor = predictor.GetInput(0);
   input_tensor->Resize({100, 100});
-  input_tensor->mutable_data<float>();
+  auto* data = input_tensor->mutable_data<float>();
+  for (int i = 0; i < 100 * 100; i++) {
+    data[i] = i;
+  }
+
+  LOG(INFO) << "input " << input_tensor;
+  LOG(INFO) << "input " << *input_tensor;
+
   predictor.Run();
+
+  auto* out = predictor.GetOutput(0);
+  LOG(INFO) << out << " memory size " << out->memory_size();
+  LOG(INFO) << "out " << out->data<float>()[0];
+  LOG(INFO) << "out " << out->data<float>()[1];
+  LOG(INFO) << "dims " << out->dims();
+  LOG(INFO) << "out " << *out;
 }
 
 }  // namespace lite
diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h
index 9fb7b60b34e7eaa2f3d42bc7d7bf67917bcd3365..9db4874d1bd3405ef28b911a4a31f6c20089a50d 100644
--- a/paddle/fluid/lite/core/kernel.h
+++ b/paddle/fluid/lite/core/kernel.h
@@ -65,6 +65,13 @@ class KernelBase {
 
   virtual ~KernelBase() = default;
 
+  std::string DebugString() const {
+    std::stringstream ss;
+    ss << op_type() << ":" << TargetToStr(target()) << "/"
+       << PrecisionToStr(precision()) << "/" << DataLayoutToStr(layout());
+    return ss.str();
+  }
+
  protected:
   std::unique_ptr<KernelContext> context_;
   mutable operators::param_t param_;
diff --git a/paddle/fluid/lite/core/memory.h b/paddle/fluid/lite/core/memory.h
index ccfe412edad40a3961b83907b90642fa54e9d5af..6cb321c637bf687d0f45608f37e5e51ba014f365 100644
--- a/paddle/fluid/lite/core/memory.h
+++ b/paddle/fluid/lite/core/memory.h
@@ -21,18 +21,16 @@ namespace lite {
 
 static void* TargetMalloc(TargetType target, size_t size) {
   void* data{nullptr};
-  switch (static_cast<int>(target)) {
-    case static_cast<int>(TargetType::kX86):
-      data = TargetWrapper<TARGET(kX86)>::Malloc(size);
+  switch (target) {
+    case TargetType::kHost:
+    case TargetType::kX86:
+      data = TargetWrapper<TARGET(kHost)>::Malloc(size);
       break;
-    case static_cast<int>(TargetType::kCUDA):
+    case TargetType::kCUDA:
       data = TargetWrapper<TARGET(kCUDA)>::Malloc(size);
       break;
-    case static_cast<int>(TargetType::kHost):
-      data = TargetWrapper<TARGET(kHost)>::Malloc(size);
-      break;
     default:
-      LOG(FATAL) << "Unknown type";
+      LOG(FATAL) << "Unsupported target " << TargetToStr(target);
   }
   return data;
 }
@@ -52,17 +50,19 @@ static void TargetFree(TargetType target, void* data) {
 
 static void TargetCopy(TargetType target, void* dst, const void* src,
                        size_t size) {
-  switch (static_cast<int>(target)) {
-    case static_cast<int>(TargetType::kX86):
-    case static_cast<int>(TargetType::kHost):
+  switch (target) {
+    case TargetType::kX86:
+    case TargetType::kHost:
       TargetWrapper<TARGET(kHost)>::MemcpySync(dst, src, size, IoDirection::DtoD);
       break;
-    case static_cast<int>(TargetType::kCUDA):
+    case TargetType::kCUDA:
       TargetWrapper<TARGET(kCUDA)>::MemcpySync(dst, src, size, IoDirection::DtoD);
       break;
+    default:
+      LOG(FATAL) << "Unsupported target " << TargetToStr(target);
   }
 }
@@ -79,12 +79,10 @@ class Buffer {
   void ResetLazy(TargetType target, size_t size) {
     if (target != target_ || space_ < size) {
       Free();
+      data_ = TargetMalloc(target, size);
+      target_ = target;
+      space_ = size;
     }
-
-    if (size < space_) return;
-    target_ = target;
-    data_ = TargetMalloc(target, size);
-    space_ = size;
   }
 
   void ResizeLazy(size_t size) { ResetLazy(target_, size); }
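Note on the `ResetLazy` rewrite above: the old version could `Free()` the buffer and then return early (when the target changed but the requested size was still below `space_`), leaving a freed buffer and stale bookkeeping behind. Below is a minimal standalone sketch of the intended lazy-reallocation semantics; `std::malloc`/`std::free` stand in for the target wrappers and `Target` is a hypothetical enum, not the real Paddle types:

```cpp
#include <cstddef>
#include <cstdlib>

enum class Target { kHost, kCUDA };  // hypothetical stand-in for TargetType

// Sketch of Buffer::ResetLazy's intended semantics: reallocate only when the
// target changes or the current allocation is too small; otherwise reuse.
class LazyBuffer {
 public:
  void ResetLazy(Target target, std::size_t size) {
    if (target != target_ || space_ < size) {
      std::free(data_);           // stand-in for TargetFree(target_, data_)
      data_ = std::malloc(size);  // stand-in for TargetMalloc(target, size)
      target_ = target;
      space_ = size;
    }
    // else: the existing allocation is kept untouched (the "lazy" fast path)
  }

  ~LazyBuffer() { std::free(data_); }

 private:
  void* data_{nullptr};
  Target target_{Target::kHost};
  std::size_t space_{0};
};
```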
diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h
index 81b1d8565efc7b26f53258d0102a1417d5bd5382..070747f4e5797bad9ce75a037baf30212c005583 100644
--- a/paddle/fluid/lite/core/mir/ssa_graph.h
+++ b/paddle/fluid/lite/core/mir/ssa_graph.h
@@ -60,9 +60,6 @@ class SSAGraph : GraphBase {
       op->SetValidPlaces(valid_places);
       auto &new_node = node_storage_.back();
       auto kernels = op->CreateKernels(valid_places);
-      for (auto &kernel : kernels) {
-        op->AttachKernel(kernel.get());
-      }
       node_storage_.back().AsInstruct(op->op_type_, std::move(kernels), op,
                                       op->op_info());
diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc
index 1493b577c5d6eabd9200d6f31d619172a513ff0e..86557ce0abc6f8bd40aa5258e38e6224fb709c3b 100644
--- a/paddle/fluid/lite/core/op_lite.cc
+++ b/paddle/fluid/lite/core/op_lite.cc
@@ -29,6 +29,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
         (kernel_type.empty() ? op_type_ : kernel_type), place.target,
         place.precision);
     for (auto &&it : ks) {
+      AttachKernel(it.get());
       kernels.emplace_back(std::move(it));
     }
   }
diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h
index 8046a3473fc210843146ec4ae3d5f6a37636ac06..a31a0eb0d5ff886abd5a1d6bced63fdf4914b8a8 100644
--- a/paddle/fluid/lite/core/program.h
+++ b/paddle/fluid/lite/core/program.h
@@ -105,6 +105,11 @@ struct Instruction {
   void Run() {
     CHECK(op_);
     CHECK(kernel_);
+    LOG(INFO) << "running kernel> " << kernel_->DebugString();
+    if (UNLIKELY(first_epoch_)) {
+      first_epoch_ = false;
+      op_->CheckShape();
+    }
     op_->InferShape();
     kernel_->Run();
   }
@@ -112,6 +117,7 @@ struct Instruction {
  private:
   std::shared_ptr<OpLite> op_;
   std::unique_ptr<KernelBase> kernel_;
+  bool first_epoch_{true};
 };
 
 /*
diff --git a/paddle/fluid/lite/core/target_wrapper.h b/paddle/fluid/lite/core/target_wrapper.h
index c8b7aa18cb8b0b73e8b1ddf0a17818faa90ca7d7..2852f5825726538b09f5f4d93c29bc47fe850f96 100644
--- a/paddle/fluid/lite/core/target_wrapper.h
+++ b/paddle/fluid/lite/core/target_wrapper.h
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #pragma once
+#include <glog/logging.h>
 #include <iostream>
 #include <sstream>
 
@@ -138,11 +139,48 @@ class TargetWrapper {
 
   static void StreamSync(const stream_t& stream) {}
 
-  static void* Malloc(size_t size) { return new char[size]; }
-  static void Free(void* ptr) { delete[] static_cast<char*>(ptr); }
+  static void* Malloc(size_t size) {
+    LOG(FATAL) << "Unimplemented malloc for " << TargetToStr(Target);
+    return nullptr;
+  }
+  static void Free(void* ptr) { LOG(FATAL) << "Unimplemented"; }
+
+  static void MemcpySync(void* dst, const void* src, size_t size,
+                         IoDirection dir) {
+    LOG(FATAL) << "Unimplemented";
+  }
+  static void MemcpyAsync(void* dst, const void* src, size_t size,
+                          IoDirection dir, const stream_t& stream) {
+    MemcpySync(dst, src, size, dir);
+  }
+};
+
+// This interface should be specialized for each kind of target.
+template <>
+class TargetWrapper<TARGET(kHost)> {
+ public:
+  using stream_t = int;
+  using event_t = int;
+
+  static size_t num_devices() { return 0; }
+  static size_t maximum_stream() { return 0; }
+
+  static void CreateStream(stream_t* stream) {}
+  static void DestroyStream(const stream_t& stream) {}
+
+  static void CreateEvent(event_t* event) {}
+  static void DestroyEvent(const event_t& event) {}
+
+  static void RecordEvent(const event_t& event) {}
+  static void SyncEvent(const event_t& event) {}
+
+  static void StreamSync(const stream_t& stream) {}
+
+  static void* Malloc(size_t size);
+  static void Free(void* ptr);
 
   static void MemcpySync(void* dst, const void* src, size_t size,
-                         IoDirection dir) {}
+                         IoDirection dir);
   static void MemcpyAsync(void* dst, const void* src, size_t size,
                           IoDirection dir, const stream_t& stream) {
     MemcpySync(dst, src, size, dir);
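The `target_wrapper.h` change above turns the primary `TargetWrapper` template into a loud stub and moves real behavior into per-target explicit specializations (`kHost` here, `kCUDA` in cuda/target_wrapper.cc). A self-contained sketch of that pattern, using simplified names (`Wrapper`, `TargetType`) that are not the real Paddle types:

```cpp
#include <cstddef>
#include <cstdio>
#include <cstdlib>

enum class TargetType { kHost, kCUDA };  // simplified, not the Paddle enum

// Primary template: a loud trap, so code built for a target with no real
// wrapper fails fast instead of silently allocating host memory (which is
// what the old `return new char[size];` default did).
template <TargetType Target>
struct Wrapper {
  static void* Malloc(std::size_t) {
    std::fprintf(stderr, "Unimplemented malloc for this target\n");
    std::abort();
  }
  static void Free(void*) { std::abort(); }
};

// Explicit specialization: the host target gets the real implementation,
// mirroring the new paddle/fluid/lite/host/target_wrapper.cc.
template <>
struct Wrapper<TargetType::kHost> {
  static void* Malloc(std::size_t size) { return new char[size]; }
  static void Free(void* ptr) { delete[] static_cast<char*>(ptr); }
};

int main() {
  void* p = Wrapper<TargetType::kHost>::Malloc(128);  // uses specialization
  Wrapper<TargetType::kHost>::Free(p);
  // Wrapper<TargetType::kCUDA>::Malloc(128);  // would abort: no real impl
}
```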
diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/tensor.h
index 864e5db9ecd65c9ee79a8a58754459181761d50d..f78b90801396848e3f8cadb0695de27530677dd4 100644
--- a/paddle/fluid/lite/core/tensor.h
+++ b/paddle/fluid/lite/core/tensor.h
@@ -95,13 +95,15 @@ class Tensor {
     dims_ = other.dims_;
     target_ = other.target_;
     lod_ = other.lod_;
+    memory_size_ = other.memory_size_;
   }
 
   void CopyDataFrom(const Tensor& other) {
     dims_ = other.dims_;
     target_ = other.target_;
     lod_ = other.lod_;
-    *buffer_ = *other.buffer_;
+    memory_size_ = other.memory_size_;
+    buffer_->CopyDataFrom(*other.buffer_, memory_size_);
   }
 
   TargetType target() const { return target_; }
diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h
index e9aabee51f2708b663a3ee5571327d01944febe8..c8c0c0f5c12191491634dd7907a33ecba6647014 100644
--- a/paddle/fluid/lite/core/type_system.h
+++ b/paddle/fluid/lite/core/type_system.h
@@ -38,9 +38,39 @@ namespace lite {
 // The DNN system is simple, and the architecture can not process that many data
 // types as a compiler, or that will turn out to a chaos.
 //
-// We should make sure that supported data types should be registered here, and
-// keep the quantity small. And avoid using some special data types as op's IO,
-// such as some runtime cache, that need to be avoided.
+// We should make sure that the supported data types are registered here, and
+// keep their quantity small, and avoid using special data types as op's
+// inputs or outputs, such as some runtime cache, since those types can't be
+// processed by the MIR.
+//
+// A tensor with a different place (target, precision, data layout or device)
+// should be treated as a different type. Different types might be compatible
+// with each other; for example, `VoidTy` means any type, so any other type
+// can be treated as a `VoidTy`.
+//
+// A Type can be transformed to another by inserting special transform
+// operators: for example, a DataLayoutTransformOp can convert a
+// `TensorFp32NCHWTy` to a `TensorFp32NHWCTy`; an IoCopyOp can convert a
+// `TensorFp32NCHWTy(kHost)` to a `TensorFp32NCHWTy(kCUDA)`. There are many
+// other conversions between different Types, but some conversions are
+// unsupported; for example, there is no way to convert an `UnsupportedTy`
+// to a `TensorAnyTy`.
+//
+// We use Types to declare the definition of a kernel; each input and output
+// argument has a specific Type.
+//
+// REGISTER_LITE_KERNEL(mul, kHost, kFloat,
+//                      paddle::lite::kernels::host::MulCompute, def)
+//     .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+//         TARGET(kHost))})
+//     .BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+//         TARGET(kHost))})
+//     .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
+//         TARGET(kHost))})
+//     .Finalize();
+//
+// The above definition will be used by the MIR for type inference and
+// incompatible-type checking.
 //
 // TODO(Superjomn) Add operator/kernel-wise static checking to avoid unsupported
 // type mixed in the system.
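The rewritten comment treats a tensor type as the tuple (target, precision, layout): the same data at a different place is a different type, and a mismatch must be bridged by a transform op. A tiny illustration of that identity rule, with hypothetical stand-ins rather than the real `Type` registry:

```cpp
#include <iostream>

// Hypothetical stand-ins for the dimensions a lite::Type is keyed on.
enum class TargetType { kHost, kX86, kCUDA };
enum class PrecisionType { kFloat };
enum class DataLayoutType { kNCHW, kNHWC };

// A tensor "type" in the sense of the comment above: the same bytes at a
// different target/precision/layout count as a different, distinct type.
struct TensorTy {
  TargetType target;
  PrecisionType precision;
  DataLayoutType layout;
  bool operator==(const TensorTy& o) const {
    return target == o.target && precision == o.precision &&
           layout == o.layout;
  }
};

int main() {
  TensorTy host_nchw{TargetType::kHost, PrecisionType::kFloat,
                     DataLayoutType::kNCHW};
  TensorTy cuda_nchw{TargetType::kCUDA, PrecisionType::kFloat,
                     DataLayoutType::kNCHW};
  // Same precision and layout but a different target: an IoCopy-like op must
  // be inserted before one kernel's output can feed the other's input.
  std::cout << (host_nchw == cuda_nchw) << "\n";  // prints 0
}
```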
diff --git a/paddle/fluid/lite/cuda/target_wrapper.cc b/paddle/fluid/lite/cuda/target_wrapper.cc
index 21df004aa91bf1e8fc0bc71f5cfe6c151c2288d0..cca9a95d8e077a265988d4f598601b52c292aef7 100644
--- a/paddle/fluid/lite/cuda/target_wrapper.cc
+++ b/paddle/fluid/lite/cuda/target_wrapper.cc
@@ -26,12 +26,14 @@ using TargetW = TargetWrapper<TARGET(kCUDA)>;
 
 template <>
 void* TargetW::Malloc(size_t size) {
-  return new char[size];
+  void* ptr{};
+  CHECK_EQ(cudaSuccess, cudaMalloc(&ptr, size));
+  return ptr;
 }
 
 template <>
 void TargetW::Free(void* ptr) {
-  delete[] static_cast<char*>(ptr);
+  CHECK_EQ(cudaSuccess, cudaFree(ptr));
 }
 
 template <>
diff --git a/paddle/fluid/lite/host/CMakeLists.txt b/paddle/fluid/lite/host/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..90812f3f3cd712571eb7f11261e23c8dcb78b0fe
--- /dev/null
+++ b/paddle/fluid/lite/host/CMakeLists.txt
@@ -0,0 +1 @@
+cc_library(target_wrapper_host SRCS target_wrapper.cc)
+ +#include "paddle/fluid/lite/core/target_wrapper.h" +#include + +namespace paddle { +namespace lite { + +void* TargetWrapper::Malloc(size_t size) { + return new char[size]; +} +void TargetWrapper::Free(void* ptr) { + delete[] static_cast(ptr); +} +void TargetWrapper::MemcpySync(void* dst, const void* src, + size_t size, IoDirection dir) { + memcpy(dst, src, size); +} + +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/kernels/host/feed_compute.cc b/paddle/fluid/lite/kernels/host/feed_compute.cc index 3f8015614bf76c8647712d6107050e760c2c7c30..d727f0c22ac3a8a5ba85517dceadc5ba192dfa18 100644 --- a/paddle/fluid/lite/kernels/host/feed_compute.cc +++ b/paddle/fluid/lite/kernels/host/feed_compute.cc @@ -28,6 +28,8 @@ class FeedCompute : public OpKernel { auto ¶m = Param(); const Tensor &feed_item = param.feed_list->at(param.col); param.out->CopyDataFrom(feed_item); + LOG(INFO) << "FEED input " << feed_item << " col " << param.col; + LOG(INFO) << "FEED output " << *param.out; } }; @@ -40,6 +42,6 @@ REGISTER_LITE_KERNEL(feed, kHost, kFloat, paddle::lite::kernels::host::FeedCompute, def) .BindInput("X", {paddle::lite::Type::Get( TARGET(kHost))}) - .BindOutput("Out", {paddle::lite::Type::Get( + .BindOutput("Out", {paddle::lite::Type::Get( TARGET(kHost))}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/host/fetch_compute.cc b/paddle/fluid/lite/kernels/host/fetch_compute.cc index 4bc71266ed2bdcb9cba79ddb313f2139b6d95b78..b3193a01942c033d075683a74dde816d980657bf 100644 --- a/paddle/fluid/lite/kernels/host/fetch_compute.cc +++ b/paddle/fluid/lite/kernels/host/fetch_compute.cc @@ -32,7 +32,7 @@ class FetchCompute : public OpKernel { } auto& dst = fetch_list->at(param.col); - dst.CopyDataFrom(*param.input); + dst.ShareDataWith(*param.input); } }; diff --git a/paddle/fluid/lite/kernels/host/mul_compute.cc b/paddle/fluid/lite/kernels/host/mul_compute.cc index b2062fa9308a1db241ea1dfefbacaa78fcdb7d30..ee7b168503aff12e49b0d92f8f45dbe57d402388 100644 --- a/paddle/fluid/lite/kernels/host/mul_compute.cc +++ b/paddle/fluid/lite/kernels/host/mul_compute.cc @@ -40,22 +40,23 @@ class MulCompute : public OpKernel { using param_t = operators::MulParam; void Run() override { - auto& theparam = Param(); - core::dim2 x_shape( - {product(theparam.x->dims().begin(), - theparam.x->dims().begin() + theparam.x_num_col_dims), - product(theparam.x->dims().begin() + theparam.x_num_col_dims, - theparam.x->dims().end())}); + auto& param = Param(); + core::dim2 x_shape({product(param.x->dims().begin(), + param.x->dims().begin() + param.x_num_col_dims), + product(param.x->dims().begin() + param.x_num_col_dims, + param.x->dims().end())}); - core::dim2 y_shape( - {product(theparam.y->dims().begin(), - theparam.y->dims().begin() + theparam.x_num_col_dims), - product(theparam.y->dims().begin() + theparam.x_num_col_dims, - theparam.y->dims().end())}); + core::dim2 y_shape({product(param.y->dims().begin(), + param.y->dims().begin() + param.x_num_col_dims), + product(param.y->dims().begin() + param.x_num_col_dims, + param.y->dims().end())}); - mul_compute_eigen(theparam.x->data(), x_shape.x, x_shape.y, // - theparam.y->data(), y_shape.x, y_shape.y, // - theparam.output->mutable_data()); + mul_compute_eigen(param.x->data(), x_shape.x, x_shape.y, // + param.y->data(), y_shape.x, y_shape.y, // + param.output->mutable_data()); + LOG(INFO) << "MUL x " << *param.x; + LOG(INFO) << "MUL W " << *param.y; + LOG(INFO) << "MUL out " << *param.output; } virtual ~MulCompute() = default; diff --git 
diff --git a/paddle/fluid/lite/kernels/host/mul_compute.cc b/paddle/fluid/lite/kernels/host/mul_compute.cc
index b2062fa9308a1db241ea1dfefbacaa78fcdb7d30..ee7b168503aff12e49b0d92f8f45dbe57d402388 100644
--- a/paddle/fluid/lite/kernels/host/mul_compute.cc
+++ b/paddle/fluid/lite/kernels/host/mul_compute.cc
@@ -40,22 +40,23 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::MulParam;
 
   void Run() override {
-    auto& theparam = Param<param_t>();
-    core::dim2 x_shape(
-        {product(theparam.x->dims().begin(),
-                 theparam.x->dims().begin() + theparam.x_num_col_dims),
-         product(theparam.x->dims().begin() + theparam.x_num_col_dims,
-                 theparam.x->dims().end())});
+    auto& param = Param<param_t>();
+    core::dim2 x_shape({product(param.x->dims().begin(),
+                                param.x->dims().begin() + param.x_num_col_dims),
+                        product(param.x->dims().begin() + param.x_num_col_dims,
+                                param.x->dims().end())});
 
-    core::dim2 y_shape(
-        {product(theparam.y->dims().begin(),
-                 theparam.y->dims().begin() + theparam.x_num_col_dims),
-         product(theparam.y->dims().begin() + theparam.x_num_col_dims,
-                 theparam.y->dims().end())});
+    core::dim2 y_shape({product(param.y->dims().begin(),
+                                param.y->dims().begin() + param.x_num_col_dims),
+                        product(param.y->dims().begin() + param.x_num_col_dims,
+                                param.y->dims().end())});
 
-    mul_compute_eigen(theparam.x->data<float>(), x_shape.x, x_shape.y,  //
-                      theparam.y->data<float>(), y_shape.x, y_shape.y,  //
-                      theparam.output->mutable_data<float>());
+    mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y,  //
+                      param.y->data<float>(), y_shape.x, y_shape.y,  //
+                      param.output->mutable_data<float>());
+    LOG(INFO) << "MUL x " << *param.x;
+    LOG(INFO) << "MUL W " << *param.y;
+    LOG(INFO) << "MUL out " << *param.output;
   }
 
   virtual ~MulCompute() = default;
diff --git a/paddle/fluid/lite/kernels/host/scale_compute.cc b/paddle/fluid/lite/kernels/host/scale_compute.cc
index ebf6f2ff4b3cbeb7d20fb48383b5db27142a72db..7ad6ab288183ece3043ba03f26f94da53cee7389 100644
--- a/paddle/fluid/lite/kernels/host/scale_compute.cc
+++ b/paddle/fluid/lite/kernels/host/scale_compute.cc
@@ -36,10 +36,10 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::MulParam;
 
   void Run() override {
-    auto& theparam = Param<operators::ScaleParam>();
-    scale_compute(theparam.x->data<float>(), theparam.x->mutable_data<float>(),
-                  product(theparam.x->dims()), theparam.scale, theparam.bias,
-                  theparam.bias_after_scale);
+    auto& param = Param<operators::ScaleParam>();
+    scale_compute(param.x->data<float>(), param.output->mutable_data<float>(),
+                  product(param.x->dims()), param.scale, param.bias,
+                  param.bias_after_scale);
   }
 
   virtual ~ScaleCompute() = default;
diff --git a/paddle/fluid/lite/model_parser/model_parser.cc b/paddle/fluid/lite/model_parser/model_parser.cc
index 63316d34cd0b9e5023eebfeae5c2295e9692a6d1..feaff82d31354818559661d08b8371c603430e58 100644
--- a/paddle/fluid/lite/model_parser/model_parser.cc
+++ b/paddle/fluid/lite/model_parser/model_parser.cc
@@ -77,7 +77,7 @@ void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
       DO(INT64, int64_t);
 #undef DO
     default:
-      LOG(FATAL) << "unknown type";
+      LOG(FATAL) << "unknown type " << desc.data_type();
   }
 
   is.read(static_cast<char*>(buf), size);
diff --git a/paddle/fluid/lite/operators/feed_op.cc b/paddle/fluid/lite/operators/feed_op.cc
index 041fe07f6340d0701782a980799ba9408990ed92..47876b55e76f3b552eefecd6c0b000b03bc18919 100644
--- a/paddle/fluid/lite/operators/feed_op.cc
+++ b/paddle/fluid/lite/operators/feed_op.cc
@@ -51,7 +51,6 @@ class FeedOp : public OpLite {
     // NOTE need boost here
     // TODO(Superjomn) drop the need of framework::op_desc
     param_.col = boost::get<int>(opdesc.GetAttr("col"));
-    kernel_->SetParam(param_);
     return true;
   }
diff --git a/paddle/fluid/lite/utils/macros.h b/paddle/fluid/lite/utils/macros.h
index 1115c71cd24620b624d93820c892436110ba8880..1861f20f839b822dbce68161552a7d2f05191d0d 100644
--- a/paddle/fluid/lite/utils/macros.h
+++ b/paddle/fluid/lite/utils/macros.h
@@ -21,3 +21,10 @@
 #endif
 
 #define LITE_UNIMPLEMENTED CHECK(false) << "Not Implemented";
+
+#ifndef LIKELY
+#define LIKELY(x) __builtin_expect(!!(x), 1)
+#endif
+#ifndef UNLIKELY
+#define UNLIKELY(x) __builtin_expect(!!(x), 0)
+#endif
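`UNLIKELY` (note the fix above: the GCC/Clang builtin is `__builtin_expect`, not `__built_expect`) is used in `Instruction::Run` to keep the one-time `CheckShape` call off the hot path. A small usage sketch of the macro:

```cpp
#include <cstdio>

#ifndef UNLIKELY
#define UNLIKELY(x) __builtin_expect(!!(x), 0)
#endif

// Mirrors Instruction::Run above: hint the compiler that the first-epoch
// branch is cold, so the steady-state path stays straight-line code.
void RunOnce(bool* first_epoch) {
  if (UNLIKELY(*first_epoch)) {
    *first_epoch = false;
    std::puts("one-time CheckShape()");
  }
  std::puts("hot path: InferShape() + kernel Run()");
}

int main() {
  bool first = true;
  RunOnce(&first);  // prints both lines
  RunOnce(&first);  // prints only the hot-path line
}
```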