Commit f41d73b3 authored by Superjomn

init mir

Parent a1b1feb4
@@ -16,3 +16,5 @@ cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
cc_test(test_tensor_lite SRCS tensor_test.cc)
cc_test(test_executor_lite SRCS executor_test.cc DEPS executor_lite ops_lite host_kernels)
cc_test(test_type_system SRCS type_system_test.cc DEPS type_system)
+add_subdirectory(mir)
@@ -16,4 +16,4 @@
// Created by chunwei on 19-2-22.
//
#include "context.h"
#include "paddle/fluid/lite/core/context.h"
cc_library(mir_pass SRCS pass.cc)
cc_library(mir_node SRCS node.cc)
cc_library(mir_ssa_graph SRCS ssa_graph.cc)
\ No newline at end of file
#include "paddle/fluid/lite/core/mir/node.h"
namespace paddle {
namespace lite {
namespace mir {
class Node {
public:
// Tells whether this node is an instruction.
bool IsInstruct() const;
// Tells whether this node is an argument.
bool IsArgument() const;
};
} // namespace mir
} // namespace lite
} // namespace paddle
\ No newline at end of file
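node.h only declares the two predicates so far; the node payload presumably lands in a follow-up commit. A hypothetical sketch of how a pass might branch on them (the traversal and helper name are illustrative, not from this commit):

#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/mir/node.h"

// Hypothetical helper: counts instruction vs. argument nodes. Assumes the
// caller already holds some container of mir::Node pointers.
std::pair<int, int> CountNodeKinds(
    const std::vector<paddle::lite::mir::Node*>& nodes) {
  int instructions = 0;
  int arguments = 0;
  for (const auto* n : nodes) {
    if (n->IsInstruct()) {
      ++instructions;  // an operator/kernel node
    } else if (n->IsArgument()) {
      ++arguments;  // a variable (tensor) node
    }
  }
  return {instructions, arguments};
}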
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
@@ -38,11 +38,17 @@ struct Place {
TargetType target{TARGET(kHost)};
PrecisionType precision{PRECISION(kFloat)};
DataLayoutType layout{DATALAYOUT(kNCHW)};
+short device{0}; // device ID
+Place() = default;
Place(TargetType target, PrecisionType precision,
-DataLayoutType layout = DATALAYOUT(kNCHW))
-: target(target), precision(precision), layout(layout) {}
+DataLayoutType layout = DATALAYOUT(kNCHW), short device = 0)
+: target(target), precision(precision), layout(layout), device(device) {}
+bool operator==(const Place& other) const {
+return target == other.target && precision == other.precision &&
+layout == other.layout && device == other.device;
+}
};
constexpr const int kNumPrecisions =
......
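A short usage sketch of the extended Place (not part of the diff): the new device field participates in equality, so otherwise-identical places on different devices now compare unequal. Assumes the TARGET/PRECISION/DATALAYOUT macros visible above:

Place host_fp32{TARGET(kHost), PRECISION(kFloat)};  // device defaults to 0
Place host_fp32_dev1{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW),
                     /*device=*/1};

bool same = (host_fp32 == host_fp32_dev1);  // false: the device IDs differ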
@@ -24,7 +24,7 @@ std::ostream &operator<<(std::ostream &os, const DDim &dims) {
}
os << "[";
-for (int i = 0; i < dims.size() - 1; i++) {
+for (size_t i = 0; i < dims.size() - 1; i++) {
os << dims[i] << " ";
}
os << dims.back() << "]";
......
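The int to size_t change above silences a signed/unsigned comparison warning, since DDim::size() presumably returns an unsigned type. One caveat with the pattern: for an empty container, `dims.size() - 1` wraps around to a huge unsigned value instead of -1, so the early return for empty dims (elided above) is load-bearing. A standalone illustration:

#include <iostream>
#include <vector>

int main() {
  std::vector<int> dims;  // deliberately empty
  // With size_t, dims.size() - 1 is SIZE_MAX rather than -1, so the loop
  // must be guarded against the empty case.
  if (!dims.empty()) {
    for (size_t i = 0; i < dims.size() - 1; i++) {
      std::cout << dims[i] << " ";
    }
  }
  std::cout << "empty dims handled\n";
  return 0;
}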
@@ -31,6 +31,119 @@
namespace paddle {
namespace lite {
// Type is the definition of all the types supported by the Variable, which
// represents the inputs and outputs of an operator or kernel.
// A DNN system is simple and cannot process as many data types as a
// compiler; supporting too many would end in chaos.
//
// We should make sure that every supported data type is registered here, and
// keep their number small. Avoid using special data types as an op's I/O,
// such as runtime caches.
//
// TODO(Superjomn) Add operator/kernel-wise static checking to avoid
// unsupported types getting mixed into the system.
class DataTypeBase {
public:
// The Void type can be cast to any other type.
// Unsupported marks data types that slipped into the system during
// development, for example a `std::set` used as the input of some operator.
// Such types won't be analyzed or optimized by the system; they caused many
// bugs in the previous system and should be avoided.
enum class ID : int {
Void = 0, // unknown type that can be cast to any data type.
Unsupported, // Unsupported data type that will not be analyzed.
Tensor_Fp32_NCHW,
Tensor_Int8_NCHW,
Tensor_Int64_NCHW,
NumTypes, // Must remain the last defined ID.
};
ID id() const { return id_; }
// type check.
bool IsTensor() const { return is_tensor_; }
bool IsVoid() const { return id_ == ID::Void; }
bool IsUnsupported() const { return id_ == ID::Unsupported; }
bool IsTensorFp32NCHW() const { return id_ == ID::Tensor_Fp32_NCHW; }
bool IsTensorInt8NCHW() const { return id_ == ID::Tensor_Int8_NCHW; }
bool IsTensorInt64NCHW() const { return id_ == ID::Tensor_Int64_NCHW; }
int num_types() const { return static_cast<int>(ID::NumTypes); }
protected:
// Can only be extended by subclasses.
DataTypeBase(ID id, bool is_tensor) : id_(id), is_tensor_(is_tensor) {}
ID id_{ID::Unsupported};
bool is_tensor_{false};
};
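Because NumTypes must stay last, num_types() doubles as the count of registered IDs. A compile-time guard one could keep next to the enum (a sketch, not in the commit):

// Void, Unsupported, and the three tensor IDs precede NumTypes, so its
// integer value is 5; this trips if an ID is ever appended after NumTypes.
static_assert(static_cast<int>(DataTypeBase::ID::NumTypes) == 5,
              "keep NumTypes as the last enumerator");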
/*
* Data type with device info considered.
* NOTE Types with different devices are treated as different DeviceDataTypes.
*/
class DeviceDataType : public DataTypeBase {
public:
TargetType target() const { return place_.target; }
PrecisionType precision() const { return place_.precision; }
DataLayoutType layout() const { return place_.layout; }
const Place& place() const { return place_; }
const std::string& name() const { return name_; }
bool operator==(const DeviceDataType& other) const {
return id_ == other.id() && place_ == other.place();
}
// Whether this type can be cast to another. This is heavily used in MIR to
// determine whether an instruction can be inserted to transform one type
// into another.
virtual bool TypeCastable(const DeviceDataType& type) const {
return id_ == type.id();
}
virtual ~DeviceDataType() = default;
protected:
DeviceDataType(ID id, const std::string& name, bool is_tensor,
TargetType target = TargetType::kHost,
PrecisionType precision = PrecisionType::kFloat,
DataLayoutType layout = DataLayoutType::kNCHW)
: DataTypeBase(id, is_tensor),
place_{target, precision, layout},
name_(name) {}
protected:
Place place_;
const std::string name_;
};
// -------------------------------- predefined types ---------------------------
class Void : public DeviceDataType {
public:
Void() : DeviceDataType(ID::Void, "Void", false /*is_tensor*/) {}
};
class TensorFp32NCHW : public DeviceDataType {
public:
explicit TensorFp32NCHW(TargetType target)
: DeviceDataType(ID::Tensor_Fp32_NCHW, "TensorFp32NCHW",
true /*is_tensor*/, target, PrecisionType::kFloat,
DataLayoutType::kNCHW) {}
};
class TensorInt8NCHW : public DeviceDataType {
public:
explicit TensorInt8NCHW(TargetType target)
: DeviceDataType(ID::Tensor_Int8_NCHW, "TensorInt8NCHW",
true /*is_tensor*/, target, PrecisionType::kInt8,
DataLayoutType::kNCHW) {}
};
class TensorInt64NCHW : public DeviceDataType {
public:
explicit TensorInt64NCHW(TargetType target)
: DeviceDataType(ID::Tensor_Int64_NCHW, "TensorInt64NCHW",
true /*is_tensor*/, target, PrecisionType::kInt64,
DataLayoutType::kNCHW) {}
};
// ------------------------- end predefined types ---------------------------
// NOTE TypeSystem has some overhead and is better used only in the analysis phase.
class TypeSystem {
private:
......
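Putting the predefined types together, a sketch of the intended semantics (not from the commit; it assumes a kX86 target exists alongside kHost): each concrete type pins an ID, a precision, and a layout, the target stays a constructor parameter, and the default TypeCastable only admits identity casts.

TensorFp32NCHW x86_fp32(TargetType::kX86);   // kX86 is assumed here
TensorFp32NCHW host_fp32(TargetType::kHost);
TensorInt8NCHW host_int8(TargetType::kHost);

bool same_id = x86_fp32.TypeCastable(host_fp32);   // true: IDs match, even
                                                   // across targets
bool diff_id = host_fp32.TypeCastable(host_int8);  // false until a subclass
                                                   // overrides TypeCastable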
@@ -58,8 +58,8 @@ TEST(fc_op_lite, test) {
FcOpLite fc("fc");
-fc.SetValidPlaces({OpLite::Place{TARGET(kHost), PRECISION(kFloat)}});
-fc.PickKernel({OpLite::Place{TARGET(kHost), PRECISION(kFloat)}});
+fc.SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)}});
+fc.PickKernel({Place{TARGET(kHost), PRECISION(kFloat)}});
fc.Attach(desc, &scope);
fc.Run();
......