From a14dc68820dbb221831b13b8c43155f537e265e9 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Fri, 21 Jan 2022 20:56:04 +0800 Subject: [PATCH] [pten] fix test concat dev api build failed (#39117) * fix test concat dev api build failed * fix conflict * fix conflict --- paddle/fluid/operators/concat_op.h | 5 ++++- paddle/pten/kernels/cpu/concat_kernel.cc | 2 +- paddle/pten/kernels/gpu/concat_kernel.cu | 2 +- paddle/pten/tests/api/test_concat_api.cc | 6 ++++-- paddle/pten/tests/kernels/test_concat_dev_api.cc | 16 +++++++--------- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index 3eaffbdc8bf..1d9c10bdb8c 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -80,7 +80,10 @@ class ConcatKernel : public framework::OpKernel { pt_ins.push_back(*in); } - pten::ConcatKernel(dev_ctx, pt_ins, axis, out); + pten::ConcatKernel( + static_cast::TYPE&>(dev_ctx), + pt_ins, axis, out); } }; diff --git a/paddle/pten/kernels/cpu/concat_kernel.cc b/paddle/pten/kernels/cpu/concat_kernel.cc index fb59c9c6005..c4aed7679bd 100644 --- a/paddle/pten/kernels/cpu/concat_kernel.cc +++ b/paddle/pten/kernels/cpu/concat_kernel.cc @@ -43,7 +43,7 @@ void ConcatKernel(const Context& dev_ctx, pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis); out->Resize(out_dims); - out->mutable_data(); + out->mutable_data(dev_ctx.GetPlace()); // If axis is 0, the lod of the output is not the same as inputs. if (axis == 0 && x[0].lod().size() > 0) { diff --git a/paddle/pten/kernels/gpu/concat_kernel.cu b/paddle/pten/kernels/gpu/concat_kernel.cu index 6ddfef460fc..e52e3a3d644 100644 --- a/paddle/pten/kernels/gpu/concat_kernel.cu +++ b/paddle/pten/kernels/gpu/concat_kernel.cu @@ -43,7 +43,7 @@ void ConcatKernel(const Context& dev_ctx, pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis); out->Resize(out_dims); - out->mutable_data(); + out->mutable_data(dev_ctx.GetPlace()); // If axis is 0, the lod of the output is not the same as inputs. if (axis == 0 && x[0].lod().size() > 0) { diff --git a/paddle/pten/tests/api/test_concat_api.cc b/paddle/pten/tests/api/test_concat_api.cc index e84aee0aaaf..c003e89f6c0 100644 --- a/paddle/pten/tests/api/test_concat_api.cc +++ b/paddle/pten/tests/api/test_concat_api.cc @@ -37,14 +37,16 @@ TEST(API, concat) { pten::DenseTensorMeta(pten::DataType::FLOAT32, framework::make_ddim({3, 10}), pten::DataLayout::NCHW)); - auto* dense_x_data = dense_x->mutable_data(); + auto* dense_x_data = + dense_x->mutable_data(paddle::platform::CPUPlace()); auto dense_y = std::make_shared( alloc.get(), pten::DenseTensorMeta(pten::DataType::FLOAT32, framework::make_ddim({3, 10}), pten::DataLayout::NCHW)); - auto* dense_y_data = dense_y->mutable_data(); + auto* dense_y_data = + dense_y->mutable_data(paddle::platform::CPUPlace()); for (size_t i = 0; i < 3; ++i) { for (size_t j = 0; j < 10; ++j) { diff --git a/paddle/pten/tests/kernels/test_concat_dev_api.cc b/paddle/pten/tests/kernels/test_concat_dev_api.cc index c5d979ad908..6f9ea1b0d99 100644 --- a/paddle/pten/tests/kernels/test_concat_dev_api.cc +++ b/paddle/pten/tests/kernels/test_concat_dev_api.cc @@ -25,7 +25,7 @@ namespace pten { namespace tests { namespace framework = paddle::framework; -using DDim = paddle::framework::DDim; +using DDim = pten::framework::DDim; TEST(DEV_API, concat) { // 1. create tensor @@ -35,13 +35,15 @@ TEST(DEV_API, concat) { pten::DenseTensorMeta(pten::DataType::FLOAT32, framework::make_ddim({3, 10}), pten::DataLayout::NCHW)); - auto* dense_x_data = dense_x.mutable_data(); + auto* dense_x_data = + dense_x.mutable_data(paddle::platform::CPUPlace()); pten::DenseTensor dense_y(alloc.get(), pten::DenseTensorMeta(pten::DataType::FLOAT32, framework::make_ddim({3, 10}), pten::DataLayout::NCHW)); - auto* dense_y_data = dense_y.mutable_data(); + auto* dense_y_data = + dense_y.mutable_data(paddle::platform::CPUPlace()); for (size_t i = 0; i < 3; ++i) { for (size_t j = 0; j < 10; ++j) { @@ -50,15 +52,11 @@ TEST(DEV_API, concat) { } } - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); - std::vector inputs = {dense_x, dense_y}; // 2. test API - auto out = pten::Concat( - *(static_cast(dev_ctx)), inputs, 0); + pten::CPUContext dev_ctx; + auto out = pten::Concat(dev_ctx, inputs, 0); // 3. check result ASSERT_EQ(out.dims().size(), 2); -- GitLab