提交 72cc64e4 编写于 作者: T Tomasz Patejko

Device blobs are created only in training. Added testing attribute

上级 2d955275
......@@ -22,6 +22,22 @@ namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
namespace {
template <typename T, typename... Args>
std::shared_ptr<T> insert_to_context(const std::string& key,
const MKLDNNDeviceContext& dev_ctx,
Args&&... args) {
auto p = std::static_pointer_cast<T, void>(dev_ctx.GetBlob(key));
if (!p) {
p = std::make_shared<T>(args...);
dev_ctx.SetBlob(key, std::static_pointer_cast<void, T>(p));
}
return p;
}
} // namespace
template <typename T>
class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
......@@ -42,15 +58,11 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto output_data = out->mutable_data<T>(ctx.GetPlace());
mid->mutable_data<T>(ctx.GetPlace());
const std::string key = ctx.op().Output("Out");
const std::string key_src_memory = key + "@lrn_src_memory";
const std::string key_pd = key + "@lrn_pd";
const std::string key_workspace_memory = key + "@lrn_workspace_memory";
const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha");
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
const bool is_test = ctx.Attr<bool>("is_test");
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);
......@@ -71,28 +83,47 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
beta,
k};
auto forward_pd = std::make_shared<mkldnn::lrn_forward::primitive_desc>(
forward_desc, mkldnn_engine);
dev_ctx.SetBlob(key_pd, forward_pd);
auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
auto src_memory = std::make_shared<mkldnn::memory>(
src_memory_pd, static_cast<void*>(const_cast<float*>(input_data)));
dev_ctx.SetBlob(key_src_memory, src_memory);
auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
static_cast<void*>(output_data)};
auto workspace_md = forward_pd->workspace_primitive_desc();
auto workspace_memory = std::make_shared<mkldnn::memory>(workspace_md);
std::unique_ptr<mkldnn::lrn_forward> forward_op = nullptr;
dev_ctx.SetBlob(key_workspace_memory, workspace_memory);
if (!is_test) {
const std::string key = ctx.op().Output("Out");
const std::string key_src_memory = key + "@lrn_src_memory";
const std::string key_pd = key + "@lrn_pd";
const std::string key_workspace_memory = key + "@lrn_workspace_memory";
auto forward_pd = insert_to_context<mkldnn::lrn_forward::primitive_desc>(
key_pd, dev_ctx, forward_desc, mkldnn_engine);
auto src_memory = insert_to_context<mkldnn::memory>(
key_src_memory, dev_ctx, src_memory_pd);
auto forward_op = mkldnn::lrn_forward{*forward_pd, *src_memory,
*workspace_memory, dst_memory};
src_memory->set_data_handle(
static_cast<void*>(const_cast<T*>(input_data)));
auto workspace_memory = insert_to_context<mkldnn::memory>(
key_workspace_memory, dev_ctx,
forward_pd->workspace_primitive_desc());
forward_op.reset(new mkldnn::lrn_forward{*forward_pd, *src_memory,
*workspace_memory, dst_memory});
} else {
auto forward_pd =
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
auto src_memory = mkldnn::memory{
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
auto workspace_memory =
mkldnn::memory{forward_pd.workspace_primitive_desc()};
forward_op.reset(new mkldnn::lrn_forward{forward_pd, src_memory,
workspace_memory, dst_memory});
}
std::vector<mkldnn::primitive> pipeline = {forward_op};
std::vector<mkldnn::primitive> pipeline = {*forward_op};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
};
......
......@@ -214,6 +214,7 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddAttr<bool>("is_test", "").SetDefault(false);
AddComment(R"DOC(
Local Response Normalization Operator.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册