// reader_test.cc (committed by yuyang18)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/reader.h"

#include <memory>
#include <vector>

#include "gtest/gtest.h"
#include "gtest/gtest_pred_impl.h"

class StubDecoratedReader : public paddle::framework::DecoratedReader {
 public:
  explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
      : DecoratedReader(reader) {}

27
  void ReadNextImpl(paddle::framework::LoDTensorArray *out) override {}
Y
yuyang18 已提交
28 29 30 31
};

class StubRootReader : public paddle::framework::ReaderBase {
 public:
32 33 34 35 36
  explicit StubRootReader(
      const std::vector<paddle::framework::DDim> &dims,
      const std::vector<paddle::framework::proto::VarType::Type> &var_types,
      const std::vector<bool> &need_check_feed)
      : paddle::framework::ReaderBase(dims, var_types, need_check_feed) {}
37
  void ReadNextImpl(paddle::framework::LoDTensorArray *out) override {}
Y
yuyang18 已提交
38 39 40
};

// Checks the decorator bookkeeping of a root reader:
//  * every reader created with MakeDecoratedReader on `root` is reported by
//    root->GetEndPoints() while it is alive,
//  * a decorator that goes out of scope is expected to stop being reported,
//  * a decorator reports the same shapes / var types / need-check-feed flags
//    that the root reader was constructed with.
TEST(READER, decorate_chain) {
  paddle::framework::proto::VarType::Type dtype =
      paddle::framework::proto::VarType::FP32;
  paddle::framework::DDim dim = phi::make_ddim({5, 7});
  // Four identical slots: shape [5, 7], FP32, feed check enabled.
  std::vector<paddle::framework::DDim> init_dims(4, dim);
  std::vector<paddle::framework::proto::VarType::Type> init_types(4, dtype);
  std::vector<bool> init_need_check(4, true);
  auto root =
      std::make_shared<StubRootReader>(init_dims, init_types, init_need_check);
  auto end_point1 =
      paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
  auto end_point2 =
      paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);

  // Both live decorators must be registered as end points of the root.
  {
    auto endpoints = root->GetEndPoints();
    ASSERT_EQ(endpoints.size(), 2U);
    ASSERT_NE(endpoints.count(end_point1.get()), 0UL);
    ASSERT_NE(endpoints.count(end_point2.get()), 0UL);
  }

  // A third decorator is registered while it is in scope...
  {
    auto end_point3 =
        paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
    ASSERT_EQ(root->GetEndPoints().size(), 3U);
  }
  // ...and is no longer counted once it has gone out of scope.
  ASSERT_EQ(root->GetEndPoints().size(), 2U);

  // A decorator exposes the metadata of the reader it wraps.
  {
    ASSERT_EQ(end_point1->Shapes(), init_dims);
    ASSERT_EQ(end_point1->VarTypes(), init_types);
    ASSERT_EQ(end_point1->NeedCheckFeed(), init_need_check);
  }
}