未验证 提交 64f7516a 编写于 作者: T tensor-tang 提交者: GitHub

fix lrn on mac (#14426)

* rename Blas VSQR to VSQUARE and fix its non-MKLML fallback, which computed std::sqrt(x[i]) instead of x[i] * x[i]

test=develop

* update
上级 8a1eeec5
...@@ -46,7 +46,7 @@ struct LRNFunctor<platform::CPUDeviceContext, T> { ...@@ -46,7 +46,7 @@ struct LRNFunctor<platform::CPUDeviceContext, T> {
int pre_pad = (n - 1) / 2; int pre_pad = (n - 1) / 2;
// compute batches one by one // compute batches one by one
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
blas.VSQR(fea_size, idata + i * fea_size, sdata + pre_pad * img_size); blas.VSQUARE(fea_size, idata + i * fea_size, sdata + pre_pad * img_size);
// init the first channel of mid // init the first channel of mid
for (int c = 0; c < n; ++c) { for (int c = 0; c < n; ++c) {
blas.AXPY(img_size, alpha, sdata + c * img_size, mdata + i * fea_size); blas.AXPY(img_size, alpha, sdata + c * img_size, mdata + i * fea_size);
......
...@@ -153,7 +153,7 @@ class Blas { ...@@ -153,7 +153,7 @@ class Blas {
void VEXP(int n, const T* x, T* y) const; void VEXP(int n, const T* x, T* y) const;
template <typename T> template <typename T>
void VSQR(int n, const T* x, T* y) const; void VSQUARE(int n, const T* x, T* y) const;
template <typename T> template <typename T>
void VPOW(int n, const T* x, T alpha, T* y) const; void VPOW(int n, const T* x, T alpha, T* y) const;
...@@ -245,8 +245,8 @@ class BlasT : private Blas<DeviceContext> { ...@@ -245,8 +245,8 @@ class BlasT : private Blas<DeviceContext> {
} }
template <typename... ARGS> template <typename... ARGS>
void VSQR(ARGS... args) const { void VSQUARE(ARGS... args) const {
Base()->template VSQR<T>(args...); Base()->template VSQUARE<T>(args...);
} }
template <typename... ARGS> template <typename... ARGS>
......
...@@ -105,7 +105,7 @@ struct CBlas<float> { ...@@ -105,7 +105,7 @@ struct CBlas<float> {
} }
template <typename... ARGS> template <typename... ARGS>
static void VSQR(ARGS... args) { static void VSQUARE(ARGS... args) {
platform::dynload::vsSqr(args...); platform::dynload::vsSqr(args...);
} }
...@@ -195,7 +195,7 @@ struct CBlas<double> { ...@@ -195,7 +195,7 @@ struct CBlas<double> {
} }
template <typename... ARGS> template <typename... ARGS>
static void VSQR(ARGS... args) { static void VSQUARE(ARGS... args) {
platform::dynload::vdSqr(args...); platform::dynload::vdSqr(args...);
} }
...@@ -262,7 +262,9 @@ struct CBlas<platform::float16> { ...@@ -262,7 +262,9 @@ struct CBlas<platform::float16> {
} }
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); } static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); } static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
static void VSQR(...) { PADDLE_THROW("float16 VSQR not supported on CPU"); } static void VSQUARE(...) {
PADDLE_THROW("float16 VSQUARE not supported on CPU");
}
static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); } static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); }; static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); }; static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
...@@ -423,12 +425,12 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const { ...@@ -423,12 +425,12 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CPUDeviceContext>::VSQR(int n, const T *x, T *y) const { void Blas<platform::CPUDeviceContext>::VSQUARE(int n, const T *x, T *y) const {
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
CBlas<T>::VSQR(n, x, y); CBlas<T>::VSQUARE(n, x, y);
#else #else
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
y[i] = std::sqrt(x[i]); y[i] = x[i] * x[i];
} }
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册