提交 b4dfba17 编写于 作者: T tensor-tang

refine lrn_op cpu forward and speedup

test=develop
上级 1be85d01
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/lrn_op.h" #include "paddle/fluid/operators/lrn_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
...@@ -29,34 +30,43 @@ struct LRNFunctor<platform::CPUDeviceContext, T> { ...@@ -29,34 +30,43 @@ struct LRNFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& input, framework::Tensor* out, const framework::Tensor& input, framework::Tensor* out,
framework::Tensor* mid, int N, int C, int H, int W, int n, framework::Tensor* mid, int N, int C, int H, int W, int n,
T k, T alpha, T beta) { T k, T alpha, T beta) {
auto x_v = framework::EigenVector<T>::Flatten(input); const T* idata = input.data<T>();
auto place = ctx.GetPlace();
const int start = -(n - 1) / 2; auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
const int end = start + n; T* odata = out->mutable_data<T>(place);
T* mdata = mid->mutable_data<T>(place);
auto e_mid = framework::EigenTensor<T, 4>::From(*mid); Tensor squared;
e_mid = e_mid.constant(k); T* sdata = squared.mutable_data<T>({1, C + n - 1, H, W}, place);
std::memset(sdata, 0, sizeof(T) * squared.numel());
auto e_x = framework::EigenTensor<T, 4>::From(input); for (int i = 0; i < mid->numel(); ++i) {
for (int m = 0; m < N; m++) { mdata[i] = k;
for (int i = 0; i < C; i++) { }
for (int c = start; c < end; c++) { int img_size = H * W;
int ch = i + c; int fea_size = C * img_size;
if (ch >= 0 && ch < C) { int pre_pad = (n - 1) / 2;
auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}), // compute batches one by one
Eigen::array<int, 4>({{1, 1, H, W}})); for (int i = 0; i < N; ++i) {
blas.VSQR(fea_size, idata + i * fea_size, sdata + pre_pad * img_size);
auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}), // init the first channel of mid
Eigen::array<int, 4>({{1, 1, H, W}})); for (int c = 0; c < n; ++c) {
blas.AXPY(img_size, alpha, sdata + c * img_size, mdata + i * fea_size);
s += alpha * r.square(); }
} for (int c = 1; c < C; ++c) {
} // copy previous scale
int mid_offset = i * fea_size + c * img_size;
std::memcpy(mdata + mid_offset, mdata + mid_offset - img_size,
img_size * sizeof(T));
// add last
blas.AXPY(img_size, alpha, sdata + (c + n - 1) * img_size,
mdata + mid_offset);
// sub rest
blas.AXPY(img_size, -alpha, sdata + (c - 1) * img_size,
mdata + mid_offset);
} }
} }
// compute the final output
auto out_e = framework::EigenVector<T>::Flatten(*out); blas.VPOW(mid->numel(), mdata, -beta, odata);
out_e = x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta); blas.VMUL(mid->numel(), odata, idata, odata);
} }
}; };
template struct LRNFunctor<platform::CPUDeviceContext, float>; template struct LRNFunctor<platform::CPUDeviceContext, float>;
...@@ -156,6 +166,9 @@ class LRNOp : public framework::OperatorWithKernel { ...@@ -156,6 +166,9 @@ class LRNOp : public framework::OperatorWithKernel {
auto x_dim = ctx->GetInputDim("X"); auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4."); PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4.");
int n = ctx->Attrs().Get<int>("n");
PADDLE_ENFORCE(n > 0 && n % 2 == 1, "n should be positive odd value");
ctx->SetOutputDim("Out", x_dim); ctx->SetOutputDim("Out", x_dim);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
ctx->SetOutputDim("MidOut", x_dim); ctx->SetOutputDim("MidOut", x_dim);
......
...@@ -60,7 +60,6 @@ class LRNKernel : public framework::OpKernel<T> { ...@@ -60,7 +60,6 @@ class LRNKernel : public framework::OpKernel<T> {
T beta = ctx.Attr<float>("beta"); T beta = ctx.Attr<float>("beta");
T k = ctx.Attr<float>("k"); T k = ctx.Attr<float>("k");
PADDLE_ENFORCE(n > 0, "n should >= 0");
PADDLE_ENFORCE(alpha >= 0.0, "alpha should >= 0.0"); PADDLE_ENFORCE(alpha >= 0.0, "alpha should >= 0.0");
PADDLE_ENFORCE(beta >= 0.0, "beta should >= 0.0"); PADDLE_ENFORCE(beta >= 0.0, "beta should >= 0.0");
PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0"); PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册