提交 2eea0009 编写于 作者: M Megvii Engine Team

feat(mgb): add fast run batch size graph option

GitOrigin-RevId: 94e333ec805d81a279365a03a9665a69f789f522
上级 0ac642b5
......@@ -91,6 +91,11 @@ class MaxTensorDiff : public OperatorBase {
void check_exec(const TensorLayout& layout1,
const TensorLayout& layout2, size_t workspace_in_bytes);
};
bool check_bias_share_in_channel(const TensorLayout& bias,
const param::ConvBias::Format format);
} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"
......
......@@ -318,36 +318,6 @@ void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args,
megdnn_assert(false);
}
}
bool check_bias_share_in_channel(const TensorLayout& bias,
const param::ConvBias::Format format) {
bool share_in_channel = false;
if (format == param::ConvBias::Format::NCHW ||
format == param::ConvBias::Format::NCHW4_NCHW) {
share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 &&
bias[3] == 1);
} else if (format == param::ConvBias::Format::NHWC) {
share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 &&
bias[2] == 1);
} else if (format == param::ConvBias::Format::NCHW4 ||
format == param::ConvBias::Format::NCHW8 ||
format == param::ConvBias::Format::NCHW32 ||
format == param::ConvBias::Format::NCHW64 ||
format == param::ConvBias::Format::NCHW4_NCHW32 ||
format == param::ConvBias::Format::NCHW32_NCHW4) {
share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 &&
bias[3] == 1);
} else if (format == param::ConvBias::Format::NHWCD4) {
share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[1] == 1 &&
bias[3] == 1);
} else {
megdnn_assert(format == param::ConvBias::Format::CHWN4);
share_in_channel = (bias.ndim == 5 && bias[1] == 1 && bias[2] == 1 &&
bias[3] == 1);
}
return share_in_channel;
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -22,8 +22,6 @@ void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args,
const TensorND* dst_tensor,
const TensorND* bias_tensor);
bool check_bias_share_in_channel(const TensorLayout& bias,
const param::ConvBias::Format format);
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -10,6 +10,7 @@
*/
#include "src/common/utils.h"
#include "megdnn/oprs/utils.h"
#include "megdnn/handle.h"
#include <cstdarg>
......@@ -344,4 +345,33 @@ size_t& CpuNDRange::operator[](size_t idx) {
return m_dim[idx];
}
bool megdnn::check_bias_share_in_channel(const TensorLayout& bias,
const param::ConvBias::Format format) {
bool share_in_channel = false;
if (format == param::ConvBias::Format::NCHW ||
format == param::ConvBias::Format::NCHW4_NCHW) {
share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 &&
bias[3] == 1);
} else if (format == param::ConvBias::Format::NHWC) {
share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 &&
bias[2] == 1);
} else if (format == param::ConvBias::Format::NCHW4 ||
format == param::ConvBias::Format::NCHW8 ||
format == param::ConvBias::Format::NCHW32 ||
format == param::ConvBias::Format::NCHW64 ||
format == param::ConvBias::Format::NCHW4_NCHW32 ||
format == param::ConvBias::Format::NCHW32_NCHW4) {
share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 &&
bias[3] == 1);
} else if (format == param::ConvBias::Format::NHWCD4) {
share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[1] == 1 &&
bias[3] == 1);
} else {
megdnn_assert(format == param::ConvBias::Format::CHWN4);
share_in_channel = (bias.ndim == 5 && bias[1] == 1 && bias[2] == 1 &&
bias[3] == 1);
}
return share_in_channel;
}
// vim: syntax=cpp.doxygen
......@@ -158,6 +158,11 @@ R"__usage__(
R"__usage__(
--fast-run-algo-policy <path>
It will read the cache file before profile, and save new fastrun in cache file.
--fast-run-shared-batch-size
Set the batch size used during fastrun, Note that it may not be the same as the actual running batch size
--binary-equal-between-batch
Each batch of output is promised binary equal if each batch of input is binary equal.
Note that if this option is turned on, `--reproducible` will also be turned on.
--reproducible
Enable choose algo which is reproducible. It mainly used for cudnn algos.
See https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#reproducibility
......@@ -1356,6 +1361,20 @@ Args Args::from_argv(int argc, char **argv) {
ret.fast_run_cache_path = argv[i];
continue;
}
if (!strcmp(argv[i], "--fast-run-shared-batch-size")) {
++i;
mgb_assert(i < argc,
"value not given for --fast-run-shared-batch-size");
int32_t batch_size = std::stoi(argv[i]);
mgb_assert(batch_size >= 0);
graph_opt.fast_run_config.shared_batch_size = batch_size;
continue;
}
if (!strcmp(argv[i], "--binary-equal-between-batch")) {
graph_opt.fast_run_config.binary_equal_between_batch = true;
ret.reproducible = true;
continue;
}
if (!strcmp(argv[i], "--reproducible")) {
ret.reproducible = true;
continue;
......@@ -1452,6 +1471,14 @@ Args Args::from_argv(int argc, char **argv) {
return ret;
}
#if MGB_ENABLE_FASTRUN
if (graph_opt.fast_run_config.shared_batch_size) {
mgb_assert(ret.use_fast_run || ret.use_full_run ||
!ret.fast_run_cache_path.empty(),
"--fast-run-shared-batch-size should be used with "
"--fast-run/--full-run/--fast-run-algo-policy");
}
#endif
return ret;
}
......
......@@ -502,7 +502,28 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
//! contains any user data associated with this graph
UserDataContainer user_data;
}; // Options
//! Control parameter for fast run
struct FastRunConfig {
/*!
* the batch size used by fastrun
*
* Non-zero value means that fastrun use this batch size
* regardless of the batch size of the model
*
* Zero means fastrun use batch size of the model
*/
uint32_t shared_batch_size = 0;
/*!
* \brief if the content of each input batch is binary equal,
* whether the content of each output batch is promised to be
* equal
*/
bool binary_equal_between_batch = false;
} fast_run_config;
}; // Options
Options& options() {
return m_options;
......
......@@ -68,7 +68,10 @@ class AlgoChooser {
public:
using FixedTensorLayouts = std::array<TensorLayout, arity>;
class AlgoChooserHelper {
FixedTensorLayouts m_layouts;
//! fastrun layouts
FixedTensorLayouts m_fastrun_layouts;
//! layouts used when get and set cache item
FixedTensorLayouts m_incache_layouts;
Opr* m_dnn_opr;
std::string m_param;
const cg::OperatorNodeBase* m_base_mgb_opr;
......@@ -89,7 +92,7 @@ public:
const cg::OperatorNodeBase* mgb_opr() const { return m_base_mgb_opr; }
const TensorLayout& inp_layout(size_t idx) const {
return m_layouts[idx];
return m_fastrun_layouts[idx];
}
cg::ComputingGraph* owner_graph() const {
return m_base_mgb_opr->owner_graph();
......@@ -109,7 +112,13 @@ public:
return m_dnn_opr->get_algorithm_from_desc(desc);
}
const FixedTensorLayouts& layouts() const { return m_layouts; }
const FixedTensorLayouts& fastrun_layouts() const {
return m_fastrun_layouts;
}
const FixedTensorLayouts& incache_layouts() const {
return m_incache_layouts;
}
//! construct algo chain by heuristic
ImplExecutionPolicy choose_by_heuristic(
......@@ -141,7 +150,8 @@ public:
//! get workspace size required for specific execution policy
size_t get_workspace_size_bytes(
const ImplExecutionPolicy& policy) const;
const ImplExecutionPolicy& policy,
const FixedTensorLayouts& layouts = {}) const;
//! get all candidate algos, and the one choose_by_heuristic() is
//! put first
......@@ -173,7 +183,8 @@ public:
const ExecutionStrategy& strategy) const;
private:
Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const;
Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter(
const FixedTensorLayouts& layouts = {}) const;
};
template <typename U>
......
......@@ -54,11 +54,11 @@ constexpr bool opr_contain_bias() {
return std::is_same<Opr, megdnn::ConvBias>::value;
}
//! matmul and batchedMatrixMul may not be usable once shape changed
//! matmul and batchedMatrixMul
template <typename Opr>
constexpr bool algo_usable_on_shape_change() {
return !(std::is_same<Opr, megdnn::MatrixMul>::value ||
std::is_same<Opr, megdnn::BatchedMatrixMul>::value);
constexpr bool is_matmul() {
return std::is_same<Opr, megdnn::MatrixMul>::value ||
std::is_same<Opr, megdnn::BatchedMatrixMul>::value;
}
template <typename Opr, bool has_prep>
......
/**
* \file src/opr/test/algo_chooser.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/comp_node_env.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/tensor_manip.h"
#include "megdnn/oprs/base.h"
#include "megdnn/dtype.h"
#include <cmath>
#include <random>
#include <utility>
using namespace mgb;
namespace {
#if MGB_CUDA
#if MGB_ENABLE_FASTRUN
template <typename MgbOpr, int arith>
struct GraphMaker;
template <typename MgbOpr>
struct GraphMaker<MgbOpr, 2> {
SymbolVar operator()(const std::array<cg::SymbolVar, 2>& inputs,
typename MgbOpr::Param& param,
typename MgbOpr::ExecutionPolicy& policy) {
return MgbOpr::make(inputs[0], inputs[1], param, policy);
}
};
template <>
struct GraphMaker<opr::ConvolutionBackwardData, 2> {
SymbolVar operator()(
const std::array<cg::SymbolVar, 2>& inputs,
opr::ConvolutionBackwardData::Param& param,
opr::ConvolutionBackwardData::ExecutionPolicy& policy) {
return opr::ConvolutionBackwardData::make_deconv(inputs[0], inputs[1],
param, policy);
}
};
template <>
struct GraphMaker<opr::Convolution3DBackwardData, 2> {
SymbolVar operator()(
const std::array<cg::SymbolVar, 2>& inputs,
opr::Convolution3DBackwardData::Param& param,
opr::Convolution3DBackwardData::ExecutionPolicy& policy) {
return opr::Convolution3DBackwardData::make_deconv(inputs[0], inputs[1],
param, policy);
}
};
template <typename MgbOpr>
struct GraphMaker<MgbOpr, 3> {
SymbolVar operator()(const std::array<cg::SymbolVar, 3>& inputs,
typename MgbOpr::Param& param,
typename MgbOpr::ExecutionPolicy& policy) {
return MgbOpr::make(inputs[0], inputs[1], inputs[2], param, policy, {});
}
};
template <typename MgbOpr>
struct GraphMaker<MgbOpr, 4> {
SymbolVar operator()(const std::array<cg::SymbolVar, 4>& inputs,
typename MgbOpr::Param& param,
typename MgbOpr::ExecutionPolicy& policy) {
return MgbOpr::make(inputs[0], inputs[1], inputs[2], inputs[3], param,
policy, {});
}
};
template <typename MgbOpr>
struct GraphMaker<MgbOpr, 5> {
SymbolVar operator()(const std::array<cg::SymbolVar, 5>& inputs,
typename MgbOpr::Param& param,
typename MgbOpr::ExecutionPolicy& policy) {
return MgbOpr::make(inputs[0], inputs[1], inputs[2], inputs[3],
inputs[4], param, policy, {});
}
};
template <typename MgbOpr, int arith, typename dtype = dtype::Float32>
void test_fastrun_opr(std::array<TensorShape, arith> inps0,
std::array<TensorShape, arith> inps1,
size_t expect_nr_cache_set_inp0 = 0,
size_t expect_nr_cache_set_inp1 = 0,
typename MgbOpr::Param param = {}) {
using Policy = opr::Convolution::ExecutionPolicy;
using S = Policy::Strategy;
using InputGenerator = std::function<void(HostTensorND & dest)>;
using ShapeInpArray = std::array<TensorShape, arith>;
using CacheMem = std::pair<const void*, size_t>;
auto on_get = [](const std::string&, const void*, size_t, const void*,
size_t) {};
std::vector<std::pair<CacheMem, CacheMem>> cache_set_history;
auto on_set = [&cache_set_history](const std::string&, const void* key,
size_t key_size, const void* val,
size_t val_size) {
cache_set_history.emplace_back(std::make_pair(key, key_size),
std::make_pair(val, val_size));
};
PersistentCacheHook cache_hook{on_get, on_set};
CompNode comp_node = CompNode::load("xpu0");
GraphMaker<MgbOpr, arith> graph_maker;
auto run = [&param, &comp_node, &graph_maker](
const std::shared_ptr<cg::ComputingGraph>& graph,
const ShapeInpArray& shapes) {
std::array<InputGenerator, arith> inputs_generator;
std::array<std::shared_ptr<HostTensorND>, arith> inputs;
for (size_t i = 0; i < arith; ++i) {
inputs[i] = std::make_shared<HostTensorND>(comp_node,
dtype());
}
HostTensorGenerator<dtype> gen_host;
for (size_t i = 0; i < arith; ++i) {
inputs[i]->resize(shapes[i]);
*inputs[i] = *gen_host(inputs[i]->shape(), comp_node);
mgb_assert(inputs[i]->shape().eq_shape(shapes[i]));
}
std::array<cg::SymbolVar, arith> sym_in;
for (size_t i = 0; i < arith; ++i) {
// to trigger graph trans
sym_in[i] = opr::Host2DeviceCopy::make(*graph, inputs[i],
ssprintf("inp%zu", i));
}
Policy policy;
policy.strategy = S::PROFILE;
auto out = graph_maker(sym_in, param, policy);
std::unique_ptr<cg::AsyncExecutable> func =
graph->compile({{out, {}}});
func->execute();
};
std::shared_ptr<cg::ComputingGraph> fastrun_ignore_batchsize_graph =
ComputingGraph::make();
fastrun_ignore_batchsize_graph->options()
.fast_run_config.shared_batch_size = 20;
run(fastrun_ignore_batchsize_graph, inps0);
size_t nr_set_inp0 = cache_set_history.size();
if (expect_nr_cache_set_inp0) {
ASSERT_EQ(cache_set_history.size(), expect_nr_cache_set_inp0);
}
run(fastrun_ignore_batchsize_graph, inps1);
size_t nr_set_total = expect_nr_cache_set_inp1 + nr_set_inp0;
ASSERT_EQ(cache_set_history.size(), nr_set_total);
}
TEST(TestOprDNN, FastrunIgnoreBatchSizeConvolution) {
REQUIRE_GPU(1);
test_fastrun_opr<opr::Convolution, 2>(
{TensorShape{12, 3, 36, 36}, TensorShape{4, 3, 3, 3}},
{TensorShape{1, 3, 36, 36}, TensorShape{4, 3, 3, 3}});
test_fastrun_opr<opr::ConvolutionBackwardData, 2>(
{TensorShape{12, 4, 23, 29}, TensorShape{4, 5, 3, 2}},
{TensorShape{2, 4, 23, 29}, TensorShape{4, 5, 3, 2}});
test_fastrun_opr<opr::ConvolutionBackwardFilter, 3>(
{TensorShape{12, 4, 23, 29}, TensorShape{12, 5, 21, 28},
TensorShape{5, 4, 3, 2}},
{TensorShape{2, 4, 23, 29}, TensorShape{2, 5, 21, 28},
TensorShape{5, 4, 3, 2}});
}
TEST(TestOprDNN, FastrunIgnoreBatchSizeConvBias) {
REQUIRE_GPU(1);
test_fastrun_opr<opr::ConvBias, 3>(
{TensorShape{20, 16, 50, 50}, TensorShape{24, 16, 3, 3},
TensorShape{1, 24, 1, 1}},
{TensorShape{1, 16, 50, 50}, TensorShape{24, 16, 3, 3},
TensorShape{1, 24, 1, 1}});
}
TEST(TestOprDNN, FastrunIgnoreBatchSizeConvolution3D) {
REQUIRE_GPU(1);
test_fastrun_opr<opr::Convolution3D, 2>(
{TensorShape{8, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}},
{TensorShape{3, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}});
test_fastrun_opr<opr::Convolution3DBackwardData, 2>(
{TensorShape{14, 5, 12, 12, 16}, TensorShape{5, 5, 3, 3, 3}},
{TensorShape{4, 5, 12, 12, 16}, TensorShape{5, 5, 3, 3, 3}});
test_fastrun_opr<opr::Convolution3DBackwardFilter, 3>(
{TensorShape{64, 16, 18, 18, 18}, TensorShape{64, 16, 18, 18, 18},
TensorShape{16, 16, 1, 1, 1}},
{TensorShape{4, 16, 18, 18, 18}, TensorShape{4, 16, 18, 18, 18},
TensorShape{16, 16, 1, 1, 1}});
}
TEST(TestOprDNN, FastrunIgnoreBatchSizeLocalShare) {
REQUIRE_GPU(1);
opr::LocalShare::Param local_share_param;
local_share_param.mode = opr::LocalShare::Param::Mode::CROSS_CORRELATION;
local_share_param.pad_h = local_share_param.pad_w = 1;
local_share_param.stride_h = local_share_param.stride_w = 1;
local_share_param.spatial_groups_h = local_share_param.spatial_groups_w = 2;
test_fastrun_opr<opr::LocalShareForward, 2>(
{TensorShape{32, 2, 23, 23}, TensorShape{2, 2, 2, 2, 2, 7}},
{TensorShape{3, 2, 23, 23}, TensorShape{2, 2, 2, 2, 2, 7}}, 0, 0,
local_share_param);
test_fastrun_opr<opr::LocalShareBackwardData, 3>(
{TensorShape{3, 3, 128, 1, 1, 128}, TensorShape{32, 128, 24, 24},
TensorShape{32, 128, 24, 24}},
{TensorShape{3, 3, 128, 1, 1, 128}, TensorShape{2, 128, 24, 24},
TensorShape{2, 128, 24, 24}});
test_fastrun_opr<opr::LocalShareBackwardFilter, 3>(
{TensorShape{12, 3, 36, 36}, TensorShape{12, 4, 35, 35},
TensorShape{3, 3, 3, 3, 3, 4}},
{TensorShape{4, 3, 36, 36}, TensorShape{4, 4, 35, 35},
TensorShape{3, 3, 3, 3, 3, 4}});
}
TEST(TestOprDNN, FastrunIgnoreBatchSizeDeformableConv) {
REQUIRE_GPU(1);
test_fastrun_opr<opr::DeformableConvForward, 4>(
{TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3},
TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18}},
{TensorShape{4, 6, 20, 20}, TensorShape{6, 6, 3, 3},
TensorShape{4, 18, 18, 18}, TensorShape{4, 9, 18, 18}});
test_fastrun_opr<opr::DeformableConvBackwardData, 5>(
{TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3},
TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18},
TensorShape{12, 6, 18, 18}},
{TensorShape{4, 6, 20, 20},
TensorShape{6, 6, 3, 3},
TensorShape{4, 18, 18, 18},
TensorShape{4, 9, 18, 18},
TensorShape{4, 6, 18, 18}});
test_fastrun_opr<opr::DeformableConvBackwardFilter, 5>(
{TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3},
TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18},
TensorShape{12, 6, 18, 18}},
{TensorShape{4, 6, 20, 20}, TensorShape{6, 6, 3, 3},
TensorShape{4, 18, 18, 18}, TensorShape{4, 9, 18, 18},
TensorShape{4, 6, 18, 18}});
}
TEST(TestOprDNN, FastrunIgnoreBatchSizeMatrixMul) {
REQUIRE_GPU(1);
//! fastrun_shared_batch_size == 20
//! {20(12), 12(1)}, {12(12), 20(1)} -> {20(12), 20(1)} origin
//! {12(10), 20(1)}, {12(12), 20(1)} -> {20(12), 20(1)} transA
//! {12(10), 20(1)}, {20(12), 12(1)} -> {20(12), 20(1)} transA, transB
//! {20(12), 12(1)}, {20(12), 12(1)} -> {20(12), 20(1)} transB
//!
//! {20(12), 12(1)}, {12(12), 20(1)} -> {20(12), 20(1)} origin duplicate
//! {12(4), 20(1)}, {12(12), 20(1)} -> {20(12), 20(1)} transA
//! {12(4), 20(1)}, {20(12), 12(1)} -> {20(12), 20(1)} transA, transB
//! {20(12), 12(1)}, {20(12), 12(1)} -> {20(12), 20(1)} transB duplicate
test_fastrun_opr<opr::MatrixMul, 2>(
{TensorShape{10, 12}, TensorShape{12, 12}},
{TensorShape{4, 12}, TensorShape{12, 12}}, 4, 2);
}
TEST(TestOprDNN, FastrunIgnoreBatchSizeBatchedMatrixMul) {
REQUIRE_GPU(1);
//! fastrun_shared_batch_size == 20
//! {20(48), 6(8), 8(1)}, {20(32), 8(4), 4(1)} -> {20(24), 6(4), 4(1)} origin
//! {20(48), 8(6), 6(1)}, {20(32), 8(4), 4(1)} -> {20(24), 6(4), 4(1)} transA
//! {20(48), 8(6), 6(1)}, {20(32), 4(8), 8(1)} -> {20(24), 6(4), 4(1)} transA, transB
//! {20(48), 6(8), 8(1)}, {20(32), 4(8), 8(1)} -> {20(24), 6(4), 4(1)} transB
//!
//! {20(48), 6(8), 8(1)}, {20(32), 8(4), 4(1)} -> {20(24), 6(4), 4(1)} origin duplicate
//! {20(48), 8(6), 6(1)}, {20(32), 8(4), 4(1)} -> {20(24), 6(4), 4(1)} transA duplicate
//! {20(48), 8(6), 6(1)}, {20(32), 4(8), 8(1)} -> {20(24), 6(4), 4(1)} transA, transB duplicate
//! {20(48), 6(8), 8(1)}, {20(32), 4(8), 8(1)} -> {20(24), 6(4), 4(1)} transB duplicate
test_fastrun_opr<opr::BatchedMatrixMul, 2>(
{TensorShape{12, 6, 8}, TensorShape{12, 8, 4}},
{TensorShape{4, 6, 8}, TensorShape{4, 8, 4}});
}
#endif // MGB_ENABLE_FASTRUN
#endif // MGB_CUDA
} // anonymous namespace
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -460,12 +460,13 @@ mgb::make_callback_copy(SymbolVar dev, HostTensorND &host, bool sync) {
/* ========================== PersistentCacheHook ========================== */
class PersistentCacheHook::HookedImpl final : public PersistentCache {
GetHook m_on_get;
Hook m_on_get, m_on_set;
public:
std::shared_ptr<PersistentCache> orig_impl;
HookedImpl(GetHook on_get) : m_on_get{std::move(on_get)} {}
HookedImpl(Hook on_get, Hook on_set)
: m_on_get{std::move(on_get)}, m_on_set{std::move(on_set)} {}
Maybe<Blob> get(const std::string& category, const Blob& key) override {
auto ret = orig_impl->get(category, key);
......@@ -476,12 +477,18 @@ public:
void put(const std::string& category, const Blob& key,
const Blob& value) override {
m_on_set(category, key.ptr, key.size, value.ptr,
value.size);
orig_impl->put(category, key, value);
}
};
PersistentCacheHook::PersistentCacheHook(GetHook on_get)
: m_impl{std::make_shared<HookedImpl>(std::move(on_get))} {
PersistentCacheHook::Hook PersistentCacheHook::default_set_hook =
[](const std::string&, const void*, size_t, const void*, size_t) {};
PersistentCacheHook::PersistentCacheHook(Hook on_get, Hook on_set)
: m_impl{std::make_shared<HookedImpl>(std::move(on_get),
std::move(on_set))} {
m_impl->orig_impl = PersistentCache::set_impl(m_impl);
}
......
......@@ -512,17 +512,17 @@ bool check_device_type_avaiable(CompNode::DeviceType device_type);
//! hook persistent cache get calls during the lifetime
class PersistentCacheHook {
class HookedImpl;
std::shared_ptr<HookedImpl> m_impl;
public:
//! if value is not available, \p val and \p val_size would be zero
using GetHook = thin_function<void(const std::string& category,
const void* key, size_t key_size,
const void* val, size_t val_size)>;
PersistentCacheHook(GetHook on_get);
using Hook = thin_function<void(const std::string& category,
const void* key, size_t key_size,
const void* val, size_t val_size)>;
PersistentCacheHook(Hook on_get, Hook on_set = default_set_hook);
~PersistentCacheHook();
private:
static Hook default_set_hook;
class HookedImpl;
std::shared_ptr<HookedImpl> m_impl;
};
//! skip a testcase if xpu not available
#define REQUIRE_XPU(n) do { \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册