From 4e66e0eb1fc1210566474f775d31ef17bea93e86 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 6 May 2022 16:53:34 +0800 Subject: [PATCH] feat(megdnn/softmax): add softmax operator in OpenCL GitOrigin-RevId: e207d6ceb43f3616d0daf0f415706ceaecc8c7de --- dnn/src/naive/softmax/opr_impl.h | 6 ++-- dnn/test/naive/softmax.cpp | 59 +++++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/dnn/src/naive/softmax/opr_impl.h b/dnn/src/naive/softmax/opr_impl.h index 39cefbe27..e0a782771 100644 --- a/dnn/src/naive/softmax/opr_impl.h +++ b/dnn/src/naive/softmax/opr_impl.h @@ -4,7 +4,7 @@ namespace megdnn { namespace naive { -class SoftmaxForwardImpl final : public SoftmaxForward { +class SoftmaxForwardImpl : public SoftmaxForward { public: using SoftmaxForward::SoftmaxForward; void exec( @@ -16,7 +16,7 @@ public: } }; -class SoftmaxBackwardImpl final : public SoftmaxBackward { +class SoftmaxBackwardImpl : public SoftmaxBackward { public: using SoftmaxBackward::SoftmaxBackward; void exec( @@ -32,4 +32,4 @@ public: } // namespace naive } // namespace megdnn -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/test/naive/softmax.cpp b/dnn/test/naive/softmax.cpp index 94e278c86..4ac616bc5 100644 --- a/dnn/test/naive/softmax.cpp +++ b/dnn/test/naive/softmax.cpp @@ -42,4 +42,61 @@ TEST_F(NAIVE, SOFTMAX_BACKWARD) { {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); checker.set_param(param).exect(Testcase{input, diff, {}}, Testcase{{}, {}, output}); -} \ No newline at end of file +} + +TEST_F(NAIVE, SOFTMAX_FORWARD_NHWCD4) { + Checker checker(handle(), false); + Softmax::Param param{0}; + + TensorND input1 = TensorValue( + {1, 2, 1, 2, 4}, dtype::Float32(), + {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}); + TensorND output1 = TensorValue( + {1, 2, 1, 2, 4}, dtype::Float32(), + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + checker.set_param(param).exect(Testcase{input1, {}}, Testcase{{}, output1}); + + TensorND input2 = TensorValue( + {2, 2, 1, 2, 4}, dtype::Float32(), + {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, + 16, 20, 24, 28, 17, 21, 25, 29, 18, 22, 26, 30, 19, 23, 27, 31}); + TensorND output2 = TensorValue( + {2, 2, 1, 2, 4}, dtype::Float32(), + {1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, + 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, + 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, + 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, + 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, + 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, + 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, + 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01}); + checker.set_param(param).exect(Testcase{input2, {}}, Testcase{{}, output2}); +} + +TEST_F(NAIVE, SOFTMAX_BACKWARD_NHWCD4) { + Checker checker(handle(), false); + Softmax::Param param{0}; + + TensorND input = TensorValue( + {2, 2, 1, 2, 4}, dtype::Float32(), + {1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, + 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, + 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, + 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, 1.12535162e-07, + 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, + 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, + 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, + 9.99999887e-01, 9.99999887e-01, 9.99999887e-01, 9.99999887e-01}); + + TensorND diff = TensorValue( + {2, 2, 1, 2, 4}, dtype::Float32(), + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); + + TensorND output = TensorValue( + {2, 2, 1, 2, 4}, dtype::Float32(), + {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); + + checker.set_param(param).exect(Testcase{input, diff, {}}, Testcase{{}, {}, output}); +} -- GitLab