未验证 提交 a14dc688 编写于 作者: C chentianyu03 提交者: GitHub

[pten] fix test concat dev api build failed (#39117)

* fix test concat dev api build failed

* fix conflict

* fix conflict
上级 a0f586bc
......@@ -80,7 +80,10 @@ class ConcatKernel : public framework::OpKernel<T> {
pt_ins.push_back(*in);
}
pten::ConcatKernel<T>(dev_ctx, pt_ins, axis, out);
pten::ConcatKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
pt_ins, axis, out);
}
};
......
......@@ -43,7 +43,7 @@ void ConcatKernel(const Context& dev_ctx,
pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis);
out->Resize(out_dims);
out->mutable_data<T>();
out->mutable_data<T>(dev_ctx.GetPlace());
// If axis is 0, the lod of the output is not the same as inputs.
if (axis == 0 && x[0].lod().size() > 0) {
......
......@@ -43,7 +43,7 @@ void ConcatKernel(const Context& dev_ctx,
pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis);
out->Resize(out_dims);
out->mutable_data<T>();
out->mutable_data<T>(dev_ctx.GetPlace());
// If axis is 0, the lod of the output is not the same as inputs.
if (axis == 0 && x[0].lod().size() > 0) {
......
......@@ -37,14 +37,16 @@ TEST(API, concat) {
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}),
pten::DataLayout::NCHW));
auto* dense_x_data = dense_x->mutable_data<float>();
auto* dense_x_data =
dense_x->mutable_data<float>(paddle::platform::CPUPlace());
auto dense_y = std::make_shared<pten::DenseTensor>(
alloc.get(),
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}),
pten::DataLayout::NCHW));
auto* dense_y_data = dense_y->mutable_data<float>();
auto* dense_y_data =
dense_y->mutable_data<float>(paddle::platform::CPUPlace());
for (size_t i = 0; i < 3; ++i) {
for (size_t j = 0; j < 10; ++j) {
......
......@@ -25,7 +25,7 @@ namespace pten {
namespace tests {
namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;
using DDim = pten::framework::DDim;
TEST(DEV_API, concat) {
// 1. create tensor
......@@ -35,13 +35,15 @@ TEST(DEV_API, concat) {
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}),
pten::DataLayout::NCHW));
auto* dense_x_data = dense_x.mutable_data<float>();
auto* dense_x_data =
dense_x.mutable_data<float>(paddle::platform::CPUPlace());
pten::DenseTensor dense_y(alloc.get(),
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}),
pten::DataLayout::NCHW));
auto* dense_y_data = dense_y.mutable_data<float>();
auto* dense_y_data =
dense_y.mutable_data<float>(paddle::platform::CPUPlace());
for (size_t i = 0; i < 3; ++i) {
for (size_t j = 0; j < 10; ++j) {
......@@ -50,15 +52,11 @@ TEST(DEV_API, concat) {
}
}
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
std::vector<pten::DenseTensor> inputs = {dense_x, dense_y};
// 2. test API
auto out = pten::Concat<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), inputs, 0);
pten::CPUContext dev_ctx;
auto out = pten::Concat<float>(dev_ctx, inputs, 0);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册