diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index ec3c39097a01c1404f10455c32c585bdc090900e..b3784ed0744095c2032dd8a0de7bd6b12827cf5c 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -37,6 +37,8 @@ limitations under the License. */ namespace paddle { namespace operators { +using framework::To32BitIndex; + enum ActBwdOpFwdDeps { kNoDeps = 0x00, // Do not need any forward input/output kDepX = 0x01, // Only need forward input X @@ -177,7 +179,14 @@ class ActivationKernel for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } - functor(*place, x, out); + // use 32bit index to speed up computation + bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); + bool is_gpu_place = platform::is_gpu_place(context.GetPlace()); + if (use_32bit_index && is_gpu_place) { + functor(*place, To32BitIndex(x), To32BitIndex(out)); + } else { + functor(*place, x, out); + } } }; @@ -208,7 +217,15 @@ class ActivationGradKernel for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } - functor(*place, x, out, dout, dx); + // use 32bit index to speed up computation + bool use_32bit_index = out.size() < Eigen::NumTraits::highest(); + bool is_gpu_place = platform::is_gpu_place(context.GetPlace()); + if (use_32bit_index && is_gpu_place) { + functor(*place, To32BitIndex(x), To32BitIndex(out), To32BitIndex(dout), + To32BitIndex(dx)); + } else { + functor(*place, x, out, dout, dx); + } } };