Commit 668145b9 authored by dingminghui, committed by jackzhang235

feat(slice): support slice when input dim is not 4

Parent 1bb607bb
@@ -30,6 +30,9 @@ std::shared_ptr<MLUTensor> Graph::AddNode(const std::string& name,
                                           cnmlDataOrder_t data_order,
                                           void* raw_ptr) {
   CHECK(!HasNode(name));
+  VLOG(5) << "add mlu node: " << name << "\t data type "
+          << static_cast<int>(mlu_dtype) << "\t data order "
+          << static_cast<int>(data_order);
   auto node = std::shared_ptr<MLUTensor>(
       new MLUTensor(shape, tensor_type, shape_order, mlu_dtype, data_order));
   node->set_mlu_ptr(raw_ptr);
......
@@ -53,17 +53,21 @@ int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   std::vector<int32_t> begin_index(input_shape.size(), 0);
   std::vector<int32_t> end_index(input_shape.size());
   std::vector<int32_t> strides(input_shape.size(), 1);
-  CHECK(input_shape.size() == 4) << "only support 4 dimention";
-  std::vector<int> nchw2nhwc_index = {0, 3, 1, 2};
+  std::vector<int> nhwc2nchw_axis(input_shape.size());
+  nhwc2nchw_axis[0] = 0;
+  if (input_shape.size() > 1) nhwc2nchw_axis[1] = input_shape.size() - 1;
+  for (size_t i = 2; i < input_shape.size(); ++i) {
+    nhwc2nchw_axis[i] = i - 1;
+  }
   for (size_t i = 0; i < input_shape.size(); ++i) {
-    end_index[nchw2nhwc_index[i]] = input_shape[i];
+    end_index[nhwc2nchw_axis[i]] = input_shape[i];
   }
   for (size_t i = 0; i < axes.size(); i++) {
     int dim_value = input_shape[axes[i]];
     int end = ends[i] < 0 ? std::max(ends[i] + dim_value, 0) : ends[i];
-    begin_index[nchw2nhwc_index[axes[i]]] =
+    begin_index[nhwc2nchw_axis[axes[i]]] =
         starts[i] < 0 ? std::max(starts[i] + dim_value, 0) : starts[i];
-    end_index[nchw2nhwc_index[axes[i]]] = std::min(end, dim_value);
+    end_index[nhwc2nchw_axis[axes[i]]] = std::min(end, dim_value);
   }
   cnmlNdStridedSliceOpParam_t param;
......
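The hunk above replaces the hard-coded 4-D permutation table {0, 3, 1, 2} with a loop that builds the same kind of map for any rank: axis 0 stays put, NCHW axis 1 (the channel axis) is sent to the last position, and the remaining axes shift down by one. The snippet below is a minimal, standalone sketch of that construction only; the helper name BuildAxisMap and the main() driver are hypothetical and not part of the patch.

// Standalone sketch of the rank-generic axis map introduced above.
// BuildAxisMap is a hypothetical name; it mirrors the in-place loop that
// replaces the constant {0, 3, 1, 2} table in SliceConverter.
#include <cstdio>
#include <vector>

std::vector<int> BuildAxisMap(size_t rank) {
  std::vector<int> axis(rank);
  axis[0] = 0;                       // batch axis stays in place
  if (rank > 1) axis[1] = rank - 1;  // channel axis moves to the last position
  for (size_t i = 2; i < rank; ++i) {
    axis[i] = i - 1;                 // remaining axes shift down by one
  }
  return axis;
}

int main() {
  for (size_t rank = 1; rank <= 5; ++rank) {
    printf("rank %zu:", rank);
    for (int a : BuildAxisMap(rank)) printf(" %d", a);
    printf("\n");  // rank 4 prints "0 3 1 2", matching the old constant table
  }
  return 0;
}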
@@ -108,30 +108,47 @@ static void test_case(std::vector<int64_t> x_shape,
   std::vector<float> out_ref(out->data_size(), 0);
   slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
+  std::vector<int> nhwc2nchw_axis(x_shape.size());
+  nhwc2nchw_axis[0] = 0;
+  if (x_shape.size() > 1) nhwc2nchw_axis[1] = x_shape.size() - 1;
+  for (size_t i = 2; i < x_shape.size(); ++i) {
+    nhwc2nchw_axis[i] = i - 1;
+  }
+
+  std::vector<int> nchw2nhwc_axis(x_shape.size());
+  nchw2nhwc_axis[0] = 0;
+  for (size_t i = 1; i < x_shape.size() - 1; ++i) {
+    nchw2nhwc_axis[i] = i + 1;
+  }
+  if (x_shape.size() > 1) nchw2nhwc_axis[x_shape.size() - 1] = 1;
+
+  auto type_cast = [](int64_t in) { return static_cast<int>(in); };
+  std::vector<int> i_dims;
+  std::transform(
+      x_shape.cbegin(), x_shape.cend(), std::back_inserter(i_dims), type_cast);
+
   Tensor input_x;
   input_x.Resize(x->dims());
-  transpose(x->mutable_data<float>(),
-            input_x.mutable_data<float>(),
-            {static_cast<int>(x_shape[0]),
-             static_cast<int>(x_shape[1]),
-             static_cast<int>(x_shape[2]),
-             static_cast<int>(x_shape[3])},
-            {0, 2, 3, 1});
+  transpose<float*>(x->mutable_data<float>(),
+                    input_x.mutable_data<float>(),
+                    i_dims,
+                    nchw2nhwc_axis);
   x->CopyDataFrom(input_x);

   auto op = CreateOp<operators::SliceOp>(opdesc, &scope);
   LaunchOp(op, {x_var_name}, {out_var_name});

   Tensor output_trans;
-  auto os = out->dims();
+  auto os = out->dims().Vectorize();
   output_trans.Resize(os);
-  transpose(out->mutable_data<float>(),
-            output_trans.mutable_data<float>(),
-            {static_cast<int>(os[0]),
-             static_cast<int>(os[2]),
-             static_cast<int>(os[3]),
-             static_cast<int>(os[1])},
-            {0, 3, 1, 2});
+  std::vector<int> o_dims(os.size());
+  for (size_t i = 0; i < os.size(); ++i) {
+    o_dims[i] = os[nchw2nhwc_axis[i]];
+  }
+  transpose<float*>(out->mutable_data<float>(),
+                    output_trans.mutable_data<float>(),
+                    o_dims,
+                    nhwc2nchw_axis);
   auto out_data = output_trans.mutable_data<float>();
   for (int i = 0; i < out->dims().production(); i++) {
@@ -141,8 +158,8 @@ static void test_case(std::vector<int64_t> x_shape,
 TEST(MLUBridges, slice) {
   /* test_case({3}, {3}, {-3}, {3}, {0}); */
-  /* test_case({3, 4}, {3, 4}, {-3, 0}, {3, 100}, {0, 1}); */
-  /* test_case({3, 4, 5}, {3, 4, 2}, {-3, 0, 2}, {3, 100, -1}, {0, 1, 2}); */
+  test_case({3, 4}, {3, 4}, {-3, 0}, {3, 100}, {0, 1});
+  test_case({3, 4, 5}, {3, 4, 2}, {-3, 0, 2}, {3, 100, -1}, {0, 1, 2});
   test_case({3, 4, 5, 6}, {3, 4, 2, 6}, {-3, 0, 2}, {3, 100, -1}, {0, 1, 2});
   /* test_case({3, 4, 5, 6, 3}, {3, 4, 2, 6, 3}, {-3, 0, 2}, {3, 100, -1}, {0,
    * 1, 2}); */
......
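The test now builds two rank-generic permutations, nchw2nhwc_axis to move the input into the MLU's channel-last layout and nhwc2nchw_axis to move the output back, and uses them with the generalized transpose calls; the 2-D and 3-D test cases can therefore be re-enabled. The round trip only works if the two tables are inverse permutations of each other. The following is a small self-contained check of that property, not part of the patch; it simply re-runs the two construction loops from the test for ranks 1 through 5.

// Standalone sanity check (hypothetical, not in the patch): the two axis
// tables built in test_case must compose to the identity permutation.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  for (size_t rank = 1; rank <= 5; ++rank) {
    std::vector<int> nhwc2nchw(rank), nchw2nhwc(rank);
    // Same construction as nhwc2nchw_axis in the test.
    nhwc2nchw[0] = 0;
    if (rank > 1) nhwc2nchw[1] = rank - 1;
    for (size_t i = 2; i < rank; ++i) nhwc2nchw[i] = i - 1;
    // Same construction as nchw2nhwc_axis in the test.
    nchw2nhwc[0] = 0;
    for (size_t i = 1; i + 1 < rank; ++i) nchw2nhwc[i] = i + 1;
    if (rank > 1) nchw2nhwc[rank - 1] = 1;
    // Composing one map with the other must give the identity permutation.
    for (size_t i = 0; i < rank; ++i) {
      assert(nchw2nhwc[nhwc2nchw[i]] == static_cast<int>(i));
    }
    printf("rank %zu: axis maps are mutual inverses\n", rank);
  }
  return 0;
}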