提交 bb95dd2c 编写于 作者: E eclipsess

lrn coarse presicion

上级 02143bbd
...@@ -23,21 +23,27 @@ namespace operators { ...@@ -23,21 +23,27 @@ namespace operators {
template <> template <>
void LrnKernel<CPU, float>::Compute(const LrnParam &param) const { void LrnKernel<CPU, float>::Compute(const LrnParam &param) const {
const float alpha = param.Alpha();
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
auto x_dims = input_x->dims(); auto x_dims = input_x->dims();
/// data_format = NCHW
const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
Tensor *out = param.Out(); Tensor *out = param.Out();
out->mutable_data<float>(); out->mutable_data<float>();
const int n = param.N(); if (alpha < 0.001) {
const float alpha = param.Alpha(); // coarse precision
const float beta = param.Beta(); out->ShareDataWith(*input_x);
const float k = param.K(); } else {
LRNFunctor<float> lrnFunctor; /// data_format = NCHW
lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta); const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
const int n = param.N();
const float beta = param.Beta();
const float k = param.K();
LRNFunctor<float> lrnFunctor;
lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta);
}
} }
template class LrnKernel<CPU, float>; template class LrnKernel<CPU, float>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册