未验证 提交 d234aa02 编写于 作者: T Tao Luo 提交者: GitHub

add transfer_scope_cache unit-test (#18467)

test=develop
上级 7c6f2350
...@@ -46,27 +46,5 @@ Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1, ...@@ -46,27 +46,5 @@ Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1,
return new_scope; return new_scope;
} }
void RemoveKidsFromTransferScopeCache(Scope* scope) {
auto it = global_transfer_scope_cache().find(scope);
if (it != global_transfer_scope_cache().end()) {
global_transfer_scope_cache().erase(it);
}
for (auto* s : scope->kids()) {
auto it = global_transfer_scope_cache().find(s);
if (it != global_transfer_scope_cache().end()) {
global_transfer_scope_cache().erase(it);
}
}
// remove global transfer data cache
auto& cache = global_transfer_data_cache();
for (auto it = cache.begin(); it != cache.end();) {
if (it->second == scope)
it = cache.erase(it);
else
it++;
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -35,7 +35,5 @@ static size_t CombineHash(size_t seed, size_t a) { ...@@ -35,7 +35,5 @@ static size_t CombineHash(size_t seed, size_t a) {
Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1, Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1,
const Scope* scope); const Scope* scope);
void RemoveKidsFromTransferScopeCache(Scope* scope);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h" #include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle { namespace paddle {
...@@ -228,5 +229,44 @@ TEST(Analyzer_bert, compare_determine) { ...@@ -228,5 +229,44 @@ TEST(Analyzer_bert, compare_determine) {
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg), CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
inputs); inputs);
} }
TEST(Analyzer_bert, transfer_scope_cache) {
AnalysisConfig config;
SetConfig(&config);
std::vector<PaddleTensor> input, output;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
int threads_num = 10;
std::vector<std::thread> threads;
std::unordered_set<std::unordered_set<paddle::framework::Scope *> *>
global_transfer_scope_cache;
std::unordered_set<std::unordered_map<size_t, paddle::framework::Scope *> *>
global_transfer_data_cache;
std::ifstream fin(FLAGS_infer_data);
std::string line;
for (int i = 0; i < threads_num; i++) {
threads.emplace_back([&, i]() {
std::getline(fin, line);
ParseLine(line, &input);
predictor->Run(input, &output, FLAGS_batch_size);
global_transfer_scope_cache.insert(
&paddle::framework::global_transfer_scope_cache());
global_transfer_data_cache.insert(
&paddle::framework::global_transfer_data_cache());
});
threads[0].join();
threads.clear();
std::vector<PaddleTensor>().swap(input);
}
// Since paddle::framework::global_transfer_scope_cache() and
// paddle::framework::global_transfer_data_cache() are thread_local,
// their pointer should be different among different thread id.
PADDLE_ENFORCE(global_transfer_scope_cache.size(), threads_num);
PADDLE_ENFORCE(global_transfer_data_cache.size(), threads_num);
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册