提交 01fb2be9 作者: Tomasz Patejko

MKL elementwise add: use the default implementation for integral types, float16, and/or GPU

上级 6f932482
...@@ -36,9 +36,12 @@ void default_elementwise_add(const framework::ExecutionContext& ctx, ...@@ -36,9 +36,12 @@ void default_elementwise_add(const framework::ExecutionContext& ctx,
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add( typename std::enable_if<
const framework::ExecutionContext& ctx, const framework::Tensor* x, std::is_floating_point<T>::value &&
const framework::Tensor* y, framework::Tensor* z) { std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto eigen_x = framework::EigenVector<T>::Flatten(*x); auto eigen_x = framework::EigenVector<T>::Flatten(*x);
auto eigen_y = framework::EigenVector<T>::Flatten(*y); auto eigen_y = framework::EigenVector<T>::Flatten(*y);
auto eigen_z = framework::EigenVector<T>::Flatten(*z); auto eigen_z = framework::EigenVector<T>::Flatten(*z);
...@@ -48,9 +51,12 @@ typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add( ...@@ -48,9 +51,12 @@ typename std::enable_if<std::is_floating_point<T>::value>::type elementwise_add(
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
typename std::enable_if<std::is_integral<T>::value>::type elementwise_add( typename std::enable_if<
const framework::ExecutionContext& ctx, const framework::Tensor* x, !std::is_floating_point<T>::value ||
const framework::Tensor* y, framework::Tensor* z) { !std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
default_elementwise_add<DeviceContext, T>(ctx, x, y, z); default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
} }
...@@ -66,7 +72,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> { ...@@ -66,7 +72,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
auto dims_equal = x->dims() == y->dims(); auto dims_equal = x->dims() == y->dims();
if (platform::is_cpu_place(ctx.GetPlace()) && dims_equal) { if (dims_equal) {
elementwise_add<DeviceContext, T>(ctx, x, y, z); elementwise_add<DeviceContext, T>(ctx, x, y, z);
} else { } else {
default_elementwise_add<DeviceContext, T>(ctx, x, y, z); default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册