Commit 63da1451 authored by dingminghui, committed by jackzhang235

fix(layout): fix compute error in nhwc2nchw layout trans

Parent 067111d4
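
What the diff below changes: both layout kernels switch the trait member from `FPTypeTraits<Precision>::T` to `FPTypeTraits<Precision>::type`, and the rank variable `x_dims` becomes `x_ndims`, freeing the name `x_dims` for the dimension vector in the NHWC-to-NCHW kernel. The NCHW-to-NHWC kernel now builds its target shape from the previously captured `origin_dims` rather than re-reading `out->dims()` mid-resize. The actual compute error sits in the NHWC-to-NCHW kernel: the transpose was fed `*x` directly, but the input's recorded dims stay in NCHW order while the data physically sits in NHWC order, so the permutation ran against the wrong source shape. The fix shares the input buffer into a temporary tensor (`tmp_t.ShareDataWith(*x)`), resizes that view to the real NHWC memory shape ({x_dims[0], x_dims[2], x_dims[3], x_dims[1]} in the 4-D case), and only then permutes with axis {0, 3, 1, 2}.

A minimal standalone sketch of that last point, assuming the usual stride-based permutation semantics (the `Transpose` helper, shapes, and values are illustrative, not the Paddle Lite API):

// Standalone illustration (hypothetical helper, not Paddle Lite code).
#include <cstdio>
#include <vector>

// Permute `src` according to `axis`; `in_shape` must describe the source
// buffer's real memory layout, or the mapping below is wrong.
std::vector<float> Transpose(const std::vector<float>& src,
                             const std::vector<int>& in_shape,
                             const std::vector<int>& axis) {
  const size_t rank = in_shape.size();
  std::vector<int> in_stride(rank, 1), out_shape(rank), out_stride(rank, 1);
  for (int i = static_cast<int>(rank) - 2; i >= 0; --i)
    in_stride[i] = in_stride[i + 1] * in_shape[i + 1];
  for (size_t o = 0; o < rank; ++o) out_shape[o] = in_shape[axis[o]];
  for (int i = static_cast<int>(rank) - 2; i >= 0; --i)
    out_stride[i] = out_stride[i + 1] * out_shape[i + 1];
  std::vector<float> dst(src.size());
  for (size_t lin = 0; lin < src.size(); ++lin) {
    size_t rem = lin, dst_off = 0;
    for (size_t d = 0; d < rank; ++d) {  // source coordinate along dim d
      const int coord = static_cast<int>(rem / in_stride[d]);
      rem %= in_stride[d];
      for (size_t o = 0; o < rank; ++o)  // slot of dim d in the output
        if (axis[o] == static_cast<int>(d)) dst_off += coord * out_stride[o];
    }
    dst[dst_off] = src[lin];
  }
  return dst;
}

int main() {
  // A 1x2x2x3 NHWC buffer; each value encodes its (h, w, c) coordinate.
  std::vector<float> nhwc;
  for (int h = 0; h < 2; ++h)
    for (int w = 0; w < 2; ++w)
      for (int c = 0; c < 3; ++c) nhwc.push_back(100.0f * h + 10.0f * w + c);
  // The equivalent of tmp_t: describe the data by its real layout
  // {N, H, W, C} = {1, 2, 2, 3}, then permute with {0, 3, 1, 2} to get NCHW.
  const auto nchw = Transpose(nhwc, {1, 2, 2, 3}, {0, 3, 1, 2});
  // Element (c, h, w) of the NCHW result equals 100*h + 10*w + c.
  std::printf("%.0f %.0f %.0f\n", nchw[0], nchw[1], nchw[4]);  // 0 10 1
  return 0;
}

With `in_shape` describing the buffer's real memory layout (the job `tmp_t` does in the kernel), element (h, w, c) lands at (c, h, w) in the output; describing the same buffer with NCHW-ordered dims instead, as the old code effectively did, scrambles the mapping.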
@@ -34,17 +34,17 @@ struct FPTypeTraits {};
 template <>
 struct FPTypeTraits<paddle::lite_api::PrecisionType::kFloat> {
-  typedef float T;
+  using type = float;
 };
 template <>
 struct FPTypeTraits<paddle::lite_api::PrecisionType::kFP16> {
-  typedef paddle::lite::fluid::float16 T;
+  using type = paddle::lite::fluid::float16;
 };
 template <>
 struct FPTypeTraits<paddle::lite_api::PrecisionType::kInt8> {
-  typedef int8_t T;
+  using type = int8_t;
 };
 template <lite::TargetType Target, typename T>
@@ -81,36 +81,36 @@ class LayoutNchwToNhwcCompute
     auto& param = this->template Param<param_t>();
     auto* x = param.x;
     auto* out = param.y;
-    out->template mutable_data<typename FPTypeTraits<Precision>::T>();
-    auto x_dims = param.x->dims().size();
+    out->template mutable_data<typename FPTypeTraits<Precision>::type>();
+    auto x_ndims = param.x->dims().size();
     auto& context = this->ctx_->template As<X86Context>();
     const auto origin_dims = out->dims().Vectorize();
     std::vector<int> axis;
-    switch (x_dims) {
+    switch (x_ndims) {
       case 2:
         axis = {0, 1};
         break;
      case 3:
        axis = {0, 2, 1};
        out->Resize(std::vector<int64_t>{
-            out->dims()[0], out->dims()[2], out->dims()[1]});
+            origin_dims[0], origin_dims[2], origin_dims[1]});
        break;
      case 4:
        axis = {0, 2, 3, 1};
        out->Resize(std::vector<int64_t>{
-            out->dims()[0], out->dims()[2], out->dims()[3], out->dims()[1]});
+            origin_dims[0], origin_dims[2], origin_dims[3], origin_dims[1]});
        break;
      default:
        CHECK(0) << "Unsupport dim in mlu layout nchw to nhwc";
    }
    LayoutTransCompute<lite::TargetType::kX86,
-                       typename FPTypeTraits<Precision>::T>(
-        x_dims, context, *x, out, axis);
-    if (x_dims > 2) {
+                       typename FPTypeTraits<Precision>::type>(
+        x_ndims, context, *x, out, axis);
+    if (x_ndims > 2) {
      out->Resize(origin_dims);
    }
  }
@@ -130,25 +130,26 @@ class LayoutNhwcToNchwCompute
     auto& param = this->template Param<param_t>();
     auto* x = param.x;
     auto* out = param.y;
-    out->template mutable_data<typename FPTypeTraits<Precision>::T>();
-    auto x_dims = param.x->dims().size();
+    out->template mutable_data<typename FPTypeTraits<Precision>::type>();
     auto& context = this->ctx_->template As<X86Context>();
-    const auto origin_dims = out->dims().Vectorize();
+    TensorLite tmp_t;
+    tmp_t.ShareDataWith(*x);
+    const auto x_dims = x->dims().Vectorize();
+    auto x_ndims = param.x->dims().size();
     std::vector<int> axis;
-    switch (x_dims) {
+    switch (x_ndims) {
       case 2:
         axis = {0, 1};
         break;
      case 3:
-        out->Resize(std::vector<int64_t>{
-            out->dims()[0], out->dims()[2], out->dims()[1]});
+        tmp_t.Resize(std::vector<int64_t>{x_dims[0], x_dims[2], x_dims[1]});
        axis = {0, 2, 1};
        break;
      case 4:
-        out->Resize(std::vector<int64_t>{
-            out->dims()[0], out->dims()[3], out->dims()[1], out->dims()[2]});
+        tmp_t.Resize(
+            std::vector<int64_t>{x_dims[0], x_dims[2], x_dims[3], x_dims[1]});
        axis = {0, 3, 1, 2};
        break;
      default:
@@ -156,12 +157,8 @@ class LayoutNhwcToNchwCompute
    }
    LayoutTransCompute<lite::TargetType::kX86,
-                       typename FPTypeTraits<Precision>::T>(
-        x_dims, context, *x, out, axis);
-    if (x_dims > 2) {
-      out->Resize(origin_dims);
-    }
+                       typename FPTypeTraits<Precision>::type>(
+        x_ndims, context, tmp_t, out, axis);
  }
  std::string doc() const override {
......
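
A side note on the first hunk: replacing `typedef float T;` with `using type = float;` is a rename as much as a modernization, which is why every call site moves from `FPTypeTraits<Precision>::T` to `FPTypeTraits<Precision>::type`, in line with the `::type` convention of standard type traits. A reduced sketch of the pattern (the trait name and enum values mirror the diff; everything else is illustrative, and kFP16 is omitted since paddle::lite::fluid::float16 is not available standalone):

// Standalone illustration of the FPTypeTraits pattern with a stand-in enum.
#include <cstdint>
#include <cstdio>

enum class PrecisionType { kFloat, kFP16, kInt8 };

template <PrecisionType P>
struct FPTypeTraits {};

template <>
struct FPTypeTraits<PrecisionType::kFloat> {
  using type = float;  // was: typedef float T;
};
template <>
struct FPTypeTraits<PrecisionType::kInt8> {
  using type = int8_t;  // was: typedef int8_t T;
};

// A caller spells the element type exactly the way the kernels now do.
template <PrecisionType P>
void PrintElementSize() {
  using Elem = typename FPTypeTraits<P>::type;
  std::printf("sizeof(element) = %zu\n", sizeof(Elem));
}

int main() {
  PrintElementSize<PrecisionType::kFloat>();  // 4
  PrintElementSize<PrecisionType::kInt8>();   // 1
  return 0;
}

The `typename` at the call sites (`typename FPTypeTraits<Precision>::type`) remains required either way, since the member names a type that depends on a template parameter.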