diff --git a/paddle/math/MathFunctions.cpp b/paddle/math/MathFunctions.cpp index de404cad89fba8021b8645a40e25c1f5b7e86596..f48119aa511578b21602a225277f01b4c6a9e9a8 100644 --- a/paddle/math/MathFunctions.cpp +++ b/paddle/math/MathFunctions.cpp @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "MathFunctions.h" +#include "paddle/math/MathFunctions.h" #include "hl_matrix_apply.cuh" #include "hl_matrix_ops.cuh" #include "paddle/utils/DynamicLoader.h" @@ -240,6 +240,36 @@ template <> void vAdd(const int n, const double* a, const double* b, double* r) { vdAdd(n, a, b, r); } + +template <> +void vTanh(const int n, const float* a, float* r) { + vsTanh(n, a, r); +} + +template <> +void vTanh(const int n, const double* a, double* r) { + vdTanh(n, a, r); +} + +template <> +void vInvSqrt(const int n, const float* a, float* r) { + vsInvSqrt(n, a, r); +} + +template <> +void vInvSqrt(const int n, const double* a, double* r) { + vdInvSqrt(n, a, r); +} + +template <> +void vLog1p(const int n, const float* a, float* r) { + vsLog1p(n, a, r); +} + +template <> +void vLog1p(const int n, const double* a, double* r) { + vdLog1p(n, a, r); +} #else DEFINE_MATRIX_BINARY_OP(vExp, b = std::exp(a)); @@ -277,17 +307,6 @@ void vAdd(const int n, const T* a, const T* b, T* r) { n); } -template void vExp(const int n, const float* a, float* r); -template void vExp(const int n, const double* a, double* r); -template void vLog(const int n, const float* a, float* r); -template void vLog(const int n, const double* a, double* r); -template void vPow(const int n, const float* a, const float b, float* r); -template void vPow(const int n, const double* a, const double b, double* r); -template void vAdd(const int n, const float* a, const float* b, float* r); -template void vAdd(const int n, const double* a, const double* b, double* r); - -#endif - DEFINE_MATRIX_BINARY_OP(vInvSqrt, b = 1.0f / std::sqrt(a)); template void vInvSqrt(const int n, const T* a, T* r) { @@ -311,11 +330,19 @@ void vTanh(const int n, const T* a, T* r) { binary::vTanh(), const_cast(a), r, 1, n, n, n); } +template void vExp(const int n, const float* a, float* r); +template void vExp(const int n, const double* a, double* r); +template void vLog(const int n, const float* a, float* r); +template void vLog(const int n, const double* a, double* r); +template void vPow(const int n, const float* a, const float b, float* r); +template void vPow(const int n, const double* a, const double b, double* r); +template void vAdd(const int n, const float* a, const float* b, float* r); +template void vAdd(const int n, const double* a, const double* b, double* r); template void vInvSqrt(const int n, const double* a, double* r); template void vInvSqrt(const int n, const float* a, float* r); template void vLog1p(const int n, const float* a, float* r); template void vLog1p(const int n, const double* a, double* r); template void vTanh(const int n, const float* a, float* r); template void vTanh(const int n, const double* a, double* r); - +#endif } // namespace paddle