未验证 提交 b71abeee 编写于 作者: Z Zhang Ting 提交者: GitHub

Use 32-bit indexing to improve activation op performance (#24206)

* improve activation ops performance, test=develop

* use the 32-bit index only for GPU computation, test=develop
上级 89c76a53
......@@ -37,6 +37,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
using framework::To32BitIndex;
enum ActBwdOpFwdDeps {
kNoDeps = 0x00, // Do not need any forward input/output
kDepX = 0x01, // Only need forward input X
......@@ -177,8 +179,15 @@ class ActivationKernel
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
// use 32bit index to speed up computation
bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
bool is_gpu_place = platform::is_gpu_place(context.GetPlace());
if (use_32bit_index && is_gpu_place) {
functor(*place, To32BitIndex(x), To32BitIndex(out));
} else {
functor(*place, x, out);
}
}
};
template <typename DeviceContext, typename Functor>
......@@ -208,8 +217,16 @@ class ActivationGradKernel
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
// use 32bit index to speed up computation
bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
bool is_gpu_place = platform::is_gpu_place(context.GetPlace());
if (use_32bit_index && is_gpu_place) {
functor(*place, To32BitIndex(x), To32BitIndex(out), To32BitIndex(dout),
To32BitIndex(dx));
} else {
functor(*place, x, out, dout, dx);
}
}
};
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册