提交 32929cdf 编写于 作者: K Krzysztof Binias

Cache input data

上级 0aa01929
......@@ -58,6 +58,7 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
std::vector<int> src_tz = framework::vectorize2int(src->dims());
const std::string key = gethash(src_tz, algorithm);
const std::string key_src_data = key + "@eltwise_fwd_src_data";
const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
const std::string key_dst_mem = key + "@eltwise_fwd_dst_mem";
const std::string key_fwd = key + "@eltwise_fwd";
......@@ -65,6 +66,10 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
dev_ctx.GetBlob(key_fwd));
// save input data to be referred in backward path
auto p_src_data = std::make_shared<const T *>(src_data);
dev_ctx.SetBlob(key_src_data, p_src_data);
if (p_fwd == nullptr) {
// create memory description
auto data_md = src_tz.size() == 2
......@@ -131,11 +136,19 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
std::vector<int> src_tz = framework::vectorize2int(out->dims());
const std::string key = gethash(src_tz, algorithm);
const std::string key_diff_src_mem = key + "@eltwise_diff_src_mem";
const std::string key_diff_dst_mem = key + "@eltwise_diff_dst_mem";
const std::string key_grad = key + "@eltwise_grad";
const std::string key_src_data = key + "@eltwise_fwd_src_data";
const auto p_src_data =
std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data));
const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
auto p_src_mem =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
p_src_mem->set_data_handle(*p_src_data.get());
auto p_grad = std::static_pointer_cast<mkldnn::eltwise_forward::primitive>(
dev_ctx.GetBlob(key_grad));
......@@ -167,9 +180,6 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
bwd_desc, mkldnn_engine, *p_fwd_pd);
const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
const std::shared_ptr<void> p_src_mem = dev_ctx.GetBlob(key_src_mem);
p_grad = std::make_shared<mkldnn::eltwise_backward>(
eltwise_bwd_prim_desc, *static_cast<mkldnn::memory *>(p_src_mem.get()),
*(static_cast<mkldnn::memory *>(p_diff_dst_mem.get())),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册