reader_test.cc 2.6 KB
Newer Older
Y
yuyang18 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
18
#include "paddle/fluid/framework/ddim.h"
Y
yuyang18 已提交
19 20 21 22 23 24

class StubDecoratedReader : public paddle::framework::DecoratedReader {
 public:
  explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
      : DecoratedReader(reader) {}

F
fengjiayi 已提交
25
  void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
Y
yuyang18 已提交
26 27 28 29
};

class StubRootReader : public paddle::framework::ReaderBase {
 public:
30 31 32 33 34
  explicit StubRootReader(
      const std::vector<paddle::framework::DDim> &dims,
      const std::vector<paddle::framework::proto::VarType::Type> &var_types,
      const std::vector<bool> &need_check_feed)
      : paddle::framework::ReaderBase(dims, var_types, need_check_feed) {}
F
fengjiayi 已提交
35
  void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
Y
yuyang18 已提交
36 37 38
};

// Verifies that decorated readers register as end points of their root
// reader, are expected to be removed again once destroyed, and report the
// root's shapes / var types / need-check-feed metadata.
TEST(READER, decorate_chain) {
  // Number of data slots described by the root reader's metadata.
  constexpr int kNumSlots = 4;
  const paddle::framework::proto::VarType::Type dtype =
      paddle::framework::proto::VarType::FP32;
  const paddle::framework::DDim dim = paddle::framework::make_ddim({5, 7});
  const std::vector<paddle::framework::DDim> init_dims(kNumSlots, dim);
  const std::vector<paddle::framework::proto::VarType::Type> init_types(
      kNumSlots, dtype);
  const std::vector<bool> init_need_check(kNumSlots, true);
  auto root =
      std::make_shared<StubRootReader>(init_dims, init_types, init_need_check);

  // Two decorators built directly on top of the same root.
  auto end_point1 =
      paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
  auto end_point2 =
      paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);

  // Both decorators are reported as end points of the root.
  {
    auto endpoints = root->GetEndPoints();
    ASSERT_EQ(endpoints.size(), 2U);
    ASSERT_NE(endpoints.count(end_point1.get()), 0UL);
    ASSERT_NE(endpoints.count(end_point2.get()), 0UL);
  }

  // A third decorator is counted while it is alive; it is destroyed when this
  // scope ends.
  {
    auto end_point3 =
        paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
    ASSERT_EQ(root->GetEndPoints().size(), 3U);
  }
  // Once end_point3 is gone, the root is expected to be back to two end points.
  ASSERT_EQ(root->GetEndPoints().size(), 2U);

  // A decorated reader reports the root reader's metadata.
  ASSERT_EQ(end_point1->Shapes(), init_dims);
  ASSERT_EQ(end_point1->VarTypes(), init_types);
  ASSERT_EQ(end_point1->NeedCheckFeed(), init_need_check);
}