Commit 9553fab2 authored by dingminghui, committed by jackzhang235

feat(mlu): support NCHW node

Parent 8d9cc823
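In short: MLUTensor previously used a single cnmlDataOrder_t for two distinct purposes; this commit splits it into shape_order (how the shape vector handed to the constructor is ordered, still CNML_NCHW by default) and a new data_order (how the tensor's data is actually laid out on the device, CNML_NHWC by default, with CNML_NCHW now supported, e.g. for subgraphs that begin with a cast op). A minimal usage sketch of the new Graph::AddNode signature, assuming a Graph instance `graph` and a float Tensor* `input_tensor` (both names illustrative):

    // Sketch only: mirrors the PrepareInput call added in this commit.
    auto node = graph.AddNode(
        "x",                               // node name (illustrative)
        input_tensor->dims().Vectorize(),  // shape, given in NCHW order
        CNML_TENSOR,
        CNML_NCHW,                         // shape_order: order of the shape above
        CNML_DATA_FLOAT32,                 // mlu_dtype
        CNML_NHWC,                         // data_order: layout of the device data
        reinterpret_cast<void*>(
            input_tensor->mutable_data<float>(TARGET(kMLU))));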
@@ -14,6 +14,7 @@
 #include "lite/core/mir/mlu_postprocess_pass.h"
 #include <list>
 #include <map>
+#include <memory>
 #include <string>
 #include <utility>
@@ -676,7 +677,7 @@ std::string CheckOutputAndInsert(cpp::BlockDesc* block_desc,
 void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
                                         const Type* op_type) {
   auto subgraph_op = subgraph_node->AsStmt().op();
-  CHECK(subgraph_op->Type() == "subgraph");
+  CHECK_EQ(subgraph_op->Type(), "subgraph");
   auto op = dynamic_cast<operators::SubgraphOp*>(subgraph_op.get());
   CHECK(op);
   auto block_desc = op->GetSubBlock();
......
@@ -27,10 +27,11 @@ std::shared_ptr<MLUTensor> Graph::AddNode(const std::string& name,
                                           cnmlTensorType_t tensor_type,
                                           cnmlDataOrder_t shape_order,
                                           cnmlDataType_t mlu_dtype,
+                                          cnmlDataOrder_t data_order,
                                           void* raw_ptr) {
   CHECK(!HasNode(name));
   auto node = std::shared_ptr<MLUTensor>(
-      new MLUTensor(shape, tensor_type, shape_order, mlu_dtype));
+      new MLUTensor(shape, tensor_type, shape_order, mlu_dtype, data_order));
   node->set_mlu_ptr(raw_ptr);
   nodes_.insert(std::make_pair(name, node));
   return node;
......
@@ -66,8 +66,9 @@ class Graph {
       const std::string& name,
       std::vector<int64_t> shape,
       cnmlTensorType_t tensor_type = CNML_TENSOR,
-      cnmlDataOrder_t data_order = CNML_NCHW,
+      cnmlDataOrder_t shape_order = CNML_NCHW,
       cnmlDataType_t mlu_dtype = CNML_DATA_FLOAT32,
+      cnmlDataOrder_t data_order = CNML_NHWC,
       void* raw_ptr = nullptr);
   std::shared_ptr<MLUTensor> GetNode(const std::string& name) {
......
@@ -16,6 +16,8 @@
 #include <glog/logging.h>
 #include <algorithm>
 #include <climits>
+#include <fstream>
+#include <sstream>
 #include <string>
 #include <vector>
@@ -26,8 +28,9 @@ namespace mlu {
 
 MLUTensor::MLUTensor(const std::vector<int64_t>& shape,
                      cnmlTensorType_t tensor_type,
-                     cnmlDataOrder_t data_order,
-                     cnmlDataType_t mlu_dtype)
+                     cnmlDataOrder_t shape_order,
+                     cnmlDataType_t mlu_dtype,
+                     cnmlDataOrder_t data_order)
     : mlu_tensor_(nullptr), tensor_type_(tensor_type), mlu_ptr_(nullptr) {
   std::vector<int> int_shape;
   for (auto i : shape) {
@@ -37,15 +40,17 @@ MLUTensor::MLUTensor(const std::vector<int64_t>& shape,
       LOG(FATAL) << "Shape size is beyond the limitation of MLUTensor!";
     }
   }
-  remember(int_shape, tensor_type, mlu_dtype, data_order);
+  remember(int_shape, tensor_type, mlu_dtype, shape_order, data_order);
 }
 
 void MLUTensor::remember(const std::vector<int>& shape,
                          cnmlTensorType_t tensor_type,
                          cnmlDataType_t mlu_dtype,
-                         cnmlDataOrder_t shape_order) {
+                         cnmlDataOrder_t shape_order,
+                         cnmlDataOrder_t data_order) {
   tensor_type_ = tensor_type;
   mlu_dtype_ = mlu_dtype;
+  data_order_ = data_order;
   origin_shape_.assign(shape.begin(), shape.end());
 
   int size = 4;
@@ -248,6 +253,12 @@ void MLUTensor::Create() {
   if (mlu_tensor_ == nullptr) {
     CNML_CALL(cnmlCreateTensor_V2(&mlu_tensor_, tensor_type_));
     std::vector<int> dim_shape(shape_);
+    if (data_order_ == CNML_NCHW) {
+      std::transform(origin_shape_.cbegin(),
+                     origin_shape_.cend(),
+                     dim_shape.begin(),
+                     [](DDim::value_type in) { return static_cast<int>(in); });
+    }
     int* dim_strides = nullptr;
     CNML_CALL(cnmlSetTensorShape_V2(
         mlu_tensor_, dim_, dim_shape.data(), dim_strides));
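With this branch, a tensor whose data_order_ is CNML_NCHW is created with its original dimensions (origin_shape_) instead of the NHWC-converted shape_. For reference, the {0, 3, 1, 2} axis order used by the transpose calls in this file permutes NHWC dims to NCHW; a standalone sketch (names hypothetical):

    #include <array>
    // Permute NHWC dims {N, H, W, C} to NCHW {N, C, H, W} via axis order {0, 3, 1, 2}.
    std::array<int, 4> NhwcToNchw(const std::array<int, 4>& nhwc) {
      const int axis[4] = {0, 3, 1, 2};
      std::array<int, 4> nchw{};
      for (int i = 0; i < 4; ++i) nchw[i] = nhwc[axis[i]];
      return nchw;  // e.g. {1, 224, 224, 3} -> {1, 3, 224, 224}
    }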
@@ -297,15 +308,23 @@ void MLUTensor::ToFile(std::string file_name) {
     // trans to nchw
     std::vector<float> cpu_data_trans(count);
-    transpose(
-        cpu_data_fp32.data(), cpu_data_trans.data(), shape_, {0, 3, 1, 2});
+    if (data_order_ != CNML_NCHW) {
+      transpose(
+          cpu_data_fp32.data(), cpu_data_trans.data(), shape_, {0, 3, 1, 2});
+    }
 
     // to file
-    std::ofstream of;
-    of.open(file_name, std::ios::out);
+    std::ostringstream outs;
     for (size_t i = 0; i < count; i++) {
-      of << cpu_data_trans[i] << std::endl;
+      if (data_order_ == CNML_NCHW) {
+        outs << cpu_data_fp32[i] << std::endl;
+      } else {
+        outs << cpu_data_trans[i] << std::endl;
+      }
     }
+    std::ofstream of;
+    of.open(file_name, std::ios::out);
+    of << outs.str();
     of.close();
   } else {
     LOG(FATAL) << "mlu ptr is null, cannot dump mlu content to: " << file_name
......
@@ -35,13 +35,15 @@ class MLUTensor {
   MLUTensor(const std::vector<int64_t>& shape,
             cnmlTensorType_t tensor_type = CNML_TENSOR,
-            cnmlDataOrder_t data_order = CNML_NCHW,
-            cnmlDataType_t mlu_dtype = CNML_DATA_FLOAT32);
+            cnmlDataOrder_t shape_order = CNML_NCHW,
+            cnmlDataType_t mlu_dtype = CNML_DATA_FLOAT32,
+            cnmlDataOrder_t data_order = CNML_NHWC);
 
   void remember(const std::vector<int>& shape,
                 cnmlTensorType_t tensor_type,
                 cnmlDataType_t mlu_dtype,
-                cnmlDataOrder_t shape_order);
+                cnmlDataOrder_t shape_order,
+                cnmlDataOrder_t data_order);
   void Create();
   cnmlTensor_t mlu_tensor();
 
   void* mlu_data() {
......
@@ -24,11 +24,35 @@ namespace lite {
 namespace subgraph {
 namespace mlu {
 
+template <lite_api::PrecisionType Dtype>
+void PrepareInput(Graph* graph,
+                  const std::string& input_name,
+                  Tensor* input_tensor) {
+  thread_local Tensor temp_input;
+  temp_input.Resize(input_tensor->dims().Vectorize());
+  temp_input.CopyDataFrom(*input_tensor);
+  using data_type = typename MLUTypeTraits<Dtype>::type;
+  auto input_node = graph->AddNode(
+      input_name,
+      input_tensor->dims().Vectorize(),
+      CNML_TENSOR,
+      CNML_NCHW,
+      MLUTypeTraits<Dtype>::cnml_type,
+      CNML_NHWC,
+      reinterpret_cast<void*>(
+          input_tensor->template mutable_data<data_type>(TARGET(kMLU))));
+  CHECK(input_node);
+  CNRT_CHECK(cnrtMemcpy(input_tensor->template mutable_data<data_type>(),
+                        temp_input.mutable_data<data_type>(),
+                        sizeof(data_type) * input_tensor->dims().production(),
+                        CNRT_MEM_TRANS_DIR_HOST2DEV));
+}
+
 void LaunchOp(const std::shared_ptr<lite::OpLite> op,
               const std::vector<std::string>& input_var_names,
               const std::vector<std::string>& output_var_names) {
   CNRT_CALL(cnrtInit(0));
-  ::paddle::lite::SetMluDevice(0);
+  lite::SetMluDevice(0);
   cnrtQueue_t queue_;
   cnrtInvokeFuncParam_t forward_param;
   u32_t affinity = 1;
@@ -51,70 +75,20 @@ void LaunchOp(const std::shared_ptr<lite::OpLite> op,
   for (auto& input_name : input_var_names) {
     auto input_tensor = scope->FindMutableTensor(input_name);
     auto data_type = input_tensor->precision();
-    cnmlDataType_t fp_type;
     switch (data_type) {
-      case paddle::lite_api::PrecisionType::kFP16:
-        fp_type = CNML_DATA_FLOAT16;
-        break;
-      case paddle::lite_api::PrecisionType::kFloat:
-        fp_type = CNML_DATA_FLOAT32;
-        break;
-      case paddle::lite_api::PrecisionType::kInt32:
-        fp_type = CNML_DATA_INT32;
-        break;
+#define PREPARE_INPUT(type__)                                          \
+  case PRECISION(type__):                                              \
+    PrepareInput<PRECISION(type__)>(&graph, input_name, input_tensor); \
+    break;
+      PREPARE_INPUT(kFP16)
+      PREPARE_INPUT(kFloat)
+      PREPARE_INPUT(kInt8)
+      PREPARE_INPUT(kInt32)
+#undef PREPARE_INPUT
       default:
         CHECK(0);
     }
-    CHECK(input_tensor);
-    Tensor temp_input;
-    temp_input.Resize(input_tensor->dims().Vectorize());
-    temp_input.CopyDataFrom(*input_tensor);
-    if (fp_type == CNML_DATA_INT32) {
-      auto input_node =
-          graph.AddNode(input_name,
-                        input_tensor->dims().Vectorize(),
-                        CNML_TENSOR,
-                        CNML_NCHW,
-                        fp_type,
-                        reinterpret_cast<void*>(
-                            input_tensor->mutable_data<int>(TARGET(kMLU))));
-      CHECK(input_node);
-      CNRT_CHECK(cnrtMemcpy(input_tensor->mutable_data<int>(),
-                            temp_input.mutable_data<int>(),
-                            sizeof(int) * input_tensor->dims().production(),
-                            CNRT_MEM_TRANS_DIR_HOST2DEV));
-    } else if (fp_type == CNML_DATA_FLOAT16) {
-      auto input_node = graph.AddNode(
-          input_name,
-          input_tensor->dims().Vectorize(),
-          CNML_TENSOR,
-          CNML_NCHW,
-          fp_type,
-          reinterpret_cast<void*>(
-              input_tensor->mutable_data<paddle::lite::fluid::float16>(
-                  TARGET(kMLU))));
-      CHECK(input_node);
-      CNRT_CHECK(
-          cnrtMemcpy(input_tensor->mutable_data<paddle::lite::fluid::float16>(),
-                     temp_input.mutable_data<paddle::lite::fluid::float16>(),
-                     sizeof(paddle::lite::fluid::float16) *
-                         input_tensor->dims().production(),
-                     CNRT_MEM_TRANS_DIR_HOST2DEV));
-    } else {
-      auto input_node =
-          graph.AddNode(input_name,
-                        input_tensor->dims().Vectorize(),
-                        CNML_TENSOR,
-                        CNML_NCHW,
-                        fp_type,
-                        reinterpret_cast<void*>(
-                            input_tensor->mutable_data<float>(TARGET(kMLU))));
-      CHECK(input_node);
-      CNRT_CHECK(cnrtMemcpy(input_tensor->mutable_data<float>(),
-                            temp_input.mutable_data<float>(),
-                            sizeof(float) * input_tensor->dims().production(),
-                            CNRT_MEM_TRANS_DIR_HOST2DEV));
-    }
   }
 
   op->CheckShape();
   op->InferShape();
......
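The rewrite above folds three nearly identical branches into one templated PrepareInput plus a macro-generated switch: the macro stamps out one case per precision, each forwarding to the matching template instantiation. A self-contained sketch of the pattern (all names hypothetical, not Paddle-Lite APIs):

    #include <cstdio>

    enum class Precision { kFloat, kFP16, kInt8, kInt32 };

    // One handler instantiation per precision; a real version would do the
    // per-type tensor preparation here.
    template <Precision P> void Handle();
    template <> void Handle<Precision::kFloat>() { std::puts("float"); }
    template <> void Handle<Precision::kFP16>() { std::puts("fp16"); }
    template <> void Handle<Precision::kInt8>() { std::puts("int8"); }
    template <> void Handle<Precision::kInt32>() { std::puts("int32"); }

    void Dispatch(Precision p) {
      switch (p) {
    #define PREPARE_INPUT(type__)    \
      case Precision::type__:        \
        Handle<Precision::type__>(); \
        break;
        PREPARE_INPUT(kFloat)
        PREPARE_INPUT(kFP16)
        PREPARE_INPUT(kInt8)
        PREPARE_INPUT(kInt32)
    #undef PREPARE_INPUT
      }
    }

    int main() { Dispatch(Precision::kInt8); }  // prints "int8"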
@@ -144,16 +144,33 @@ inline const std::vector<DDimLite::value_type> DimNCHW2NHWC(
 }
 
 template <paddle::lite_api::PrecisionType>
-struct FPTypeTraits {};
+struct MLUTypeTraits {
+  /* using type = void; */
+  /* static constexpr cnmlDataType_t cnml_type = CNML_DATA_INVALID; */
+};
+
+template <>
+struct MLUTypeTraits<paddle::lite_api::PrecisionType::kFloat> {
+  using type = float;
+  static constexpr cnmlDataType_t cnml_type = CNML_DATA_FLOAT32;
+};
+
+template <>
+struct MLUTypeTraits<paddle::lite_api::PrecisionType::kFP16> {
+  using type = paddle::lite::fluid::float16;
+  static constexpr cnmlDataType_t cnml_type = CNML_DATA_FLOAT16;
+};
 
 template <>
-struct FPTypeTraits<paddle::lite_api::PrecisionType::kFloat> {
-  typedef float T;
+struct MLUTypeTraits<paddle::lite_api::PrecisionType::kInt8> {
+  using type = int8_t;
+  static constexpr cnmlDataType_t cnml_type = CNML_DATA_INT8;
 };
 
 template <>
-struct FPTypeTraits<paddle::lite_api::PrecisionType::kFP16> {
-  typedef paddle::lite::fluid::float16 T;
+struct MLUTypeTraits<paddle::lite_api::PrecisionType::kInt32> {
+  using type = int32_t;
+  static constexpr cnmlDataType_t cnml_type = CNML_DATA_INT32;
 };
 
 }  // namespace mlu
......
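Each MLUTypeTraits specialization pairs the C++ storage type with the matching cnmlDataType_t, replacing the two divergent FPTypeTraits copies that existed before. A short usage sketch (assuming the header above is included; variable names illustrative):

    // Map a compile-time precision to its storage type and CNML dtype.
    using FP16Traits = paddle::lite::subgraph::mlu::MLUTypeTraits<
        paddle::lite_api::PrecisionType::kFP16>;
    FP16Traits::type* buffer = nullptr;            // paddle::lite::fluid::float16*
    cnmlDataType_t dtype = FP16Traits::cnml_type;  // CNML_DATA_FLOAT16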
@@ -22,6 +22,7 @@
 #include "lite/core/op_lite.h"
 #include "lite/core/op_registry.h"
 #include "lite/core/type_system.h"
+#include "lite/kernels/mlu/bridges/utility.h"
 #include "lite/operators/layout_op.h"
 
 namespace paddle {
@@ -29,24 +30,6 @@ namespace lite {
 namespace kernels {
 namespace mlu {
 
-template <paddle::lite_api::PrecisionType>
-struct FPTypeTraits {};
-
-template <>
-struct FPTypeTraits<paddle::lite_api::PrecisionType::kFloat> {
-  using type = float;
-};
-
-template <>
-struct FPTypeTraits<paddle::lite_api::PrecisionType::kFP16> {
-  using type = paddle::lite::fluid::float16;
-};
-
-template <>
-struct FPTypeTraits<paddle::lite_api::PrecisionType::kInt8> {
-  using type = int8_t;
-};
-
 template <lite::TargetType Target, typename T>
 inline void LayoutTransCompute(const int dim,
                                const lite::Context<Target>& context,
@@ -81,7 +64,8 @@ class LayoutNchwToNhwcCompute
     auto& param = this->template Param<param_t>();
     auto* x = param.x;
     auto* out = param.y;
-    out->template mutable_data<typename FPTypeTraits<Precision>::type>();
+    out->template mutable_data<
+        typename subgraph::mlu::MLUTypeTraits<Precision>::type>();
     auto x_ndims = param.x->dims().size();
 
     auto& context = this->ctx_->template As<X86Context>();
@@ -107,7 +91,7 @@ class LayoutNchwToNhwcCompute
     }
 
     LayoutTransCompute<lite::TargetType::kX86,
-                       typename FPTypeTraits<Precision>::type>(
+                       typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
         x_ndims, context, *x, out, axis);
 
     if (x_ndims > 2) {
@@ -130,7 +114,8 @@ class LayoutNhwcToNchwCompute
     auto& param = this->template Param<param_t>();
    auto* x = param.x;
     auto* out = param.y;
-    out->template mutable_data<typename FPTypeTraits<Precision>::type>();
+    out->template mutable_data<
+        typename subgraph::mlu::MLUTypeTraits<Precision>::type>();
     auto& context = this->ctx_->template As<X86Context>();
 
     TensorLite tmp_t;
@@ -157,7 +142,7 @@ class LayoutNhwcToNchwCompute
     }
 
     LayoutTransCompute<lite::TargetType::kX86,
-                       typename FPTypeTraits<Precision>::type>(
+                       typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
         x_ndims, context, tmp_t, out, axis);
   }
......
@@ -147,6 +147,9 @@ class SubgraphEngine : public subgraph::Engine {
     origin_itensors_.clear();
     origin_otensors_.clear();
 
+    auto data_order = block_desc_->GetOp<cpp::OpDesc>(0)->Type() == "cast"
+                          ? CNML_NCHW
+                          : CNML_NHWC;
     // Convert all input data vars and add them into the MLU IR graph
     status |= subgraph::REBUILD_WHEN_SHAPE_CHANGED;
     for (auto& input_name : input_names_) {
@@ -167,7 +170,8 @@ class SubgraphEngine : public subgraph::Engine {
           input_tensor->dims().Vectorize(),
           CNML_TENSOR,
           CNML_NCHW,
-          fp_type);
+          fp_type,
+          data_order);
       CHECK(input_node);
       // MLU doesn't support dynamic dimensions/shapes, so need to rebuild
       // the program when the shape of any input tensor is changed.
@@ -367,8 +371,9 @@ class SubgraphEngine : public subgraph::Engine {
         // origin_otensors_[i]->Resize(new_output_size.at(i));
         void* p_data = static_cast<void*>(
             origin_otensors_[i]
-                ->mutable_data<typename paddle::lite::subgraph::mlu::
-                                   FPTypeTraits<Precision>::T>(TARGET(kMLU)));
+                ->template mutable_data<
+                    typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
+                    TARGET(kMLU)));
         graph_out[i]->set_mlu_ptr(p_data);
       }
     } else {
@@ -377,8 +382,9 @@ class SubgraphEngine : public subgraph::Engine {
         // origin_otensors_[i]->Resize(new_output_size.at(i));
         void* p_data = static_cast<void*>(
             origin_otensors_[i]
-                ->mutable_data<typename paddle::lite::subgraph::mlu::
-                                   FPTypeTraits<Precision>::T>(TARGET(kMLU)));
+                ->template mutable_data<
+                    typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
+                    TARGET(kMLU)));
         paddle::lite::subgraph::mlu::MLUTensor tmp(
             origin_otensors_[i]->dims().Vectorize());
         tmp.set_mlu_dtype(graph_output->at(i)->dtype());
@@ -398,8 +404,9 @@ class SubgraphEngine : public subgraph::Engine {
       origin_otensors_[i]->Resize(graph_output->at(i)->get_origin_shape());
         void* p_data = static_cast<void*>(
             origin_otensors_[i]
-                ->mutable_data<typename paddle::lite::subgraph::mlu::
-                                   FPTypeTraits<Precision>::T>(TARGET(kMLU)));
+                ->template mutable_data<
+                    typename subgraph::mlu::MLUTypeTraits<Precision>::type>(
+                    TARGET(kMLU)));
         graph_output->at(i)->set_mlu_ptr(p_data);
       }
       graph->Compute(forward_param, exec_queue);
......