Unverified commit a893f156 authored by dzhwinter, committed by GitHub

fix layout transform (#7149)

* "fix typo"

* "fix based on comments"

* "follow gogle style"

* "fix based on comemnts"
Parent dd8ffe1e
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <functional>
 #include "paddle/framework/data_transform.h"
 #include "paddle/framework/lod_tensor.h"
@@ -74,26 +75,28 @@ void TransDataType(const platform::DeviceContext* ctx,
   }
 }
 
-void TransDataLayout(const platform::DeviceContext* ctx,
+void TransDataLayout(const std::vector<int>& axis,
+                     const platform::DeviceContext* ctx,
                      const KernelTypePair& kernel_pair, const Variable& in,
                      Variable* out) {
-  PADDLE_ENFORCE(in.IsType<Tensor>(), "Only Support Tensor transform!.");
+  PADDLE_ENFORCE(in.IsType<Tensor>(), "Only support Tensor transform!.");
   PADDLE_ENFORCE(
       platform::places_are_same_class(kernel_pair.first.place_,
                                       kernel_pair.second.place_),
-      "TransDataType Only Support DataType transform on same place!");
+      "TransDataLayout only support DataLayout transform on same place!");
+  PADDLE_ENFORCE(kernel_pair.first.data_type_ == kernel_pair.second.data_type_,
+                 "TransDataLayout only support Datatype are same!");
 
   auto src = in.Get<Tensor>();
   auto* dst = out->GetMutable<Tensor>();
   PADDLE_ENFORCE(arity(src.dims()) == 4, "Input Arity Only Suppport 4!");
 
+  auto src_dim = src.dims();
+  dst->Resize(src_dim);
   auto place = kernel_pair.second.place_;
   CopyFrom(src, place, *ctx, dst);
-  const std::vector<int> axis = {0, 2, 3, 1};
-  auto src_dim = src.dims();
+
   std::vector<int64_t> dst_dim;
   dst_dim.resize(axis.size());
   for (size_t i = 0; i < axis.size(); i++) {
     dst_dim[i] = src_dim[axis[i]];
@@ -102,7 +105,7 @@ void TransDataLayout(const platform::DeviceContext* ctx,
   dst->Resize(make_ddim(dst_dim));
 
   auto src_type = kernel_pair.first.data_type_;
-  framework::VisitDataType(src_type, CastDataLayout(src, dst, ctx, axis));
+  framework::VisitDataType(src_type, CastDataLayout(ctx, axis, src, dst));
 
   dst->set_layout(kernel_pair.second.data_layout_);
 }
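The key change above is that the transpose permutation is no longer hard-coded inside TransDataLayout: the caller supplies axis, and the output shape follows dst_dim[i] = src_dim[axis[i]]. A minimal standalone sketch of that dimension shuffle, using a hypothetical PermuteDims helper that is not part of this patch:

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the dst_dim loop in TransDataLayout:
// output dimension i takes its extent from input dimension axis[i].
std::vector<int64_t> PermuteDims(const std::vector<int64_t>& src_dim,
                                 const std::vector<int>& axis) {
  std::vector<int64_t> dst_dim(axis.size());
  for (size_t i = 0; i < axis.size(); i++) {
    dst_dim[i] = src_dim[axis[i]];
  }
  return dst_dim;
}

// Example: NHWC dims {N, H, W, C} with axis = {0, 3, 1, 2} come out
// as NCHW dims {N, C, H, W}.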
@@ -111,5 +114,22 @@ void TransDataLayout(const platform::DeviceContext* ctx,
 }  // namespace paddle
 
 namespace f = paddle::framework;
 
+namespace {
+std::vector<int> NHWC2NCHW = {0, 3, 1, 2};
+std::vector<int> NCHW2NHWC = {0, 2, 3, 1};
+}
+
 REGISTER_DATA_TRANSFORM_FN(f::KernelFP32, f::KernelFP64, f::TransDataType);
-REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW, f::TransDataLayout);
+REGISTER_DATA_TRANSFORM_FN(f::KernelNHWC, f::KernelNCHW,
+                           std::bind(f::TransDataLayout, NHWC2NCHW,
+                                     std::placeholders::_1,
+                                     std::placeholders::_2,
+                                     std::placeholders::_3,
+                                     std::placeholders::_4));
+REGISTER_DATA_TRANSFORM_FN(f::KernelNCHW, f::KernelNHWC,
+                           std::bind(f::TransDataLayout, NCHW2NHWC,
+                                     std::placeholders::_1,
+                                     std::placeholders::_2,
+                                     std::placeholders::_3,
+                                     std::placeholders::_4));
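The registry invokes transforms through a fixed four-argument signature (ctx, kernel_pair, in, out), so the new leading axis parameter is fixed up front with std::bind and the remaining arguments are forwarded through placeholders. A minimal sketch of the same adapter pattern, with illustrative types that stand in for Paddle's registry and are not its actual API:

#include <functional>
#include <iostream>
#include <vector>

// Illustrative four-argument callable type standing in for the registry's
// transform signature.
using TransformFn =
    std::function<void(int /*ctx*/, int /*pair*/, int /*in*/, int* /*out*/)>;

void Trans(const std::vector<int>& axis, int /*ctx*/, int /*pair*/, int in,
           int* out) {
  *out = in + static_cast<int>(axis.size());  // stand-in body
}

int main() {
  // Fix the leading axis argument; callers still see a 4-argument callable.
  TransformFn fn = std::bind(Trans, std::vector<int>{0, 3, 1, 2},
                             std::placeholders::_1, std::placeholders::_2,
                             std::placeholders::_3, std::placeholders::_4);
  int out = 0;
  fn(0, 0, 1, &out);
  std::cout << out << std::endl;  // prints 5
}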
@@ -73,6 +73,7 @@ struct CastDataType {
     auto numel = in_.numel();
     auto* in_end = in_begin + numel;
     auto* out_begin = out_->mutable_data<OutType>(place);
+
     if (platform::is_cpu_place(place)) {
       platform::Transform<platform::CPUDeviceContext> trans;
       auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
@@ -86,9 +87,9 @@ struct CastDataType {
 };
 
 struct CastDataLayout {
-  CastDataLayout(const framework::Tensor& in, framework::Tensor* out,
-                 const platform::DeviceContext* ctx,
-                 const std::vector<int>& axis)
+  CastDataLayout(const platform::DeviceContext* ctx,
+                 const std::vector<int>& axis, const framework::Tensor& in,
+                 framework::Tensor* out)
       : in_(in), out_(out), ctx_(ctx), axis_(axis) {}
 
   const framework::Tensor in_;
   framework::Tensor* out_;
@@ -98,6 +99,7 @@ struct CastDataLayout {
   template <typename T>
   void operator()() {
     auto place = ctx_->GetPlace();
+
     if (platform::is_cpu_place(place)) {
       operators::math::Transpose<platform::CPUDeviceContext, T, 4> trans4;
       auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
...
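CastDataLayout pairs with framework::VisitDataType: the visitor's operator()<T> is instantiated for the tensor's runtime data type and then runs the rank-4 Transpose functor with the bound axis. A rough sketch of that runtime-to-compile-time dispatch, simplified to two types (all names here are illustrative, not Paddle's API):

#include <cstdio>
#include <stdexcept>

enum class DataType { FP32, FP64 };

// Simplified stand-in for framework::VisitDataType: switch on the runtime
// type tag and instantiate the visitor's templated operator() accordingly.
template <typename Visitor>
void VisitDataType(DataType type, Visitor visitor) {
  switch (type) {
    case DataType::FP32: visitor.template operator()<float>(); break;
    case DataType::FP64: visitor.template operator()<double>(); break;
    default: throw std::runtime_error("unsupported data type");
  }
}

// Toy visitor; in the patch this role is played by CastDataLayout, whose
// operator()<T> launches Transpose<CPUDeviceContext, T, 4> on CPU places.
struct ReportElementSize {
  template <typename T>
  void operator()() { std::printf("element size: %zu\n", sizeof(T)); }
};

int main() {
  VisitDataType(DataType::FP32, ReportElementSize{});  // element size: 4
  VisitDataType(DataType::FP64, ReportElementSize{});  // element size: 8
}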
@@ -106,7 +106,7 @@ TEST(DataTransform, Register) {
   ASSERT_EQ(test_value, 2);
 }
 
-TEST(DataTransform, Layout) {
+TEST(DataTransform, DataLayout) {
   using namespace paddle::framework;
   using namespace paddle::platform;
@@ -127,7 +127,19 @@ TEST(DataTransform, Layout) {
   }
 
   Tensor dst = out.Get<Tensor>();
-  EXPECT_TRUE(dst.layout() != src->layout());
+
+  EXPECT_TRUE(dst.layout() == DataLayout::kNCHW);
+  EXPECT_TRUE(dst.dims() == make_ddim({2, 2, 3, 1}));
+
+  {
+    auto kernel1 = GenFromBit({1, 0, 1, 0});
+    auto kernel2 = GenFromBit({1, 0, 0, 0});
+    auto pair0 = std::make_pair(kernel1, kernel2);
+    instance.Get(pair0)(ctx, pair0, out, &in);
+  }
+
+  EXPECT_TRUE(src->layout() == DataLayout::kNHWC);
+  EXPECT_TRUE(src->dims() == make_ddim({2, 3, 1, 2}));
 }
 TEST(DataTransform, DataType) {
...
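The strengthened test checks a full round trip of the permutation arithmetic. The source tensor is NHWC with dims {2, 3, 1, 2}; the forward transform applies NHWC2NCHW = {0, 3, 1, 2}, so dst_dim = {src[0], src[3], src[1], src[2]} = {2, 2, 3, 1}, matching the first pair of EXPECTs. The reverse transform applies NCHW2NHWC = {0, 2, 3, 1} to that result, giving {2, 3, 1, 2} again, which is exactly what the final EXPECTs assert.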