提交 537f57a4 编写于 作者: T tensor-tang

fix undefine error on gpu

上级 315e08eb
...@@ -307,6 +307,29 @@ void vAdd(const int n, const T* a, const T* b, T* r) { ...@@ -307,6 +307,29 @@ void vAdd(const int n, const T* a, const T* b, T* r) {
n); n);
} }
DEFINE_MATRIX_BINARY_OP(vInvSqrt, b = 1.0f / std::sqrt(a));
template <class T>
void vInvSqrt(const int n, const T* a, T* r) {
hl_cpu_apply_binary_op<T, binary::vInvSqrt<T>, 0, 0>(
binary::vInvSqrt<T>(), const_cast<T*>(a), r, 1, n, n, n);
}
DEFINE_MATRIX_BINARY_OP(vLog1p, b = std::log(1.0f + a));
template <class T>
void vLog1p(const int n, const T* a, T* r) {
hl_cpu_apply_binary_op<T, binary::vLog1p<T>, 0, 0>(
binary::vLog1p<T>(), const_cast<T*>(a), r, 1, n, n, n);
}
DEFINE_MATRIX_BINARY_OP(vTanh, T tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
b = 2.0 / (1.0 + std::exp(tmp)) - 1.0);
template <class T>
void vTanh(const int n, const T* a, T* r) {
hl_cpu_apply_binary_op<T, binary::vTanh<T>, 0, 0>(
binary::vTanh<T>(), const_cast<T*>(a), r, 1, n, n, n);
}
template void vExp(const int n, const float* a, float* r); template void vExp(const int n, const float* a, float* r);
template void vExp(const int n, const double* a, double* r); template void vExp(const int n, const double* a, double* r);
template void vLog(const int n, const float* a, float* r); template void vLog(const int n, const float* a, float* r);
...@@ -315,6 +338,11 @@ template void vPow(const int n, const float* a, const float b, float* r); ...@@ -315,6 +338,11 @@ template void vPow(const int n, const float* a, const float b, float* r);
template void vPow(const int n, const double* a, const double b, double* r); template void vPow(const int n, const double* a, const double b, double* r);
template void vAdd(const int n, const float* a, const float* b, float* r); template void vAdd(const int n, const float* a, const float* b, float* r);
template void vAdd(const int n, const double* a, const double* b, double* r); template void vAdd(const int n, const double* a, const double* b, double* r);
template void vInvSqrt(const int n, const double* a, double* r);
template void vInvSqrt(const int n, const float* a, float* r);
template void vLog1p(const int n, const float* a, float* r);
template void vLog1p(const int n, const double* a, double* r);
template void vTanh(const int n, const float* a, float* r);
template void vTanh(const int n, const double* a, double* r);
#endif #endif
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册