提交 79e3c2e4 编写于 作者: E eclipsess

restore lrn fix error

上级 923216e1
......@@ -23,15 +23,10 @@ namespace operators {
template <>
void LrnKernel<CPU, float>::Compute(const LrnParam &param) const {
const float alpha = param.Alpha();
const Tensor *input_x = param.InputX();
auto x_dims = input_x->dims();
Tensor *out = param.Out();
out->mutable_data<float>();
if (alpha < 0.001) {
// coarse precision
out->ShareDataWith(*input_x);
} else {
/// data_format = NCHW
const int N = x_dims[0];
const int C = x_dims[1];
......@@ -39,11 +34,11 @@ void LrnKernel<CPU, float>::Compute(const LrnParam &param) const {
const int W = x_dims[3];
const int n = param.N();
const float alpha = param.Alpha();
const float beta = param.Beta();
const float k = param.K();
LRNFunctor<float> lrnFunctor;
lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta);
}
}
template class LrnKernel<CPU, float>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册