未验证 提交 1b67bc02 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #9329 from tpatejko/tpatejko/mkldnn-lrn

Improvements for MKLDNN LRN
......@@ -36,6 +36,14 @@ std::shared_ptr<T> insert_to_context(const std::string& key,
return p;
}
template <typename... Args>
void run_primitive(Args&&... args) {
auto forward_op = mkldnn::lrn_forward{args...};
std::vector<mkldnn::primitive> pipeline = {forward_op};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
} // namespace
template <typename T>
......@@ -87,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
static_cast<void*>(output_data)};
std::unique_ptr<mkldnn::lrn_forward> forward_op = nullptr;
if (!is_test) {
const std::string key = ctx.op().Output("Out");
const std::string key_src_memory = key + "@lrn_src_memory";
......@@ -108,9 +114,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
key_workspace_memory, dev_ctx,
forward_pd->workspace_primitive_desc());
forward_op.reset(new mkldnn::lrn_forward{*forward_pd, *src_memory,
*workspace_memory, dst_memory});
run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
} else {
auto forward_pd =
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
......@@ -119,12 +123,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto workspace_memory =
mkldnn::memory{forward_pd.workspace_primitive_desc()};
forward_op.reset(new mkldnn::lrn_forward{forward_pd, src_memory,
workspace_memory, dst_memory});
run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
}
std::vector<mkldnn::primitive> pipeline = {*forward_op};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
};
......@@ -136,6 +136,9 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
"MKLDNN LRN must use float data.");
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"MKLDNN LRN must use CPUPlace.");
PADDLE_ENFORCE(
!ctx.Attr<bool>("is_test"),
"is_test attribute should be set to False in training phase.");
auto x = ctx.Input<Tensor>("X");
......
......@@ -155,8 +155,8 @@ class LRNOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4.");
ctx->SetOutputDim("Out", x_dim);
ctx->SetOutputDim("MidOut", x_dim);
ctx->ShareLoD("X", /*->*/ "Out");
ctx->SetOutputDim("MidOut", x_dim);
}
framework::OpKernelType GetExpectedKernelType(
......
......@@ -97,5 +97,24 @@ class TestLRNMKLDNNOp(TestLRNOp):
self.check_output(atol=0.002)
class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp):
def get_attrs(self):
attrs = TestLRNMKLDNNOp.get_attrs(self)
attrs['is_test'] = True
return attrs
def test_check_grad_normal(self):
def check_raise_is_test():
try:
self.check_grad(['X'], 'Out', max_relative_error=0.01)
except Exception as e:
t = \
"is_test attribute should be set to False in training phase."
if t in str(e):
raise AttributeError
self.assertRaises(AttributeError, check_raise_is_test)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册