提交 b4dfba17 编写于 作者: T tensor-tang

refine lrn_op cpu forward and speedup

test=develop
上级 1be85d01
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/lrn_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
......@@ -29,34 +30,43 @@ struct LRNFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& input, framework::Tensor* out,
framework::Tensor* mid, int N, int C, int H, int W, int n,
T k, T alpha, T beta) {
auto x_v = framework::EigenVector<T>::Flatten(input);
const int start = -(n - 1) / 2;
const int end = start + n;
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);
auto e_x = framework::EigenTensor<T, 4>::From(input);
for (int m = 0; m < N; m++) {
for (int i = 0; i < C; i++) {
for (int c = start; c < end; c++) {
int ch = i + c;
if (ch >= 0 && ch < C) {
auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
s += alpha * r.square();
}
}
const T* idata = input.data<T>();
auto place = ctx.GetPlace();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
T* odata = out->mutable_data<T>(place);
T* mdata = mid->mutable_data<T>(place);
Tensor squared;
T* sdata = squared.mutable_data<T>({1, C + n - 1, H, W}, place);
std::memset(sdata, 0, sizeof(T) * squared.numel());
for (int i = 0; i < mid->numel(); ++i) {
mdata[i] = k;
}
int img_size = H * W;
int fea_size = C * img_size;
int pre_pad = (n - 1) / 2;
// compute batches one by one
for (int i = 0; i < N; ++i) {
blas.VSQR(fea_size, idata + i * fea_size, sdata + pre_pad * img_size);
// init the first channel of mid
for (int c = 0; c < n; ++c) {
blas.AXPY(img_size, alpha, sdata + c * img_size, mdata + i * fea_size);
}
for (int c = 1; c < C; ++c) {
// copy previous scale
int mid_offset = i * fea_size + c * img_size;
std::memcpy(mdata + mid_offset, mdata + mid_offset - img_size,
img_size * sizeof(T));
// add last
blas.AXPY(img_size, alpha, sdata + (c + n - 1) * img_size,
mdata + mid_offset);
// sub rest
blas.AXPY(img_size, -alpha, sdata + (c - 1) * img_size,
mdata + mid_offset);
}
}
auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e = x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
// compute the final output
blas.VPOW(mid->numel(), mdata, -beta, odata);
blas.VMUL(mid->numel(), odata, idata, odata);
}
};
template struct LRNFunctor<platform::CPUDeviceContext, float>;
......@@ -156,6 +166,9 @@ class LRNOp : public framework::OperatorWithKernel {
auto x_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4.");
int n = ctx->Attrs().Get<int>("n");
PADDLE_ENFORCE(n > 0 && n % 2 == 1, "n should be positive odd value");
ctx->SetOutputDim("Out", x_dim);
ctx->ShareLoD("X", /*->*/ "Out");
ctx->SetOutputDim("MidOut", x_dim);
......
......@@ -60,7 +60,6 @@ class LRNKernel : public framework::OpKernel<T> {
T beta = ctx.Attr<float>("beta");
T k = ctx.Attr<float>("k");
PADDLE_ENFORCE(n > 0, "n should >= 0");
PADDLE_ENFORCE(alpha >= 0.0, "alpha should >= 0.0");
PADDLE_ENFORCE(beta >= 0.0, "beta should >= 0.0");
PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册