提交 cf688607 编写于 作者: J jiaopu 提交者: jackzhang235

Add layout op

上级 c096096f
......@@ -620,7 +620,8 @@ std::string CheckInputAndInsert(Scope* scope,
auto layout_op = block_desc->AddOp<cpp::OpDesc>();
auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
scope->Var(layout_arg_name);
VLOG(5) << "insert layout in subgraph, arg tensor name: " << layout_arg_name;
VLOG(5) << "insert layout in subgraph, arg tensor name: "
<< layout_arg_name;
layout_op->SetType("layout");
layout_op->SetInput("Input", {cur_node});
layout_op->SetOutput("Out", {layout_arg_name});
......@@ -663,7 +664,8 @@ std::string CheckOutputAndInsert(Scope* scope,
if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
scope->Var(layout_arg_name);
VLOG(5) << "insert layout in subgraph, arg tensor name: " << layout_arg_name;
VLOG(5) << "insert layout in subgraph, arg tensor name: "
<< layout_arg_name;
layout_op = block_desc->AddOp<cpp::OpDesc>();
layout_op->SetType("layout");
layout_op->SetInput("Input", {layout_arg_name});
......@@ -709,16 +711,22 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
auto input_name = input->AsArg().name;
if (!(input->AsArg().is_weight || input->AsArg().is_persist)) {
i_names.emplace_back(input_name);
node_replace[input_name] = CheckInputAndInsert(
op->scope(), new_block_desc, input_name, input->AsArg().type, subgraph_type);
node_replace[input_name] = CheckInputAndInsert(op->scope(),
new_block_desc,
input_name,
input->AsArg().type,
subgraph_type);
}
}
for (auto& output : subgraph_node->outlinks) {
auto output_name = output->AsArg().name;
if (!(output->AsArg().is_weight || output->AsArg().is_persist)) {
o_names.emplace_back(output_name);
node_replace[output_name] = CheckOutputAndInsert(
op->scope(), block_desc, output_name, output->AsArg().type, subgraph_type);
node_replace[output_name] = CheckOutputAndInsert(op->scope(),
block_desc,
output_name,
output->AsArg().type,
subgraph_type);
}
}
......
......@@ -23,6 +23,7 @@ lite_cc_library(subgraph_bridge_dropout_op_mlu SRCS dropout_op.cc DEPS ${subgrap
lite_cc_library(subgraph_bridge_slice_op_mlu SRCS slice_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_split_op_mlu SRCS split_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_cast_op_mlu SRCS cast_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_layout_op_mlu SRCS layout_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_argmax_op_mlu SRCS argmax_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_squeeze_op_mlu SRCS squeeze_op.cc DEPS ${subgraph_bridge_deps_mlu})
set(mlu_subgraph_bridges
......@@ -44,6 +45,7 @@ set(mlu_subgraph_bridges
subgraph_bridge_slice_op_mlu
subgraph_bridge_split_op_mlu
subgraph_bridge_cast_op_mlu
subgraph_bridge_layout_op_mlu
subgraph_bridge_argmax_op_mlu
subgraph_bridge_squeeze_op_mlu
CACHE INTERNAL "mlu_subgraph_bridges")
......@@ -71,6 +73,7 @@ lite_cc_test(test_transpose_converter_mlu SRCS transpose_op_test.cc DEPS scope o
lite_cc_test(test_dropout_converter_mlu SRCS dropout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_slice_converter_mlu SRCS slice_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_split_converter_mlu SRCS split_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_layout_converter_mlu SRCS layout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_cast_converter_mlu SRCS cast_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_argmax_converter_mlu SRCS argmax_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_squeeze_converter_mlu SRCS squeeze_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int LayoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("Input").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
std::shared_ptr<MLUTensor> output_tensor;
CHECK(graph->HasNode(x_var_name));
std::vector<int> axis;
auto x_tensor = graph->GetNode(x_var_name);
auto x_data_order = x_tensor->dorder();
auto x_dims = x->dims().Vectorize();
if (x_data_order == CNML_NCHW) {
switch (x_dims.size()) {
case 2:
axis = {0, 1};
break;
case 3:
axis = {0, 2, 1};
break;
case 4:
axis = {0, 2, 3, 1};
break;
case 5:
axis = {0, 2, 3, 4, 1};
break;
default:
CHECK(0) << "Unsupport shape";
}
output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
} else {
switch (x_dims.size()) {
case 2:
axis = {0, 1};
break;
case 3:
axis = {0, 2, 1};
break;
case 4:
axis = {0, 3, 1, 2};
break;
case 5:
axis = {0, 4, 1, 2, 3};
break;
default:
CHECK(0) << "Unsupport shpae";
}
output_tensor = graph->AddNode(out_var_name,
output_dims,
CNML_TENSOR,
CNML_NCHW,
graph->FPType(),
CNML_NCHW);
}
cnmlBaseOp_t layout_op;
cnmlNdTransposeOpParam_t transpose_param;
CNML_CALL(
cnmlCreateNdTransposeOpParam(&transpose_param, axis.data(), axis.size()));
CNML_CALL(cnmlCreateNdTransposeProOp(&layout_op,
x_tensor->mlu_tensor(),
output_tensor->mlu_tensor(),
transpose_param));
graph->FuseOp(layout_op);
CNML_CALL(cnmlDestroyBaseOp(&layout_op));
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(layout,
kMLU,
paddle::lite::subgraph::mlu::LayoutConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/layout_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
void test_layout_NHWC2NCHW(std::vector<int64_t> input_shape) {
// prepare input&output variables
std::string x_var_name = "input";
std::string out_var_name = "out";
Scope scope;
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
x->Resize(DDim(input_shape));
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("layout");
opdesc.SetInput("Input", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
auto op = CreateOp<operators::LayoutOp>(opdesc, &scope);
// execute reference implementation and save to output tensor
Tensor input;
input.Resize(DDim(input_shape));
switch (input_shape.size()) {
case 2:
transpose<float*>(
x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]), static_cast<int>(input_shape[1])},
{0, 1});
break;
case 3:
transpose<float*>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[1])},
{0, 2, 1});
break;
case 4:
transpose<float*>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3]),
static_cast<int>(input_shape[1])},
{0, 3, 1, 2});
break;
case 5:
transpose<float*>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3]),
static_cast<int>(input_shape[4]),
static_cast<int>(input_shape[1])},
{0, 4, 1, 2, 3});
break;
default:
CHECK(0) << "Unsupport";
}
auto* x_data = input.mutable_data<float>();
LaunchOp(op, {x_var_name}, {out_var_name});
// compare results
auto* out_data = out->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data[i], x_data[i], 5e-4);
}
}
void test_layout_NCHW2NHWC(std::vector<int64_t> input_shape) {
// prepare input&output variables
std::string x_var_name = "input";
std::string out_var_name = "out";
Scope scope;
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
x->Resize(DDim(input_shape));
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("layout");
opdesc.SetInput("Input", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
auto op = CreateOp<operators::LayoutOp>(opdesc, &scope);
// execute reference implementation and save to output tensor
Tensor input;
input.Resize(DDim(input_shape));
switch (input_shape.size()) {
case 2:
transpose<float*>(
x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]), static_cast<int>(input_shape[1])},
{0, 1});
break;
case 3:
transpose<float*>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2])},
{0, 2, 1});
break;
case 4:
transpose<float*>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3])},
{0, 2, 3, 1});
break;
case 5:
transpose<float*>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3]),
static_cast<int>(input_shape[4])},
{0, 2, 3, 4, 1});
break;
default:
CHECK(0) << "Unsupport";
}
auto* x_data = input.mutable_data<float>();
LaunchOp(op, {x_var_name}, {out_var_name}, CNML_NCHW);
// compare results
auto* out_data = out->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data[i], x_data[i], 5e-4);
}
}
TEST(MLUBridges, layout) {
test_layout_NHWC2NCHW({12, 32, 4});
test_layout_NHWC2NCHW({12, 32, 44, 3});
test_layout_NHWC2NCHW({12, 32, 44, 3, 6});
test_layout_NCHW2NHWC({12, 32, 55});
test_layout_NCHW2NHWC({12, 32, 44, 3});
test_layout_NCHW2NHWC({12, 32, 44, 3, 8});
test_layout_NHWC2NCHW({12, 32});
test_layout_NCHW2NHWC({12, 32});
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(layout, kMLU);
......@@ -35,6 +35,7 @@ USE_SUBGRAPH_BRIDGE(dropout, kMLU);
USE_SUBGRAPH_BRIDGE(argmax, kMLU);
USE_SUBGRAPH_BRIDGE(split, kMLU);
USE_SUBGRAPH_BRIDGE(cast, kMLU);
USE_SUBGRAPH_BRIDGE(layout, kMLU);
USE_SUBGRAPH_BRIDGE(slice, kMLU);
USE_SUBGRAPH_BRIDGE(squeeze, kMLU);
USE_SUBGRAPH_BRIDGE(squeeze2, kMLU);
......
......@@ -246,7 +246,10 @@ void MLUTensor::remember(const std::vector<int>& shape,
break;
}
}
dim_ = shape_.size();
auto shape_NCHW = DimNHWC2NCHW(shape_);
shape_NCHW.erase(shape_NCHW.begin() + shape.size(), shape_NCHW.end());
dim_ = shape_NCHW.size();
shape_ = DimNCHW2NHWC(shape_NCHW);
}
void MLUTensor::Create() {
......
......@@ -59,6 +59,7 @@ class MLUTensor {
~MLUTensor();
void ToFile(std::string file_name);
cnmlDataOrder_t dorder() { return data_order_; }
private:
cnmlTensor_t mlu_tensor_;
......
......@@ -27,7 +27,8 @@ namespace mlu {
template <lite_api::PrecisionType Dtype>
void PrepareInput(Graph* graph,
const std::string& input_name,
Tensor* input_tensor) {
Tensor* input_tensor,
cnmlDataOrder_t order) {
thread_local Tensor temp_input;
temp_input.Resize(input_tensor->dims().Vectorize());
temp_input.CopyDataFrom(*input_tensor);
......@@ -38,7 +39,7 @@ void PrepareInput(Graph* graph,
CNML_TENSOR,
CNML_NCHW,
MLUTypeTraits<Dtype>::cnml_type,
CNML_NHWC,
order,
reinterpret_cast<void*>(
input_tensor->template mutable_data<data_type>(TARGET(kMLU))));
CHECK(input_node);
......@@ -50,7 +51,8 @@ void PrepareInput(Graph* graph,
void LaunchOp(const std::shared_ptr<lite::OpLite> op,
const std::vector<std::string>& input_var_names,
const std::vector<std::string>& output_var_names) {
const std::vector<std::string>& output_var_names,
cnmlDataOrder_t order) {
CNRT_CALL(cnrtInit(0));
lite::SetMluDevice(0);
cnrtQueue_t queue_;
......@@ -77,9 +79,9 @@ void LaunchOp(const std::shared_ptr<lite::OpLite> op,
auto data_type = input_tensor->precision();
switch (data_type) {
#define PREPARE_INPUT(type__) \
case PRECISION(type__): \
PrepareInput<PRECISION(type__)>(&graph, input_name, input_tensor); \
#define PREPARE_INPUT(type__) \
case PRECISION(type__): \
PrepareInput<PRECISION(type__)>(&graph, input_name, input_tensor, order); \
break;
PREPARE_INPUT(kFP16)
PREPARE_INPUT(kFloat)
......
......@@ -58,7 +58,8 @@ void FillTensor(Tensor* x,
void LaunchOp(const std::shared_ptr<lite::OpLite> op,
const std::vector<std::string>& input_var_names,
const std::vector<std::string>& output_var_names);
const std::vector<std::string>& output_var_names,
cnmlDataOrder_t order = CNML_NHWC);
} // namespace mlu
} // namespace subgraph
......
......@@ -47,22 +47,74 @@ void transpose(dtype input_data,
std::vector<int> axis) {
int old_index = -1;
int new_index = -1;
int dim[4] = {0};
std::vector<int> shape = input_shape;
for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) {
old_index = dim[0] * shape[1] * shape[2] * shape[3] +
dim[1] * shape[2] * shape[3] + dim[2] * shape[3] + dim[3];
new_index =
dim[axis[0]] * shape[axis[1]] * shape[axis[2]] * shape[axis[3]] +
dim[axis[1]] * shape[axis[2]] * shape[axis[3]] +
dim[axis[2]] * shape[axis[3]] + dim[axis[3]];
if (input_shape.size() == 2) {
int dim[2] = {0};
std::vector<int> shape = input_shape;
for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
old_index = dim[0] * shape[1] + dim[1];
new_index = dim[axis[0]] * shape[axis[1]] + dim[axis[1]];
output_data[new_index] = input_data[old_index];
}
}
} else if (input_shape.size() == 3) {
int dim[3] = {0};
std::vector<int> shape = input_shape;
for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
old_index = dim[0] * shape[1] * shape[2] + dim[1] * shape[2] + dim[2];
new_index = dim[axis[0]] * shape[axis[1]] * shape[axis[2]] +
dim[axis[1]] * shape[axis[2]] + dim[axis[2]];
output_data[new_index] = input_data[old_index];
}
}
}
} else if (input_shape.size() == 4) {
int dim[4] = {0};
std::vector<int> shape = input_shape;
for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) {
old_index = dim[0] * shape[1] * shape[2] * shape[3] +
dim[1] * shape[2] * shape[3] + dim[2] * shape[3] +
dim[3];
new_index = dim[axis[0]] * shape[axis[1]] * shape[axis[2]] *
shape[axis[3]] +
dim[axis[1]] * shape[axis[2]] * shape[axis[3]] +
dim[axis[2]] * shape[axis[3]] + dim[axis[3]];
output_data[new_index] = input_data[old_index];
}
}
}
}
} else if (input_shape.size() == 5) {
int dim[5] = {0};
std::vector<int> shape = input_shape;
for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) {
for (dim[4] = 0; dim[4] < input_shape[4]; dim[4]++) {
old_index = dim[0] * shape[1] * shape[2] * shape[3] * shape[4] +
dim[1] * shape[2] * shape[3] * shape[4] +
dim[2] * shape[3] * shape[4] + dim[3] * shape[4] +
dim[4];
new_index = dim[axis[0]] * shape[axis[1]] * shape[axis[2]] *
shape[axis[3]] * shape[axis[4]] +
dim[axis[1]] * shape[axis[2]] * shape[axis[3]] *
shape[axis[4]] +
dim[axis[2]] * shape[axis[3]] * shape[axis[4]] +
dim[axis[3]] * shape[axis[4]] + dim[axis[4]];
output_data[new_index] = input_data[old_index];
}
}
}
}
}
} else {
}
}
......@@ -103,41 +155,39 @@ inline const ::paddle::lite::DDimLite DimNCHW2NHWC(
std::vector<int64_t>({dim[0], dim[2], dim[3], dim[1]}));
}
inline const std::vector<DDimLite::value_type> DimNHWC2NCHW(
const std::vector<DDimLite::value_type>& dim) {
template <typename data_type>
inline const std::vector<data_type> DimNHWC2NCHW(
const std::vector<data_type>& dim) {
switch (dim.size()) {
case 1:
return dim;
case 2:
return dim;
case 3:
return std::vector<DDimLite::value_type>({dim[0], dim[2], dim[1]});
return std::vector<data_type>({dim[0], dim[2], dim[1]});
case 4:
return std::vector<DDimLite::value_type>(
{dim[0], dim[3], dim[1], dim[2]});
return std::vector<data_type>({dim[0], dim[3], dim[1], dim[2]});
case 5:
return std::vector<DDimLite::value_type>(
{dim[0], dim[4], dim[1], dim[2], dim[3]});
return std::vector<data_type>({dim[0], dim[4], dim[1], dim[2], dim[3]});
default:
CHECK(0) << "unsupport dimension";
}
}
inline const std::vector<DDimLite::value_type> DimNCHW2NHWC(
const std::vector<DDimLite::value_type>& dim) {
template <typename data_type>
inline const std::vector<data_type> DimNCHW2NHWC(
const std::vector<data_type>& dim) {
switch (dim.size()) {
case 1:
return dim;
case 2:
return dim;
case 3:
return std::vector<DDimLite::value_type>({dim[0], dim[2], dim[1]});
return std::vector<data_type>({dim[0], dim[2], dim[1]});
case 4:
return std::vector<DDimLite::value_type>(
{dim[0], dim[2], dim[3], dim[1]});
return std::vector<data_type>({dim[0], dim[2], dim[3], dim[1]});
case 5:
return std::vector<DDimLite::value_type>(
{dim[0], dim[2], dim[3], dim[4], dim[1]});
return std::vector<data_type>({dim[0], dim[2], dim[3], dim[4], dim[1]});
default:
CHECK(0) << "unsupport dimension";
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册