提交 a1b1feb4 编写于 作者: S superjomn

init type system

上级 bb185125
cc_library(memory_lite SRCS memory.cc) cc_library(memory_lite SRCS memory.cc)
cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite) cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite)
cc_library(kernel_lite SRCS kernel.cc)
cc_library(variable_lite SRCS variable.cc) cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc) cc_library(op_registry_lite SRCS op_registry.cc)
cc_library(scope_lite SRCS scope.cc) cc_library(scope_lite SRCS scope.cc)
...@@ -7,9 +8,11 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite) ...@@ -7,9 +8,11 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite)
cc_library(executor_lite SRCS executor.cc DEPS scope_lite tensor_lite op_lite op_registry_lite cc_library(executor_lite SRCS executor.cc DEPS scope_lite tensor_lite op_lite op_registry_lite
#TODO(Superjomn) remove these dependencies from original framework #TODO(Superjomn) remove these dependencies from original framework
proto_desc) proto_desc)
cc_library(type_system SRCS type_system.cc DEPS tensor_lite)
cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite) cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86) cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite) cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
cc_test(test_tensor_lite SRCS tensor_test.cc) cc_test(test_tensor_lite SRCS tensor_test.cc)
cc_test(test_executor_lite SRCS executor_test.cc DEPS executor_lite ops_lite host_kernels) cc_test(test_executor_lite SRCS executor_test.cc DEPS executor_lite ops_lite host_kernels)
cc_test(test_type_system SRCS type_system_test.cc DEPS type_system)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/kernel.h"
namespace paddle {
namespace lite {
bool operator==(const Place &a, const Place &b) {
return a.target == b.target && a.precision == b.precision &&
a.layout == b.layout;
}
bool operator<(const Place &a, const Place &b) {
if (a.target != b.target)
return a.target < b.target;
else if (a.precision != b.precision)
return a.precision < b.precision;
else if (a.layout != b.layout)
return a.layout < b.layout;
return true;
}
bool ParamTypeRegistry::KeyCmp::operator()(
const ParamTypeRegistry::key_t &a,
const ParamTypeRegistry::key_t &b) const {
if (a.kernel_type != b.kernel_type)
return a.kernel_type < b.kernel_type;
else if (a.io != b.io)
return a.io < b.io;
else if (a.offset != b.offset)
return a.offset < b.offset;
else if (!(a.place == b.place)) {
return a.place < b.place;
}
return true;
}
} // namespace lite
} // namespace paddle
\ No newline at end of file
...@@ -59,10 +59,125 @@ class KernelBase { ...@@ -59,10 +59,125 @@ class KernelBase {
mutable operators::param_t param_; mutable operators::param_t param_;
}; };
/*
* ParamType is used to represent a data type of a parameter for the kernel. It
* can represent any Variable data type.
* The element_type_hash is the hash code of the element, it should be
* registered in the `TypeSystem`.
*/
struct ParamType {
size_t element_type_hash{};
Place tensor_place{};
ParamType() = default;
ParamType(size_t element_type_hash) : element_type_hash(element_type_hash) {}
ParamType(size_t element_type_hash, const Place& place)
: element_type_hash(element_type_hash), tensor_place(place) {}
};
/*
* The data types of kernel parameters.
*/
struct ParamTypes {
std::vector<std::vector<ParamType>> inputs;
std::vector<std::vector<ParamType>> outputs;
void RegisterInputType(int offset, const ParamType& type) {
Register(&inputs, offset, type);
}
void RegisterOutputType(int offset, const ParamType& type) {
Register(&outputs, offset, type);
}
private:
void Register(std::vector<std::vector<ParamType>>* ts, int offset,
ParamType type) {
CHECK_GE(offset, 0) << "invalid offset";
CHECK_GE(offset, 50) << "invalid offset";
for (size_t i = 0; i < offset - inputs.size() + 1; i++) {
ts->emplace_back();
}
ts->at(offset).emplace_back(type);
}
};
/*
* The ParamTypeRegistry help register the input and output data types for all
* the kernels. It is made singleton so that all the objects of the same kernel
* can share the same information.
*
* Usage:
* for register a kernel for FC operator.
* ParamTypeRegistry::Global().Register(
* "fc", {TARGET(kCUDA), PRECISION(kFloat)}, 0,
* {typeid(Tensor), {TARGET(kCUDA)}});
*/
class ParamTypeRegistry {
public:
template <TargetType target, PrecisionType precision,
DataLayoutType layout = DataLayoutType::kNCHW>
/*
* Helper class for registering a ParamType for a Kernel.
* Usage:
*
* NewInstance<TARGET(kHost), PRECISION(kFloat)>("fc")
* .BindInput(0, {typeid(Tensor).hash_code(), {TARGET(kHost)})
* .BindInput(1, {typeid(Tensor).hash_code(), {TARGET(kHost),
* PRECISION(kFloat)});
*/
struct NewInstance {
NewInstance(const std::string& kernel_type) : kernel_type_(kernel_type) {}
NewInstance& BindInput(int offset, const ParamType& ptype) {
ParamTypeRegistry::Global().Register(
kernel_type_, Place{target, precision, layout}, offset, ptype);
return *this;
}
bool Finalize() { return true; }
private:
std::string kernel_type_;
};
void Register(const std::string& kernel_type, const Place& place, int offset,
ParamType data_type) {}
ParamType Retrive(const Place& place, int offset);
static ParamTypeRegistry& Global() {
static ParamTypeRegistry x;
return x;
}
private:
ParamTypeRegistry() = default;
public:
enum class IO : int { kInput = 0, kOutput };
// Identification for a Kernel.
struct KernelIdT {
std::string kernel_type;
Place place;
IO io;
int offset;
};
using key_t = KernelIdT;
struct KeyCmp {
bool operator()(const key_t& a, const key_t& b) const;
};
private:
std::map<key_t, ParamType, ParamTypeRegistry::KeyCmp> types_;
};
// Light-weight kernel implementation. // Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target // The OpKernel is designed to implement the specific algorithm on a target
// device. // device.
template <TargetType Target, PrecisionType Precision> template <TargetType Target, PrecisionType Precision,
DataLayoutType DataLayout = DataLayoutType::kNCHW>
class OpKernel : public KernelBase { class OpKernel : public KernelBase {
public: public:
// Set runtime context. // Set runtime context.
...@@ -74,6 +189,8 @@ class OpKernel : public KernelBase { ...@@ -74,6 +189,8 @@ class OpKernel : public KernelBase {
TargetType target() const override { return Target; } TargetType target() const override { return Target; }
PrecisionType precision() const override { return Precision; } PrecisionType precision() const override { return Precision; }
void Touch() {}
OpKernel() = default; OpKernel() = default;
virtual ~OpKernel() = default; virtual ~OpKernel() = default;
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace lite { namespace lite {
std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels( std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
const std::vector<OpLite::Place> &places, const std::string &kernel_type) { const std::vector<Place> &places, const std::string &kernel_type) {
std::vector<std::unique_ptr<KernelBase>> kernels; std::vector<std::unique_ptr<KernelBase>> kernels;
CHECK(!op_type_.empty()) << "op_type_ should be set first"; CHECK(!op_type_.empty()) << "op_type_ should be set first";
...@@ -33,7 +33,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels( ...@@ -33,7 +33,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
return kernels; return kernels;
} }
void OpLite::PickKernel(const std::vector<OpLite::Place> &valid_places, void OpLite::PickKernel(const std::vector<Place> &valid_places,
OpLite::KernelStrategy kernel_strategy) { OpLite::KernelStrategy kernel_strategy) {
switch (kernel_strategy) { switch (kernel_strategy) {
case KernelStrategy::kStatic: case KernelStrategy::kStatic:
......
...@@ -57,14 +57,6 @@ class OpLite : public Registry { ...@@ -57,14 +57,6 @@ class OpLite : public Registry {
kRuntime, kRuntime,
}; };
struct Place {
TargetType target{TARGET(kHost)};
PrecisionType precision{PRECISION(kFloat)};
Place(TargetType target, PrecisionType precision)
: target(target), precision(precision) {}
};
OpLite() = default; OpLite() = default;
OpLite(const std::string &type) : op_type_(type) {} OpLite(const std::string &type) : op_type_(type) {}
OpLite(std::unique_ptr<OpContext> &&x, const std::vector<Place> &valid_places) OpLite(std::unique_ptr<OpContext> &&x, const std::vector<Place> &valid_places)
...@@ -119,8 +111,7 @@ class OpLite : public Registry { ...@@ -119,8 +111,7 @@ class OpLite : public Registry {
// Create all the kernels for the valid targets. // Create all the kernels for the valid targets.
std::vector<std::unique_ptr<KernelBase>> CreateKernels( std::vector<std::unique_ptr<KernelBase>> CreateKernels(
const std::vector<OpLite::Place> &places, const std::vector<Place> &places, const std::string &kernel_type = "");
const std::string &kernel_type = "");
protected: protected:
std::unique_ptr<OpContext> op_context_; std::unique_ptr<OpContext> op_context_;
......
...@@ -43,6 +43,7 @@ std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type, ...@@ -43,6 +43,7 @@ std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
} }
#undef CREATE_KERNEL #undef CREATE_KERNEL
return nullptr;
} }
KernelRegistry::KernelRegistry() { KernelRegistry::KernelRegistry() {
......
...@@ -161,7 +161,10 @@ class KernelRegistor : public lite::Registor<KernelType> { ...@@ -161,7 +161,10 @@ class KernelRegistor : public lite::Registor<KernelType> {
int touch_##op_type__##target__##precision__() { \ int touch_##op_type__##target__##precision__() { \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__).Touch(); \ LITE_KERNEL_INSTANCE(op_type__, target__, precision__).Touch(); \
return 0; \ return 0; \
} } \
static bool op_type__##target__##precision__##param_register \
__attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \
TARGET(target__), PRECISION(precision__)>(#op_type__)
#define USE_LITE_KERNEL(op_type__, target__, precision__) \ #define USE_LITE_KERNEL(op_type__, target__, precision__) \
extern int touch_##op_type__##target__##precision__(); \ extern int touch_##op_type__##target__##precision__(); \
......
...@@ -18,41 +18,53 @@ ...@@ -18,41 +18,53 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
enum class TargetType { kHost = 0, kX86, kCUDA, kLastAsPlaceHolder }; enum class TargetType : int { kHost = 0, kX86, kCUDA, kLastAsPlaceHolder };
enum class PrecisionType : int { kFloat = 0, kInt8, kLastAsPlaceHolder };
enum class DataLayoutType : int { kNCHW = 0, kLastAsPlaceHolder };
// Some helper macro to get a specific TargetType. // Some helper macro to get a specific TargetType.
#define TARGET(item__) paddle::lite::TargetType::item__ #define TARGET(item__) paddle::lite::TargetType::item__
#define TARGET_VAL(item__) static_cast<int>(TARGET(item__)) #define TARGET_VAL(item__) static_cast<int>(TARGET(item__))
// Some helper macro to get a specific PrecisionType.
constexpr int kNumTargets = TARGET_VAL(kLastAsPlaceHolder) - TARGET_VAL(kHost); #define PRECISION(item__) paddle::lite::PrecisionType::item__
#define PRECISION_VAL(item__) static_cast<int>(PRECISION(item__))
#define DATALAYOUT(item__) paddle::lite::DataLayoutType::item__
/* /*
template <TargetType target> * Place specifies the execution context of a Kernel or input/output for a
struct Target {}; * kernel. It is used to make the analysis of the MIR more clear and accurate.
using Host = Target<TargetType::kHost>;
using X86 = Target<TargetType::kX86>;
using CUDA = Target<TargetType::kCUDA>;
using ARM = Target<TargetType::kARM>;
*/ */
struct Place {
TargetType target{TARGET(kHost)};
PrecisionType precision{PRECISION(kFloat)};
DataLayoutType layout{DATALAYOUT(kNCHW)};
Place() = default;
Place(TargetType target, PrecisionType precision,
DataLayoutType layout = DATALAYOUT(kNCHW))
: target(target), precision(precision), layout(layout) {}
};
enum class PrecisionType { kFloat = 0, kInt8, kLastAsPlaceHolder }; constexpr const int kNumPrecisions =
// Some helper macro to get a specific PrecisionType.
#define PRECISION(item__) paddle::lite::PrecisionType::item__
#define PRECISION_VAL(item__) static_cast<int>(PRECISION(item__))
constexpr int kNumPrecisions =
PRECISION_VAL(kLastAsPlaceHolder) - PRECISION_VAL(kFloat); PRECISION_VAL(kLastAsPlaceHolder) - PRECISION_VAL(kFloat);
constexpr const int kNumTargets =
TARGET_VAL(kLastAsPlaceHolder) - TARGET_VAL(kHost);
static const std::string target2string[] = {"host", "x86", "cuda"}; static const std::string target2string[] = {"host", "x86", "cuda"};
static const std::string& TargetToStr(TargetType target) { static const std::string& TargetToStr(TargetType target) {
return target2string[static_cast<int>(target)]; return target2string[static_cast<int>(target)];
} }
static const std::string precision2string[] = {"float, int8"}; static const std::string precision2string[] = {"float", "int8"};
static const std::string& PrecisionToStr(PrecisionType precision) { static const std::string& PrecisionToStr(PrecisionType precision) {
return precision2string[static_cast<int>(precision)]; return precision2string[static_cast<int>(precision)];
} }
static const std::string datalayout2string[] = {"NCHW"};
static const std::string& DataLayoutToStr(DataLayoutType x) {
return datalayout2string[static_cast<int>(x)];
}
// Event sync for multi-stream devices like CUDA and OpenCL. // Event sync for multi-stream devices like CUDA and OpenCL.
// For the devices without support of stream, leave it empty. // For the devices without support of stream, leave it empty.
template <TargetType Target> template <TargetType Target>
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/type_system.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// This file contains the file system of the lite system. Every data type in
// Variable should be registered here, and the analysis phase will check the
// data type correction.
// This mechanism is made for keeping our system simpler and more stable, for
// the dubious typed Variables in the Operators' inputs and outputs are disaster
// for analysis and runtime.
#include <glog/logging.h>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/lite/core/tensor.h"
namespace paddle {
namespace lite {
// NOTE TypeSystem has some overhead, and better to be used in analysis phase.
class TypeSystem {
private:
// Put all valid types for Variables here!
TypeSystem() {
// Tensor is a valid data type for Variable.
Register<Tensor>("tensor");
}
public:
static TypeSystem& Global() {
static TypeSystem x;
return x;
}
template <typename T>
void Register(const std::string& type) {
size_t hash = typeid(T).hash_code();
CHECK(!types_.count(hash)) << "duplicate register type " << type
<< " found!";
types_[hash] = type;
names_.insert(type);
}
template <typename T>
bool Contains() const {
return types_.count(typeid(T).hash_code());
}
bool Contains(size_t hash) const { return types_.count(hash); }
bool Contains(const std::string& type) { return names_.count(type); }
std::string DebugInfo() const {
std::stringstream ss;
for (const auto& it : types_) {
ss << it.second << "\n";
}
return ss.str();
}
private:
std::unordered_map<size_t /*hash*/, std::string /*name*/> types_;
TypeSystem(const TypeSystem&) = delete;
std::unordered_set<std::string> names_;
};
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/type_system.h"
#include <gtest/gtest.h>
namespace paddle {
namespace lite {
TEST(TypeSystem, test) {
ASSERT_TRUE(TypeSystem::Global().Contains<lite::Tensor>());
}
TEST(TypeSystem, register_new) {
TypeSystem::Global().Register<int>("int32");
ASSERT_TRUE(TypeSystem::Global().Contains<int>());
ASSERT_TRUE(TypeSystem::Global().Contains(typeid(int).hash_code()));
ASSERT_TRUE(TypeSystem::Global().Contains("int32"));
}
} // namespace lite
} // namespace paddle
...@@ -50,7 +50,7 @@ namespace paddle { ...@@ -50,7 +50,7 @@ namespace paddle {
namespace lite { namespace lite {
namespace cuda { namespace cuda {
const char* CublasErrorInfo(int error) { static const char* CublasErrorInfo(int error) {
switch (error) { switch (error) {
#define LITE_CUBLAS_ERROR_INFO(xx) \ #define LITE_CUBLAS_ERROR_INFO(xx) \
case xx: \ case xx: \
......
...@@ -23,9 +23,6 @@ namespace host { ...@@ -23,9 +23,6 @@ namespace host {
// NOTE should use pure std C++ implementation. // NOTE should use pure std C++ implementation.
void FcCompute::Run() { void FcCompute::Run() {
using matrix_t = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>;
using matrix_map_t = Eigen::Map<matrix_t>;
auto& param = this->param<operators::FcParam>(); auto& param = this->param<operators::FcParam>();
CHECK_GE(param.input->dims().size(), 2UL); CHECK_GE(param.input->dims().size(), 2UL);
...@@ -53,4 +50,7 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); } ...@@ -53,4 +50,7 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute); REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute)
.BindInput(0, {typeid(paddle::lite::Tensor).hash_code(),
paddle::lite::Place{TARGET(kHost), PRECISION(kFloat)}})
.Finalize();
...@@ -67,4 +67,5 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -67,4 +67,5 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(mul, kHost, kFloat, REGISTER_LITE_KERNEL(mul, kHost, kFloat,
paddle::lite::kernels::host::MulCompute); paddle::lite::kernels::host::MulCompute)
.Finalize();
...@@ -43,4 +43,5 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -43,4 +43,5 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(relu, kHost, kFloat, REGISTER_LITE_KERNEL(relu, kHost, kFloat,
paddle::lite::kernels::host::ReluCompute); paddle::lite::kernels::host::ReluCompute)
.Finalize();
...@@ -51,4 +51,5 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -51,4 +51,5 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL(scale, kHost, kFloat, REGISTER_LITE_KERNEL(scale, kHost, kFloat,
paddle::lite::kernels::host::ScaleCompute); paddle::lite::kernels::host::ScaleCompute)
.Finalize();
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册