Commit 1ebad864 authored by dingminghui, committed by jackzhang235

fix(mlu kernel): fix error caused by cancelling expanding tensor to 4 dims

Parent d397458f
......
@@ -44,9 +44,14 @@ int ConcatConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto dims = output_dims.size();
   int axis = (param_axis < 0) ? (param_axis + dims) : param_axis;
-  CHECK_LE(axis, 4) << "Unsupport dims in mlu concat";
-  int nchw_to_nhwc_axis_map[4] = {0, 3, 1, 2};
-  int nhwc_axis = nchw_to_nhwc_axis_map[axis];
+  CHECK_LT(axis, dims) << "Unsupport dims in mlu concat";
+  std::vector<int> nchw2nhwc_axis(dims);
+  nchw2nhwc_axis[0] = 0;
+  if (dims > 1) nchw2nhwc_axis[1] = dims - 1;
+  for (size_t i = 2; i < dims; ++i) {
+    nchw2nhwc_axis[i] = i - 1;
+  }
+  int nhwc_axis = nchw2nhwc_axis[axis];
   auto output_tensor = graph->AddNode(
       out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
......
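As a side note, below is a standalone sketch (our illustration, not part of the patch) of the rank-generic NCHW-to-NHWC axis map that the new converter code builds inline. Entry i is the position NCHW axis i occupies in NHWC order: batch stays in front, the channel axis moves to the last slot, and each spatial axis shifts one slot forward. At rank 4 this reproduces the old hard-coded table {0, 3, 1, 2}; lower ranks, which the fixed 4-element table could not describe once tensors stopped being expanded to 4 dims, now get a correctly sized map.

#include <cassert>
#include <cstddef>
#include <vector>

// Rank-generic NCHW -> NHWC axis map, mirroring the patched concat converter.
std::vector<int> BuildNchw2NhwcAxis(size_t dims) {
  std::vector<int> map(dims);
  map[0] = 0;                       // batch keeps position 0
  if (dims > 1) map[1] = dims - 1;  // channels move to the last slot
  for (size_t i = 2; i < dims; ++i) {
    map[i] = i - 1;                 // spatial axes shift left by one
  }
  return map;
}

int main() {
  assert(BuildNchw2NhwcAxis(4) == (std::vector<int>{0, 3, 1, 2}));  // old table
  assert(BuildNchw2NhwcAxis(3) == (std::vector<int>{0, 2, 1}));
  assert(BuildNchw2NhwcAxis(2) == (std::vector<int>{0, 1}));
  return 0;
}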
......
@@ -22,12 +22,24 @@ namespace subgraph {
 namespace mlu {
 std::vector<int> axis_to_nhwc(const std::vector<int>& axis) {
   CHECK_EQ(axis.size(), 4) << "Unsupport dim in mlu transpose";
-  std::vector<int> new_axis(4, 0);
-  const std::vector<int> axis_map1 = {0, 2, 3, 1};
-  const std::vector<int> axis_map2 = {0, 3, 1, 2};
+  std::vector<int> new_axis(axis.size());
+  std::vector<int> nhwc2nchw_axis(axis.size());
+  nhwc2nchw_axis[0] = 0;
+  if (axis.size() > 1) nhwc2nchw_axis[1] = axis.size() - 1;
+  for (size_t i = 2; i < axis.size(); ++i) {
+    nhwc2nchw_axis[i] = i - 1;
+  }
+  std::vector<int> nchw2nhwc_axis(axis.size());
+  nchw2nhwc_axis[0] = 0;
+  for (size_t i = 1; i < axis.size() - 1; ++i) {
+    nchw2nhwc_axis[i] = i + 1;
+  }
+  if (axis.size() > 1) nchw2nhwc_axis[axis.size() - 1] = 1;
   for (size_t i = 0; i < new_axis.size(); ++i) {
-    new_axis[i] = axis_map2[axis[axis_map1[i]]];
+    new_axis[i] = nhwc2nchw_axis[axis[nchw2nhwc_axis[i]]];
   }
   return new_axis;
 }
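The double lookup in axis_to_nhwc is a permutation conjugation: a transpose "axis" attribute written against NCHW dimension order is rewritten against NHWC order as new_axis[i] = nhwc2nchw_axis[axis[nchw2nhwc_axis[i]]]. Below is a self-contained check of that behavior (the helper name and test cases are ours, assuming the maps shown in the hunk; the sketch omits the CHECK on the rank):

#include <cassert>
#include <cstddef>
#include <vector>

// Conjugates an NCHW-order permutation into NHWC order, as in the patch.
std::vector<int> AxisToNhwc(const std::vector<int>& axis) {
  const size_t n = axis.size();
  std::vector<int> to_nchw(n), to_nhwc(n);
  to_nchw[0] = 0;
  if (n > 1) to_nchw[1] = n - 1;
  for (size_t i = 2; i < n; ++i) to_nchw[i] = i - 1;
  to_nhwc[0] = 0;
  for (size_t i = 1; i + 1 < n; ++i) to_nhwc[i] = i + 1;
  if (n > 1) to_nhwc[n - 1] = 1;
  std::vector<int> out(n);
  for (size_t i = 0; i < n; ++i) out[i] = to_nchw[axis[to_nhwc[i]]];
  return out;
}

int main() {
  // The identity permutation is layout-independent.
  assert(AxisToNhwc({0, 1, 2, 3}) == (std::vector<int>{0, 1, 2, 3}));
  // Swapping H and W in NCHW terms swaps the two middle axes of NHWC.
  assert(AxisToNhwc({0, 1, 3, 2}) == (std::vector<int>{0, 2, 1, 3}));
  // The rank-generic maps also extend below rank 4, unlike the old tables.
  assert(AxisToNhwc({0, 2, 1}) == (std::vector<int>{0, 2, 1}));
  return 0;
}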
......
@@ -51,9 +63,6 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto output_dims = output->dims().Vectorize();
   auto axis = op_info->GetAttr<std::vector<int>>("axis");
-  while (axis.size() < 4) {
-    axis.push_back(axis.size());
-  }
   std::vector<int> axis_nhwc = axis_to_nhwc(axis);
   auto output_tensor = graph->AddNode(
......
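The deleted loop padded a shorter permutation up to rank 4 with trailing identity entries (e.g. a rank-3 axis {0, 2, 1} became {0, 2, 1, 3}) so that the old, strictly 4-D axis_to_nhwc could consume it. With the maps now built at the permutation's native rank, the attribute is passed through unchanged.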
......
@@ -173,14 +173,14 @@ REGISTER_LITE_KERNEL(
     kMLU,
     kInt8,
     kNHWC,
-    paddle::lite::kernels::mlu::IoCopyMluToHostCompute<PRECISION(kInt8)>,
-    device_to_host_kInt8)
+    paddle::lite::kernels::mlu::IoCopyHostToMluCompute<PRECISION(kInt8)>,
+    host_to_device_to_kInt8)
     .BindInput("Input",
-               {LiteType::GetTensorTy(TARGET(kMLU),
+               {LiteType::GetTensorTy(TARGET(kHost),
                                       PRECISION(kInt8),
                                       DATALAYOUT(kAny))})
     .BindOutput("Out",
-                {LiteType::GetTensorTy(TARGET(kHost),
+                {LiteType::GetTensorTy(TARGET(kMLU),
                                        PRECISION(kInt8),
                                        DATALAYOUT(kAny))})
     .Finalize();
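Judging from this hunk alone, the kInt8 io_copy kernel had been registered as a (second) device-to-host copy; the patch turns it into the host-to-device registration instead, flipping the compute template from IoCopyMluToHostCompute to IoCopyHostToMluCompute, renaming the kernel alias accordingly, and swapping the input/output tensor targets (input now kHost, output now kMLU).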