未验证 提交 25aa4539 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #10934 from tensor-tang/mklml_funcs

speedup vInvSqrt vLogqp vTanh with mklml
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "MathFunctions.h" #include "paddle/math/MathFunctions.h"
#include "hl_matrix_apply.cuh" #include "hl_matrix_apply.cuh"
#include "hl_matrix_ops.cuh" #include "hl_matrix_ops.cuh"
#include "paddle/utils/DynamicLoader.h" #include "paddle/utils/DynamicLoader.h"
...@@ -240,6 +240,36 @@ template <> ...@@ -240,6 +240,36 @@ template <>
void vAdd<double>(const int n, const double* a, const double* b, double* r) { void vAdd<double>(const int n, const double* a, const double* b, double* r) {
vdAdd(n, a, b, r); vdAdd(n, a, b, r);
} }
template <>
void vTanh<float>(const int n, const float* a, float* r) {
vsTanh(n, a, r);
}
template <>
void vTanh<double>(const int n, const double* a, double* r) {
vdTanh(n, a, r);
}
template <>
void vInvSqrt<float>(const int n, const float* a, float* r) {
vsInvSqrt(n, a, r);
}
template <>
void vInvSqrt<double>(const int n, const double* a, double* r) {
vdInvSqrt(n, a, r);
}
template <>
void vLog1p<float>(const int n, const float* a, float* r) {
vsLog1p(n, a, r);
}
template <>
void vLog1p<double>(const int n, const double* a, double* r) {
vdLog1p(n, a, r);
}
#else #else
DEFINE_MATRIX_BINARY_OP(vExp, b = std::exp(a)); DEFINE_MATRIX_BINARY_OP(vExp, b = std::exp(a));
...@@ -277,17 +307,6 @@ void vAdd(const int n, const T* a, const T* b, T* r) { ...@@ -277,17 +307,6 @@ void vAdd(const int n, const T* a, const T* b, T* r) {
n); n);
} }
template void vExp(const int n, const float* a, float* r);
template void vExp(const int n, const double* a, double* r);
template void vLog(const int n, const float* a, float* r);
template void vLog(const int n, const double* a, double* r);
template void vPow(const int n, const float* a, const float b, float* r);
template void vPow(const int n, const double* a, const double b, double* r);
template void vAdd(const int n, const float* a, const float* b, float* r);
template void vAdd(const int n, const double* a, const double* b, double* r);
#endif
DEFINE_MATRIX_BINARY_OP(vInvSqrt, b = 1.0f / std::sqrt(a)); DEFINE_MATRIX_BINARY_OP(vInvSqrt, b = 1.0f / std::sqrt(a));
template <class T> template <class T>
void vInvSqrt(const int n, const T* a, T* r) { void vInvSqrt(const int n, const T* a, T* r) {
...@@ -311,11 +330,19 @@ void vTanh(const int n, const T* a, T* r) { ...@@ -311,11 +330,19 @@ void vTanh(const int n, const T* a, T* r) {
binary::vTanh<T>(), const_cast<T*>(a), r, 1, n, n, n); binary::vTanh<T>(), const_cast<T*>(a), r, 1, n, n, n);
} }
template void vExp(const int n, const float* a, float* r);
template void vExp(const int n, const double* a, double* r);
template void vLog(const int n, const float* a, float* r);
template void vLog(const int n, const double* a, double* r);
template void vPow(const int n, const float* a, const float b, float* r);
template void vPow(const int n, const double* a, const double b, double* r);
template void vAdd(const int n, const float* a, const float* b, float* r);
template void vAdd(const int n, const double* a, const double* b, double* r);
template void vInvSqrt(const int n, const double* a, double* r); template void vInvSqrt(const int n, const double* a, double* r);
template void vInvSqrt(const int n, const float* a, float* r); template void vInvSqrt(const int n, const float* a, float* r);
template void vLog1p(const int n, const float* a, float* r); template void vLog1p(const int n, const float* a, float* r);
template void vLog1p(const int n, const double* a, double* r); template void vLog1p(const int n, const double* a, double* r);
template void vTanh(const int n, const float* a, float* r); template void vTanh(const int n, const float* a, float* r);
template void vTanh(const int n, const double* a, double* r); template void vTanh(const int n, const double* a, double* r);
#endif
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册