提交 01fb2be9 编写于 作者: T Tomasz Patejko

MKL elementwise add: default implementation used for integral types, float16 and/or GPU

上级 6f932482
......@@ -36,9 +36,12 @@ void default_elementwise_add(const framework::ExecutionContext& ctx,
}
template <typename DeviceContext, typename T>
typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add(
const framework::ExecutionContext& ctx, const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
typename std::enable_if<
std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto eigen_x = framework::EigenVector<T>::Flatten(*x);
auto eigen_y = framework::EigenVector<T>::Flatten(*y);
auto eigen_z = framework::EigenVector<T>::Flatten(*z);
......@@ -48,9 +51,12 @@ typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add(
}
template <typename DeviceContext, typename T>
typename std::enable_if<std::is_integral<T>::value>::type elementwise_add(
const framework::ExecutionContext& ctx, const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
typename std::enable_if<
!std::is_floating_point<T>::value ||
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
}
......@@ -66,7 +72,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
z->mutable_data<T>(ctx.GetPlace());
auto dims_equal = x->dims() == y->dims();
if (platform::is_cpu_place(ctx.GetPlace()) && dims_equal) {
if (dims_equal) {
elementwise_add<DeviceContext, T>(ctx, x, y, z);
} else {
default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册