Commit 124b097a authored by jackzhang235

shape is always NCHW, layout can be chosen

Parent 68533439
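In short: the bridge converters now hand Graph::AddNode the tensor's logical NCHW shape and select the physical layout only through the cnmlDataOrder_t argument (CNML_NCHW in the hunks below), instead of permuting shapes to NHWC by hand. A standalone illustration of the relationship between the two orderings (hypothetical shape values, not Paddle-Lite API):

// Standalone sketch: the same logical NCHW shape can be laid out physically
// as NCHW or NHWC; only the layout tag needs to change.
#include <cstdio>
#include <vector>

int main() {
  std::vector<long> nchw_shape = {1, 3, 224, 224};  // logical shape, kept in NCHW
  // The NHWC view is obtained by the fixed permutation {0, 2, 3, 1}.
  const int to_nhwc[4] = {0, 2, 3, 1};
  std::vector<long> nhwc_shape(4);
  for (int i = 0; i < 4; ++i) nhwc_shape[i] = nchw_shape[to_nhwc[i]];
  std::printf("NCHW: %ld %ld %ld %ld\n", nchw_shape[0], nchw_shape[1],
              nchw_shape[2], nchw_shape[3]);
  std::printf("NHWC: %ld %ld %ld %ld\n", nhwc_shape[0], nhwc_shape[1],
              nhwc_shape[2], nhwc_shape[3]);
  return 0;
}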
@@ -37,7 +37,7 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
   auto output_dims = output->dims().Vectorize();
   auto output_tensor = graph->AddNode(
-      out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, fp_type);
+      out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, fp_type);
   CHECK(graph->HasNode(x_var_name));
   auto input_tensor = graph->GetNode(x_var_name);
   cnmlBaseOp_t activation_op;
...
@@ -42,7 +42,7 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto output = scope->FindVar(y_var_name)->GetMutable<Tensor>();
   auto output_dims = output->dims().Vectorize();
   auto output_tensor = graph->AddNode(
-      y_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
+      y_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
   CHECK(graph->HasNode(x_var_name));
...
@@ -32,60 +32,33 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto x_var_name = op_info->Input("X");
   auto out_var_name = op_info->Output("Out").front();
+  auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
+  auto output_dims = output->dims().Vectorize();
   auto param_axis = op_info->GetAttr<int>("axis");
-  // auto x = scope->FindVar(x_var_name[0])->GetMutable<Tensor>();
-  auto input_num = x_var_name.size();
   std::vector<cnmlTensor_t> input_tensor;
-  std::vector<std::vector<int64_t>> input_dims;
   for (auto x_name : x_var_name) {
     CHECK(graph->HasNode(x_name));
     input_tensor.push_back(graph->GetNode(x_name)->mlu_tensor());
-    auto x = scope->FindVar(x_name)->GetMutable<Tensor>();
-    input_dims.push_back(x->dims().Vectorize());
   }
-  auto dims = input_dims[0].size();
+  auto dims = output_dims.size();
   int axis = (param_axis < 0) ? (param_axis + dims) : param_axis;
-  int nhwc_axis = -1;
-  if (dims == 4) {
-    int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2};
-    nhwc_axis = nchw_to_nhwc_axis_map[axis];
-  } else if (dims == 3) {
-    int nchw_to_nhwc_axis_map[3] = {0, 2, 1};
-    nhwc_axis = nchw_to_nhwc_axis_map[axis];
-  } else {
-    CHECK(0) << "Unsupport dims in mlu concat";
-  }
-  std::vector<int64_t> output_dims;
-  output_dims.assign(dims, 0);
-  /* std::cout << string_format("concat axis: %d(NCHW), %d(NHWC)", axis,
-   * nhwc_axis) << std::endl; */
-  for (int i = 0; i < output_dims.size(); ++i) {
-    if (i == nhwc_axis) {
-      for (auto& dim : input_dims) output_dims[i] += dim[i];
-    } else {
-      output_dims[i] = input_dims[0][i];
-    }
-  }
-  /* std::cout << string_format("concat output dim: %ld, %ld, %ld, %ld") <<
-   * std::endl; */
-  auto* output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
-  output->Resize(output_dims);
+  CHECK_LE(axis, 4) << "Unsupport dims in mlu concat";
+  int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2};
+  int nhwc_axis = nchw_to_nhwc_axis_map[axis];
   auto output_tensor = graph->AddNode(
-      out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
+      out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
   cnmlBaseOp_t concat_op;
-  cnmlTensor_t outputs[1];
-  outputs[0] = output_tensor->mlu_tensor();
-  CNML_CALL(cnmlCreateNdConcatOp(
-      &concat_op, nhwc_axis, input_tensor.data(), input_num, outputs, 1));
+  cnmlTensor_t outputs = output_tensor->mlu_tensor();
+  CNML_CALL(cnmlCreateNdConcatOp(&concat_op,
+                                 nhwc_axis,
+                                 input_tensor.data(),
+                                 x_var_name.size(),
+                                 &outputs,
+                                 1));
   graph->FuseOp(concat_op);
   return SUCCESS;
 }
...
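The rewritten concat bridge keeps the axis attribute in NCHW terms and remaps it to the NHWC position the MLU kernel expects through nchw_to_nhwc_axis_map = {0, 3, 1, 2}. A standalone check of that mapping (assuming 4-D tensors, which the new CHECK_LE(axis, 4) enforces; the wrapper function name is made up for illustration):

#include <cassert>

// Mirrors the table used in the concat bridge: NCHW axis index -> NHWC axis index.
int NchwAxisToNhwc(int axis) {
  const int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2};
  assert(axis >= 0 && axis < 4);
  return nchw_to_nhwc_axis_map[axis];
}

int main() {
  // Concatenating along channels (axis 1 in NCHW) means axis 3 in NHWC.
  assert(NchwAxisToNhwc(1) == 3);
  // Height (2) and width (3) in NCHW map to 1 and 2 in NHWC.
  assert(NchwAxisToNhwc(2) == 1);
  assert(NchwAxisToNhwc(3) == 2);
  return 0;
}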
@@ -33,13 +33,14 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // get input, filter and op attributes
   const auto input_var_name = op_info->Input("Input").front();
-  const auto& input_dims_nhwc =
+  const auto& input_dims =
       scope->FindVar(input_var_name)->GetMutable<Tensor>()->dims();
-  const auto input_dims = DimNHWC2NCHW(input_dims_nhwc);
   const auto filter_var_name = op_info->Input("Filter").front();
   auto* filter = scope->FindVar(filter_var_name)->GetMutable<Tensor>();
   const auto& filter_dims = filter->dims();
   const auto output_var_name = op_info->Output("Output").front();
+  auto* output = scope->FindVar(output_var_name)->GetMutable<Tensor>();
+  const auto output_shape = output->dims().Vectorize();
   const auto bs = input_dims[0];
   const auto oc = filter_dims[0];
   CHECK_EQ(input_dims.size(), 4);
@@ -70,24 +71,8 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
                            input_dims,
                            filter_dims);
-  std::vector<int64_t> output_shape({bs, oc});
-  for (size_t i = 0; i < 2; i++) {
-    const int dkernel = dilations[i] * (filter_dims[2 + i] - 1) + 1;
-    output_shape.push_back(
-        (input_dims[i + 2] + paddings[2 * i] + paddings[2 * i + 1] - dkernel) /
-            strides[i] +
-        1);
-  }
-  const auto output_shape_nhwc = DimNCHW2NHWC(output_shape);
-  const auto output_tensor = graph->AddNode(output_var_name,
-                                            output_shape_nhwc,
-                                            CNML_TENSOR,
-                                            CNML_NHWC,
-                                            graph->FPType());
-  scope->FindVar(output_var_name)
-      ->GetMutable<::paddle::lite::Tensor>()
-      ->Resize(output_shape_nhwc);
+  const auto output_tensor = graph->AddNode(
+      output_var_name, output_shape, CNML_TENSOR, CNML_NCHW, graph->FPType());
   // Create filter node
   const auto filter_tensor = graph->AddNode(filter_var_name,
@@ -156,7 +141,7 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   const auto input_scale = op_info->GetAttr<float>("input_scale");
   bool use_first_conv = false;
-  if (lite::DeviceInfo::Global().UseFirstConv() && input_dims_nhwc[3] == 3) {
+  if (lite::DeviceInfo::Global().UseFirstConv() && input_dims[1] == 3) {
     use_first_conv = true;
   }
...
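The deleted conv lines computed the output spatial sizes by the usual convolution formula; the converter now reads the already-inferred output dims instead. For reference, a standalone restatement of what the removed code computed (function name and test values are illustrative only):

#include <cassert>

// Output spatial size of a convolution, as in the removed lines:
//   dkernel = dilation * (kernel - 1) + 1
//   out     = (in + pad_left + pad_right - dkernel) / stride + 1
int ConvOutSize(int in, int kernel, int stride, int pad_left, int pad_right,
                int dilation) {
  const int dkernel = dilation * (kernel - 1) + 1;
  return (in + pad_left + pad_right - dkernel) / stride + 1;
}

int main() {
  // 224x224 input, 3x3 kernel, stride 2, padding 1, no dilation -> 112.
  assert(ConvOutSize(224, 3, 2, 1, 1, 1) == 112);
  return 0;
}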
@@ -77,7 +77,7 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     auto output_tensor = graph->AddNode(out_var_name,
                                         x->dims().Vectorize(),
                                         CNML_TENSOR,
-                                        CNML_NHWC,
+                                        CNML_NCHW,
                                         graph->FPType());
     cnmlBaseOp_t elementwise_op;
@@ -90,7 +90,7 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     auto mid_tensor = graph->AddNode(out_var_name + "_mid",
                                      x->dims().Vectorize(),
                                      CNML_TENSOR,
-                                     CNML_NHWC,
+                                     CNML_NCHW,
                                      graph->FPType());
     CNML_CALL(cnmlCreateBroadcastAddOp(&elementwise_op,
                                        x_tensor->mlu_tensor(),
...
@@ -37,6 +37,7 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // int in_num_col_dims = op_info->GetAttr<int>("in_num_col_dims");
   auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
   auto w = scope->FindVar(w_var_name)->GetMutable<Tensor>();
+  auto output = scope->FindVar(output_var_name)->GetMutable<Tensor>();
   auto x_dims = x->dims();
   auto w_dims = w->dims();
@@ -50,15 +51,11 @@ int FCConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto input_scale = op_info->GetAttr<float>("input_scale");
-  std::vector<int64_t> output_shape_nhwc({x_dims[0], 1, 1, w_dims[1]});
   auto output_tensor = graph->AddNode(output_var_name,
-                                      output_shape_nhwc,
+                                      output->dims().Vectorize(),
                                       CNML_TENSOR,
-                                      CNML_NHWC,
+                                      CNML_NCHW,
                                       graph->FPType());
-  scope->FindVar(output_var_name)
-      ->GetMutable<::paddle::lite::Tensor>()
-      ->Resize(output_shape_nhwc);
   std::string bias_var_name;
   std::shared_ptr<MLUTensor> bias_tensor;
...
@@ -25,12 +25,12 @@ namespace mlu {
 std::shared_ptr<MLUTensor> Graph::AddNode(const std::string& name,
                                           std::vector<int64_t> shape,
                                           cnmlTensorType_t tensor_type,
-                                          cnmlDataOrder_t data_order,
+                                          cnmlDataOrder_t shape_order,
                                           cnmlDataType_t mlu_dtype,
                                           void* raw_ptr) {
   CHECK(!HasNode(name));
   auto node = std::shared_ptr<MLUTensor>(
-      new MLUTensor(shape, tensor_type, data_order, mlu_dtype));
+      new MLUTensor(shape, tensor_type, shape_order, mlu_dtype));
   node->set_mlu_ptr(raw_ptr);
   nodes_.insert(std::make_pair(name, node));
   return node;
...
@@ -45,8 +45,8 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   CHECK(graph->HasNode(x_var_name));
   auto input_tensor = graph->GetNode(x_var_name);
-  auto in_h = x_dims[1];
-  auto in_w = x_dims[2];
+  auto in_h = x_dims[2];
+  auto in_w = x_dims[3];
   // Priority: SizeTensor > OutSize > Scale > scale > out_h/out_w
   if (HasInputArg(op_info, scope, "SizeTensor")) {
@@ -69,25 +69,13 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     }
   }
-  out->Resize({x_dims[0], out_h, out_w, x_dims[3]});
   auto output_tensor = graph->AddNode(out_var_name,
                                       out->dims().Vectorize(),
                                       CNML_TENSOR,
-                                      CNML_NHWC,
+                                      CNML_NCHW,
                                       graph->FPType());
   cnmlBaseOp_t interp_op;
-  /* if (interp_method == "bilinear") { */
-  /*   cnmlInterpOpParam_t interp_param; */
-  /*   CNML_CALL(cnmlCreateInterpOpParam(&interp_param, out_w, out_h,
-   * align_corners)); */
-  /*   CNML_CALL(cnmlCreateInterpOp(&interp_op, */
-  /*                                input_tensor->mlu_tensor(), */
-  /*                                output_tensor->mlu_tensor(), */
-  /*                                interp_param)); */
-  /*   CNML_CALL(cnmlDestroyInterpOpParam(&interp_param)); */
-  /* } else if (interp_method == "nearest") { */
   cnmlNearestNeighborOpParam_t nn_param;
   CNML_CALL(cnmlCreateNearestNeighborOpParam(&nn_param, out_w, out_h));
   CNML_CALL(cnmlSetNearestNeighborAlignCorner(&nn_param, align_corners));
@@ -96,11 +84,6 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
                                     output_tensor->mlu_tensor(),
                                     nn_param));
   CNML_CALL(cnmlDestroyNearestNeighborOpParam(&nn_param));
-  /* } else { */
-  /*   LOG(WARNING) << "[MLU] Unsupported interpolate method: " <<
-   * interp_method; */
-  /*   return FAILED; */
-  /* } */
   graph->FuseOp(interp_op);
   return SUCCESS;
...
@@ -47,9 +47,8 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   // Get input, and attributes
   auto x_var_name = op_info->Input("X").front();
   auto x = scope->FindTensor(x_var_name);
-  auto input_dims_nhwc = x->dims();
-  const auto input_dims = DimNHWC2NCHW(input_dims_nhwc);
   auto output_var_name = op_info->Output("Out").front();
+  auto output_shape = scope->FindTensor(output_var_name)->dims().Vectorize();
   auto pooling_type = op_info->GetAttr<std::string>("pooling_type");
   auto ceil_mode = op_info->GetAttr<bool>("ceil_mode");
   auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
@@ -81,23 +80,17 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
                              strides,
                              ksize);
-  std::vector<int64_t> output_shape({input_dims[0], input_dims[1]});
-  for (size_t i = 0; i < 2; i++) {
-    output_shape.push_back(
-        (input_dims[i + 2] + paddings[2 * i] + paddings[2 * i + 1] - ksize[0]) /
-            strides[i] +
-        1);
-  }
-  auto output_shape_nhwc = DimNCHW2NHWC(output_shape);
-  auto output_tensor = graph->AddNode(output_var_name,
-                                      output_shape_nhwc,
-                                      CNML_TENSOR,
-                                      CNML_NHWC,
-                                      graph->FPType());
-  scope->FindVar(output_var_name)
-      ->GetMutable<::paddle::lite::Tensor>()
-      ->Resize(output_shape_nhwc);
+  // std::vector<int64_t> output_shape({input_dims[0], input_dims[1]});
+  // for (size_t i = 0; i < 2; i++) {
+  //   output_shape.push_back(
+  //       (input_dims[i + 2] + paddings[2 * i] + paddings[2 * i + 1] -
+  //       ksize[0]) /
+  //       strides[i] +
+  //       1);
+  // }
+  auto output_tensor = graph->AddNode(
+      output_var_name, output_shape, CNML_TENSOR, CNML_NCHW, graph->FPType());
   cnmlPoolOpParam_t pool_param;
   CNML_CALL(
...
@@ -36,7 +36,7 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
   auto output_dims = output->dims().Vectorize();
   auto output_tensor = graph->AddNode(
-      out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
+      out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
   auto bias_after_scale = op_info->GetAttr<bool>("bias_after_scale");
   auto scale = op_info->GetAttr<float>("scale");
   auto bias = op_info->GetAttr<float>("bias");
...
@@ -45,11 +45,10 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
       axis = output_dims.size() + axis;
     }
   }
   int nhwc_axis = nchw_to_nhwc_aixs_map[axis];
   auto output_tensor = graph->AddNode(
-      out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
+      out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
   cnmlBaseOp_t softmax_op;
   CNML_CALL(cnmlCreateNdSoftmaxOp(&softmax_op,
                                   nhwc_axis,
...
@@ -21,8 +21,30 @@ namespace lite {
 namespace subgraph {
 namespace mlu {
-std::vector<int> axis_to_nhwc4d(const std::vector<int>& axis) {
-  CHECK_EQ(axis.size(), 4);
+// std::vector<int> axis_to_nhwc4d(const std::vector<int>& axis) {
+//  CHECK_EQ(axis.size(), 4);
+//  std::vector<int> new_axis(4, 0);
+//  const std::vector<int> axis_map1 = {0, 2, 3, 1};
+//  const std::vector<int> axis_map2 = {0, 3, 1, 2};
+//  for (size_t i = 0; i < new_axis.size(); ++i) {
+//    new_axis[i] = axis_map2[axis[axis_map1[i]]];
+//  }
+//  return new_axis;
+//}
+//
+// std::vector<int> axis_to_nhw3d(const std::vector<int>& axis) {
+//  CHECK_EQ(axis.size(), 3);
+//  std::vector<int> new_axis(3, 0);
+//  const std::vector<int> axis_map = {0, 2, 1};
+//  for (size_t i = 0; i < new_axis.size(); ++i) {
+//    new_axis[i] = axis_map[axis[axis_map[i]]];
+//  }
+//  new_axis.push_back(3);
+//  return new_axis;
+//}
+
+std::vector<int> axis_to_nhwc(const std::vector<int>& axis) {
+  CHECK_EQ(axis.size(), 4) << "Unsupport dim in mlu transpose";
   std::vector<int> new_axis(4, 0);
   const std::vector<int> axis_map1 = {0, 2, 3, 1};
   const std::vector<int> axis_map2 = {0, 3, 1, 2};
@@ -32,26 +54,6 @@ std::vector<int> axis_to_nhwc4d(const std::vector<int>& axis) {
   return new_axis;
 }
-std::vector<int> axis_to_nhw3d(const std::vector<int>& axis) {
-  CHECK_EQ(axis.size(), 3);
-  std::vector<int> new_axis(3, 0);
-  const std::vector<int> axis_map = {0, 2, 1};
-  for (size_t i = 0; i < new_axis.size(); ++i) {
-    new_axis[i] = axis_map[axis[axis_map[i]]];
-  }
-  new_axis.push_back(3);
-  return new_axis;
-}
-
-std::vector<int64_t> infer_shape(const std::vector<int64_t>& x_dims,
-                                 const std::vector<int>& axis_nhwc) {
-  std::vector<int64_t> out_dims(x_dims);
-  for (size_t i = 0; i < out_dims.size(); ++i) {
-    out_dims[i] = x_dims[axis_nhwc[i]];
-  }
-  return out_dims;
-}
 int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   CHECK(ctx != nullptr);
   CHECK(op != nullptr);
@@ -71,21 +73,13 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto output_dims = output->dims().Vectorize();
   auto axis = op_info->GetAttr<std::vector<int>>("axis");
-  std::vector<int> axis_nhwc;
-  if (axis.size() == 4) {
-    axis_nhwc = axis_to_nhwc4d(axis);
-  } else if (axis.size() == 3) {
-    axis_nhwc = axis_to_nhw3d(axis);
-  } else {
-    CHECK(0) << "Unsupport dim in mlu transpose";
-  }
-  auto output_dims_nhwc = infer_shape(x_dims, axis_nhwc);
-  output->Resize(output_dims_nhwc);
+  while (axis.size() < 4) {
+    axis.push_back(axis.size());
+  }
+  std::vector<int> axis_nhwc = axis_to_nhwc(axis);
   auto output_tensor = graph->AddNode(
-      out_var_name, output_dims_nhwc, CNML_TENSOR, CNML_NHWC, graph->FPType());
+      out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
   CHECK(graph->HasNode(x_var_name));
   auto input_tensor = graph->GetNode(x_var_name);
@@ -113,7 +107,6 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
 REGISTER_SUBGRAPH_BRIDGE(transpose,
                          kMLU,
                          paddle::lite::subgraph::mlu::TransposeConverter);
 REGISTER_SUBGRAPH_BRIDGE(transpose2,
                          kMLU,
                          paddle::lite::subgraph::mlu::TransposeConverter);
...
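The surviving axis_to_nhwc helper rewrites an NCHW permutation as the equivalent NHWC permutation by composing it with the fixed maps axis_map1 = {0, 2, 3, 1} and axis_map2 = {0, 3, 1, 2}; the converter first pads shorter axis vectors to length 4 with identity entries. A standalone check of that composition (wrapper name and test values are illustrative):

#include <cassert>
#include <vector>

// Same composition as axis_to_nhwc in the diff: new_axis[i] = map2[axis[map1[i]]].
std::vector<int> AxisToNhwc(std::vector<int> axis) {
  // Pad with identity entries, as the converter does before calling the helper.
  while (axis.size() < 4) axis.push_back(axis.size());
  const std::vector<int> axis_map1 = {0, 2, 3, 1};  // NHWC position -> NCHW position
  const std::vector<int> axis_map2 = {0, 3, 1, 2};  // NCHW position -> NHWC position
  std::vector<int> new_axis(4, 0);
  for (size_t i = 0; i < new_axis.size(); ++i) {
    new_axis[i] = axis_map2[axis[axis_map1[i]]];
  }
  return new_axis;
}

int main() {
  // Swapping H and W in NCHW ({0, 1, 3, 2}) becomes swapping H and W in NHWC ({0, 2, 1, 3}).
  std::vector<int> nhwc = AxisToNhwc({0, 1, 3, 2});
  assert((nhwc == std::vector<int>{0, 2, 1, 3}));
  return 0;
}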
@@ -67,6 +67,8 @@ class LayoutNchwToNhwcCompute
     auto x_dims = param.x->dims().size();
     auto& context = this->ctx_->template As<X86Context>();
+    const auto origin_dims = out->dims().Vectorize();
     std::vector<int> axis;
     switch (x_dims) {
       case 2:
@@ -88,6 +90,10 @@ class LayoutNchwToNhwcCompute
     LayoutTransCompute<lite::TargetType::kX86, float>(
         x_dims, context, *x, out, axis);
+    if (x_dims > 2) {
+      out->Resize(origin_dims);
+    }
   }
   std::string doc() const override {
@@ -109,20 +115,22 @@ class LayoutNhwcToNchwCompute
     auto x_dims = param.x->dims().size();
     auto& context = this->ctx_->template As<X86Context>();
+    const auto origin_dims = out->dims().Vectorize();
     std::vector<int> axis;
     switch (x_dims) {
       case 2:
         axis = {0, 1};
         break;
       case 3:
-        axis = {0, 2, 1};
         out->Resize(std::vector<int64_t>{
             out->dims()[0], out->dims()[2], out->dims()[1]});
+        axis = {0, 2, 1};
         break;
       case 4:
-        axis = {0, 3, 1, 2};
         out->Resize(std::vector<int64_t>{
             out->dims()[0], out->dims()[3], out->dims()[1], out->dims()[2]});
+        axis = {0, 3, 1, 2};
         break;
       default:
         CHECK(0) << "Unsupport dim in mlu layout nhwc to nchw";
@@ -130,6 +138,10 @@ class LayoutNhwcToNchwCompute
     LayoutTransCompute<lite::TargetType::kX86, float>(
         x_dims, context, *x, out, axis);
+    if (x_dims > 2) {
+      out->Resize(origin_dims);
+    }
   }
   std::string doc() const override {
...
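Both layout kernels now record the output's registered dims up front and call Resize(origin_dims) after the transpose for ranks above 2, since logical shapes stay in NCHW everywhere; the in-switch Resize calls visible in the NHWC-to-NCHW hunk only give the permutation routine the right extents. The dim permutation itself, isolated as a standalone sketch (helper name and shape values are illustrative, not the kernel's code):

#include <cassert>
#include <cstdint>
#include <vector>

// Permute a dim vector by `axis`, the same reordering the layout kernels apply
// to the data before restoring the originally registered dims.
std::vector<int64_t> Permute(const std::vector<int64_t>& dims,
                             const std::vector<int>& axis) {
  std::vector<int64_t> out(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) out[i] = dims[axis[i]];
  return out;
}

int main() {
  // An NHWC shape {1, 224, 224, 3} viewed through axis {0, 3, 1, 2} gives the
  // NCHW shape {1, 3, 224, 224}.
  std::vector<int64_t> nhwc = {1, 224, 224, 3};
  std::vector<int64_t> nchw = Permute(nhwc, {0, 3, 1, 2});
  assert((nchw == std::vector<int64_t>{1, 3, 224, 224}));
  return 0;
}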
@@ -83,7 +83,7 @@ class SubgraphEngine : public subgraph::Engine {
           graph_.AddNode(input_name,
                          input_tensor->dims().Vectorize(),
                          CNML_TENSOR,
-                         CNML_NHWC,
+                         CNML_NCHW,
                          graph_.FPType(),
                          const_cast<void*>(input_tensor->raw_data()));
       CHECK(input_node);
@@ -99,9 +99,7 @@ class SubgraphEngine : public subgraph::Engine {
       CHECK(op);
       std::string op_type = op->op_info()->Type();
       op->CheckShape();
-      if (op_type != "concat") {
-        op->InferShape();
-      }
+      op->InferShape();
       if (!bridges.Exists(op_type, TARGET(kMLU))) {
         LOG(INFO) << "MLU bridges doesn't support op_type: " << op_type;
         return subgraph::FAILED;
...