// paddle_gtest_main.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstdlib>
#include <string>
#include <unordered_set>
#include <vector>

#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/platform/init.h"

int main(int argc, char** argv) {
Y
Refine  
Yu Yang 已提交
21
  paddle::memory::allocation::UseAllocatorStrategyGFlag();
Y
Yu Yang 已提交
22
  testing::InitGoogleTest(&argc, argv);
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
  // Because the dynamic library libpaddle_fluid.so clips the symbol table, the
  // external program cannot recognize the flag inside the so, and the flag
  // defined by the external program cannot be accessed inside the so.
  // Therefore, the ParseCommandLine function needs to be called separately
  // inside and outside.
  std::vector<char*> external_argv;
  std::vector<char*> internal_argv;

  // ParseNewCommandLineFlags in gflags.cc starts processing
  // commandline strings from idx 1.
  // The reason is, it assumes that the first one (idx 0) is
  // the filename of executable file.
  external_argv.push_back(argv[0]);
  internal_argv.push_back(argv[0]);

  std::vector<google::CommandLineFlagInfo> all_flags;
  std::vector<std::string> external_flags_name;
  google::GetAllFlags(&all_flags);
  for (size_t i = 0; i < all_flags.size(); ++i) {
    external_flags_name.push_back(all_flags[i].name);
  }

45
  for (int i = 0; i < argc; ++i) {
46 47 48 49 50 51 52 53 54 55 56 57
    bool flag = true;
    std::string tmp(argv[i]);
    for (size_t j = 0; j < external_flags_name.size(); ++j) {
      if (tmp.find(external_flags_name[j]) != std::string::npos) {
        external_argv.push_back(argv[i]);
        flag = false;
        break;
      }
    }
    if (flag) {
      internal_argv.push_back(argv[i]);
    }
58
  }
59 60 61 62

  std::vector<std::string> envs;
  std::vector<std::string> undefok;
#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC)
G
gongweibao 已提交
63 64 65 66 67
  std::string str_max_body_size;
  if (google::GetCommandLineOption("max_body_size", &str_max_body_size)) {
    setenv("FLAGS_max_body_size", "2147483647", 1);
    envs.push_back("max_body_size");
  }
68 69
#endif

S
sabreshao 已提交
70
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
71
  envs.push_back("fraction_of_gpu_memory_to_use");
72 73
  envs.push_back("initial_gpu_memory_in_mb");
  envs.push_back("reallocate_gpu_memory_in_mb");
74
  envs.push_back("allocator_strategy");
J
JiabinYang 已提交
75
#elif __clang__
76 77 78 79 80 81
  envs.push_back("use_mkldnn");
  envs.push_back("initial_cpu_memory_in_mb");
  envs.push_back("allocator_strategy");

  undefok.push_back("use_mkldnn");
  undefok.push_back("initial_cpu_memory_in_mb");
82
#else
83 84 85 86 87
  envs.push_back("use_pinned_memory");
  envs.push_back("use_mkldnn");
  envs.push_back("initial_cpu_memory_in_mb");
  envs.push_back("allocator_strategy");

88
  undefok.push_back("use_pinned_memory");
89 90
  undefok.push_back("use_mkldnn");
  undefok.push_back("initial_cpu_memory_in_mb");
91
#endif
92

93
  char* env_str = nullptr;
94 95 96 97 98 99
  if (envs.size() > 0) {
    std::string env_string = "--tryfromenv=";
    for (auto t : envs) {
      env_string += t + ",";
    }
    env_string = env_string.substr(0, env_string.length() - 1);
100
    env_str = strdup(env_string.c_str());
101
    internal_argv.push_back(env_str);
102 103 104
    VLOG(1) << "gtest env_string:" << env_string;
  }

105
  char* undefok_str = nullptr;
106 107 108 109 110 111
  if (undefok.size() > 0) {
    std::string undefok_string = "--undefok=";
    for (auto t : undefok) {
      undefok_string += t + ",";
    }
    undefok_string = undefok_string.substr(0, undefok_string.length() - 1);
112
    undefok_str = strdup(undefok_string.c_str());
113
    internal_argv.push_back(undefok_str);
114 115 116
    VLOG(1) << "gtest undefok_string:" << undefok_string;
  }

117 118 119 120 121 122 123
  int new_argc = static_cast<int>(external_argv.size());
  char** external_argv_address = external_argv.data();
  google::ParseCommandLineFlags(&new_argc, &external_argv_address, false);

  int internal_argc = internal_argv.size();
  char** arr = internal_argv.data();
  paddle::platform::ParseCommandLineFlags(internal_argc, arr, true);
124
  paddle::framework::InitDevices();
125 126 127 128 129 130 131

  int ret = RUN_ALL_TESTS();

  if (env_str) free(env_str);
  if (undefok_str) free(undefok_str);

  return ret;
132
}