Commit 63da1451 authored by dingminghui, committed by jackzhang235

fix(layout): fix compute error in nhwc2nchw layout trans

Parent 067111d4
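What the fix addresses: `LayoutTransCompute` performs a generic axis-permuting transpose, so it must be handed the *physical* in-memory shape of its source buffer. In the NHWC→NCHW kernel the input tensor carries logical NCHW dims metadata while its data is laid out as NHWC; the old code permuted and restored the *output's* dims, which gave the transpose the wrong source shape. The sketch below is a hypothetical stand-in for `LayoutTransCompute` (not PaddleLite code) that shows why the source shape matters:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for LayoutTransCompute: output dim i takes source
// dim axis[i]. `src_dims` must describe the physical layout of `src`;
// wrong dims mean wrong strides and a scrambled result.
template <typename T>
void TransposeCopy(const std::vector<int64_t>& src_dims,
                   const std::vector<int>& axis, const T* src, T* dst) {
  const int n = static_cast<int>(src_dims.size());
  std::vector<int64_t> src_strides(n, 1), dst_dims(n);
  for (int i = n - 2; i >= 0; --i)
    src_strides[i] = src_strides[i + 1] * src_dims[i + 1];
  for (int i = 0; i < n; ++i) dst_dims[i] = src_dims[axis[i]];
  int64_t total = 1;
  for (auto d : src_dims) total *= d;
  std::vector<int64_t> idx(n, 0);  // multi-index over dst, row-major order
  for (int64_t o = 0; o < total; ++o) {
    int64_t src_off = 0;
    for (int i = 0; i < n; ++i) src_off += idx[i] * src_strides[axis[i]];
    dst[o] = src[src_off];
    // odometer-style increment of the dst multi-index
    for (int i = n - 1; i >= 0 && ++idx[i] == dst_dims[i]; --i) idx[i] = 0;
  }
}
```

With an NHWC buffer, `src_dims = {N, H, W, C}` and `axis = {0, 3, 1, 2}` yield NCHW; passing the tensor's logical NCHW dims instead silently corrupts the output, which appears to be the compute error this commit fixes.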
@@ -34,17 +34,17 @@ struct FPTypeTraits {};
 template <>
 struct FPTypeTraits<paddle::lite_api::PrecisionType::kFloat> {
-  typedef float T;
+  using type = float;
 };
 template <>
 struct FPTypeTraits<paddle::lite_api::PrecisionType::kFP16> {
-  typedef paddle::lite::fluid::float16 T;
+  using type = paddle::lite::fluid::float16;
 };
 template <>
 struct FPTypeTraits<paddle::lite_api::PrecisionType::kInt8> {
-  typedef int8_t T;
+  using type = int8_t;
 };
 template <lite::TargetType Target, typename T>
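The first hunk is a pure rename: the trait's member typedef `T` becomes `type`, matching the standard-library trait convention (`std::remove_cv<T>::type` and friends). Illustration only, not from the patch (requires `<type_traits>`):

```cpp
using DataT =
    typename FPTypeTraits<paddle::lite_api::PrecisionType::kFloat>::type;
static_assert(std::is_same<DataT, float>::value, "kFloat maps to float");
```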
@@ -81,36 +81,36 @@ class LayoutNchwToNhwcCompute
     auto& param = this->template Param<param_t>();
     auto* x = param.x;
     auto* out = param.y;
-    out->template mutable_data<typename FPTypeTraits<Precision>::T>();
-    auto x_dims = param.x->dims().size();
+    out->template mutable_data<typename FPTypeTraits<Precision>::type>();
+    auto x_ndims = param.x->dims().size();
     auto& context = this->ctx_->template As<X86Context>();
     const auto origin_dims = out->dims().Vectorize();
     std::vector<int> axis;
-    switch (x_dims) {
+    switch (x_ndims) {
       case 2:
         axis = {0, 1};
         break;
       case 3:
         axis = {0, 2, 1};
         out->Resize(std::vector<int64_t>{
-            out->dims()[0], out->dims()[2], out->dims()[1]});
+            origin_dims[0], origin_dims[2], origin_dims[1]});
         break;
       case 4:
         axis = {0, 2, 3, 1};
         out->Resize(std::vector<int64_t>{
-            out->dims()[0], out->dims()[2], out->dims()[3], out->dims()[1]});
+            origin_dims[0], origin_dims[2], origin_dims[3], origin_dims[1]});
         break;
       default:
         CHECK(0) << "Unsupport dim in mlu layout nchw to nhwc";
     }
     LayoutTransCompute<lite::TargetType::kX86,
-                       typename FPTypeTraits<Precision>::T>(
-        x_dims, context, *x, out, axis);
-    if (x_dims > 2) {
+                       typename FPTypeTraits<Precision>::type>(
+        x_ndims, context, *x, out, axis);
+    if (x_ndims > 2) {
       out->Resize(origin_dims);
     }
   }
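In this NCHW→NHWC hunk the behavior is unchanged: besides the `::type` rename, the edit indexes the cached `origin_dims` instead of re-reading `out->dims()` inside the `Resize` argument list, and renames `x_dims` to the clearer `x_ndims` (it holds a rank, not dims). Note that both `case` bodies follow one rule, output dim `i` equals `origin_dims[axis[i]]`, which could be written generically (a sketch, not code from the patch):

```cpp
// Equivalent to the case 3/4 bodies above: permute the cached dims by
// `axis` ({0, 2, 1} for rank 3, {0, 2, 3, 1} for rank 4) in one pass.
std::vector<int64_t> permuted(axis.size());
for (size_t i = 0; i < axis.size(); ++i) permuted[i] = origin_dims[axis[i]];
out->Resize(permuted);
```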
@@ -130,25 +130,26 @@ class LayoutNhwcToNchwCompute
     auto& param = this->template Param<param_t>();
     auto* x = param.x;
     auto* out = param.y;
-    out->template mutable_data<typename FPTypeTraits<Precision>::T>();
-    auto x_dims = param.x->dims().size();
+    out->template mutable_data<typename FPTypeTraits<Precision>::type>();
     auto& context = this->ctx_->template As<X86Context>();
-    const auto origin_dims = out->dims().Vectorize();
+    TensorLite tmp_t;
+    tmp_t.ShareDataWith(*x);
+    const auto x_dims = x->dims().Vectorize();
+    auto x_ndims = param.x->dims().size();
     std::vector<int> axis;
-    switch (x_dims) {
+    switch (x_ndims) {
       case 2:
         axis = {0, 1};
         break;
       case 3:
-        out->Resize(std::vector<int64_t>{
-            out->dims()[0], out->dims()[2], out->dims()[1]});
+        tmp_t.Resize(std::vector<int64_t>{x_dims[0], x_dims[2], x_dims[1]});
         axis = {0, 2, 1};
         break;
       case 4:
-        out->Resize(std::vector<int64_t>{
-            out->dims()[0], out->dims()[3], out->dims()[1], out->dims()[2]});
+        tmp_t.Resize(
+            std::vector<int64_t>{x_dims[0], x_dims[2], x_dims[3], x_dims[1]});
         axis = {0, 3, 1, 2};
         break;
       default:
@@ -156,12 +157,8 @@ class LayoutNhwcToNchwCompute
     }
     LayoutTransCompute<lite::TargetType::kX86,
-                       typename FPTypeTraits<Precision>::T>(
-        x_dims, context, *x, out, axis);
-    if (x_dims > 2) {
-      out->Resize(origin_dims);
-    }
+                       typename FPTypeTraits<Precision>::type>(
+        x_ndims, context, tmp_t, out, axis);
   }
   std::string doc() const override {
...
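Putting the NHWC→NCHW fix together, the rank-4 path of the corrected `Run` body now reads roughly as follows (names and calls taken from the hunks above; the surrounding class scaffolding is elided):

```cpp
TensorLite tmp_t;
tmp_t.ShareDataWith(*x);                    // reuse x's buffer, no copy
const auto x_dims = x->dims().Vectorize();  // logical NCHW metadata
// Re-describe the buffer by its physical NHWC shape before transposing.
tmp_t.Resize(
    std::vector<int64_t>{x_dims[0], x_dims[2], x_dims[3], x_dims[1]});
std::vector<int> axis = {0, 3, 1, 2};       // NHWC -> NCHW
LayoutTransCompute<lite::TargetType::kX86,
                   typename FPTypeTraits<Precision>::type>(
    4, context, tmp_t, out, axis);
// `out` keeps its NCHW dims throughout, so the old save/restore of
// origin_dims (and the trailing Resize) is gone.
```

`ShareDataWith` keeps the fix zero-copy: only `tmp_t`'s dims differ from `x`'s, so the transpose walks the buffer with the correct strides while `x` itself is untouched.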