提交 9174652b 编写于 作者: S Superjomn

refactor TypeSystem

make typesystem simpler
上级 658c5432
...@@ -13,3 +13,5 @@ add_subdirectory(pybind) ...@@ -13,3 +13,5 @@ add_subdirectory(pybind)
add_subdirectory(train) add_subdirectory(train)
# NOTE: please add subdirectory inference at last. # NOTE: please add subdirectory inference at last.
add_subdirectory(inference) add_subdirectory(inference)
add_subdirectory(lite)
...@@ -80,6 +80,8 @@ class Tensor { ...@@ -80,6 +80,8 @@ class Tensor {
template <typename T> template <typename T>
const T* data() const; const T* data() const;
const void* raw_data() const { return holder_->ptr(); }
inline bool IsInitialized() const; inline bool IsInitialized() const;
/** /**
......
...@@ -24,6 +24,9 @@ ...@@ -24,6 +24,9 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/lite/utils/logging.h" #include "paddle/fluid/lite/utils/logging.h"
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include <glog/logging.h>
#endif
namespace paddle { namespace paddle {
namespace inference { namespace inference {
......
...@@ -2,6 +2,11 @@ if (NOT WITH_LITE) ...@@ -2,6 +2,11 @@ if (NOT WITH_LITE)
return() return()
endif() endif()
message(WARNING "Enable Lite")
message(STATUS "LIGHT_FRAMEWORK: ${LITE_WITH_LIGHT_WEIGHT_FRAMEWORK}")
message(STATUS "LITE_WITH_CUDA: ${LITE_WITH_CUDA}")
message(STATUS "LITE_WITH_X86: ${LITE_WITH_X86}")
add_subdirectory(core) add_subdirectory(core)
add_subdirectory(x86) add_subdirectory(x86)
add_subdirectory(host) add_subdirectory(host)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h" #include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/mir/passes.h" #include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
...@@ -45,14 +59,13 @@ void Run(const char* model_dir) { ...@@ -45,14 +59,13 @@ void Run(const char* model_dir) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
int main(int argc, char** argv ) { int main(int argc, char** argv) {
CHECK_EQ(argc, 2) << "usage: ./cmd <model_dir>"; CHECK_EQ(argc, 2) << "usage: ./cmd <model_dir>";
paddle::lite::Run(argv[1]); paddle::lite::Run(argv[1]);
return 0; return 0;
} }
USE_LITE_OP(mul); USE_LITE_OP(mul);
USE_LITE_OP(fc); USE_LITE_OP(fc);
USE_LITE_OP(scale); USE_LITE_OP(scale);
......
...@@ -13,7 +13,7 @@ else() ...@@ -13,7 +13,7 @@ else()
set(tensor_lite hvy_tensor) set(tensor_lite hvy_tensor)
endif() endif()
proto_library(framework_proto SRCS framework.proto) proto_library(framework_proto_lite SRCS framework.proto)
cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite) cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite)
cc_library(variable_lite SRCS variable.cc) cc_library(variable_lite SRCS variable.cc)
...@@ -22,7 +22,7 @@ cc_library(scope_lite SRCS scope.cc) ...@@ -22,7 +22,7 @@ cc_library(scope_lite SRCS scope.cc)
cc_library(context_lite SRCS context.cc DEPS any_lite) cc_library(context_lite SRCS context.cc DEPS any_lite)
cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite) cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite compatible_pb_lite)
cc_library(types_lite SRCS types.cc) cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite}) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite) cc_library(kernel_executor_lite SRCS kernel_executor.cc DEPS mir_ssa_graph kernel_lite)
cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite) cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite)
...@@ -37,14 +37,14 @@ endif() ...@@ -37,14 +37,14 @@ endif()
cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph cc_library(program_fake_utils SRCS program_fake_utils.cc DEPS mir_ssa_graph
scope_lite op_registry_lite proto_desc op_lite scope_lite op_registry_lite proto_desc op_lite
ops_lite ${ops_lite}
host_kernels ${host_kernels}
) )
lite_cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite) lite_cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
lite_cc_test(test_kernel_lite SRCS kernel_test.cc DEPS kernel_lite target_wrapper_x86) lite_cc_test(test_kernel_lite SRCS kernel_test.cc DEPS kernel_lite target_wrapper_x86)
lite_cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite) lite_cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
lite_cc_test(test_tensor_lite SRCS lite_tensor_test.cc DEPS lite_tensor) lite_cc_test(test_tensor_lite SRCS lite_tensor_test.cc DEPS lite_tensor)
lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system) lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_lite)
lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes) lite_cc_test(test_optimizer_lite SRCS optimizer_test.cc DEPS mir_pass_manager program_fake_utils mir_passes)
lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite) lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
...@@ -83,6 +83,8 @@ class TensorHvy : public TensorBase<TensorHvy> { ...@@ -83,6 +83,8 @@ class TensorHvy : public TensorBase<TensorHvy> {
return data_.data<T>(); return data_.data<T>();
} }
const void* raw_data() const { return data_.raw_data(); }
template <typename DimT> template <typename DimT>
void Resize(const DimT& dims) { void Resize(const DimT& dims) {
LOG(INFO) << "dims.size " << dims.size(); LOG(INFO) << "dims.size " << dims.size();
......
...@@ -22,15 +22,15 @@ endif() ...@@ -22,15 +22,15 @@ endif()
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes) cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
mir_ssa_graph scope_lite op_lite mir_ssa_graph scope_lite op_lite
ops_lite ${ops_lite}
host_kernels ${host_kernels}
mir_passes mir_passes
mir_pass_manager mir_pass_manager
program_fake_utils program_fake_utils
) )
set(test_variable_place_infrence_pass_DEPS set(test_variable_place_infrence_pass_DEPS
ops_lite ${ops_lite}
host_kernels ${host_kernels}
mir_passes mir_passes
mir_pass_manager mir_pass_manager
optimizer_lite optimizer_lite
......
...@@ -60,7 +60,7 @@ class RuntimeContextAssignPass : public StmtPass { ...@@ -60,7 +60,7 @@ class RuntimeContextAssignPass : public StmtPass {
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
std::unique_ptr<KernelContext> NewCudaContext() { std::unique_ptr<KernelContext> NewCudaContext() {
std::unique_ptr<KernelContext> ctx(new KernelContext); std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& cuda = ctx->AsCudaContext(); auto& cuda = ctx->As<CUDAContext>();
// Some initialization here. // Some initialization here.
CHECK(cublas_fp32_) << "cublas_fp32 should be set first"; CHECK(cublas_fp32_) << "cublas_fp32 should be set first";
cuda.blas_fp32 = cublas_fp32_; cuda.blas_fp32 = cublas_fp32_;
......
...@@ -67,9 +67,8 @@ bool OpLite::Run() { ...@@ -67,9 +67,8 @@ bool OpLite::Run() {
bool OpLite::Attach(const OpDesc &opdesc, lite::Scope *scope) { bool OpLite::Attach(const OpDesc &opdesc, lite::Scope *scope) {
// valid_places_.clear(); // valid_places_.clear();
LOG(INFO) << "valid_places " << valid_places_.size();
CHECK(scope != nullptr); CHECK(scope != nullptr);
CHECK(!op_info_.get()); // CHECK(!op_info_.get());
scope_ = scope; scope_ = scope;
op_info_.reset(new OpInfo); // Force clean the out-of-date infomation. op_info_.reset(new OpInfo); // Force clean the out-of-date infomation.
op_info_->Build(opdesc.ReadonlyProto()); op_info_->Build(opdesc.ReadonlyProto());
......
...@@ -49,9 +49,7 @@ class OpInfo; ...@@ -49,9 +49,7 @@ class OpInfo;
class OpLite : public Registry { class OpLite : public Registry {
public: public:
OpLite() = default; OpLite() = default;
OpLite(const std::string &type) : op_type_(type) { OpLite(const std::string &type) : op_type_(type) {}
LOG(INFO) << "valid places " << valid_places_.size();
}
OpLite(const std::vector<Place> &valid_places) : valid_places_(valid_places) { OpLite(const std::vector<Place> &valid_places) : valid_places_(valid_places) {
LOG(INFO) << "valid places " << valid_places.size(); LOG(INFO) << "valid places " << valid_places.size();
} }
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/utils/all.h" #include "paddle/fluid/lite/utils/all.h"
using LiteType = paddle::lite::Type;
namespace paddle { namespace paddle {
namespace lite { namespace lite {
......
...@@ -106,7 +106,7 @@ struct Instruct { ...@@ -106,7 +106,7 @@ struct Instruct {
void Run() { void Run() {
CHECK(op_); CHECK(op_);
CHECK(kernel_); CHECK(kernel_);
if (UNLIKELY(first_epoch_)) { if (first_epoch_) {
first_epoch_ = false; first_epoch_ = false;
CHECK(op_->CheckShape()); CHECK(op_->CheckShape());
} }
......
...@@ -48,32 +48,27 @@ enum class DataLayoutType : int { ...@@ -48,32 +48,27 @@ enum class DataLayoutType : int {
// Some helper macro to get a specific TargetType. // Some helper macro to get a specific TargetType.
#define TARGET(item__) paddle::lite::TargetType::item__ #define TARGET(item__) paddle::lite::TargetType::item__
#define TARGET_VAL(item__) static_cast<int>(TARGET(item__))
// Some helper macro to get a specific PrecisionType. // Some helper macro to get a specific PrecisionType.
#define PRECISION(item__) paddle::lite::PrecisionType::item__ #define PRECISION(item__) paddle::lite::PrecisionType::item__
#define PRECISION_VAL(item__) static_cast<int>(PRECISION(item__))
#define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__ #define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
constexpr const int kNumPrecisions = PRECISION_VAL(NUM);
constexpr const int kNumTargets = TARGET_VAL(NUM);
static const std::string target2string[] = {"unk", "host", "x86", "cuda",
"any"};
static const std::string& TargetToStr(TargetType target) { static const std::string& TargetToStr(TargetType target) {
static const std::string target2string[] = {"unk", "host", "x86", "cuda",
"any"};
auto x = static_cast<int>(target); auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM))); CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x]; return target2string[x];
} }
static const std::string precision2string[] = {"unk", "float", "int8", "any"};
static const std::string& PrecisionToStr(PrecisionType precision) { static const std::string& PrecisionToStr(PrecisionType precision) {
static const std::string precision2string[] = {"unk", "float", "int8", "any"};
auto x = static_cast<int>(precision); auto x = static_cast<int>(precision);
CHECK_LT(x, static_cast<int>(PRECISION(NUM))); CHECK_LT(x, static_cast<int>(PRECISION(NUM)));
return precision2string[x]; return precision2string[x];
} }
static const std::string datalayout2string[] = {"unk", "NCHW", "any"};
static const std::string& DataLayoutToStr(DataLayoutType layout) { static const std::string& DataLayoutToStr(DataLayoutType layout) {
static const std::string datalayout2string[] = {"unk", "NCHW", "any"};
auto x = static_cast<int>(layout); auto x = static_cast<int>(layout);
CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM))); CHECK_LT(x, static_cast<int>(DATALAYOUT(NUM)));
return datalayout2string[x]; return datalayout2string[x];
......
...@@ -54,7 +54,7 @@ class DDimBase { ...@@ -54,7 +54,7 @@ class DDimBase {
value_type production() const { value_type production() const {
value_type res = 1; value_type res = 1;
for (int i = 0; i < const_self()->size(); i++) { for (size_t i = 0; i < const_self()->size(); i++) {
res *= (*const_self())[i]; res *= (*const_self())[i];
} }
return res; return res;
...@@ -142,6 +142,8 @@ class TensorBase { ...@@ -142,6 +142,8 @@ class TensorBase {
return const_self()->data(); return const_self()->data();
} }
const void *raw_data() const { return const_self()->data(); }
size_t data_size() const { return const_self()->dims().production(); } size_t data_size() const { return const_self()->dims().production(); }
void ShareDataWith(const TensorBase &other) { self()->ShareDataWith(other); } void ShareDataWith(const TensorBase &other) { self()->ShareDataWith(other); }
......
...@@ -13,14 +13,10 @@ ...@@ -13,14 +13,10 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/type_system.h" #include "paddle/fluid/lite/core/type_system.h"
#include "type_system.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
// ------------------------- GetType specification ----------------------------
// ------------------------- end GetType specification ------------------------
size_t ParamTypeRegistry::KernelIdTy::hash() const { size_t ParamTypeRegistry::KernelIdTy::hash() const {
std::hash<std::string> h; std::hash<std::string> h;
size_t hash = h(kernel_type); size_t hash = h(kernel_type);
...@@ -31,26 +27,21 @@ size_t ParamTypeRegistry::KernelIdTy::hash() const { ...@@ -31,26 +27,21 @@ size_t ParamTypeRegistry::KernelIdTy::hash() const {
} }
std::ostream &operator<<(std::ostream &os, const Type &other) { std::ostream &operator<<(std::ostream &os, const Type &other) {
if (other.IsUnsupported()) { os << other.name();
os << "<Unsupported>";
return os;
}
if (other.IsVoid()) {
os << "<Void>";
return os;
}
if (other.IsTensor()) {
os << "<Tensor:";
} else {
os << "<";
}
os << TargetToStr(other.target()) << "/" << PrecisionToStr(other.precision())
<< "/" << DataLayoutToStr(other.layout()) << ">";
return os; return os;
} }
// An map is used to maintain a global repo for types. We don't use
// MACROs with static variables for that the TypeSystem should only used in
// compile time, that is not performance sensitive, and a map-based way is
// easier to implement and maintain.
//
// The map is declared in each Type::GetXXX method other than in the Type class
// so that it will force to construct before any usage.
const Type *Type::GetTensorTy(TargetType target, PrecisionType precision, const Type *Type::GetTensorTy(TargetType target, PrecisionType precision,
DataLayoutType layout, int device) { DataLayoutType layout, int device) {
static std::map<size_t, const Type *> type_repo;
// NOTE quite naive implementation here, but not performance sensitive. // NOTE quite naive implementation here, but not performance sensitive.
DataType::ID type_id = DataType::ID::Tensor; DataType::ID type_id = DataType::ID::Tensor;
...@@ -72,17 +63,16 @@ const Type *Type::GetTensorTy(TargetType target, PrecisionType precision, ...@@ -72,17 +63,16 @@ const Type *Type::GetTensorTy(TargetType target, PrecisionType precision,
name << device; name << device;
name << ">"; name << ">";
auto it = type_repo_.find(v); if (!type_repo[v])
if (it == type_repo_.end()) {
// The Types should alive across the process life, no need to delete. // The Types should alive across the process life, no need to delete.
type_repo_[v] = type_repo[v] =
new Type(type_id, name.str(), target, precision, layout, device); new Type(type_id, name.str(), target, precision, layout, device);
} return type_repo[v];
return type_repo_[v];
} }
const Type *Type::GetTensorListTy(TargetType target, PrecisionType precision, const Type *Type::GetTensorListTy(TargetType target, PrecisionType precision,
DataLayoutType layout, int device) { DataLayoutType layout, int device) {
static std::map<size_t, const Type *> type_repo;
DataType::ID type_id = DataType::ID::TensorList; DataType::ID type_id = DataType::ID::TensorList;
#define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x))) #define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
...@@ -103,28 +93,50 @@ const Type *Type::GetTensorListTy(TargetType target, PrecisionType precision, ...@@ -103,28 +93,50 @@ const Type *Type::GetTensorListTy(TargetType target, PrecisionType precision,
name << device; name << device;
name << ">"; name << ">";
if (!type_repo_[v]) if (!type_repo[v])
// The Types should alive across the process life, no need to delete. // The Types should alive across the process life, no need to delete.
type_repo_[v] = type_repo[v] =
new Type(type_id, name.str(), target, precision, layout, device); new Type(type_id, name.str(), target, precision, layout, device);
return type_repo_[v]; return type_repo[v];
} }
const Type *Type::GetUnsupportedTy() { const Type *Type::GetUnsupportedTy() {
static std::map<size_t, const Type *> type_repo;
std::hash<int> hasher; std::hash<int> hasher;
size_t v = hasher(static_cast<int>(DataType::ID::Unsupported)); size_t v = hasher(static_cast<int>(DataType::ID::Unsupported));
if (!type_repo_[v]) if (!type_repo[v])
type_repo_[v] = type_repo[v] =
new Type(DataType::ID::Unsupported, "Unsupported", TARGET(kUnk), new Type(DataType::ID::Unsupported, "Unsupported", TARGET(kUnk),
PRECISION(kUnk), DATALAYOUT(kUnk), -1); PRECISION(kUnk), DATALAYOUT(kUnk), -1);
return type_repo[v];
} }
const Type *Type::GetVoidTy() { const Type *Type::GetVoidTy() {
static std::map<size_t, const Type *> type_repo;
std::hash<int> hasher; std::hash<int> hasher;
size_t v = hasher(static_cast<int>(DataType::ID::Void)); size_t v = hasher(static_cast<int>(DataType::ID::Void));
if (!type_repo_[v]) if (!type_repo[v])
type_repo_[v] = new Type(DataType::ID::Void, "Void", TARGET(kAny), type_repo[v] = new Type(DataType::ID::Void, "Void", TARGET(kAny),
PRECISION(kAny), DATALAYOUT(kAny), -1); PRECISION(kAny), DATALAYOUT(kAny), -1);
return type_repo[v];
}
const Type *Type::Get(DataType::ID type_id, TargetType target,
PrecisionType precision, DataLayoutType layout,
int device) {
switch (type_id) {
case DataType::ID::Void:
return GetVoidTy();
case DataType::ID::Unsupported:
return GetUnsupportedTy();
case DataType::ID::Tensor:
return GetTensorTy(target, precision, layout, device);
case DataType::ID::TensorList:
return GetTensorListTy(target, precision, layout, device);
default:
LOG(FATAL) << "Unknown Type found";
return nullptr;
}
} }
} // namespace lite } // namespace lite
......
...@@ -133,6 +133,11 @@ class Type : public DataType { ...@@ -133,6 +133,11 @@ class Type : public DataType {
/// Get an Void type. /// Get an Void type.
static const Type* GetVoidTy(); static const Type* GetVoidTy();
static const Type* Get(DataType::ID type_id, TargetType target = TARGET(kUnk),
PrecisionType precision = PRECISION(kUnk),
DataLayoutType layout = DATALAYOUT(kUnk),
int device = 0);
TargetType target() const { return place_.target; } TargetType target() const { return place_.target; }
PrecisionType precision() const { return place_.precision; } PrecisionType precision() const { return place_.precision; }
DataLayoutType layout() const { return place_.layout; } DataLayoutType layout() const { return place_.layout; }
...@@ -154,12 +159,6 @@ class Type : public DataType { ...@@ -154,12 +159,6 @@ class Type : public DataType {
DataLayoutType layout = DataLayoutType::kNCHW, short device = 0) DataLayoutType layout = DataLayoutType::kNCHW, short device = 0)
: DataType(id), place_{target, precision, layout, device}, name_(name) {} : DataType(id), place_{target, precision, layout, device}, name_(name) {}
// An map is used here to maintain a global repo for types. We don't use
// MACROs with static variables for that the TypeSystem should only used in
// compile time, that is not performance sensitive, and a map-based way is
// easier to implement and maintain.
static std::map<size_t, const Type*> type_repo_;
Place place_; Place place_;
const std::string name_; const std::string name_;
}; };
...@@ -203,22 +202,6 @@ static bool TypeCompatibleTo(const Type& a, const Type& b) { ...@@ -203,22 +202,6 @@ static bool TypeCompatibleTo(const Type& a, const Type& b) {
PrecisionCompatibleTo(a, b) && DeviceCompatibleTo(a, b); PrecisionCompatibleTo(a, b) && DeviceCompatibleTo(a, b);
} }
// -------------------------------- predefined types ---------------------------
// TODO(Superjomn) make all the Types' constructs protected to make sure there
// is only one instance across the system.
class VoidTy : public Type {
public:
VoidTy() : Type(ID::Void, "Void") {}
};
class UnsupportedTy : public Type {
public:
UnsupportedTy() : Type(ID::Unsupported, "Unsupported", false /*is_tensor*/) {}
};
const Type* LookupType(DataType::ID type_id, bool is_unknown, bool is_tensor,
Place place);
// ------------------------- end predefined types ---------------------------
/* /*
* ParamType is used to represent a data type of a parameter for the kernel. It * ParamType is used to represent a data type of a parameter for the kernel. It
* can represent any Variable data type. * can represent any Variable data type.
...@@ -226,13 +209,12 @@ const Type* LookupType(DataType::ID type_id, bool is_unknown, bool is_tensor, ...@@ -226,13 +209,12 @@ const Type* LookupType(DataType::ID type_id, bool is_unknown, bool is_tensor,
* registered in the `TypeSystem`. * registered in the `TypeSystem`.
*/ */
struct ParamType { struct ParamType {
Place tensor_place{};
const Type* type; const Type* type;
ParamType() = default; ParamType() = default;
ParamType(const Type* type) : type(type) { tensor_place = type->place(); } ParamType(const Type* type) : type(type) {}
std::string DebugString() const { return tensor_place.DebugString(); } std::string DebugString() const { return type->name(); }
}; };
/* /*
......
...@@ -25,6 +25,10 @@ TEST(TypeSystem, CheckDuplicateGet) { ...@@ -25,6 +25,10 @@ TEST(TypeSystem, CheckDuplicateGet) {
Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); Type::GetTensorTy(TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
ASSERT_EQ(tensor_ty, tensor_ty1); ASSERT_EQ(tensor_ty, tensor_ty1);
ASSERT_EQ(tensor_ty->target(), TARGET(kHost));
ASSERT_EQ(tensor_ty->precision(), PRECISION(kFloat));
ASSERT_EQ(tensor_ty->layout(), DATALAYOUT(kNCHW));
} }
} // namespace lite } // namespace lite
......
...@@ -50,7 +50,7 @@ class IoCopyHostToCudaCompute ...@@ -50,7 +50,7 @@ class IoCopyHostToCudaCompute
param.x->target() == TARGET(kX86)); param.x->target() == TARGET(kX86));
LOG(INFO) << "copy size " << param.x->data_size(); LOG(INFO) << "copy size " << param.x->data_size();
auto* data = param.y->mutable_data<int8_t>(TARGET(kCUDA)); auto* data = param.y->mutable_data<int8_t>(TARGET(kCUDA));
CopyFromHostSync(data, param.x->data<int8_t>(), param.x->data_size()); CopyFromHostSync(data, param.x->raw_data(), param.x->data_size());
} }
std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() override { std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() override {
...@@ -63,8 +63,9 @@ class IoCopyHostToCudaCompute ...@@ -63,8 +63,9 @@ class IoCopyHostToCudaCompute
auto out_place = type->place(); auto out_place = type->place();
out_place.target = TARGET(kCUDA); out_place.target = TARGET(kCUDA);
auto* out_type = LookupType(type->id(), type->IsUnsupported(), auto* out_type =
type->IsUnsupported(), out_place); Type::Get(type->id(), out_place.target, out_place.precision,
out_place.layout, out_place.device);
return out_type; return out_type;
}; };
return res; return res;
...@@ -98,17 +99,13 @@ class IoCopyCudaToHostCompute ...@@ -98,17 +99,13 @@ class IoCopyCudaToHostCompute
REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, kAny,
paddle::lite::kernels::cuda::IoCopyHostToCudaCompute, paddle::lite::kernels::cuda::IoCopyHostToCudaCompute,
host_to_device) host_to_device)
.BindInput("Input", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>( .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))})
TARGET(kHost))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kCUDA))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, REGISTER_LITE_KERNEL(io_copy, kCUDA, kAny, kAny,
paddle::lite::kernels::cuda::IoCopyCudaToHostCompute, paddle::lite::kernels::cuda::IoCopyCudaToHostCompute,
device_to_host) device_to_host)
.BindInput("Input", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>( .BindInput("Input", {LiteType::GetTensorTy(TARGET(kCUDA))})
TARGET(kCUDA))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kHost))})
.Finalize(); .Finalize();
...@@ -25,10 +25,7 @@ namespace cuda {} // namespace cuda ...@@ -25,10 +25,7 @@ namespace cuda {} // namespace cuda
REGISTER_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, REGISTER_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW,
paddle::lite::kernels::cuda::MulCompute, def) paddle::lite::kernels::cuda::MulCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
TARGET(kCUDA))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
TARGET(kCUDA))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kCUDA))})
.Finalize(); .Finalize();
...@@ -36,7 +36,7 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> { ...@@ -36,7 +36,7 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
void Run() override { void Run() override {
CHECK(context_) << "running context should be set first"; CHECK(context_) << "running context should be set first";
auto& context = context_->AsCudaContext(); auto& context = context_->As<CUDAContext>();
CHECK(context.blas_fp32) << "blas should init first"; CHECK(context.blas_fp32) << "blas should init first";
/* /*
auto& blas = *context.blas_fp32; auto& blas = *context.blas_fp32;
......
...@@ -53,13 +53,8 @@ void FcCompute::Run() { ...@@ -53,13 +53,8 @@ void FcCompute::Run() {
REGISTER_LITE_KERNEL(fc, kHost, kFloat, kNCHW, REGISTER_LITE_KERNEL(fc, kHost, kFloat, kNCHW,
paddle::lite::kernels::host::FcCompute, def) paddle::lite::kernels::host::FcCompute, def)
.BindInput("Input", .BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))})
{paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kHost))})
TARGET(kHost))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Bias", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
TARGET(kHost))})
.BindInput("W", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
.Finalize(); .Finalize();
...@@ -43,8 +43,6 @@ class FeedCompute ...@@ -43,8 +43,6 @@ class FeedCompute
REGISTER_LITE_KERNEL(feed, kHost, kAny, kAny, REGISTER_LITE_KERNEL(feed, kHost, kAny, kAny,
paddle::lite::kernels::host::FeedCompute, def) paddle::lite::kernels::host::FeedCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>( .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
TARGET(kHost))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>(
TARGET(kHost))})
.Finalize(); .Finalize();
...@@ -44,8 +44,8 @@ class FetchCompute ...@@ -44,8 +44,8 @@ class FetchCompute
REGISTER_LITE_KERNEL(fetch, kHost, kAny, kAny, REGISTER_LITE_KERNEL(fetch, kHost, kAny, kAny,
paddle::lite::kernels::host::FetchCompute, def) paddle::lite::kernels::host::FetchCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorAnyTy>( .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny),
TARGET(kHost))}) DATALAYOUT(kAny), -1)})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorListAnyTy>( .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kAny),
TARGET(kHost))}) DATALAYOUT(kAny), -1)})
.Finalize(); .Finalize();
...@@ -74,10 +74,7 @@ class MulCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> { ...@@ -74,10 +74,7 @@ class MulCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL(mul, kHost, kFloat, kNCHW, REGISTER_LITE_KERNEL(mul, kHost, kFloat, kNCHW,
paddle::lite::kernels::host::MulCompute, def) paddle::lite::kernels::host::MulCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
TARGET(kHost))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
TARGET(kHost))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
.Finalize(); .Finalize();
...@@ -52,8 +52,6 @@ class ScaleCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> { ...@@ -52,8 +52,6 @@ class ScaleCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL(scale, kHost, kFloat, kNCHW, REGISTER_LITE_KERNEL(scale, kHost, kFloat, kNCHW,
paddle::lite::kernels::host::ScaleCompute, def) paddle::lite::kernels::host::ScaleCompute, def)
.BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>( .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
TARGET(kHost))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
TARGET(kHost))})
.Finalize(); .Finalize();
...@@ -3,7 +3,7 @@ lite_cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_ ...@@ -3,7 +3,7 @@ lite_cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite) cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite)
else() else()
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto proto_desc) cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto_lite proto_desc)
endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
set(model_parser_deps variable_lite scope_lite ${tensor_lite} scope_lite set(model_parser_deps variable_lite scope_lite ${tensor_lite} scope_lite
......
cc_library(var_desc_lite SRCS var_desc.cc DEPS framework_proto) cc_library(var_desc_lite SRCS var_desc.cc DEPS framework_proto_lite)
cc_library(op_desc_lite SRCS op_desc.cc DEPS framework_proto) cc_library(op_desc_lite SRCS op_desc.cc DEPS framework_proto_lite)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/utils/cp_logging.h" #include "paddle/fluid/lite/utils/cp_logging.h"
namespace paddle { namespace paddle {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/utils/logging.h" #include "paddle/fluid/lite/utils/logging.h"
#else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
......
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
// LOG() // LOG()
#define LOG(status) LOG_##status.stream() #define LOG(status) LOG_##status.stream()
#define LOG_ERROR LOG_INFO #define LOG_ERROR LOG_INFO
...@@ -86,3 +88,5 @@ class LogMessageFatal : public LogMessage { ...@@ -86,3 +88,5 @@ class LogMessageFatal : public LogMessage {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
#endif // LITE_WITH_LIGHT_FRAMEWORK
...@@ -22,10 +22,13 @@ ...@@ -22,10 +22,13 @@
#define LITE_UNIMPLEMENTED CHECK(false) << "Not Implemented"; #define LITE_UNIMPLEMENTED CHECK(false) << "Not Implemented";
/*
#ifndef LIKELY #ifndef LIKELY
#define LIKELY(x) __builtin_expect(!!(x), 1) #define LIKELY(x) __builtin_expect(!!(x), 1)
#endif #endif
#ifndef UNLIKELY #ifndef UNLIKELY
//#define UNLIKELY(x) __built_expect(!!(x), 0) //#define UNLIKELY(x) __built_expect(!!(x), 0)
#define UNLIKELY(x) (x) #define UNLIKELY(x) (x)
#endif #endif
*/
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册