提交 14ba67c0 编写于 作者: T Tomasz Patejko

Added a helper function for running an MKLDNN primitive. Added a unit test for the is_test attribute.

上级 e027eb40
...@@ -36,6 +36,14 @@ std::shared_ptr<T> insert_to_context(const std::string& key, ...@@ -36,6 +36,14 @@ std::shared_ptr<T> insert_to_context(const std::string& key,
return p; return p;
} }
// Builds an LRN forward primitive from the given arguments and executes it
// immediately on an eager MKLDNN stream, blocking until completion.
//
// Args are forwarding references; std::forward preserves the value category
// of each argument when constructing the primitive (the original expansion
// `args...` silently passed everything as an lvalue).
template <typename... Args>
void run_primitive(Args&&... args) {
  auto forward_op = mkldnn::lrn_forward{std::forward<Args>(args)...};

  std::vector<mkldnn::primitive> pipeline = {forward_op};
  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
} // namespace } // namespace
template <typename T> template <typename T>
...@@ -87,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -87,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine}, auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
static_cast<void*>(output_data)}; static_cast<void*>(output_data)};
std::unique_ptr<mkldnn::lrn_forward> forward_op = nullptr;
if (!is_test) { if (!is_test) {
const std::string key = ctx.op().Output("Out"); const std::string key = ctx.op().Output("Out");
const std::string key_src_memory = key + "@lrn_src_memory"; const std::string key_src_memory = key + "@lrn_src_memory";
...@@ -108,9 +114,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -108,9 +114,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
key_workspace_memory, dev_ctx, key_workspace_memory, dev_ctx,
forward_pd->workspace_primitive_desc()); forward_pd->workspace_primitive_desc());
forward_op.reset(new mkldnn::lrn_forward{*forward_pd, *src_memory, run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
*workspace_memory, dst_memory});
} else { } else {
auto forward_pd = auto forward_pd =
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine}; mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
...@@ -119,12 +123,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -119,12 +123,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto workspace_memory = auto workspace_memory =
mkldnn::memory{forward_pd.workspace_primitive_desc()}; mkldnn::memory{forward_pd.workspace_primitive_desc()};
forward_op.reset(new mkldnn::lrn_forward{forward_pd, src_memory, run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
workspace_memory, dst_memory});
} }
std::vector<mkldnn::primitive> pipeline = {*forward_op};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
} }
}; };
...@@ -136,6 +136,9 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -136,6 +136,9 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
"MKLDNN LRN must use float data."); "MKLDNN LRN must use float data.");
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"MKLDNN LRN must use CPUPlace."); "MKLDNN LRN must use CPUPlace.");
PADDLE_ENFORCE(
!ctx.Attr<bool>("is_test"),
"is_test attribute should be set to False in training phase.");
auto x = ctx.Input<Tensor>("X"); auto x = ctx.Input<Tensor>("X");
......
...@@ -155,8 +155,8 @@ class LRNOp : public framework::OperatorWithKernel { ...@@ -155,8 +155,8 @@ class LRNOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4."); PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4.");
ctx->SetOutputDim("Out", x_dim); ctx->SetOutputDim("Out", x_dim);
ctx->SetOutputDim("MidOut", x_dim);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
ctx->SetOutputDim("MidOut", x_dim);
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
......
...@@ -97,5 +97,24 @@ class TestLRNMKLDNNOp(TestLRNOp): ...@@ -97,5 +97,24 @@ class TestLRNMKLDNNOp(TestLRNOp):
self.check_output(atol=0.002) self.check_output(atol=0.002)
class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp):
    # LRN MKLDNN op configured for inference (is_test=True): computing a
    # gradient must be rejected by the op with an enforce failure.

    def get_attrs(self):
        # Reuse the base attributes but switch the op into inference mode.
        attrs = TestLRNMKLDNNOp.get_attrs(self)
        attrs['is_test'] = True
        return attrs

    def test_check_grad_normal(self):
        expected_msg = \
            "is_test attribute should be set to False in training phase."

        def check_raise_is_test():
            # Translate the expected enforce failure into AttributeError so
            # that assertRaises below pins exactly this error message and
            # nothing else.
            try:
                self.check_grad(['X'], 'Out', max_relative_error=0.01)
            except Exception as e:
                if expected_msg in str(e):
                    raise AttributeError

        self.assertRaises(AttributeError, check_raise_is_test)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册