From 9174652bc0fc3710d2f7c097d0265db9b989943e Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sun, 12 May 2019 14:48:25 +0800 Subject: [PATCH] refactor TypeSystem make typesystem simpler --- paddle/fluid/CMakeLists.txt | 2 + paddle/fluid/framework/tensor.h | 2 + paddle/fluid/inference/analysis/dot.h | 3 + paddle/fluid/lite/CMakeLists.txt | 5 ++ paddle/fluid/lite/api/cxx_api_bin.cc | 17 ++++- paddle/fluid/lite/core/CMakeLists.txt | 10 +-- paddle/fluid/lite/core/hvy_tensor.h | 2 + paddle/fluid/lite/core/mir/CMakeLists.txt | 8 +- .../core/mir/runtime_context_assign_pass.cc | 2 +- paddle/fluid/lite/core/op_lite.cc | 3 +- paddle/fluid/lite/core/op_lite.h | 4 +- paddle/fluid/lite/core/op_registry.h | 2 + paddle/fluid/lite/core/program.h | 2 +- paddle/fluid/lite/core/target_wrapper.h | 13 +--- paddle/fluid/lite/core/tensor.h | 4 +- paddle/fluid/lite/core/type_system.cc | 76 +++++++++++-------- paddle/fluid/lite/core/type_system.h | 32 ++------ paddle/fluid/lite/core/type_system_test.cc | 4 + .../lite/kernels/cuda/io_copy_compute.cc | 19 ++--- paddle/fluid/lite/kernels/cuda/mul_compute.cc | 9 +-- paddle/fluid/lite/kernels/cuda/mul_compute.h | 2 +- paddle/fluid/lite/kernels/host/fc_compute.cc | 13 +--- .../fluid/lite/kernels/host/feed_compute.cc | 6 +- .../fluid/lite/kernels/host/fetch_compute.cc | 8 +- paddle/fluid/lite/kernels/host/mul_compute.cc | 9 +-- .../fluid/lite/kernels/host/scale_compute.cc | 6 +- paddle/fluid/lite/model_parser/CMakeLists.txt | 2 +- .../fluid/lite/model_parser/pb/CMakeLists.txt | 4 +- paddle/fluid/lite/utils/cp_logging.cc | 14 ++++ paddle/fluid/lite/utils/cp_logging.h | 14 ++++ paddle/fluid/lite/utils/logging.h | 4 + paddle/fluid/lite/utils/macros.h | 3 + 32 files changed, 171 insertions(+), 133 deletions(-) diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt index bf7bcfa5d4a..c212d579921 100644 --- a/paddle/fluid/CMakeLists.txt +++ b/paddle/fluid/CMakeLists.txt @@ -13,3 +13,5 @@ add_subdirectory(pybind) add_subdirectory(train) # NOTE: please add subdirectory inference at last. add_subdirectory(inference) + +add_subdirectory(lite) diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 0fa76f943ec..f83a1aa49d5 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -80,6 +80,8 @@ class Tensor { template const T* data() const; + const void* raw_data() const { return holder_->ptr(); } + inline bool IsInitialized() const; /** diff --git a/paddle/fluid/inference/analysis/dot.h b/paddle/fluid/inference/analysis/dot.h index 83212869d1f..d1eef603be4 100644 --- a/paddle/fluid/inference/analysis/dot.h +++ b/paddle/fluid/inference/analysis/dot.h @@ -24,6 +24,9 @@ #include #include #include "paddle/fluid/lite/utils/logging.h" +#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +#include +#endif namespace paddle { namespace inference { diff --git a/paddle/fluid/lite/CMakeLists.txt b/paddle/fluid/lite/CMakeLists.txt index 5c09261e4dc..13d64ca0e20 100644 --- a/paddle/fluid/lite/CMakeLists.txt +++ b/paddle/fluid/lite/CMakeLists.txt @@ -2,6 +2,11 @@ if (NOT WITH_LITE) return() endif() +message(WARNING "Enable Lite") +message(STATUS "LIGHT_FRAMEWORK: ${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK}") +message(STATUS "LITE_WITH_CUDA: ${LITE_WITH_CUDA}") +message(STATUS "LITE_WITH_X86: ${LITE_WITH_X86}") + add_subdirectory(core) add_subdirectory(x86) add_subdirectory(host) diff --git a/paddle/fluid/lite/api/cxx_api_bin.cc b/paddle/fluid/lite/api/cxx_api_bin.cc index d19a95b34a9..4ff18ded849 100644 --- a/paddle/fluid/lite/api/cxx_api_bin.cc +++ b/paddle/fluid/lite/api/cxx_api_bin.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "paddle/fluid/lite/api/cxx_api.h" #include "paddle/fluid/lite/core/mir/passes.h" #include "paddle/fluid/lite/core/op_registry.h" @@ -45,14 +59,13 @@ void Run(const char* model_dir) { } // namespace lite } // namespace paddle -int main(int argc, char** argv ) { +int main(int argc, char** argv) { CHECK_EQ(argc, 2) << "usage: ./cmd "; paddle::lite::Run(argv[1]); return 0; } - USE_LITE_OP(mul); USE_LITE_OP(fc); USE_LITE_OP(scale); diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index 56a93cf2b1b..f039f92bf83 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -13,7 +13,7 @@ else() set(tensor_lite hvy_tensor) endif() -proto_library(framework_proto SRCS framework.proto) +proto_library(framework_proto_lite SRCS framework.proto) cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite) cc_library(variable_lite SRCS variable.cc) @@ -22,7 +22,7 @@ cc_library(scope_lite SRCS scope.cc) cc_library(context_lite SRCS context.cc DEPS any_lite) cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite) cc_library(types_lite SRCS types.cc) -cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite}) +cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite) cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite) cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite) @@ -37,14 +37,14 @@ endif() cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph scope_lite op_registry_lite proto_desc op_lite - ops_lite - host_kernels + ${ops_lite} + ${host_kernels} ) lite_cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite) lite_cc_test(test_kernel_lite SRCS kernel_test.cc DEPS kernel_lite target_wrapper_x86) lite_cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite) lite_cc_test(test_tensor_lite SRCS lite_tensor_test.cc DEPS lite_tensor) -lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system) +lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_lite) lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes) lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite) diff --git a/paddle/fluid/lite/core/hvy_tensor.h b/paddle/fluid/lite/core/hvy_tensor.h index 9fb2aeea3d6..5344c96e671 100644 --- a/paddle/fluid/lite/core/hvy_tensor.h +++ b/paddle/fluid/lite/core/hvy_tensor.h @@ -83,6 +83,8 @@ class TensorHvy : public TensorBase { return data_.data(); } + const void* raw_data() const { return data_.raw_data(); } + template void Resize(const DimT& dims) { LOG(INFO) << "dims.size " << dims.size(); diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index 94f329aeda9..dd2428906d2 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -22,15 +22,15 @@ endif() cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes) cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS mir_ssa_graph scope_lite op_lite - ops_lite - host_kernels + ${ops_lite} + ${host_kernels} mir_passes mir_pass_manager program_fake_utils ) set(test_variable_place_infrence_pass_DEPS - ops_lite - host_kernels + ${ops_lite} + ${host_kernels} mir_passes mir_pass_manager optimizer_lite diff --git a/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc b/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc index 67fa3c7ccbb..8e89d922e0b 100644 --- a/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc +++ b/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc @@ -60,7 +60,7 @@ class RuntimeContextAssignPass : public StmtPass { #ifdef LITE_WITH_CUDA std::unique_ptr NewCudaContext() { std::unique_ptr ctx(new KernelContext); - auto& cuda = ctx->AsCudaContext(); + auto& cuda = ctx->As(); // Some initialization here. CHECK(cublas_fp32_) << "cublas_fp32 should be set first"; cuda.blas_fp32 = cublas_fp32_; diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc index a15bee1f1d7..dd159becc78 100644 --- a/paddle/fluid/lite/core/op_lite.cc +++ b/paddle/fluid/lite/core/op_lite.cc @@ -67,9 +67,8 @@ bool OpLite::Run() { bool OpLite::Attach(const OpDesc &opdesc, lite::Scope *scope) { // valid_places_.clear(); - LOG(INFO) << "valid_places " << valid_places_.size(); CHECK(scope != nullptr); - CHECK(!op_info_.get()); + // CHECK(!op_info_.get()); scope_ = scope; op_info_.reset(new OpInfo); // Force clean the out-of-date infomation. op_info_->Build(opdesc.ReadonlyProto()); diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h index 144198211e9..4fb3c1c7636 100644 --- a/paddle/fluid/lite/core/op_lite.h +++ b/paddle/fluid/lite/core/op_lite.h @@ -49,9 +49,7 @@ class OpInfo; class OpLite : public Registry { public: OpLite() = default; - OpLite(const std::string &type) : op_type_(type) { - LOG(INFO) << "valid places " << valid_places_.size(); - } + OpLite(const std::string &type) : op_type_(type) {} OpLite(const std::vector &valid_places) : valid_places_(valid_places) { LOG(INFO) << "valid places " << valid_places.size(); } diff --git a/paddle/fluid/lite/core/op_registry.h b/paddle/fluid/lite/core/op_registry.h index 6b8b0e03871..4de6df1c72b 100644 --- a/paddle/fluid/lite/core/op_registry.h +++ b/paddle/fluid/lite/core/op_registry.h @@ -21,6 +21,8 @@ #include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/utils/all.h" +using LiteType = paddle::lite::Type; + namespace paddle { namespace lite { diff --git a/paddle/fluid/lite/core/program.h b/paddle/fluid/lite/core/program.h index 2837e728417..585deeb33ed 100644 --- a/paddle/fluid/lite/core/program.h +++ b/paddle/fluid/lite/core/program.h @@ -106,7 +106,7 @@ struct Instruct { void Run() { CHECK(op_); CHECK(kernel_); - if (UNLIKELY(first_epoch_)) { + if (first_epoch_) { first_epoch_ = false; CHECK(op_->CheckShape()); } diff --git a/paddle/fluid/lite/core/target_wrapper.h b/paddle/fluid/lite/core/target_wrapper.h index 49b26a3f6fb..7afd6e1926d 100644 --- a/paddle/fluid/lite/core/target_wrapper.h +++ b/paddle/fluid/lite/core/target_wrapper.h @@ -48,32 +48,27 @@ enum class DataLayoutType : int { // Some helper macro to get a specific TargetType. #define TARGET(item__) paddle::lite::TargetType::item__ -#define TARGET_VAL(item__) static_cast(TARGET(item__)) // Some helper macro to get a specific PrecisionType. #define PRECISION(item__) paddle::lite::PrecisionType::item__ -#define PRECISION_VAL(item__) static_cast(PRECISION(item__)) #define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__ -constexpr const int kNumPrecisions = PRECISION_VAL(NUM); -constexpr const int kNumTargets = TARGET_VAL(NUM); - -static const std::string target2string[] = {"unk", "host", "x86", "cuda", - "any"}; static const std::string& TargetToStr(TargetType target) { + static const std::string target2string[] = {"unk", "host", "x86", "cuda", + "any"}; auto x = static_cast(target); CHECK_LT(x, static_cast(TARGET(NUM))); return target2string[x]; } -static const std::string precision2string[] = {"unk", "float", "int8", "any"}; static const std::string& PrecisionToStr(PrecisionType precision) { + static const std::string precision2string[] = {"unk", "float", "int8", "any"}; auto x = static_cast(precision); CHECK_LT(x, static_cast(PRECISION(NUM))); return precision2string[x]; } -static const std::string datalayout2string[] = {"unk", "NCHW", "any"}; static const std::string& DataLayoutToStr(DataLayoutType layout) { + static const std::string datalayout2string[] = {"unk", "NCHW", "any"}; auto x = static_cast(layout); CHECK_LT(x, static_cast(DATALAYOUT(NUM))); return datalayout2string[x]; diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/tensor.h index 59c35498e59..807fbfc6a62 100644 --- a/paddle/fluid/lite/core/tensor.h +++ b/paddle/fluid/lite/core/tensor.h @@ -54,7 +54,7 @@ class DDimBase { value_type production() const { value_type res = 1; - for (int i = 0; i < const_self()->size(); i++) { + for (size_t i = 0; i < const_self()->size(); i++) { res *= (*const_self())[i]; } return res; @@ -142,6 +142,8 @@ class TensorBase { return const_self()->data(); } + const void *raw_data() const { return const_self()->data(); } + size_t data_size() const { return const_self()->dims().production(); } void ShareDataWith(const TensorBase &other) { self()->ShareDataWith(other); } diff --git a/paddle/fluid/lite/core/type_system.cc b/paddle/fluid/lite/core/type_system.cc index 4c1ea9d729e..125d74e9f0a 100644 --- a/paddle/fluid/lite/core/type_system.cc +++ b/paddle/fluid/lite/core/type_system.cc @@ -13,14 +13,10 @@ // limitations under the License. #include "paddle/fluid/lite/core/type_system.h" -#include "type_system.h" namespace paddle { namespace lite { -// ------------------------- GetType specification ---------------------------- -// ------------------------- end GetType specification ------------------------ - size_t ParamTypeRegistry::KernelIdTy::hash() const { std::hash h; size_t hash = h(kernel_type); @@ -31,26 +27,21 @@ size_t ParamTypeRegistry::KernelIdTy::hash() const { } std::ostream &operator<<(std::ostream &os, const Type &other) { - if (other.IsUnsupported()) { - os << ""; - return os; - } - if (other.IsVoid()) { - os << ""; - return os; - } - if (other.IsTensor()) { - os << ""; + os << other.name(); return os; } +// An map is used to maintain a global repo for types. We don't use +// MACROs with static variables for that the TypeSystem should only used in +// compile time, that is not performance sensitive, and a map-based way is +// easier to implement and maintain. +// +// The map is declared in each Type::GetXXX method other than in the Type class +// so that it will force to construct before any usage. + const Type *Type::GetTensorTy(TargetType target, PrecisionType precision, DataLayoutType layout, int device) { + static std::map type_repo; // NOTE quite naive implementation here, but not performance sensitive. DataType::ID type_id = DataType::ID::Tensor; @@ -72,17 +63,16 @@ const Type *Type::GetTensorTy(TargetType target, PrecisionType precision, name << device; name << ">"; - auto it = type_repo_.find(v); - if (it == type_repo_.end()) { + if (!type_repo[v]) // The Types should alive across the process life, no need to delete. - type_repo_[v] = + type_repo[v] = new Type(type_id, name.str(), target, precision, layout, device); - } - return type_repo_[v]; + return type_repo[v]; } const Type *Type::GetTensorListTy(TargetType target, PrecisionType precision, DataLayoutType layout, int device) { + static std::map type_repo; DataType::ID type_id = DataType::ID::TensorList; #define HASH_ONE(x) v = hash_combine(v, hasher(static_cast(x))) @@ -103,28 +93,50 @@ const Type *Type::GetTensorListTy(TargetType target, PrecisionType precision, name << device; name << ">"; - if (!type_repo_[v]) + if (!type_repo[v]) // The Types should alive across the process life, no need to delete. - type_repo_[v] = + type_repo[v] = new Type(type_id, name.str(), target, precision, layout, device); - return type_repo_[v]; + return type_repo[v]; } const Type *Type::GetUnsupportedTy() { + static std::map type_repo; std::hash hasher; size_t v = hasher(static_cast(DataType::ID::Unsupported)); - if (!type_repo_[v]) - type_repo_[v] = + if (!type_repo[v]) + type_repo[v] = new Type(DataType::ID::Unsupported, "Unsupported", TARGET(kUnk), PRECISION(kUnk), DATALAYOUT(kUnk), -1); + return type_repo[v]; } const Type *Type::GetVoidTy() { + static std::map type_repo; std::hash hasher; size_t v = hasher(static_cast(DataType::ID::Void)); - if (!type_repo_[v]) - type_repo_[v] = new Type(DataType::ID::Void, "Void", TARGET(kAny), - PRECISION(kAny), DATALAYOUT(kAny), -1); + if (!type_repo[v]) + type_repo[v] = new Type(DataType::ID::Void, "Void", TARGET(kAny), + PRECISION(kAny), DATALAYOUT(kAny), -1); + return type_repo[v]; +} + +const Type *Type::Get(DataType::ID type_id, TargetType target, + PrecisionType precision, DataLayoutType layout, + int device) { + switch (type_id) { + case DataType::ID::Void: + return GetVoidTy(); + case DataType::ID::Unsupported: + return GetUnsupportedTy(); + case DataType::ID::Tensor: + return GetTensorTy(target, precision, layout, device); + case DataType::ID::TensorList: + return GetTensorListTy(target, precision, layout, device); + default: + LOG(FATAL) << "Unknown Type found"; + return nullptr; + } } } // namespace lite diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h index 4ebcfbc2acf..9701fddccb5 100644 --- a/paddle/fluid/lite/core/type_system.h +++ b/paddle/fluid/lite/core/type_system.h @@ -133,6 +133,11 @@ class Type : public DataType { /// Get an Void type. static const Type* GetVoidTy(); + static const Type* Get(DataType::ID type_id, TargetType target = TARGET(kUnk), + PrecisionType precision = PRECISION(kUnk), + DataLayoutType layout = DATALAYOUT(kUnk), + int device = 0); + TargetType target() const { return place_.target; } PrecisionType precision() const { return place_.precision; } DataLayoutType layout() const { return place_.layout; } @@ -154,12 +159,6 @@ class Type : public DataType { DataLayoutType layout = DataLayoutType::kNCHW, short device = 0) : DataType(id), place_{target, precision, layout, device}, name_(name) {} - // An map is used here to maintain a global repo for types. We don't use - // MACROs with static variables for that the TypeSystem should only used in - // compile time, that is not performance sensitive, and a map-based way is - // easier to implement and maintain. - static std::map type_repo_; - Place place_; const std::string name_; }; @@ -203,22 +202,6 @@ static bool TypeCompatibleTo(const Type& a, const Type& b) { PrecisionCompatibleTo(a, b) && DeviceCompatibleTo(a, b); } -// -------------------------------- predefined types --------------------------- -// TODO(Superjomn) make all the Types' constructs protected to make sure there -// is only one instance across the system. -class VoidTy : public Type { - public: - VoidTy() : Type(ID::Void, "Void") {} -}; -class UnsupportedTy : public Type { - public: - UnsupportedTy() : Type(ID::Unsupported, "Unsupported", false /*is_tensor*/) {} -}; - -const Type* LookupType(DataType::ID type_id, bool is_unknown, bool is_tensor, - Place place); -// ------------------------- end predefined types --------------------------- - /* * ParamType is used to represent a data type of a parameter for the kernel. It * can represent any Variable data type. @@ -226,13 +209,12 @@ const Type* LookupType(DataType::ID type_id, bool is_unknown, bool is_tensor, * registered in the `TypeSystem`. */ struct ParamType { - Place tensor_place{}; const Type* type; ParamType() = default; - ParamType(const Type* type) : type(type) { tensor_place = type->place(); } + ParamType(const Type* type) : type(type) {} - std::string DebugString() const { return tensor_place.DebugString(); } + std::string DebugString() const { return type->name(); } }; /* diff --git a/paddle/fluid/lite/core/type_system_test.cc b/paddle/fluid/lite/core/type_system_test.cc index b01aa0852ff..4e6814aa690 100644 --- a/paddle/fluid/lite/core/type_system_test.cc +++ b/paddle/fluid/lite/core/type_system_test.cc @@ -25,6 +25,10 @@ TEST(TypeSystem, CheckDuplicateGet) { Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); ASSERT_EQ(tensor_ty, tensor_ty1); + + ASSERT_EQ(tensor_ty->target(), TARGET(kHost)); + ASSERT_EQ(tensor_ty->precision(), PRECISION(kFloat)); + ASSERT_EQ(tensor_ty->layout(), DATALAYOUT(kNCHW)); } } // namespace lite diff --git a/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc b/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc index 0a81afaf7b5..f395a042e22 100644 --- a/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc +++ b/paddle/fluid/lite/kernels/cuda/io_copy_compute.cc @@ -50,7 +50,7 @@ class IoCopyHostToCudaCompute param.x->target() == TARGET(kX86)); LOG(INFO) << "copy size " << param.x->data_size(); auto* data = param.y->mutable_data(TARGET(kCUDA)); - CopyFromHostSync(data, param.x->data(), param.x->data_size()); + CopyFromHostSync(data, param.x->raw_data(), param.x->data_size()); } std::unique_ptr GetTypeInferHandler() override { @@ -63,8 +63,9 @@ class IoCopyHostToCudaCompute auto out_place = type->place(); out_place.target = TARGET(kCUDA); - auto* out_type = LookupType(type->id(), type->IsUnsupported(), - type->IsUnsupported(), out_place); + auto* out_type = + Type::Get(type->id(), out_place.target, out_place.precision, + out_place.layout, out_place.device); return out_type; }; return res; @@ -98,17 +99,13 @@ class IoCopyCudaToHostCompute REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, paddle::lite::kernels::cuda::IoCopyHostToCudaCompute, host_to_device) - .BindInput("Input", {paddle::lite::Type::Get( - TARGET(kHost))}) - .BindOutput("Out", {paddle::lite::Type::Get( - TARGET(kCUDA))}) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) .Finalize(); REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, paddle::lite::kernels::cuda::IoCopyCudaToHostCompute, device_to_host) - .BindInput("Input", {paddle::lite::Type::Get( - TARGET(kCUDA))}) - .BindOutput("Out", {paddle::lite::Type::Get( - TARGET(kHost))}) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/cuda/mul_compute.cc b/paddle/fluid/lite/kernels/cuda/mul_compute.cc index f5081d2baa9..7c88f1dd29e 100644 --- a/paddle/fluid/lite/kernels/cuda/mul_compute.cc +++ b/paddle/fluid/lite/kernels/cuda/mul_compute.cc @@ -25,10 +25,7 @@ namespace cuda {} // namespace cuda REGISTER_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::MulCompute, def) - .BindInput("X", {paddle::lite::Type::Get( - TARGET(kCUDA))}) - .BindInput("Y", {paddle::lite::Type::Get( - TARGET(kCUDA))}) - .BindOutput("Out", {paddle::lite::Type::Get( - TARGET(kCUDA))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/cuda/mul_compute.h b/paddle/fluid/lite/kernels/cuda/mul_compute.h index c8f323e0537..597d8468326 100644 --- a/paddle/fluid/lite/kernels/cuda/mul_compute.h +++ b/paddle/fluid/lite/kernels/cuda/mul_compute.h @@ -36,7 +36,7 @@ class MulCompute : public KernelLite { void Run() override { CHECK(context_) << "running context should be set first"; - auto& context = context_->AsCudaContext(); + auto& context = context_->As(); CHECK(context.blas_fp32) << "blas should init first"; /* auto& blas = *context.blas_fp32; diff --git a/paddle/fluid/lite/kernels/host/fc_compute.cc b/paddle/fluid/lite/kernels/host/fc_compute.cc index 7b84720c803..ae5b23ce3ec 100644 --- a/paddle/fluid/lite/kernels/host/fc_compute.cc +++ b/paddle/fluid/lite/kernels/host/fc_compute.cc @@ -53,13 +53,8 @@ void FcCompute::Run() { REGISTER_LITE_KERNEL(fc, kHost, kFloat, kNCHW, paddle::lite::kernels::host::FcCompute, def) - .BindInput("Input", - {paddle::lite::Type::Get( - TARGET(kHost))}) - .BindInput("Bias", {paddle::lite::Type::Get( - TARGET(kHost))}) - .BindInput("W", {paddle::lite::Type::Get( - TARGET(kHost))}) - .BindOutput("Out", {paddle::lite::Type::Get( - TARGET(kHost))}) + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/host/feed_compute.cc b/paddle/fluid/lite/kernels/host/feed_compute.cc index 02da6b2672f..ba503c577f4 100644 --- a/paddle/fluid/lite/kernels/host/feed_compute.cc +++ b/paddle/fluid/lite/kernels/host/feed_compute.cc @@ -43,8 +43,6 @@ class FeedCompute REGISTER_LITE_KERNEL(feed, kHost, kAny, kAny, paddle::lite::kernels::host::FeedCompute, def) - .BindInput("X", {paddle::lite::Type::Get( - TARGET(kHost))}) - .BindOutput("Out", {paddle::lite::Type::Get( - TARGET(kHost))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/host/fetch_compute.cc b/paddle/fluid/lite/kernels/host/fetch_compute.cc index e23b540e644..8ecb38cae6c 100644 --- a/paddle/fluid/lite/kernels/host/fetch_compute.cc +++ b/paddle/fluid/lite/kernels/host/fetch_compute.cc @@ -44,8 +44,8 @@ class FetchCompute REGISTER_LITE_KERNEL(fetch, kHost, kAny, kAny, paddle::lite::kernels::host::FetchCompute, def) - .BindInput("X", {paddle::lite::Type::Get( - TARGET(kHost))}) - .BindOutput("Out", {paddle::lite::Type::Get( - TARGET(kHost))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny), + DATALAYOUT(kAny), -1)}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/host/mul_compute.cc b/paddle/fluid/lite/kernels/host/mul_compute.cc index 34ec07a1c67..2bb509c86ac 100644 --- a/paddle/fluid/lite/kernels/host/mul_compute.cc +++ b/paddle/fluid/lite/kernels/host/mul_compute.cc @@ -74,10 +74,7 @@ class MulCompute : public KernelLite { REGISTER_LITE_KERNEL(mul, kHost, kFloat, kNCHW, paddle::lite::kernels::host::MulCompute, def) - .BindInput("X", {paddle::lite::Type::Get( - TARGET(kHost))}) - .BindInput("Y", {paddle::lite::Type::Get( - TARGET(kHost))}) - .BindOutput("Out", {paddle::lite::Type::Get( - TARGET(kHost))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/host/scale_compute.cc b/paddle/fluid/lite/kernels/host/scale_compute.cc index 78281ba8a61..3fc542646ba 100644 --- a/paddle/fluid/lite/kernels/host/scale_compute.cc +++ b/paddle/fluid/lite/kernels/host/scale_compute.cc @@ -52,8 +52,6 @@ class ScaleCompute : public KernelLite { REGISTER_LITE_KERNEL(scale, kHost, kFloat, kNCHW, paddle::lite::kernels::host::ScaleCompute, def) - .BindInput("X", {paddle::lite::Type::Get( - TARGET(kHost))}) - .BindOutput("Out", {paddle::lite::Type::Get( - TARGET(kHost))}) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) .Finalize(); diff --git a/paddle/fluid/lite/model_parser/CMakeLists.txt b/paddle/fluid/lite/model_parser/CMakeLists.txt index 55fccf996dd..5a0c4e92977 100644 --- a/paddle/fluid/lite/model_parser/CMakeLists.txt +++ b/paddle/fluid/lite/model_parser/CMakeLists.txt @@ -3,7 +3,7 @@ lite_cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite) else() - cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto proto_desc) + cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto_lite proto_desc) endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) set(model_parser_deps variable_lite scope_lite ${tensor_lite} scope_lite diff --git a/paddle/fluid/lite/model_parser/pb/CMakeLists.txt b/paddle/fluid/lite/model_parser/pb/CMakeLists.txt index d0f1af4cad2..22d88aeabf4 100644 --- a/paddle/fluid/lite/model_parser/pb/CMakeLists.txt +++ b/paddle/fluid/lite/model_parser/pb/CMakeLists.txt @@ -1,2 +1,2 @@ -cc_library(var_desc_lite SRCS var_desc.cc DEPS framework_proto) -cc_library(op_desc_lite SRCS op_desc.cc DEPS framework_proto) +cc_library(var_desc_lite SRCS var_desc.cc DEPS framework_proto_lite) +cc_library(op_desc_lite SRCS op_desc.cc DEPS framework_proto_lite) diff --git a/paddle/fluid/lite/utils/cp_logging.cc b/paddle/fluid/lite/utils/cp_logging.cc index d72cf0d782f..ef99cb217b8 100644 --- a/paddle/fluid/lite/utils/cp_logging.cc +++ b/paddle/fluid/lite/utils/cp_logging.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "paddle/fluid/lite/utils/cp_logging.h" namespace paddle { diff --git a/paddle/fluid/lite/utils/cp_logging.h b/paddle/fluid/lite/utils/cp_logging.h index 885670cc28b..3fac352c1b2 100644 --- a/paddle/fluid/lite/utils/cp_logging.h +++ b/paddle/fluid/lite/utils/cp_logging.h @@ -1,3 +1,17 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #include "paddle/fluid/lite/utils/logging.h" #else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK diff --git a/paddle/fluid/lite/utils/logging.h b/paddle/fluid/lite/utils/logging.h index df25fec2f64..37294969920 100644 --- a/paddle/fluid/lite/utils/logging.h +++ b/paddle/fluid/lite/utils/logging.h @@ -22,6 +22,8 @@ #include #include +#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + // LOG() #define LOG(status) LOG_##status.stream() #define LOG_ERROR LOG_INFO @@ -86,3 +88,5 @@ class LogMessageFatal : public LogMessage { } // namespace lite } // namespace paddle + +#endif // LITE_WITH_LIGHT_FRAMEWORK diff --git a/paddle/fluid/lite/utils/macros.h b/paddle/fluid/lite/utils/macros.h index 9dea37199b5..d12ad4cab13 100644 --- a/paddle/fluid/lite/utils/macros.h +++ b/paddle/fluid/lite/utils/macros.h @@ -22,10 +22,13 @@ #define LITE_UNIMPLEMENTED CHECK(false) << "Not Implemented"; +/* #ifndef LIKELY #define LIKELY(x) __builtin_expect(!!(x), 1) #endif + #ifndef UNLIKELY //#define UNLIKELY(x) __built_expect(!!(x), 0) #define UNLIKELY(x) (x) #endif + */ -- GitLab