Unverified commit 6e4acdcd, authored by hong19860320, committed by GitHub

[LITE][CORE][NPU][XPU] Recording the data type when mutable_data() is called, and supporting the type inference from the tensor for the op bridges (#2735)
Parent e84406a7
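
The heart of this change is in `TensorLite::mutable_data<T>()` (first two hunks below): allocating typed data now also records the tensor's `PrecisionType`, so the op bridges can infer the element type from the tensor itself instead of taking it as an extra parameter. A minimal usage sketch of the new behavior (illustrative only, not part of the diff; it assumes the existing `Resize()` and `precision()` accessors on the lite tensor):

    TensorLite tensor;
    tensor.Resize({1, 3, 224, 224});
    tensor.mutable_data<float>();    // allocates and records PrecisionType::kFloat
    CHECK(tensor.precision() == PrecisionType::kFloat);
    tensor.mutable_data<int32_t>();  // re-allocating with a new type updates it
    CHECK(tensor.precision() == PrecisionType::kInt32);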
@@ -139,6 +139,22 @@ class TensorLite {
   // For other devices, T and R may be the same type.
   template <typename T, typename R = T>
   R *mutable_data() {
+    auto type_id = typeid(T).hash_code();
+    if (type_id == typeid(bool).hash_code()) {  // NOLINT
+      precision_ = PrecisionType::kBool;
+    } else if (type_id == typeid(float).hash_code()) {  // NOLINT
+      precision_ = PrecisionType::kFloat;
+    } else if (type_id == typeid(int8_t).hash_code()) {
+      precision_ = PrecisionType::kInt8;
+    } else if (type_id == typeid(int16_t).hash_code()) {
+      precision_ = PrecisionType::kInt16;
+    } else if (type_id == typeid(int32_t).hash_code()) {
+      precision_ = PrecisionType::kInt32;
+    } else if (type_id == typeid(int64_t).hash_code()) {
+      precision_ = PrecisionType::kInt64;
+    } else {
+      precision_ = PrecisionType::kUnk;
+    }
     memory_size_ = dims_.production() * sizeof(T);
     buffer_->ResetLazy(target_, memory_size_);
     return reinterpret_cast<R *>(static_cast<char *>(buffer_->data()) +
@@ -163,10 +179,7 @@ class TensorLite {
   template <typename T, typename R = T>
   R *mutable_data(TargetType target) {
     target_ = target;
-    memory_size_ = dims_.production() * sizeof(T);
-    buffer_->ResetLazy(target, memory_size());
-    return reinterpret_cast<R *>(static_cast<char *>(buffer_->data()) +
-                                 offset_);
+    return mutable_data<T, R>();
   }
   void *mutable_data(size_t memory_size);
   void *mutable_data(TargetType target, size_t memory_size);
......
@@ -43,14 +43,14 @@ int Graph::Add(const std::string& name, std::shared_ptr<Node> node) {
 std::shared_ptr<Node> Graph::Add(const std::string& name,
                                  const Tensor& tensor,
                                  std::vector<int64_t> shape,
-                                 PrecisionType precision,
                                  DataLayoutType layout) {
   std::shared_ptr<Node> node = nullptr;
+  PrecisionType precision = tensor.precision();
   if (tensor.persistable()) {
     // Const node
     node = Add<ge::op::Const>(name, precision, layout);
     node->data<ge::op::Const>()->set_attr_value(
-        CvtTensor(tensor, shape, precision, layout));
+        CvtTensor(tensor, shape, layout));
   } else {
     // Data node
     node = Add(name, shape, precision, layout);
......
@@ -95,22 +95,19 @@ class Graph {
   std::shared_ptr<Node> Add(const std::string& name,
                             const Tensor& tensor,
                             std::vector<int64_t> shape,
-                            PrecisionType precision = PRECISION(kFloat),
                             DataLayoutType layout = DATALAYOUT(kNCHW));
   std::shared_ptr<Node> Add(const std::string& name,
                             const Tensor& tensor,
-                            PrecisionType precision = PRECISION(kFloat),
                             DataLayoutType layout = DATALAYOUT(kNCHW)) {
-    return Add(name, tensor, tensor.dims().Vectorize(), precision, layout);
+    return Add(name, tensor, tensor.dims().Vectorize(), layout);
   }
   std::shared_ptr<Node> Add(const std::string& name,
                             const Tensor& tensor,
                             DDim dims,
-                            PrecisionType precision = PRECISION(kFloat),
                             DataLayoutType layout = DATALAYOUT(kNCHW)) {
-    return Add(name, tensor, dims.Vectorize(), precision, layout);
+    return Add(name, tensor, dims.Vectorize(), layout);
   }
   // Const node
@@ -119,17 +116,6 @@ class Graph {
                             const std::vector<T>& data,
                             std::vector<int64_t> shape = {},
                             DataLayoutType layout = DATALAYOUT(kNCHW)) {
-    const std::type_info& info = typeid(T);
-    PrecisionType precision = PRECISION(kFloat);
-    if (info == typeid(float)) {
-      precision = PRECISION(kFloat);
-    } else if (info == typeid(int8_t)) {
-      precision = PRECISION(kFloat);
-    } else if (info == typeid(int32_t)) {
-      precision = PRECISION(kInt32);
-    } else {
-      LOG(FATAL) << "[NPU] Unknow data type " << info.name();
-    }
     if (shape.empty()) {
       shape = {static_cast<int64_t>(data.size())};
     } else {
@@ -145,7 +131,7 @@ class Graph {
     std::memcpy(reinterpret_cast<uint8_t*>(tensor.mutable_data<T>()),
                 reinterpret_cast<const uint8_t*>(data.data()),
                 data.size() * sizeof(T));
-    return Add(name, tensor, precision, layout);
+    return Add(name, tensor, layout);
   }
   template <typename T>
......
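
Because the precision now travels with the tensor, the NPU `Graph::Add` overloads above drop their `PrecisionType` parameters and call `tensor.precision()` internally. A hedged sketch of the call-site difference (`graph` and `weight_tensor` are hypothetical names, not taken from the diff):

    // Before this commit the precision had to be threaded through explicitly:
    //   graph->Add("conv1/filter", weight_tensor, PRECISION(kFloat), DATALAYOUT(kNCHW));
    // After it, the precision recorded by mutable_data<T>() is used:
    auto node = graph->Add("conv1/filter", weight_tensor);  // layout still defaults to kNCHW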
@@ -49,7 +49,8 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   }
   // Reshape node
-  auto reshape_node = graph->Add<ge::op::Reshape>(out_name);
+  auto reshape_node = graph->Add<ge::op::Reshape>(
+      out_name, x_node->precision(), x_node->layout());
   auto reshape_op = reshape_node->data<ge::op::Reshape>();
   reshape_op->set_input_tensor(*x_node->data());
......
@@ -103,8 +103,8 @@ std::vector<int64_t> CvtShape(const DDim& in_dims) {
 ge::TensorPtr CvtTensor(const Tensor& in_tensor,
                         std::vector<int64_t> out_shape,
-                        PrecisionType in_precision,
                         DataLayoutType in_layout) {
+  PrecisionType in_precision = in_tensor.precision();
   auto in_size = in_tensor.dims().production();
   auto in_shape = in_tensor.dims().Vectorize();
   if (out_shape.empty()) {
......
@@ -77,7 +77,6 @@ std::vector<int64_t> CvtShape(const DDim& in_dims);
 ge::TensorPtr CvtTensor(const Tensor& in_tensor,
                         std::vector<int64_t> out_shape = {},
-                        PrecisionType in_precision = PRECISION(kFloat),
                         DataLayoutType in_layout = DATALAYOUT(kNCHW));
 int CvtActMode(std::string act_type);
......
@@ -66,27 +66,30 @@ int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   if (graph->Has(index_name)) {
     index_node = graph->Get(index_name);
   } else {
-    index_node = graph->Add(
-        index_name, *index, index_type->precision(), index_type->layout());
+    index_node = graph->Add(index_name, *index);
   }
   // Flatten index node
   if (index_dims.size() != 1) {
     index_node =
         graph->Add(index_name + "/reshape",
                    graph->builder_.CreateReshape(*index_node->data(), {-1}),
-                   index_type->precision(),
-                   index_type->layout());
+                   index_node->precision(),
+                   index_node->layout());
   }
   // Reshape the gather node with the inferred shape as the output node
   auto gather_node =
       graph->Add(out_name,
                  graph->builder_.CreateGather(
-                     *x_node->data(), *index_node->data(), /* axis= */ 0));
+                     *x_node->data(), *index_node->data(), /* axis= */ 0),
+                 x_node->precision(),
+                 x_node->layout());
   if (out_dims.size() != 2) {
     graph->Add(out_name,
-               graph->builder_.CreateReshape(
-                   *gather_node->data(), CvtShape<xtcl::Integer>(out_dims)));
+               graph->builder_.CreateReshape(*gather_node->data(),
+                                             CvtShape<xtcl::Integer>(out_dims)),
+               gather_node->precision(),
+               gather_node->layout());
   }
   return SUCCESS;
 }
......
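
Note the pattern in the gather bridge above, repeated in the lookup_table and reshape bridges below: precision and layout are now read from the already-registered graph node (`x_node->precision()`, `index_node->precision()`) rather than from the kernel's declared input types (`index_type->precision()`), so intermediate nodes inherit the type their producer actually has. A minimal illustration (the node name is hypothetical):

    auto node = graph->Get("gather/index");
    auto p = node->precision();  // inherited from the producing node, not the kernel declaration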
@@ -57,9 +57,9 @@ std::shared_ptr<Node> Graph::Add(const std::string& name,
 std::shared_ptr<Node> Graph::Add(const std::string& name,
                                  const Tensor& tensor,
                                  std::vector<int64_t> shape,
-                                 PrecisionType precision,
                                  DataLayoutType layout) {
   std::shared_ptr<Node> node = nullptr;
+  PrecisionType precision = tensor.precision();
   if (tensor.persistable()) {
     // Const node
     node = std::make_shared<Node>(precision, layout, Node::Role::kConst);
@@ -67,8 +67,7 @@ std::shared_ptr<Node> Graph::Add(const std::string& name,
     CHECK_EQ(idx, 1);
     node->set_data(std::make_shared<xtcl::xExpr>(builder_.CreateTensor(
         name, CvtShape<xtcl::xIndexExpr>(shape), CvtPrecisionType(precision))));
-    params_.emplace(
-        std::make_pair(name, *CvtTensor(tensor, shape, precision, layout)));
+    params_.emplace(std::make_pair(name, *CvtTensor(tensor, shape, layout)));
   } else {
     // Data node
     node = Add(name, shape, precision, layout);
......
@@ -79,22 +79,19 @@ class Graph {
   std::shared_ptr<Node> Add(const std::string& name,
                             const Tensor& tensor,
                             std::vector<int64_t> shape,
-                            PrecisionType precision = PRECISION(kFloat),
                             DataLayoutType layout = DATALAYOUT(kNCHW));
   std::shared_ptr<Node> Add(const std::string& name,
                             const Tensor& tensor,
-                            PrecisionType precision = PRECISION(kFloat),
                             DataLayoutType layout = DATALAYOUT(kNCHW)) {
-    return Add(name, tensor, tensor.dims().Vectorize(), precision, layout);
+    return Add(name, tensor, tensor.dims().Vectorize(), layout);
   }
   std::shared_ptr<Node> Add(const std::string& name,
                             const Tensor& tensor,
                             DDim dims,
-                            PrecisionType precision = PRECISION(kFloat),
                             DataLayoutType layout = DATALAYOUT(kNCHW)) {
-    return Add(name, tensor, dims.Vectorize(), precision, layout);
+    return Add(name, tensor, dims.Vectorize(), layout);
  }
  // Const node
@@ -103,17 +100,6 @@ class Graph {
                             const std::vector<T>& data,
                             std::vector<int64_t> shape = {},
                             DataLayoutType layout = DATALAYOUT(kNCHW)) {
-    const std::type_info& info = typeid(T);
-    PrecisionType precision = PRECISION(kFloat);
-    if (info == typeid(float)) {
-      precision = PRECISION(kFloat);
-    } else if (info == typeid(int8_t)) {
-      precision = PRECISION(kFloat);
-    } else if (info == typeid(int32_t)) {
-      precision = PRECISION(kInt32);
-    } else {
-      LOG(FATAL) << "[XPU] Unknow data type " << info.name();
-    }
     if (shape.empty()) {
       shape = {static_cast<int64_t>(data.size())};
     } else {
@@ -129,7 +115,7 @@ class Graph {
     std::memcpy(reinterpret_cast<uint8_t*>(tensor.mutable_data<T>()),
                 reinterpret_cast<const uint8_t*>(data.data()),
                 data.size() * sizeof(T));
-    return Add(name, tensor, precision, layout);
+    return Add(name, tensor, layout);
   }
   template <typename T>
......
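
The deleted `typeid` ladder in the templated `Add<T>` (both the NPU and XPU copies) is redundant for the same reason: the temporary tensor is filled through `mutable_data<T>()`, which tags the precision that `CvtTensor` later reads back. A condensed sketch of the surviving flow (paraphrased from the hunks above, not verbatim):

    Tensor tensor;
    tensor.Resize(shape);
    std::memcpy(tensor.mutable_data<T>(),  // records the precision matching T
                data.data(), data.size() * sizeof(T));
    return Add(name, tensor, layout);      // CvtTensor will use tensor.precision()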
@@ -61,16 +61,15 @@ int LookupTableConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   if (graph->Has(ids_name)) {
     ids_node = graph->Get(ids_name);
   } else {
-    ids_node = graph->Add(
-        ids_name, ids_dims, ids_type->precision(), ids_type->layout());
+    ids_node = graph->Add(ids_name, *ids);
   }
   // Flatten Ids node
   if (ids_dims.size() != 1) {
     ids_node =
         graph->Add(ids_name + "/reshape",
                    graph->builder_.CreateReshape(*ids_node->data(), {-1}),
-                   ids_type->precision(),
-                   ids_type->layout());
+                   ids_node->precision(),
+                   ids_node->layout());
   }
   // W node
@@ -80,11 +79,15 @@ int LookupTableConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto gather_node =
       graph->Add(out_name,
                  graph->builder_.CreateGather(
-                     *w_node->data(), *ids_node->data(), /* axis= */ 0));
+                     *w_node->data(), *ids_node->data(), /* axis= */ 0),
+                 w_node->precision(),
+                 w_node->layout());
   if (out_dims.size() != 2) {
     graph->Add(out_name,
-               graph->builder_.CreateReshape(
-                   *gather_node->data(), CvtShape<xtcl::Integer>(out_dims)));
+               graph->builder_.CreateReshape(*gather_node->data(),
+                                             CvtShape<xtcl::Integer>(out_dims)),
+               gather_node->precision(),
+               gather_node->layout());
   }
   return SUCCESS;
 }
......
@@ -33,15 +33,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // Get input and output vars and op attributes
   auto x_name = op_info->Input("X").front();
   auto x_type = kernel->GetInputDeclType("X");
-  CHECK(x_type->precision() == PRECISION(kFloat));
-  CHECK(x_type->layout() == DATALAYOUT(kNCHW));
   auto x = scope->FindMutableTensor(x_name);
   auto x_dims = x->dims();
   auto out_name = op_info->Output("Out").front();
   auto out_type = kernel->GetOutputDeclType("Out");
-  CHECK(out_type->precision() == PRECISION(kFloat));
-  CHECK(out_type->layout() == DATALAYOUT(kNCHW));
   // X node
   std::shared_ptr<Node> x_node = nullptr;
@@ -90,7 +84,9 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // Reshape node
   graph->Add(out_name,
              graph->builder_.CreateReshape(*x_node->data(),
-                                           CvtShape<xtcl::Integer>(out_dims)));
+                                           CvtShape<xtcl::Integer>(out_dims)),
+             x_node->precision(),
+             x_node->layout());
   return REBUILD_WHEN_SHAPE_CHANGED;
 }
......
@@ -115,8 +115,8 @@ DLDeviceType CvtDLDeviceType(TargetType in_type) {
 std::shared_ptr<xtcl::xNDArray> CvtTensor(const Tensor& in_tensor,
                                           std::vector<int64_t> out_shape,
-                                          PrecisionType in_precision,
                                           DataLayoutType in_layout) {
+  PrecisionType in_precision = in_tensor.precision();
   auto in_shape = in_tensor.dims().Vectorize();
   if (out_shape.empty()) {
     out_shape = in_shape;
......
@@ -58,7 +58,6 @@ xtcl::Array<T> CvtShape(const DDim& in_dims);
 std::shared_ptr<xtcl::xNDArray> CvtTensor(
     const Tensor& in_tensor,
     std::vector<int64_t> out_shape = {},
-    PrecisionType in_precision = PRECISION(kFloat),
     DataLayoutType in_layout = DATALAYOUT(kNCHW));
 }  // namespace xpu
......