Commit bd5a82e1 authored by minqiyang

Polish unit test code

Parent 047fa2f9
@@ -46,7 +46,7 @@ void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
   in_grad.set_lod(lod);
   auto in_dims = paddle::framework::make_ddim(
       {static_cast<int64_t>(lod[0].back()), static_cast<int64_t>(second_dim)});
-  in_grad.mutable_data<T>(in_dims, context.GetPlace());
+  in_grad.mutable_data<T>(in_dims, context->GetPlace());
   // check tensor contruction result
   PADDLE_ENFORCE_EQ(in_grad.dims().size(), out_grad.dims().size());
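The only change in this hunk is `context.GetPlace()` becoming `context->GetPlace()`, consistent with the `*context` dereference in the next hunk: the helper evidently holds the device context through a pointer. A minimal sketch of the presumed setup, with `place` and `context` as hypothetical locals that are not shown in this diff:

    // Assumed setup (not part of the diff): the helper allocates a place and a
    // device context, then fills the gradient tensor on that place.
    auto* place = new Place();
    auto* context = new DeviceContext(*place);
    in_grad.mutable_data<T>(in_dims, context->GetPlace());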
@@ -56,15 +56,15 @@ void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
   // call functor
   paddle::operators::math::SequencePoolGradFunctor<DeviceContext, T>()(
-      *context, "SUM", out_grad, &in_grad)
+      *context, "SUM", out_grad, &in_grad);
   EXPECT_EQ(in_grad.numel(), lod[0].back() * second_dim);
   EXPECT_EQ(in_grad.lod(), lod);
-  for (int64_t i = 0; i < in_grad.lod().size() - 1; ++i) {
-    int64_t begin = in_grad.lod()[i];
-    int64_t end = in_grad.lod()[i + 1];
-    Tensor tmp = in_grad.Slice(begin, end);
-    for (int64_t j = 0; j != tmp.numel(); j) {
+  for (int64_t i = 0; i < in_grad.lod()[0].size() - 1; ++i) {
+    int64_t begin = in_grad.lod()[0][i];
+    int64_t end = in_grad.lod()[0][i + 1];
+    paddle::framework::Tensor tmp = in_grad.Slice(begin, end);
+    for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) {
       for (int64_t m = 0; m != second_dim; ++m) {
         EXPECT_EQ(tmp.data<T>()[m + j * second_dim],
                   out_grad.data<T>()[m + i * second_dim]);
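This hunk fixes two bugs rather than just polishing: the old loop indexed the LoD container itself instead of its first offset level, and the old inner loop `for (int64_t j = 0; j != tmp.numel(); j)` never advanced `j`. `paddle::framework::LoD` is a vector of levels, each level a `std::vector<size_t>` of row offsets, so the corrected code walks `lod()[0]` and steps over `tmp.numel() / second_dim` rows. An illustrative sketch of the layout the corrected loop relies on:

    // One offset level describing three sequences over ten rows.
    paddle::framework::LoD lod;
    lod.push_back(std::vector<size_t>{0, 2, 7, 10});  // rows [0,2), [2,7), [7,10)
    // Each row stores second_dim values, so a sequence slice contains
    // tmp.numel() / second_dim rows, which is the new outer-loop bound.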
@@ -78,16 +78,14 @@ void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
 TEST(SequencePoolingGrad, CPU_SUM) {
   paddle::framework::LoD lod1;
-  auto dim1 = std::vector<size_t>{0, 10};
-  lod1.push_back(dim1);
+  lod1.push_back(std::vector<size_t>{0, 10});
   TestSequencePoolingSum<paddle::platform::CPUDeviceContext,
-                         paddle::platform::CPUPlace, float>(dim, lod1, "SUM",
-                                                            16);
+                         paddle::platform::CPUPlace, float>(lod1);
   paddle::framework::LoD lod2;
   lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
   TestSequencePoolingSum<paddle::platform::CPUDeviceContext,
-                         paddle::platform::CPUPlace, float>(lod2, "SUM", 128);
+                         paddle::platform::CPUPlace, float>(lod2);
 }
 #ifdef PADDLE_WITH_CUDA
@@ -95,11 +93,11 @@ TEST(SequencePoolingGrad, CUDA_SUM) {
   paddle::framework::LoD lod1;
   lod1.push_back(std::vector<size_t>{0, 10});
   TestSequencePoolingSum<paddle::platform::CUDADeviceContext,
-                         paddle::platform::CUDAPlace, float>(lod1, "SUM", 16);
+                         paddle::platform::CUDAPlace, float>(lod1);
   paddle::framework::LoD lod2;
   lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
   TestSequencePoolingSum<paddle::platform::CUDADeviceContext,
-                         paddle::platform::CUDAPlace, float>(lod2, "SUM", 128);
+                         paddle::platform::CUDAPlace, float>(lod2);
 }
 #endif
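After this commit every call site passes only the LoD. Combined with the function shown in the hunk headers, the helper's signature reads as below; the pooling type ("SUM") and the row width (the trailing 16/128 argument, presumably bound to `second_dim`) that the old calls supplied are now fixed inside the helper, and their internal values are not visible in this diff:

    // Signature as reflected in the hunk headers and the updated call sites.
    template <typename DeviceContext, typename Place, typename T>
    void TestSequencePoolingSum(const paddle::framework::LoD& lod);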