未验证 提交 a14dc688 编写于 作者: C chentianyu03 提交者: GitHub

[pten] fix test concat dev api build failed (#39117)

* fix test concat dev api build failed

* fix conflict

* fix conflict
上级 a0f586bc
...@@ -80,7 +80,10 @@ class ConcatKernel : public framework::OpKernel<T> { ...@@ -80,7 +80,10 @@ class ConcatKernel : public framework::OpKernel<T> {
pt_ins.push_back(*in); pt_ins.push_back(*in);
} }
pten::ConcatKernel<T>(dev_ctx, pt_ins, axis, out); pten::ConcatKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
pt_ins, axis, out);
} }
}; };
......
...@@ -43,7 +43,7 @@ void ConcatKernel(const Context& dev_ctx, ...@@ -43,7 +43,7 @@ void ConcatKernel(const Context& dev_ctx,
pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis); pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis);
out->Resize(out_dims); out->Resize(out_dims);
out->mutable_data<T>(); out->mutable_data<T>(dev_ctx.GetPlace());
// If axis is 0, the lod of the output is not the same as inputs. // If axis is 0, the lod of the output is not the same as inputs.
if (axis == 0 && x[0].lod().size() > 0) { if (axis == 0 && x[0].lod().size() > 0) {
......
...@@ -43,7 +43,7 @@ void ConcatKernel(const Context& dev_ctx, ...@@ -43,7 +43,7 @@ void ConcatKernel(const Context& dev_ctx,
pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis); pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis);
out->Resize(out_dims); out->Resize(out_dims);
out->mutable_data<T>(); out->mutable_data<T>(dev_ctx.GetPlace());
// If axis is 0, the lod of the output is not the same as inputs. // If axis is 0, the lod of the output is not the same as inputs.
if (axis == 0 && x[0].lod().size() > 0) { if (axis == 0 && x[0].lod().size() > 0) {
......
...@@ -37,14 +37,16 @@ TEST(API, concat) { ...@@ -37,14 +37,16 @@ TEST(API, concat) {
pten::DenseTensorMeta(pten::DataType::FLOAT32, pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}), framework::make_ddim({3, 10}),
pten::DataLayout::NCHW)); pten::DataLayout::NCHW));
auto* dense_x_data = dense_x->mutable_data<float>(); auto* dense_x_data =
dense_x->mutable_data<float>(paddle::platform::CPUPlace());
auto dense_y = std::make_shared<pten::DenseTensor>( auto dense_y = std::make_shared<pten::DenseTensor>(
alloc.get(), alloc.get(),
pten::DenseTensorMeta(pten::DataType::FLOAT32, pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}), framework::make_ddim({3, 10}),
pten::DataLayout::NCHW)); pten::DataLayout::NCHW));
auto* dense_y_data = dense_y->mutable_data<float>(); auto* dense_y_data =
dense_y->mutable_data<float>(paddle::platform::CPUPlace());
for (size_t i = 0; i < 3; ++i) { for (size_t i = 0; i < 3; ++i) {
for (size_t j = 0; j < 10; ++j) { for (size_t j = 0; j < 10; ++j) {
......
...@@ -25,7 +25,7 @@ namespace pten { ...@@ -25,7 +25,7 @@ namespace pten {
namespace tests { namespace tests {
namespace framework = paddle::framework; namespace framework = paddle::framework;
using DDim = paddle::framework::DDim; using DDim = pten::framework::DDim;
TEST(DEV_API, concat) { TEST(DEV_API, concat) {
// 1. create tensor // 1. create tensor
...@@ -35,13 +35,15 @@ TEST(DEV_API, concat) { ...@@ -35,13 +35,15 @@ TEST(DEV_API, concat) {
pten::DenseTensorMeta(pten::DataType::FLOAT32, pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}), framework::make_ddim({3, 10}),
pten::DataLayout::NCHW)); pten::DataLayout::NCHW));
auto* dense_x_data = dense_x.mutable_data<float>(); auto* dense_x_data =
dense_x.mutable_data<float>(paddle::platform::CPUPlace());
pten::DenseTensor dense_y(alloc.get(), pten::DenseTensor dense_y(alloc.get(),
pten::DenseTensorMeta(pten::DataType::FLOAT32, pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}), framework::make_ddim({3, 10}),
pten::DataLayout::NCHW)); pten::DataLayout::NCHW));
auto* dense_y_data = dense_y.mutable_data<float>(); auto* dense_y_data =
dense_y.mutable_data<float>(paddle::platform::CPUPlace());
for (size_t i = 0; i < 3; ++i) { for (size_t i = 0; i < 3; ++i) {
for (size_t j = 0; j < 10; ++j) { for (size_t j = 0; j < 10; ++j) {
...@@ -50,15 +52,11 @@ TEST(DEV_API, concat) { ...@@ -50,15 +52,11 @@ TEST(DEV_API, concat) {
} }
} }
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
std::vector<pten::DenseTensor> inputs = {dense_x, dense_y}; std::vector<pten::DenseTensor> inputs = {dense_x, dense_y};
// 2. test API // 2. test API
auto out = pten::Concat<float>( pten::CPUContext dev_ctx;
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), inputs, 0); auto out = pten::Concat<float>(dev_ctx, inputs, 0);
// 3. check result // 3. check result
ASSERT_EQ(out.dims().size(), 2); ASSERT_EQ(out.dims().size(), 2);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册