diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h
index 3eaffbdc8bf35be9af8da73d28c92f4d8f00f53b..1d9c10bdb8cc6a698a4a1b6ab376e90b67eb2a03 100644
--- a/paddle/fluid/operators/concat_op.h
+++ b/paddle/fluid/operators/concat_op.h
@@ -80,7 +80,10 @@ class ConcatKernel : public framework::OpKernel<T> {
       pt_ins.push_back(*in);
     }
 
-    pten::ConcatKernel<T>(dev_ctx, pt_ins, axis, out);
+    pten::ConcatKernel<T>(
+        static_cast<const typename paddle::framework::ConvertToPtenContext<
+            DeviceContext>::TYPE&>(dev_ctx),
+        pt_ins, axis, out);
   }
 };
 
diff --git a/paddle/pten/kernels/cpu/concat_kernel.cc b/paddle/pten/kernels/cpu/concat_kernel.cc
index fb59c9c6005ff7b5d9acd1480c7145225ea07378..c4aed7679bd72c42d1d0b46d3ebf195d1c35298b 100644
--- a/paddle/pten/kernels/cpu/concat_kernel.cc
+++ b/paddle/pten/kernels/cpu/concat_kernel.cc
@@ -43,7 +43,7 @@ void ConcatKernel(const Context& dev_ctx,
 
   pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis);
   out->Resize(out_dims);
-  out->mutable_data<T>();
+  out->mutable_data<T>(dev_ctx.GetPlace());
 
   // If axis is 0, the lod of the output is not the same as inputs.
   if (axis == 0 && x[0].lod().size() > 0) {
diff --git a/paddle/pten/kernels/gpu/concat_kernel.cu b/paddle/pten/kernels/gpu/concat_kernel.cu
index 6ddfef460fc6cf2945903fbb70574272e4e18e55..e52e3a3d6446c7debdc0fa603ba326173f064181 100644
--- a/paddle/pten/kernels/gpu/concat_kernel.cu
+++ b/paddle/pten/kernels/gpu/concat_kernel.cu
@@ -43,7 +43,7 @@ void ConcatKernel(const Context& dev_ctx,
 
   pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis);
   out->Resize(out_dims);
-  out->mutable_data<T>();
+  out->mutable_data<T>(dev_ctx.GetPlace());
 
   // If axis is 0, the lod of the output is not the same as inputs.
   if (axis == 0 && x[0].lod().size() > 0) {
diff --git a/paddle/pten/tests/api/test_concat_api.cc b/paddle/pten/tests/api/test_concat_api.cc
index e84aee0aaaf4ff6151327fc556595f13eb7efc1f..c003e89f6c0097b47156e7a6439c0a0a8d4a1b29 100644
--- a/paddle/pten/tests/api/test_concat_api.cc
+++ b/paddle/pten/tests/api/test_concat_api.cc
@@ -37,14 +37,16 @@ TEST(API, concat) {
       pten::DenseTensorMeta(pten::DataType::FLOAT32,
                             framework::make_ddim({3, 10}),
                             pten::DataLayout::NCHW));
-  auto* dense_x_data = dense_x->mutable_data<float>();
+  auto* dense_x_data =
+      dense_x->mutable_data<float>(paddle::platform::CPUPlace());
 
   auto dense_y = std::make_shared<pten::DenseTensor>(
       alloc.get(),
       pten::DenseTensorMeta(pten::DataType::FLOAT32,
                             framework::make_ddim({3, 10}),
                             pten::DataLayout::NCHW));
-  auto* dense_y_data = dense_y->mutable_data<float>();
+  auto* dense_y_data =
+      dense_y->mutable_data<float>(paddle::platform::CPUPlace());
 
   for (size_t i = 0; i < 3; ++i) {
     for (size_t j = 0; j < 10; ++j) {
diff --git a/paddle/pten/tests/kernels/test_concat_dev_api.cc b/paddle/pten/tests/kernels/test_concat_dev_api.cc
index c5d979ad908fff52e4d8c95db5e01acc0c50a2f6..6f9ea1b0d990ae9e4d789bc4c37fb104c730fe82 100644
--- a/paddle/pten/tests/kernels/test_concat_dev_api.cc
+++ b/paddle/pten/tests/kernels/test_concat_dev_api.cc
@@ -25,7 +25,7 @@ namespace pten {
 namespace tests {
 
 namespace framework = paddle::framework;
-using DDim = paddle::framework::DDim;
+using DDim = pten::framework::DDim;
 
 TEST(DEV_API, concat) {
   // 1. create tensor
@@ -35,13 +35,15 @@ TEST(DEV_API, concat) {
       pten::DenseTensorMeta(pten::DataType::FLOAT32,
                             framework::make_ddim({3, 10}),
                             pten::DataLayout::NCHW));
-  auto* dense_x_data = dense_x.mutable_data<float>();
+  auto* dense_x_data =
+      dense_x.mutable_data<float>(paddle::platform::CPUPlace());
 
   pten::DenseTensor dense_y(alloc.get(),
                             pten::DenseTensorMeta(pten::DataType::FLOAT32,
                                                   framework::make_ddim({3, 10}),
                                                   pten::DataLayout::NCHW));
-  auto* dense_y_data = dense_y.mutable_data<float>();
+  auto* dense_y_data =
+      dense_y.mutable_data<float>(paddle::platform::CPUPlace());
 
   for (size_t i = 0; i < 3; ++i) {
     for (size_t j = 0; j < 10; ++j) {
@@ -50,15 +52,11 @@ TEST(DEV_API, concat) {
     }
   }
 
-  paddle::platform::DeviceContextPool& pool =
-      paddle::platform::DeviceContextPool::Instance();
-  auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
-
   std::vector<pten::DenseTensor> inputs = {dense_x, dense_y};
 
   // 2. test API
-  auto out = pten::Concat<float>(
-      *(static_cast<const pten::CPUContext*>(dev_ctx)), inputs, 0);
+  pten::CPUContext dev_ctx;
+  auto out = pten::Concat<float>(dev_ctx, inputs, 0);
 
   // 3. check result
   ASSERT_EQ(out.dims().size(), 2);
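Usage note (not part of the diff): a minimal sketch of how the updated dev API is called after this change, mirroring test_concat_dev_api.cc above. pten::Concat<float>, pten::CPUContext, and the mutable_data<T>(place) allocation come from the files touched above; the include path and the helper function name ConcatTwoTensors are assumptions for illustration only.

    // Sketch only: mirrors paddle/pten/tests/kernels/test_concat_dev_api.cc.
    // The include path below is an assumption, not confirmed by the diff.
    #include <vector>
    #include "paddle/pten/kernels/concat_kernel.h"

    pten::DenseTensor ConcatTwoTensors(const pten::DenseTensor& a,
                                       const pten::DenseTensor& b) {
      std::vector<pten::DenseTensor> inputs = {a, b};
      // The dev API now takes a directly constructed pten::CPUContext instead
      // of a platform::DeviceContext fetched from DeviceContextPool and cast.
      pten::CPUContext dev_ctx;
      // Inside ConcatKernel the output buffer is now allocated against the
      // context's place: out->mutable_data<T>(dev_ctx.GetPlace()).
      return pten::Concat<float>(dev_ctx, inputs, /*axis=*/0);
    }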