提交 79e3c2e4 编写于 作者: E eclipsess

restore lrn fix error

上级 923216e1
...@@ -23,27 +23,22 @@ namespace operators { ...@@ -23,27 +23,22 @@ namespace operators {
template <> template <>
void LrnKernel<CPU, float>::Compute(const LrnParam &param) const { void LrnKernel<CPU, float>::Compute(const LrnParam &param) const {
const float alpha = param.Alpha();
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
auto x_dims = input_x->dims(); auto x_dims = input_x->dims();
Tensor *out = param.Out(); Tensor *out = param.Out();
out->mutable_data<float>(); out->mutable_data<float>();
if (alpha < 0.001) { /// data_format = NCHW
// coarse precision const int N = x_dims[0];
out->ShareDataWith(*input_x); const int C = x_dims[1];
} else { const int H = x_dims[2];
/// data_format = NCHW const int W = x_dims[3];
const int N = x_dims[0];
const int C = x_dims[1]; const int n = param.N();
const int H = x_dims[2]; const float alpha = param.Alpha();
const int W = x_dims[3]; const float beta = param.Beta();
const float k = param.K();
const int n = param.N(); LRNFunctor<float> lrnFunctor;
const float beta = param.Beta(); lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta);
const float k = param.K();
LRNFunctor<float> lrnFunctor;
lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta);
}
} }
template class LrnKernel<CPU, float>; template class LrnKernel<CPU, float>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册