Commit d6791276 authored by dingminghui, committed by MaxwellDing

refactor: abstract function to generate axes trans vector

Parent 88513fd0
@@ -45,13 +45,9 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto dims = output_dims.size();
   int axis = (param_axis < 0) ? (param_axis + dims) : param_axis;
   CHECK_LT(axis, dims) << "Unsupport dims in mlu concat";
-  std::vector<int> nchw2nhwc_axis(dims);
-  nchw2nhwc_axis[0] = 0;
-  if (dims > 1) nchw2nhwc_axis[1] = dims - 1;
-  for (size_t i = 2; i < dims; ++i) {
-    nchw2nhwc_axis[i] = i - 1;
-  }
-  int nhwc_axis = nchw2nhwc_axis[axis];
+  // value of nhwc2nchw_axis is index of nhwc
+  // order of nhwc2nchw_axis is nchw
+  int nhwc_axis = GetAxisNHWC2NCHW<int>(dims)[axis];
 
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
......
@@ -38,24 +38,7 @@ int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // ================== Trans1: NHWC => NCHW ===========================
   auto input_tensor = graph->GetNode(x_var_name);
-  // std::vector<int> nhwc_to_nchw_axis = {0, 3, 1, 2};
-  std::vector<int> trans_1_axis;
-  switch (x->dims().size()) {
-    case 4:
-      trans_1_axis = {0, 3, 1, 2};
-      break;
-    case 3:
-      trans_1_axis = {0, 2, 1};
-      break;
-    case 2:
-      trans_1_axis = {0, 1};
-      break;
-    case 1:
-      trans_1_axis = {0};
-      break;
-    default:
-      break;
-  }
+  auto trans_1_axis = std::move(GetAxisNHWC2NCHW<int>(x->dims().size()));
 
   auto trans1_out = graph->AddNode(x_var_name + ".trans.i",
                                    x->dims().Vectorize(),
                                    CNML_TENSOR,
@@ -95,24 +78,7 @@ int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // ======================= Flatten End ===================================
   // ================== Trans2: NCHW => NHWC ===============================
-  // std::vector<int> nchw_to_nhwc_axis = {0, 2, 3, 1};
-  std::vector<int> trans_2_axis;
-  switch (output->dims().size()) {
-    case 4:
-      trans_2_axis = {0, 2, 3, 1};
-      break;
-    case 3:
-      trans_2_axis = {0, 2, 1};
-      break;
-    case 2:
-      trans_2_axis = {0, 1};
-      break;
-    case 1:
-      trans_2_axis = {0};
-      break;
-    default:
-      break;
-  }
+  auto trans_2_axis = std::move(GetAxisNCHW2NHWC<int>(output->dims().size()));
 
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
   cnmlBaseOp_t trans2_op{nullptr};
......
@@ -38,24 +38,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // ================== Trans1: NHWC => NCHW ===========================
   auto input_tensor = graph->GetNode(x_var_name);
-  // std::vector<int> nhwc_to_nchw_axis = {0, 3, 1, 2};
-  std::vector<int> trans_1_axis;
-  switch (x->dims().size()) {
-    case 4:
-      trans_1_axis = {0, 3, 1, 2};
-      break;
-    case 3:
-      trans_1_axis = {0, 2, 1};
-      break;
-    case 2:
-      trans_1_axis = {0, 1};
-      break;
-    case 1:
-      trans_1_axis = {0};
-      break;
-    default:
-      break;
-  }
+  auto trans_1_axis = std::move(GetAxisNHWC2NCHW<int>(x->dims().size()));
 
   auto trans1_out = graph->AddNode(x_var_name + ".trans.i",
                                    x->dims().Vectorize(),
                                    CNML_TENSOR,
@@ -95,24 +78,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // ======================= Reshape op End ===================================
   // ================== Trans2: NCHW => NHWC ===============================
-  // std::vector<int> nchw_to_nhwc_axis = {0, 2, 3, 1};
-  std::vector<int> trans_2_axis;
-  switch (output->dims().size()) {
-    case 4:
-      trans_2_axis = {0, 2, 3, 1};
-      break;
-    case 3:
-      trans_2_axis = {0, 2, 1};
-      break;
-    case 2:
-      trans_2_axis = {0, 1};
-      break;
-    case 1:
-      trans_2_axis = {0};
-      break;
-    default:
-      break;
-  }
+  auto trans_2_axis = std::move(GetAxisNCHW2NHWC<int>(output->dims().size()));
 
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
   cnmlBaseOp_t trans2_op{nullptr};
......
@@ -53,12 +53,7 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   std::vector<int32_t> begin_index(input_shape.size(), 0);
   std::vector<int32_t> end_index(input_shape.size());
   std::vector<int32_t> strides(input_shape.size(), 1);
-  std::vector<int> nhwc2nchw_axis(input_shape.size());
-  nhwc2nchw_axis[0] = 0;
-  if (input_shape.size() > 1) nhwc2nchw_axis[1] = input_shape.size() - 1;
-  for (size_t i = 2; i < input_shape.size(); ++i) {
-    nhwc2nchw_axis[i] = i - 1;
-  }
+  auto nhwc2nchw_axis = std::move(GetAxisNHWC2NCHW<int>(input_shape.size()));
   for (size_t i = 0; i < input_shape.size(); ++i) {
     end_index[nhwc2nchw_axis[i]] = input_shape[i];
   }
......
@@ -108,31 +108,19 @@ static void test_case(std::vector<int64_t> x_shape,
   std::vector<float> out_ref(out->data_size(), 0);
   slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
-  std::vector<int> nhwc2nchw_axis(x_shape.size());
-  nhwc2nchw_axis[0] = 0;
-  if (x_shape.size() > 1) nhwc2nchw_axis[1] = x_shape.size() - 1;
-  for (size_t i = 2; i < x_shape.size(); ++i) {
-    nhwc2nchw_axis[i] = i - 1;
-  }
-  std::vector<int> nchw2nhwc_axis(x_shape.size());
-  nchw2nhwc_axis[0] = 0;
-  for (size_t i = 1; i < x_shape.size() - 1; ++i) {
-    nchw2nhwc_axis[i] = i + 1;
-  }
-  if (x_shape.size() > 1) nchw2nhwc_axis[x_shape.size() - 1] = 1;
   auto type_cast = [](int64_t in) { return static_cast<int>(in); };
   std::vector<int> i_dims;
   std::transform(
       x_shape.cbegin(), x_shape.cend(), std::back_inserter(i_dims), type_cast);
+  auto nchw2nhwc_axis = std::move(GetAxisNCHW2NHWC<int>(x_shape.size()));
 
   Tensor input_x;
   input_x.Resize(x->dims());
-  transpose<float*>(x->mutable_data<float>(),
+  transpose<float>(x->mutable_data<float>(),
                    input_x.mutable_data<float>(),
                    i_dims,
                    nchw2nhwc_axis);
   x->CopyDataFrom(input_x);
   auto op = CreateOp<operators::SliceOp>(opdesc, &scope);
@@ -145,10 +133,10 @@ static void test_case(std::vector<int64_t> x_shape,
   for (size_t i = 0; i < os.size(); ++i) {
     o_dims[i] = os[nchw2nhwc_axis[i]];
   }
-  transpose<float*>(out->mutable_data<float>(),
+  transpose<float>(out->mutable_data<float>(),
                    output_trans.mutable_data<float>(),
                    o_dims,
-                   nhwc2nchw_axis);
+                   GetAxisNHWC2NCHW<int>(x_shape.size()));
 
   auto out_data = output_trans.mutable_data<float>();
   for (int i = 0; i < out->dims().production(); i++) {
......
@@ -38,13 +38,7 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto x_shape =
       scope->FindVar(x_var_name)->GetMutable<Tensor>()->dims().Vectorize();
-  // nchw axis to nhwc aixs
-  std::vector<int> nchw2nhwc_axis(x_shape.size());
-  nchw2nhwc_axis[0] = 0;
-  if (x_shape.size() > 1) nchw2nhwc_axis[1] = x_shape.size() - 1;
-  for (size_t i = 2; i < x_shape.size(); ++i) {
-    nchw2nhwc_axis[i] = i - 1;
-  }
+  // nchw axis to nhwc axis
 
   int axis = 1;
   if (op_info->HasAttr("axis")) {
     axis = op_info->GetAttr<int>("axis");
@@ -52,7 +46,9 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
       axis = output_dims.size() + axis;
     }
   }
-  int nhwc_axis = nchw2nhwc_axis[axis];
+  // value of nhwc2nchw_axis is index of nhwc
+  // order of nhwc2nchw_axis is nchw
+  int nhwc_axis = GetAxisNHWC2NCHW<int>(x_shape.size())[axis];
 
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
......
@@ -24,19 +24,8 @@ namespace mlu {
 std::vector<int> axis_to_nhwc(const std::vector<int>& axis) {
   std::vector<int> new_axis(axis.size());
-  std::vector<int> nhwc2nchw_axis(axis.size());
-  nhwc2nchw_axis[0] = 0;
-  if (axis.size() > 1) nhwc2nchw_axis[1] = axis.size() - 1;
-  for (size_t i = 2; i < axis.size(); ++i) {
-    nhwc2nchw_axis[i] = i - 1;
-  }
-  std::vector<int> nchw2nhwc_axis(axis.size());
-  nchw2nhwc_axis[0] = 0;
-  for (size_t i = 1; i < axis.size() - 1; ++i) {
-    nchw2nhwc_axis[i] = i + 1;
-  }
-  if (axis.size() > 1) nchw2nhwc_axis[axis.size() - 1] = 1;
+  auto nhwc2nchw_axis = std::move(GetAxisNHWC2NCHW<int>(axis.size()));
+  auto nchw2nhwc_axis = std::move(GetAxisNCHW2NHWC<int>(axis.size()));
 
   for (size_t i = 0; i < new_axis.size(); ++i) {
     new_axis[i] = nhwc2nchw_axis[axis[nchw2nhwc_axis[i]]];
......
@@ -44,12 +44,12 @@ void transpose(dtype* input_data,
   int new_index = -1;
   std::vector<int> shape;
   std::vector<int> expand_axis;
-  if (input_shape.size() < 5) {
-    for (int i = 0; i < 5 - input_shape.size(); i++) {
+  if (input_shape.size() < 5u) {
+    for (size_t i = 0; i < 5 - input_shape.size(); i++) {
       shape.push_back(1);
       expand_axis.push_back(i);
     }
-    for (int i = 0; i < input_shape.size(); i++) {
+    for (size_t i = 0; i < input_shape.size(); i++) {
       shape.push_back(input_shape[i]);
       expand_axis.push_back(axis[i] + 5 - input_shape.size());
     }
@@ -154,6 +154,28 @@ inline const std::vector<data_type> DimNCHW2NHWC(
   }
 }
 
+template <typename data_type>
+inline std::vector<data_type> GetAxisNHWC2NCHW(size_t n_dims) {
+  std::vector<data_type> nhwc2nchw_axis(n_dims);
+  nhwc2nchw_axis[0] = 0;
+  if (n_dims > 1) nhwc2nchw_axis[1] = n_dims - 1;
+  for (size_t i = 2; i < n_dims; ++i) {
+    nhwc2nchw_axis[i] = i - 1;
+  }
+  return nhwc2nchw_axis;
+}
+
+template <typename data_type>
+inline std::vector<data_type> GetAxisNCHW2NHWC(size_t n_dims) {
+  std::vector<data_type> nchw2nhwc_axis(n_dims);
+  nchw2nhwc_axis[0] = 0;
+  for (size_t i = 1; i < n_dims - 1; ++i) {
+    nchw2nhwc_axis[i] = i + 1;
+  }
+  if (n_dims > 1) nchw2nhwc_axis[n_dims - 1] = 1;
+  return nchw2nhwc_axis;
+}
+
 template <paddle::lite_api::PrecisionType>
 struct MLUTypeTraits {
   /* using type = void; */
......
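For reference, below is a minimal standalone sketch of the abstraction this commit introduces. The two helper bodies are copied from the hunk that adds them above; the main() driver is only an illustration (not part of the commit) showing that, for a 4-D tensor, the helpers reproduce the permutations the per-operator switch statements used to hard-code: NHWC -> NCHW is {0, 3, 1, 2} and NCHW -> NHWC is {0, 2, 3, 1}.

#include <cstddef>
#include <iostream>
#include <vector>

// Copied from the new helpers added in this diff.
template <typename data_type>
inline std::vector<data_type> GetAxisNHWC2NCHW(size_t n_dims) {
  std::vector<data_type> nhwc2nchw_axis(n_dims);
  nhwc2nchw_axis[0] = 0;
  if (n_dims > 1) nhwc2nchw_axis[1] = n_dims - 1;
  for (size_t i = 2; i < n_dims; ++i) {
    nhwc2nchw_axis[i] = i - 1;
  }
  return nhwc2nchw_axis;
}

template <typename data_type>
inline std::vector<data_type> GetAxisNCHW2NHWC(size_t n_dims) {
  std::vector<data_type> nchw2nhwc_axis(n_dims);
  nchw2nhwc_axis[0] = 0;
  for (size_t i = 1; i < n_dims - 1; ++i) {
    nchw2nhwc_axis[i] = i + 1;
  }
  if (n_dims > 1) nchw2nhwc_axis[n_dims - 1] = 1;
  return nchw2nhwc_axis;
}

int main() {
  // For 4-D tensors these match the literals the old switch statements used.
  for (int v : GetAxisNHWC2NCHW<int>(4)) std::cout << v << ' ';  // prints: 0 3 1 2
  std::cout << '\n';
  for (int v : GetAxisNCHW2NHWC<int>(4)) std::cout << v << ' ';  // prints: 0 2 3 1
  std::cout << '\n';
  return 0;
}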