Commit d6791276 authored by dingminghui, committed by MaxwellDing

refactor: abstract function to generate axes trans vector

Parent 88513fd0
......@@ -45,13 +45,9 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto dims = output_dims.size();
int axis = (param_axis < 0) ? (param_axis + dims) : param_axis;
CHECK_LT(axis, dims) << "Unsupported dims in mlu concat";
std::vector<int> nchw2nhwc_axis(dims);
nchw2nhwc_axis[0] = 0;
if (dims > 1) nchw2nhwc_axis[1] = dims - 1;
for (size_t i = 2; i < dims; ++i) {
nchw2nhwc_axis[i] = i - 1;
}
int nhwc_axis = nchw2nhwc_axis[axis];
// values of nhwc2nchw_axis (GetAxisNHWC2NCHW) are NHWC indices;
// the vector itself is ordered by NCHW position
int nhwc_axis = GetAxisNHWC2NCHW<int>(dims)[axis];
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
......
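The hunk above replaces the hand-rolled permutation with a lookup into the new GetAxisNHWC2NCHW helper. A minimal standalone sketch of what that lookup does for a hypothetical 4-D concat (the {0, 3, 1, 2} constant is exactly what the helper returns for four dimensions):

#include <cassert>
#include <vector>

int main() {
  // Hypothetical 4-D concat: the framework describes shapes in NCHW,
  // while the MLU tensors are physically laid out as NHWC.
  std::vector<int> nchw_dims = {2, 3, 4, 5};  // N, C, H, W
  std::vector<int> nhwc_dims = {2, 4, 5, 3};  // N, H, W, C
  int param_axis = 1;                         // concatenate along C (NCHW terms)

  // GetAxisNHWC2NCHW<int>(4) returns {0, 3, 1, 2}; indexing it with the
  // NCHW axis gives the position of the same dimension under NHWC layout.
  std::vector<int> nhwc2nchw_axis = {0, 3, 1, 2};
  int nhwc_axis = nhwc2nchw_axis[param_axis];

  assert(nhwc_axis == 3);  // C sits at position 3 in NHWC
  assert(nchw_dims[param_axis] == nhwc_dims[nhwc_axis]);
  return 0;
}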
......@@ -38,24 +38,7 @@ int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// ================== Trans1: NHWC => NCHW ===========================
auto input_tensor = graph->GetNode(x_var_name);
// std::vector<int> nhwc_to_nchw_axis = {0, 3, 1, 2};
std::vector<int> trans_1_axis;
switch (x->dims().size()) {
case 4:
trans_1_axis = {0, 3, 1, 2};
break;
case 3:
trans_1_axis = {0, 2, 1};
break;
case 2:
trans_1_axis = {0, 1};
break;
case 1:
trans_1_axis = {0};
break;
default:
break;
}
auto trans_1_axis = std::move(GetAxisNHWC2NCHW<int>(x->dims().size()));
auto trans1_out = graph->AddNode(x_var_name + ".trans.i",
x->dims().Vectorize(),
CNML_TENSOR,
......@@ -95,24 +78,7 @@ int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// ======================= Flatten End ===================================
// ================== Trans2: NCHW => NHWC ===============================
// std::vector<int> nchw_to_nhwc_axis = {0, 2, 3, 1};
std::vector<int> trans_2_axis;
switch (output->dims().size()) {
case 4:
trans_2_axis = {0, 2, 3, 1};
break;
case 3:
trans_2_axis = {0, 2, 1};
break;
case 2:
trans_2_axis = {0, 1};
break;
case 1:
trans_2_axis = {0};
break;
default:
break;
}
auto trans_2_axis = std::move(GetAxisNCHW2NHWC<int>(output->dims().size()));
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
cnmlBaseOp_t trans2_op{nullptr};
......
......@@ -38,24 +38,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// ================== Trans1: NHWC => NCHW ===========================
auto input_tensor = graph->GetNode(x_var_name);
// std::vector<int> nhwc_to_nchw_axis = {0, 3, 1, 2};
std::vector<int> trans_1_axis;
switch (x->dims().size()) {
case 4:
trans_1_axis = {0, 3, 1, 2};
break;
case 3:
trans_1_axis = {0, 2, 1};
break;
case 2:
trans_1_axis = {0, 1};
break;
case 1:
trans_1_axis = {0};
break;
default:
break;
}
auto trans_1_axis = std::move(GetAxisNHWC2NCHW<int>(x->dims().size()));
auto trans1_out = graph->AddNode(x_var_name + ".trans.i",
x->dims().Vectorize(),
CNML_TENSOR,
......@@ -95,24 +78,7 @@ int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// ======================= Reshape op End ===================================
// ================== Trans2: NCHW => NHWC ===============================
// std::vector<int> nchw_to_nhwc_axis = {0, 2, 3, 1};
std::vector<int> trans_2_axis;
switch (output->dims().size()) {
case 4:
trans_2_axis = {0, 2, 3, 1};
break;
case 3:
trans_2_axis = {0, 2, 1};
break;
case 2:
trans_2_axis = {0, 1};
break;
case 1:
trans_2_axis = {0};
break;
default:
break;
}
auto trans_2_axis = std::move(GetAxisNCHW2NHWC<int>(output->dims().size()));
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
cnmlBaseOp_t trans2_op{nullptr};
......
......@@ -53,12 +53,7 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
std::vector<int32_t> begin_index(input_shape.size(), 0);
std::vector<int32_t> end_index(input_shape.size());
std::vector<int32_t> strides(input_shape.size(), 1);
std::vector<int> nhwc2nchw_axis(input_shape.size());
nhwc2nchw_axis[0] = 0;
if (input_shape.size() > 1) nhwc2nchw_axis[1] = input_shape.size() - 1;
for (size_t i = 2; i < input_shape.size(); ++i) {
nhwc2nchw_axis[i] = i - 1;
}
auto nhwc2nchw_axis = std::move(GetAxisNHWC2NCHW<int>(input_shape.size()));
for (size_t i = 0; i < input_shape.size(); ++i) {
end_index[nhwc2nchw_axis[i]] = input_shape[i];
}
......
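The loop kept in the hunk above scatters the input extents through the permutation (end_index[nhwc2nchw_axis[i]] = input_shape[i]), which is equivalent to gathering them through the inverse permutation. A standalone sketch with hypothetical values; the two fixed vectors are what the new helpers return for four dimensions:

#include <cassert>
#include <vector>

int main() {
  // Hypothetical 4-D extents, treated here as NCHW-ordered: {N, C, H, W}.
  std::vector<int> input_shape = {2, 3, 4, 5};
  std::vector<int> nhwc2nchw_axis = {0, 3, 1, 2};  // GetAxisNHWC2NCHW<int>(4)
  std::vector<int> nchw2nhwc_axis = {0, 2, 3, 1};  // GetAxisNCHW2NHWC<int>(4)

  // Scatter through one permutation, as the slice converter does ...
  std::vector<int> end_index(input_shape.size());
  for (size_t i = 0; i < input_shape.size(); ++i) {
    end_index[nhwc2nchw_axis[i]] = input_shape[i];
  }
  // ... which equals gathering through the inverse permutation.
  for (size_t i = 0; i < input_shape.size(); ++i) {
    assert(end_index[i] == input_shape[nchw2nhwc_axis[i]]);
  }
  // The same extents, reordered to {N, H, W, C}.
  assert((end_index == std::vector<int>{2, 4, 5, 3}));
  return 0;
}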
......@@ -108,31 +108,19 @@ static void test_case(std::vector<int64_t> x_shape,
std::vector<float> out_ref(out->data_size(), 0);
slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
std::vector<int> nhwc2nchw_axis(x_shape.size());
nhwc2nchw_axis[0] = 0;
if (x_shape.size() > 1) nhwc2nchw_axis[1] = x_shape.size() - 1;
for (size_t i = 2; i < x_shape.size(); ++i) {
nhwc2nchw_axis[i] = i - 1;
}
std::vector<int> nchw2nhwc_axis(x_shape.size());
nchw2nhwc_axis[0] = 0;
for (size_t i = 1; i < x_shape.size() - 1; ++i) {
nchw2nhwc_axis[i] = i + 1;
}
if (x_shape.size() > 1) nchw2nhwc_axis[x_shape.size() - 1] = 1;
auto type_cast = [](int64_t in) { return static_cast<int>(in); };
std::vector<int> i_dims;
std::transform(
x_shape.cbegin(), x_shape.cend(), std::back_inserter(i_dims), type_cast);
auto nchw2nhwc_axis = std::move(GetAxisNCHW2NHWC<int>(x_shape.size()));
Tensor input_x;
input_x.Resize(x->dims());
transpose<float*>(x->mutable_data<float>(),
input_x.mutable_data<float>(),
i_dims,
nchw2nhwc_axis);
transpose<float>(x->mutable_data<float>(),
input_x.mutable_data<float>(),
i_dims,
nchw2nhwc_axis);
x->CopyDataFrom(input_x);
auto op = CreateOp<operators::SliceOp>(opdesc, &scope);
......@@ -145,10 +133,10 @@ static void test_case(std::vector<int64_t> x_shape,
for (size_t i = 0; i < os.size(); ++i) {
o_dims[i] = os[nchw2nhwc_axis[i]];
}
transpose<float*>(out->mutable_data<float>(),
output_trans.mutable_data<float>(),
o_dims,
nhwc2nchw_axis);
transpose<float>(out->mutable_data<float>(),
output_trans.mutable_data<float>(),
o_dims,
GetAxisNHWC2NCHW<int>(x_shape.size()));
auto out_data = output_trans.mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
......
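The transpose<float> calls in the test hunk above reorder the raw buffers by an axis permutation before feeding the op and after reading its output. Below is a generic, standalone reference implementation of that operation (a hypothetical TransposeRef, not the project's transpose<> helper, which may differ in details), assuming the usual convention out.shape[d] == in.shape[axis[d]]:

#include <cassert>
#include <cstddef>
#include <vector>

// Transpose a row-major buffer by an axis permutation: the output element at
// coordinate (i0, i1, ...) is read from the input coordinate j with
// j[axis[d]] == i_d, so out.shape[d] == in.shape[axis[d]].
template <typename T>
std::vector<T> TransposeRef(const std::vector<T>& in,
                            const std::vector<int>& in_shape,
                            const std::vector<int>& axis) {
  const size_t ndim = in_shape.size();
  std::vector<int> out_shape(ndim);
  for (size_t d = 0; d < ndim; ++d) out_shape[d] = in_shape[axis[d]];

  // Row-major strides of the input shape.
  std::vector<size_t> in_stride(ndim, 1);
  for (int d = static_cast<int>(ndim) - 2; d >= 0; --d) {
    in_stride[d] = in_stride[d + 1] * in_shape[d + 1];
  }

  std::vector<T> out(in.size());
  std::vector<int> out_idx(ndim, 0);
  for (size_t o = 0; o < out.size(); ++o) {
    // Map the current output coordinate back to an input offset.
    size_t in_off = 0;
    for (size_t d = 0; d < ndim; ++d) in_off += out_idx[d] * in_stride[axis[d]];
    out[o] = in[in_off];
    // Advance the output coordinate in row-major order.
    for (int d = static_cast<int>(ndim) - 1; d >= 0; --d) {
      if (++out_idx[d] < out_shape[d]) break;
      out_idx[d] = 0;
    }
  }
  return out;
}

int main() {
  // A 1x2x3 tensor transposed with axis {0, 2, 1} becomes 1x3x2.
  std::vector<int> data = {1, 2, 3, 4, 5, 6};
  auto t = TransposeRef(data, {1, 2, 3}, {0, 2, 1});
  assert((t == std::vector<int>{1, 4, 2, 5, 3, 6}));
  return 0;
}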
......@@ -38,13 +38,7 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto x_shape =
scope->FindVar(x_var_name)->GetMutable<Tensor>()->dims().Vectorize();
// nchw axis to nhwc aixs
std::vector<int> nchw2nhwc_axis(x_shape.size());
nchw2nhwc_axis[0] = 0;
if (x_shape.size() > 1) nchw2nhwc_axis[1] = x_shape.size() - 1;
for (size_t i = 2; i < x_shape.size(); ++i) {
nchw2nhwc_axis[i] = i - 1;
}
// nchw axis to nhwc axis
int axis = 1;
if (op_info->HasAttr("axis")) {
axis = op_info->GetAttr<int>("axis");
......@@ -52,7 +46,9 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
axis = output_dims.size() + axis;
}
}
int nhwc_axis = nchw2nhwc_axis[axis];
// values of nhwc2nchw_axis (GetAxisNHWC2NCHW) are NHWC indices;
// the vector itself is ordered by NCHW position
int nhwc_axis = GetAxisNHWC2NCHW<int>(x_shape.size())[axis];
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
......
......@@ -24,19 +24,8 @@ namespace mlu {
std::vector<int> axis_to_nhwc(const std::vector<int>& axis) {
std::vector<int> new_axis(axis.size());
std::vector<int> nhwc2nchw_axis(axis.size());
nhwc2nchw_axis[0] = 0;
if (axis.size() > 1) nhwc2nchw_axis[1] = axis.size() - 1;
for (size_t i = 2; i < axis.size(); ++i) {
nhwc2nchw_axis[i] = i - 1;
}
std::vector<int> nchw2nhwc_axis(axis.size());
nchw2nhwc_axis[0] = 0;
for (size_t i = 1; i < axis.size() - 1; ++i) {
nchw2nhwc_axis[i] = i + 1;
}
if (axis.size() > 1) nchw2nhwc_axis[axis.size() - 1] = 1;
auto nhwc2nchw_axis = std::move(GetAxisNHWC2NCHW<int>(axis.size()));
auto nchw2nhwc_axis = std::move(GetAxisNCHW2NHWC<int>(axis.size()));
for (size_t i = 0; i < new_axis.size(); ++i) {
new_axis[i] = nhwc2nchw_axis[axis[nchw2nhwc_axis[i]]];
......
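axis_to_nhwc composes the two helper permutations to rewrite a transpose-op axis list given in NCHW terms into the equivalent permutation on the NHWC-laid-out tensor. A standalone check of that composition with a hypothetical axis list; the two fixed vectors are what the helpers return for four dimensions:

#include <cassert>
#include <vector>

int main() {
  // Hypothetical transpose-op permutation in NCHW terms: swap H and W.
  std::vector<int> axis = {0, 1, 3, 2};
  std::vector<int> nhwc2nchw_axis = {0, 3, 1, 2};  // GetAxisNHWC2NCHW<int>(4)
  std::vector<int> nchw2nhwc_axis = {0, 2, 3, 1};  // GetAxisNCHW2NHWC<int>(4)

  // Same composition as axis_to_nhwc in the hunk above.
  std::vector<int> new_axis(axis.size());
  for (size_t i = 0; i < new_axis.size(); ++i) {
    new_axis[i] = nhwc2nchw_axis[axis[nchw2nhwc_axis[i]]];
  }
  // Swapping H and W expressed in NHWC terms is {0, 2, 1, 3}.
  assert((new_axis == std::vector<int>{0, 2, 1, 3}));
  return 0;
}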
......@@ -44,12 +44,12 @@ void transpose(dtype* input_data,
int new_index = -1;
std::vector<int> shape;
std::vector<int> expand_axis;
if (input_shape.size() < 5) {
for (int i = 0; i < 5 - input_shape.size(); i++) {
if (input_shape.size() < 5u) {
for (size_t i = 0; i < 5 - input_shape.size(); i++) {
shape.push_back(1);
expand_axis.push_back(i);
}
for (int i = 0; i < input_shape.size(); i++) {
for (size_t i = 0; i < input_shape.size(); i++) {
shape.push_back(input_shape[i]);
expand_axis.push_back(axis[i] + 5 - input_shape.size());
}
......@@ -154,6 +154,28 @@ inline const std::vector<data_type> DimNCHW2NHWC(
}
}
template <typename data_type>
inline std::vector<data_type> GetAxisNHWC2NCHW(size_t n_dims) {
std::vector<data_type> nhwc2nchw_axis(n_dims);
nhwc2nchw_axis[0] = 0;
if (n_dims > 1) nhwc2nchw_axis[1] = n_dims - 1;
for (size_t i = 2; i < n_dims; ++i) {
nhwc2nchw_axis[i] = i - 1;
}
return nhwc2nchw_axis;
}
template <typename data_type>
inline std::vector<data_type> GetAxisNCHW2NHWC(size_t n_dims) {
std::vector<data_type> nchw2nhwc_axis(n_dims);
nchw2nhwc_axis[0] = 0;
for (size_t i = 1; i < n_dims - 1; ++i) {
nchw2nhwc_axis[i] = i + 1;
}
if (n_dims > 1) nchw2nhwc_axis[n_dims - 1] = 1;
return nchw2nhwc_axis;
}
template <paddle::lite_api::PrecisionType>
struct MLUTypeTraits {
/* using type = void; */
......
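For reference, a standalone sketch of how the two new helpers behave. The definitions are copied from the hunk above; the assertions spell out the permutations the converters now rely on and check that the two vectors are inverse permutations of each other.

#include <cassert>
#include <cstddef>
#include <vector>

// Copied from the hunk above: ordered by NCHW position, each value is the
// matching NHWC position (e.g. {0, 3, 1, 2} for 4-D shapes).
template <typename data_type>
inline std::vector<data_type> GetAxisNHWC2NCHW(size_t n_dims) {
  std::vector<data_type> nhwc2nchw_axis(n_dims);
  nhwc2nchw_axis[0] = 0;
  if (n_dims > 1) nhwc2nchw_axis[1] = n_dims - 1;
  for (size_t i = 2; i < n_dims; ++i) {
    nhwc2nchw_axis[i] = i - 1;
  }
  return nhwc2nchw_axis;
}

// Copied from the hunk above: the inverse permutation
// (e.g. {0, 2, 3, 1} for 4-D shapes).
template <typename data_type>
inline std::vector<data_type> GetAxisNCHW2NHWC(size_t n_dims) {
  std::vector<data_type> nchw2nhwc_axis(n_dims);
  nchw2nhwc_axis[0] = 0;
  for (size_t i = 1; i < n_dims - 1; ++i) {
    nchw2nhwc_axis[i] = i + 1;
  }
  if (n_dims > 1) nchw2nhwc_axis[n_dims - 1] = 1;
  return nchw2nhwc_axis;
}

int main() {
  // 4-D: the classic NHWC <-> NCHW permutations.
  assert((GetAxisNHWC2NCHW<int>(4) == std::vector<int>{0, 3, 1, 2}));
  assert((GetAxisNCHW2NHWC<int>(4) == std::vector<int>{0, 2, 3, 1}));
  // Lower ranks degrade to the shorter permutations the flatten/reshape
  // converters previously hard-coded in their switch statements.
  assert((GetAxisNHWC2NCHW<int>(3) == std::vector<int>{0, 2, 1}));
  assert((GetAxisNCHW2NHWC<int>(2) == std::vector<int>{0, 1}));
  // The two permutations are inverses of each other.
  auto to_nchw = GetAxisNHWC2NCHW<int>(4);
  auto to_nhwc = GetAxisNCHW2NHWC<int>(4);
  for (int i = 0; i < 4; ++i) {
    assert(to_nchw[to_nhwc[i]] == i);
  }
  return 0;
}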