Unverified    Commit 7c020c71    authored by YuanRisheng and committed by GitHub

[Pten]Move CPU_implementation of elementwise kernel in new directory (#38651)

* change 'math' to 'math_kernel'

* fix compile bugs

* merge develop

* fix compile bugs

* move cpu_impl of elementwise kernel to new directory
Parent 6e9714a2
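In brief, as the hunks below show, the `pten::general` namespace becomes `pten::funcs` and the elementwise code moves with it:

* paddle/pten/kernels/hybird/cpu/elementwise.h -> paddle/pten/kernels/cpu/elementwise_impl.h
* paddle/pten/kernels/hybird/general/elementwise_base.h -> paddle/pten/kernels/funcs/elementwise_base.h
* paddle/pten/kernels/hybird/general/elementwise_functor.h -> paddle/pten/kernels/funcs/elementwise_functor.h
* paddle/pten/kernels/hybird/blas/elementwise.h and paddle/pten/kernels/hybird/eigen/elementwise.h are deleted; their bodies move into the SameDims*Functor specializations in elementwise_impl.h
* fluid's elementwise functors become aliases of, and its TransformFunctor call sites forward to, the pten::funcs implementations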
@@ -23,6 +23,9 @@ limitations under the License. */
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/init.h"
+#include "paddle/fluid/framework/pten_utils.h"
+#include "paddle/pten/include/core.h"
 namespace paddle {
 namespace framework {
@@ -73,9 +76,12 @@ class TestKernel : public OpKernel<float> {
     output->Resize(input->dims());
     output->mutable_data<T>(ctx.GetPlace());
-    operators::TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
-        input, input, output, ctx.template device_context<DeviceContext>(),
-        AddFunctor<T>());
+    auto pt_input = paddle::experimental::MakePtenDenseTensor(*input);
+    auto pt_out = paddle::experimental::MakePtenDenseTensor(*output);
+    pten::funcs::TransformFunctor<AddFunctor<T>, T, DeviceContext> functor(
+        *pt_input, *pt_input, pt_out.get(),
+        ctx.template device_context<DeviceContext>(), AddFunctor<T>());
     functor.Run();
   }
 };
......
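The pattern above generalizes to any fluid kernel that needs the relocated functor: wrap the fluid Tensors as pten DenseTensors, then hand them to pten::funcs::TransformFunctor. A minimal sketch using only the APIs visible in this diff; the helper name AddByTransform is hypothetical, and this only compiles inside the fluid codebase:

// Hedged sketch, not code from this commit.
template <typename T, typename DeviceContext>
void AddByTransform(const paddle::framework::ExecutionContext& ctx,
                    const paddle::framework::Tensor* x,
                    paddle::framework::Tensor* out) {
  // Wrap fluid Tensors as pten DenseTensors sharing the same allocation.
  auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
  auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
  // The functor now lives in pten::funcs and consumes DenseTensors.
  pten::funcs::TransformFunctor<paddle::operators::AddFunctor<T>, T,
                                DeviceContext>
      functor(*pt_x, *pt_x, pt_out.get(),
              ctx.template device_context<DeviceContext>(),
              paddle::operators::AddFunctor<T>());
  functor.Run();  // element-wise x + x into out
}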
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/hostdevice.h"
+#include "paddle/pten/kernels/funcs/elementwise_functor.h"
 namespace paddle {
 namespace operators {
@@ -25,58 +26,31 @@ namespace operators {
 // Add
-template <typename T>
-struct AddFunctor {
-  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; }
-};
-template <typename T>
-struct InverseAddFunctor {
-  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b + a; }
-};
+template <typename T>
+using AddFunctor = pten::funcs::AddFunctor<T>;
+template <typename T>
+using InverseAddFunctor = pten::funcs::InverseAddFunctor<T>;
 // Subtract
-template <typename T>
-struct SubFunctor {
-  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a - b; }
-};
-template <typename T>
-struct InverseSubFunctor {
-  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b - a; }
-};
+template <typename T>
+using SubFunctor = pten::funcs::SubtractFunctor<T>;
+template <typename T>
+using InverseSubFunctor = pten::funcs::InverseSubtractFunctor<T>;
 // Multiply
-template <typename T>
-struct MulFunctor {
-  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
-};
-template <typename T>
-struct InverseMulFunctor {
-  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b * a; }
-};
+template <typename T>
+using MulFunctor = pten::funcs::MultiplyFunctor<T>;
+template <typename T>
+using InverseMulFunctor = pten::funcs::InverseMultiplyFunctor<T>;
 // Divide
-#define DIV_ERROR_INFO                                              \
-  "InvalidArgumentError: Integer division by zero encountered in " \
-  "(floor) divide. Please check the input value."
-template <typename T, typename Enable = void>
-struct DivFunctor {
-  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; }
-};
-template <typename T>
-struct DivFunctor<T,
-                  typename std::enable_if<std::is_integral<T>::value>::type> {
-  inline HOSTDEVICE T operator()(const T& a, const T& b) const {
-    // For int32/int64, need to check whether the divison is zero.
-    PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO);
-    return a / b;
-  }
-};
-template <typename T, typename Enable = void>
-struct InverseDivFunctor {
-  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b / a; }
-};
+template <typename T>
+using DivFunctor = pten::funcs::DivideFunctor<T>;
+template <typename T>
+using InverseDivFunctor = pten::funcs::InverseDivideFunctor<T>;
 // Floor Divide
 template <typename T>
......
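Because these are alias templates, the fluid and pten functor types are identical and existing call sites compile unchanged. A small illustrative check (not part of the commit):

#include <type_traits>
// Hedged sketch: the alias makes the two names the same type.
static_assert(std::is_same<paddle::operators::AddFunctor<float>,
                           pten::funcs::AddFunctor<float>>::value,
              "fluid AddFunctor is now the pten functor");
// Usage is unchanged, e.g.:
//   float r = paddle::operators::MulFunctor<float>()(3.0f, 4.0f);  // r == 12.0f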
@@ -31,8 +31,7 @@ limitations under the License. */
 // only can include the headers in paddle/pten/include dirs
 #include "paddle/pten/api/lib/utils/tensor_utils.h"
-#include "paddle/pten/kernels/hybird/cpu/elementwise.h"
-#include "paddle/pten/kernels/hybird/general/elementwise_base.h"
+#include "paddle/pten/kernels/cpu/elementwise_impl.h"
 #if defined(__NVCC__) || defined(__HIPCC__)
 #ifdef __NVCC__
@@ -151,9 +150,9 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
                                    int *x_dims_array, int *y_dims_array,
                                    int *out_dims_array, const int max_dim,
                                    const int axis) {
-  pten::general::GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array,
-                                        y_dims_array, out_dims_array, max_dim,
-                                        axis);
+  pten::funcs::GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array,
+                                      y_dims_array, out_dims_array, max_dim,
+                                      axis);
 }
 template <typename Functor, typename T, typename OutType = T>
@@ -1073,71 +1072,9 @@ void CommonGradBroadcastCUDA(
 inline framework::DDim trim_trailing_singular_dims(
     const framework::DDim &dims) {
-  return pten::general::trim_trailing_singular_dims(dims);
+  return pten::funcs::trim_trailing_singular_dims(dims);
 }
-template <typename Functor, typename T, typename DeviceContext,
-          typename OutType = T>
-class TransformFunctor {
- public:
-  TransformFunctor(const framework::Tensor *x, const framework::Tensor *y,
-                   framework::Tensor *z, const DeviceContext &ctx, Functor func,
-                   const bool is_xsize_larger = true)
-      : x_(x->data<T>()),
-        y_(y->data<T>()),
-        z_(z->mutable_data<OutType>(ctx.GetPlace())),
-        nx_(x->numel()),
-        ctx_(ctx),
-        func_(func),
-        is_xsize_larger_(is_xsize_larger) {
-    if (is_xsize_larger_ == false) {
-      nx_ = y->numel();
-    }
-  }
-  inline void Run() const {
-    platform::Transform<DeviceContext> trans;
-    trans(ctx_, x_, x_ + nx_, y_, z_, func_);
-  }
-  inline void RunRowWise(int n, int pre) const {
-    platform::Transform<DeviceContext> trans;
-    if (is_xsize_larger_) {
-      trans(ctx_, x_, x_ + nx_,
-            pten::general::RowwiseTransformIterator<T, DeviceContext>(y_, n),
-            z_, func_);
-    } else {
-      trans(ctx_, y_, y_ + nx_,
-            pten::general::RowwiseTransformIterator<T, DeviceContext>(x_, n),
-            z_, func_);
-    }
-  }
-  inline void RunMidWise(int n, int pre, int post) const {
-    platform::Transform<DeviceContext> trans;
-    if (is_xsize_larger_) {
-      trans(ctx_, x_, x_ + nx_,
-            pten::general::MidWiseTransformIterator<T, DeviceContext>(y_, n,
-                                                                      post),
-            z_, func_);
-    } else {
-      trans(ctx_, y_, y_ + nx_,
-            pten::general::MidWiseTransformIterator<T, DeviceContext>(x_, n,
-                                                                      post),
-            z_, func_);
-    }
-  }
- private:
-  const T *x_;
-  const T *y_;
-  OutType *z_;
-  int64_t nx_;
-  const DeviceContext &ctx_;
-  Functor func_;
-  bool is_xsize_larger_;
-};
 template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
 struct ElemwiseGradNoBroadcast {
   const T *x_;
@@ -1457,13 +1394,13 @@ void ElemwiseGradComputeWithBroadcast(
   if (is_xsize_larger) {
     auto y_dims_trimed = trim_trailing_singular_dims(y_dims);
     axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
-    pten::general::get_mid_dims(x_dims, y_dims_trimed, axis_trim, &pre, &n,
-                                &post, &is_run_common_broadcast);
+    pten::funcs::get_mid_dims(x_dims, y_dims_trimed, axis_trim, &pre, &n, &post,
+                              &is_run_common_broadcast);
   } else {
     auto x_dims_trimed = trim_trailing_singular_dims(x_dims);
     axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
-    pten::general::get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n,
-                                &post, &is_run_common_broadcast);
+    pten::funcs::get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
+                              &is_run_common_broadcast);
   }
   // special case for common backward implementation.
   if (is_run_common_broadcast) {
@@ -1861,8 +1798,8 @@ void FusedElemwiseAndActComputeWithBroadcast(
   axis = (y_dim.size() == 0) ? x_dim.size() : axis;
   int pre, n, post, is_run_common_broadcast;
-  pten::general::get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post,
-                              &is_run_common_broadcast);
+  pten::funcs::get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post,
+                            &is_run_common_broadcast);
   if (post == 1) {
     int h = pre;
     int w = n;
@@ -2409,8 +2346,8 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
   axis = (y_dim.size() == 0) ? x_dim.size() : axis;
   int pre, n, post, is_run_common_broadcast;
-  pten::general::get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post,
-                              &is_run_common_broadcast);
+  pten::funcs::get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post,
+                            &is_run_common_broadcast);
   const T *x_data = nullptr;
   const T *y_data = nullptr;
   if (x->IsInitialized()) x_data = x->data<T>();
......
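get_mid_dims (now in pten::funcs) drives the row-wise/mid-wise fast paths above by collapsing a broadcast into a pre * n * post loop structure. A self-contained sketch of that decomposition, under the assumption that y's trimmed dims align inside x's at `axis`; an illustration, not the Paddle implementation:

#include <cstdio>
#include <vector>

// pre = product of x dims before axis, n = y.numel(),
// post = product of x dims after the aligned y block.
void mid_dims_sketch(const std::vector<int>& x_dims,
                     const std::vector<int>& y_dims, int axis,
                     int* pre, int* n, int* post) {
  *pre = 1; *n = 1; *post = 1;
  for (int i = 0; i < axis; ++i) *pre *= x_dims[i];
  for (size_t i = 0; i < y_dims.size(); ++i) *n *= y_dims[i];
  for (size_t i = axis + y_dims.size(); i < x_dims.size(); ++i)
    *post *= x_dims[i];
}

int main() {
  int pre, n, post;
  mid_dims_sketch({2, 3, 4, 5}, {3, 4}, 1, &pre, &n, &post);
  std::printf("pre=%d n=%d post=%d\n", pre, n, post);  // pre=2 n=12 post=5
}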
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/kernel_registry.h"
// TODO(chenweihang) After the kernel is split into a single file,
// the kernel declare statement is automatically generated according to the
// file name of the kernel, and this header file will be removed
@@ -14,7 +14,7 @@ limitations under the License. */
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/pten/infermeta/binary.h"
-#include "paddle/pten/kernels/hybird/general/elementwise_base.h"
+#include "paddle/pten/kernels/funcs/elementwise_base.h"
 namespace pten {
@@ -162,13 +162,13 @@ DenseTensorMeta ElementwiseInferMeta(const DenseTensorMeta& x_meta,
     std::vector<int> x_dims_array(max_dim);
     std::vector<int> y_dims_array(max_dim);
     std::vector<int> out_dims_array(max_dim);
-    general::GetBroadcastDimsArrays(x_dims,
-                                    y_dims,
-                                    x_dims_array.data(),
-                                    y_dims_array.data(),
-                                    out_dims_array.data(),
-                                    max_dim,
-                                    axis);
+    funcs::GetBroadcastDimsArrays(x_dims,
+                                  y_dims,
+                                  x_dims_array.data(),
+                                  y_dims_array.data(),
+                                  out_dims_array.data(),
+                                  max_dim,
+                                  axis);
     return_meta.dims = paddle::framework::make_ddim(out_dims_array);
   }
   return_meta.lod = x_meta.lod;
......
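For context, GetBroadcastDimsArrays (renamed here from general:: to funcs::) fills the dim arrays with numpy-style broadcast shapes: y's dims are aligned at `axis`, padded with 1s, and each output dim is the max of the pair, with 1 broadcasting against anything. A hedged standalone sketch, simplified to the case where x has the larger rank:

#include <cstdio>
#include <stdexcept>
#include <vector>

std::vector<int> broadcast_dims_sketch(const std::vector<int>& x,
                                       const std::vector<int>& y, int axis) {
  std::vector<int> y_pad(x.size(), 1);  // pad y to x's rank
  for (size_t i = 0; i < y.size(); ++i) y_pad[axis + i] = y[i];
  std::vector<int> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    if (x[i] != y_pad[i] && x[i] != 1 && y_pad[i] != 1)
      throw std::runtime_error("incompatible broadcast dims");
    out[i] = x[i] > y_pad[i] ? x[i] : y_pad[i];
  }
  return out;
}

int main() {
  for (int d : broadcast_dims_sketch({2, 3, 4, 5}, {3, 4}, 1))
    std::printf("%d ", d);  // 2 3 4 5
}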
--- a/paddle/pten/kernels/hybird/cpu/elementwise.h
+++ b/paddle/pten/kernels/cpu/elementwise_impl.h
@@ -15,13 +15,175 @@ limitations under the License. */
 #pragma once
 #include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/hybird/general/elementwise_base.h"
+#include "paddle/pten/kernels/funcs/elementwise_base.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/pten/kernels/hybird/eigen/common.h"
 namespace pten {
+// Add
+template <typename DevCtx, typename T, class Enable = void>
+struct SameDimsAddFunctor {
+  void operator()(const DevCtx& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z);
+};
+template <typename DevCtx, typename T>
+struct SameDimsAddFunctor<
+    DevCtx,
+    T,
+    typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  void operator()(const DevCtx& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z) {
+    auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
+    blas.VADD(x.numel(), x.data<T>(), y.data<T>(), z->mutable_data<T>());
+  }
+};
+template <typename DevCtx, typename T>
+struct SameDimsAddFunctor<
+    DevCtx,
+    T,
+    typename std::enable_if<!std::is_floating_point<T>::value>::type> {
+  void operator()(const DevCtx& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z) {
+    z->mutable_data<T>();
+    auto eigen_x = pten::EigenVector<T>::Flatten(x);
+    auto eigen_y = pten::EigenVector<T>::Flatten(y);
+    auto eigen_z = pten::EigenVector<T>::Flatten(*z);
+    auto& place = *dev_ctx.eigen_device();
+    eigen_z.device(place) = eigen_x + eigen_y;
+  }
+};
+// Subtract
+template <typename DevCtx, typename T, class Enable = void>
+struct SameDimsSubtractFunctor {
+  void operator()(const DevCtx& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z);
+};
+template <typename DevCtx, typename T>
+struct SameDimsSubtractFunctor<
+    DevCtx,
+    T,
+    typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  void operator()(const DevCtx& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z) {
+    auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
+    blas.VSUB(x.numel(), x.data<T>(), y.data<T>(), z->mutable_data<T>());
+  }
+};
+template <typename DevCtx, typename T>
+struct SameDimsSubtractFunctor<
+    DevCtx,
+    T,
+    typename std::enable_if<!std::is_floating_point<T>::value>::type> {
+  void operator()(const DevCtx& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z) {
+    auto eigen_x = pten::EigenVector<T>::Flatten(x);
+    auto eigen_y = pten::EigenVector<T>::Flatten(y);
+    auto eigen_z = pten::EigenVector<T>::Flatten(*z);
+    auto& place = *dev_ctx.eigen_device();
+    eigen_z.device(place) = eigen_x - eigen_y;
+  }
+};
+// Divide
+template <typename DevCtx, typename T, class Enable = void>
+struct SameDimsDivideFunctor {
+  void operator()(const DevCtx& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z);
+};
+template <typename DevCtx, typename T>
+struct SameDimsDivideFunctor<
+    DevCtx,
+    T,
+    typename std::enable_if<!std::is_floating_point<T>::value>::type> {
+  void operator()(const DevCtx& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z) {
+    paddle::platform::errors::InvalidArgument(
+        "If use SameDimsDivideFunctor, template args(T) must be floating "
+        "point. ");
+  }
+};
+template <typename DevCtx, typename T>
+struct SameDimsDivideFunctor<
+    DevCtx,
+    T,
+    typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  void operator()(const DevCtx& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z) {
+    auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
+    blas.VDIV(x.numel(), x.data<T>(), y.data<T>(), z->mutable_data<T>());
+  }
+};
+// Multiply
+template <typename DevCtx, typename T, class Enable = void>
+struct SameDimsMultiplyFunctor {
+  void operator()(const DevCtx& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z);
+};
+template <typename DevCtx, typename T>
+struct SameDimsMultiplyFunctor<
+    DevCtx,
+    T,
+    typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  void operator()(const DevCtx& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z) {
+    auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
+    blas.VMUL(x.numel(), x.data<T>(), y.data<T>(), z->mutable_data<T>());
+  }
+};
+template <typename DevCtx, typename T>
+struct SameDimsMultiplyFunctor<
+    DevCtx,
+    T,
+    typename std::enable_if<!std::is_floating_point<T>::value>::type> {
+  void operator()(const DevCtx& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z) {
+    auto eigen_x = pten::EigenVector<T>::Flatten(x);
+    auto eigen_y = pten::EigenVector<T>::Flatten(y);
+    auto eigen_z = pten::EigenVector<T>::Flatten(*z);
+    auto& place = *dev_ctx.eigen_device();
+    eigen_z.device(place) = eigen_x * eigen_y;
+  }
+};
-inline void UpdateElementwiseIndexArray(const int *out_dims_array,
-                                        const int max_dim,
-                                        int *index_array) {
+inline void UpdateElementwiseIndexArray(const int* out_dims_array,
+                                        const int max_dim,
+                                        int* index_array) {
   for (int i = max_dim - 1; i >= 0; --i) {
     ++index_array[i];
     if (index_array[i] >= out_dims_array[i]) {
@@ -32,9 +194,9 @@ inline void UpdateElementwiseIndexArray(const int *out_dims_array,
   }
 }
-inline int GetElementwiseIndex(const int *x_dims_array,
-                               const int max_dim,
-                               const int *index_array) {
+inline int GetElementwiseIndex(const int* x_dims_array,
+                               const int max_dim,
+                               const int* index_array) {
   int index_ = 0;
   for (int i = 0; i < max_dim; i++) {
     if (x_dims_array[i] > 1) {
@@ -45,26 +207,26 @@ inline int GetElementwiseIndex(const int *x_dims_array,
 }
 template <typename Functor, typename T, typename OutType = T>
-void CommonForwardBroadcastCPU(const DenseTensor &x,
-                               const DenseTensor &y,
-                               DenseTensor *z,
-                               int *x_dims_array,
-                               int *y_dims_array,
-                               int *out_dims_array,
+void CommonForwardBroadcastCPU(const DenseTensor& x,
+                               const DenseTensor& y,
+                               DenseTensor* z,
+                               int* x_dims_array,
+                               int* y_dims_array,
+                               int* out_dims_array,
                                int max_dim,
-                               const paddle::platform::CPUDeviceContext &ctx,
+                               const paddle::platform::CPUDeviceContext& ctx,
                                Functor func,
                                const bool is_xsize_larger = true) {
   std::vector<int> index_array(max_dim, 0);
-  const T *x_data = x.data<T>();
-  const T *y_data = y.data<T>();
+  const T* x_data = x.data<T>();
+  const T* y_data = y.data<T>();
   PADDLE_ENFORCE_NOT_NULL(x_data,
                           paddle::platform::errors::InvalidArgument(
                               "The input X should not be empty."));
   PADDLE_ENFORCE_NOT_NULL(y_data,
                           paddle::platform::errors::InvalidArgument(
                               "The input Y should not be empty."));
-  OutType *out_data = z->mutable_data<OutType>();
+  OutType* out_data = z->mutable_data<OutType>();
   const int out_size = std::accumulate(
       out_dims_array, out_dims_array + max_dim, 1, std::multiplies<int>());
@@ -84,12 +246,12 @@ void CommonForwardBroadcastCPU(const DenseTensor &x,
 template <typename Functor, typename T, typename OutType = T>
 void CommonElementwiseBroadcastForward(
-    const paddle::platform::CPUDeviceContext &dev_ctx,
-    const DenseTensor &x,
-    const DenseTensor &y,
-    DenseTensor *z,
-    const DDim &x_dims,
-    const DDim &y_dims,
+    const paddle::platform::CPUDeviceContext& dev_ctx,
+    const DenseTensor& x,
+    const DenseTensor& y,
+    DenseTensor* z,
+    const DDim& x_dims,
+    const DDim& y_dims,
     Functor func,
     int axis,
     const bool is_xsize_larger = true) {
@@ -110,13 +272,13 @@ void CommonElementwiseBroadcastForward(
   std::vector<int> x_dims_array(max_dim);
   std::vector<int> y_dims_array(max_dim);
   std::vector<int> out_dims_array(max_dim);
-  general::GetBroadcastDimsArrays(x_dims,
-                                  y_dims,
-                                  x_dims_array.data(),
-                                  y_dims_array.data(),
-                                  out_dims_array.data(),
-                                  max_dim,
-                                  axis);
+  funcs::GetBroadcastDimsArrays(x_dims,
+                                y_dims,
+                                x_dims_array.data(),
+                                y_dims_array.data(),
+                                out_dims_array.data(),
+                                max_dim,
+                                axis);
   CommonForwardBroadcastCPU<Functor, T, OutType>(x,
                                                  y,
@@ -140,12 +302,12 @@ void CommonElementwiseBroadcastForward(
 // TODO(liuyiqun): optimize the CPU implementation to support all broadcast
 // cases and avoid the need of XxxInverseFunctor.
 template <typename Functor, typename T, typename OutType = T>
-void ElementwiseCompute(const paddle::platform::CPUDeviceContext &dev_ctx,
-                        const DenseTensor &x,
-                        const DenseTensor &y,
+void ElementwiseCompute(const paddle::platform::CPUDeviceContext& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& y,
                         int axis,
                         Functor func,
-                        DenseTensor *z) {
+                        DenseTensor* z) {
   z->mutable_data<OutType>();
   auto x_dims = x.dims();
   auto y_dims = y.dims();
@@ -155,7 +317,7 @@ void ElementwiseCompute(const paddle::platform::CPUDeviceContext &dev_ctx,
     is_xsize_larger = false;
     max_dim = y_dims.size();
   }
-  general::
+  funcs::
       TransformFunctor<Functor, T, paddle::platform::CPUDeviceContext, OutType>
           functor(x, y, z, dev_ctx, func, is_xsize_larger);
   if (x_dims == y_dims) {
@@ -179,25 +341,25 @@ void ElementwiseCompute(const paddle::platform::CPUDeviceContext &dev_ctx,
   int pre, n, post, is_run_common_broadcast, axis_trim = 0;
   if (is_xsize_larger) {
-    auto y_dims_trimed = general::trim_trailing_singular_dims(y_dims);
+    auto y_dims_trimed = funcs::trim_trailing_singular_dims(y_dims);
     axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
-    general::get_mid_dims(x_dims,
-                          y_dims_trimed,
-                          axis_trim,
-                          &pre,
-                          &n,
-                          &post,
-                          &is_run_common_broadcast);
+    funcs::get_mid_dims(x_dims,
+                        y_dims_trimed,
+                        axis_trim,
+                        &pre,
+                        &n,
+                        &post,
+                        &is_run_common_broadcast);
   } else {
-    auto x_dims_trimed = general::trim_trailing_singular_dims(x_dims);
+    auto x_dims_trimed = funcs::trim_trailing_singular_dims(x_dims);
     axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
-    general::get_mid_dims(y_dims,
-                          x_dims_trimed,
-                          axis_trim,
-                          &pre,
-                          &n,
-                          &post,
-                          &is_run_common_broadcast);
+    funcs::get_mid_dims(y_dims,
+                        x_dims_trimed,
+                        axis_trim,
+                        &pre,
+                        &n,
+                        &post,
+                        &is_run_common_broadcast);
   }
   // special case for common implementation.
   // case 1: x=[2,3,1,5], y=[2,1,4,1]
@@ -219,10 +381,10 @@ void ElementwiseCompute(const paddle::platform::CPUDeviceContext &dev_ctx,
 template <typename Functor>
 struct SameDimsElementwiseCompute {
-  void operator()(const paddle::platform::CPUDeviceContext &dev_ctx,
-                  const DenseTensor &x,
-                  const DenseTensor &y,
-                  DenseTensor *z) {
+  void operator()(const paddle::platform::CPUDeviceContext& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* z) {
     Functor()(dev_ctx, x, y, z);
   }
 };
......
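The SameDims*Functor family above uses std::enable_if to pick a vectorized BLAS routine (VADD/VSUB/VDIV/VMUL) for floating-point element types and an Eigen expression for everything else. A self-contained sketch of just the dispatch mechanism, not the Paddle code:

#include <cstdio>
#include <type_traits>

template <typename T, class Enable = void>
struct SameDimsOp;  // primary template: declaration only

template <typename T>
struct SameDimsOp<
    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  void operator()() const { std::puts("BLAS path (e.g. blas.VADD)"); }
};

template <typename T>
struct SameDimsOp<
    T, typename std::enable_if<!std::is_floating_point<T>::value>::type> {
  void operator()() const { std::puts("Eigen expression path"); }
};

int main() {
  SameDimsOp<float>()();  // BLAS path
  SameDimsOp<int>()();    // Eigen path
}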
@@ -18,9 +18,11 @@
 #include "paddle/pten/backends/cpu/cpu_context.h"
 #include "paddle/pten/common/scalar.h"
 #include "paddle/pten/core/kernel_registry.h"
-#include "paddle/pten/kernels/hybird/cpu/elementwise.h"
+#include "paddle/pten/kernels/cpu/elementwise_impl.h"
+#include "paddle/pten/kernels/funcs/elementwise_functor.h"
 #include "paddle/pten/kernels/hybird/eigen/reduce.h"
-#include "paddle/pten/kernels/hybird/general/elementwise_functor.h"
 #include "paddle/pten/kernels/hybird/general/reduce_impl.h"
 // See Note [ Why still include the fluid headers? ]
@@ -30,29 +32,28 @@
 namespace pten {
 #define DEFINE_CPU_ELEMENTWISE_OP(name)                                      \
   template <typename T, typename Context>                                    \
   void name##Kernel(const Context& dev_ctx,                                  \
                     const DenseTensor& x,                                    \
                     const DenseTensor& y,                                    \
                     int axis,                                                \
                     DenseTensor* out) {                                      \
     out->mutable_data<T>();                                                  \
     if (x.dims() == y.dims()) {                                              \
-      SameDimsElementwiseCompute<                                            \
-          general::SameDims##name##Functor<CPUContext, T>>()(                \
+      SameDimsElementwiseCompute<SameDims##name##Functor<CPUContext, T>>()(  \
           dev_ctx, x, y, out);                                               \
     } else {                                                                 \
       auto x_dims = x.dims();                                                \
       auto y_dims = y.dims();                                                \
       if (x_dims.size() >= y_dims.size()) {                                  \
-        ElementwiseCompute<general::name##Functor<T>, T>(                    \
-            dev_ctx, x, y, axis, general::name##Functor<T>(), out);          \
+        ElementwiseCompute<funcs::name##Functor<T>, T>(                      \
+            dev_ctx, x, y, axis, funcs::name##Functor<T>(), out);            \
       } else {                                                               \
-        ElementwiseCompute<general::Inverse##name##Functor<T>, T>(           \
-            dev_ctx, x, y, axis, general::Inverse##name##Functor<T>(), out); \
+        ElementwiseCompute<funcs::Inverse##name##Functor<T>, T>(             \
+            dev_ctx, x, y, axis, funcs::Inverse##name##Functor<T>(), out);   \
       }                                                                      \
     }                                                                        \
   }
 template <typename T, typename Context>
@@ -76,17 +77,17 @@ void DivideKernel(const Context& dev_ctx,
   // allocate memory for out
   out->mutable_data<T>();
   if (x.dims() == y.dims() && std::is_floating_point<T>::value) {
-    SameDimsElementwiseCompute<general::SameDimsDivideFunctor<CPUContext, T>>()(
+    SameDimsElementwiseCompute<SameDimsDivideFunctor<CPUContext, T>>()(
         dev_ctx, x, y, out);
   } else {
     auto x_dims = x.dims();
     auto y_dims = y.dims();
     if (x_dims.size() >= y_dims.size()) {
-      ElementwiseCompute<general::DivideFunctor<T>, T>(
-          dev_ctx, x, y, axis, general::DivideFunctor<T>(), out);
+      ElementwiseCompute<funcs::DivideFunctor<T>, T>(
+          dev_ctx, x, y, axis, funcs::DivideFunctor<T>(), out);
     } else {
-      ElementwiseCompute<general::InverseDivideFunctor<T>, T>(
-          dev_ctx, x, y, axis, general::InverseDivideFunctor<T>(), out);
+      ElementwiseCompute<funcs::InverseDivideFunctor<T>, T>(
+          dev_ctx, x, y, axis, funcs::InverseDivideFunctor<T>(), out);
    }
  }
 }
......
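For orientation, a hedged sketch of roughly what DEFINE_CPU_ELEMENTWISE_OP(Add) expands to after this change (flattened for readability; not literal preprocessor output):

template <typename T, typename Context>
void AddKernel(const Context& dev_ctx, const DenseTensor& x,
               const DenseTensor& y, int axis, DenseTensor* out) {
  out->mutable_data<T>();
  if (x.dims() == y.dims()) {
    // same-shape fast path: BLAS or Eigen, chosen by SameDimsAddFunctor
    SameDimsElementwiseCompute<SameDimsAddFunctor<CPUContext, T>>()(
        dev_ctx, x, y, out);
  } else if (x.dims().size() >= y.dims().size()) {
    ElementwiseCompute<funcs::AddFunctor<T>, T>(
        dev_ctx, x, y, axis, funcs::AddFunctor<T>(), out);
  } else {
    // y has more dims: swap operand roles via the inverse functor
    ElementwiseCompute<funcs::InverseAddFunctor<T>, T>(
        dev_ctx, x, y, axis, funcs::InverseAddFunctor<T>(), out);
  }
}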
--- a/paddle/pten/kernels/hybird/general/elementwise_base.h
+++ b/paddle/pten/kernels/funcs/elementwise_base.h
@@ -19,7 +19,7 @@ limitations under the License. */
 #include "paddle/pten/core/dense_tensor.h"
 namespace pten {
-namespace general {
+namespace funcs {
 using DDim = paddle::framework::DDim;
@@ -378,6 +378,5 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
   }
 }
 }
-}  // namespace general
+}  // namespace funcs
 }  // namespace pten
--- a/paddle/pten/kernels/hybird/general/elementwise_functor.h
+++ b/paddle/pten/kernels/funcs/elementwise_functor.h
@@ -17,50 +17,13 @@ limitations under the License. */
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/hostdevice.h"
-#include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/hybird/blas/elementwise.h"
-#include "paddle/pten/kernels/hybird/eigen/elementwise.h"
 namespace pten {
-namespace general {
+namespace funcs {
 // Define the binary functors used in elementwise ops.
 // Add
-template <typename DevCtx, typename T, class Enable = void>
-struct SameDimsAddFunctor {
-  void operator()(const DevCtx& dev_ctx,
-                  const DenseTensor& x,
-                  const DenseTensor& y,
-                  DenseTensor* z);
-};
-template <typename DevCtx, typename T>
-struct SameDimsAddFunctor<
-    DevCtx,
-    T,
-    typename std::enable_if<std::is_floating_point<T>::value>::type> {
-  void operator()(const DevCtx& dev_ctx,
-                  const DenseTensor& x,
-                  const DenseTensor& y,
-                  DenseTensor* z) {
-    blas::ElementwiseAdd<DevCtx, T>(dev_ctx, x, y, z);
-  }
-};
-template <typename DevCtx, typename T>
-struct SameDimsAddFunctor<
-    DevCtx,
-    T,
-    typename std::enable_if<!std::is_floating_point<T>::value>::type> {
-  void operator()(const DevCtx& dev_ctx,
-                  const DenseTensor& x,
-                  const DenseTensor& y,
-                  DenseTensor* z) {
-    eigen::ElementwiseAdd<DevCtx, T>(dev_ctx, x, y, z);
-  }
-};
 template <typename T>
 struct AddFunctor {
   inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; }
@@ -71,40 +34,6 @@ struct InverseAddFunctor {
 };
 // Subtract
-template <typename DevCtx, typename T, class Enable = void>
-struct SameDimsSubtractFunctor {
-  void operator()(const DevCtx& dev_ctx,
-                  const DenseTensor& x,
-                  const DenseTensor& y,
-                  DenseTensor* z);
-};
-template <typename DevCtx, typename T>
-struct SameDimsSubtractFunctor<
-    DevCtx,
-    T,
-    typename std::enable_if<std::is_floating_point<T>::value>::type> {
-  void operator()(const DevCtx& dev_ctx,
-                  const DenseTensor& x,
-                  const DenseTensor& y,
-                  DenseTensor* z) {
-    blas::ElementwiseSub<DevCtx, T>(dev_ctx, x, y, z);
-  }
-};
-template <typename DevCtx, typename T>
-struct SameDimsSubtractFunctor<
-    DevCtx,
-    T,
-    typename std::enable_if<!std::is_floating_point<T>::value>::type> {
-  void operator()(const DevCtx& dev_ctx,
-                  const DenseTensor& x,
-                  const DenseTensor& y,
-                  DenseTensor* z) {
-    eigen::ElementwiseSub<DevCtx, T>(dev_ctx, x, y, z);
-  }
-};
 template <typename T>
 struct SubtractFunctor {
   inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a - b; }
@@ -114,43 +43,17 @@ struct InverseSubtractFunctor {
   inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b - a; }
 };
-// Divide
-template <typename DevCtx, typename T, class Enable = void>
-struct SameDimsDivideFunctor {
-  void operator()(const DevCtx& dev_ctx,
-                  const DenseTensor& x,
-                  const DenseTensor& y,
-                  DenseTensor* z);
-};
-template <typename DevCtx, typename T>
-struct SameDimsDivideFunctor<
-    DevCtx,
-    T,
-    typename std::enable_if<!std::is_floating_point<T>::value>::type> {
-  void operator()(const DevCtx& dev_ctx,
-                  const DenseTensor& x,
-                  const DenseTensor& y,
-                  DenseTensor* z) {
-    paddle::platform::errors::InvalidArgument(
-        "If use SameDimsDivideFunctor, template args(T) must be floating "
-        "point. ");
-  }
-};
-template <typename DevCtx, typename T>
-struct SameDimsDivideFunctor<
-    DevCtx,
-    T,
-    typename std::enable_if<std::is_floating_point<T>::value>::type> {
-  void operator()(const DevCtx& dev_ctx,
-                  const DenseTensor& x,
-                  const DenseTensor& y,
-                  DenseTensor* z) {
-    blas::ElementwiseDiv<DevCtx, T>(dev_ctx, x, y, z);
-  }
-};
+// Multiply
+template <typename T>
+struct MultiplyFunctor {
+  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
+};
+template <typename T>
+struct InverseMultiplyFunctor {
+  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b * a; }
+};
+// Divide
 #define DIV_ERROR_INFO                                              \
   "InvalidArgumentError: Integer division by zero encountered in " \
   "(floor) divide. Please check the input value."
@@ -176,48 +79,5 @@ struct InverseDivideFunctor {
   inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b / a; }
 };
-// Multiply
-template <typename DevCtx, typename T, class Enable = void>
-struct SameDimsMultiplyFunctor {
-  void operator()(const DevCtx& dev_ctx,
-                  const DenseTensor& x,
-                  const DenseTensor& y,
-                  DenseTensor* z);
-};
-template <typename DevCtx, typename T>
-struct SameDimsMultiplyFunctor<
-    DevCtx,
-    T,
-    typename std::enable_if<std::is_floating_point<T>::value>::type> {
-  void operator()(const DevCtx& dev_ctx,
-                  const DenseTensor& x,
-                  const DenseTensor& y,
-                  DenseTensor* z) {
-    blas::ElementwiseMul<DevCtx, T>(dev_ctx, x, y, z);
-  }
-};
-template <typename DevCtx, typename T>
-struct SameDimsMultiplyFunctor<
-    DevCtx,
-    T,
-    typename std::enable_if<!std::is_floating_point<T>::value>::type> {
-  void operator()(const DevCtx& dev_ctx,
-                  const DenseTensor& x,
-                  const DenseTensor& y,
-                  DenseTensor* z) {
-    eigen::ElementwiseMul<DevCtx, T>(dev_ctx, x, y, z);
-  }
-};
-template <typename T>
-struct MultiplyFunctor {
-  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
-};
-template <typename T>
-struct InverseMultiplyFunctor {
-  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b * a; }
-};
-}  // namespace general
+}  // namespace funcs
 }  // namespace pten
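One behavior preserved by the move: funcs::DivideFunctor keeps the integer-division-by-zero guard behind DIV_ERROR_INFO, selected by the same enable_if technique. A standalone sketch with a plain exception standing in for PADDLE_ENFORCE (not the Paddle code):

#include <stdexcept>
#include <type_traits>

template <typename T, typename Enable = void>
struct DivideSketch {
  T operator()(const T& a, const T& b) const { return a / b; }
};

template <typename T>
struct DivideSketch<
    T, typename std::enable_if<std::is_integral<T>::value>::type> {
  T operator()(const T& a, const T& b) const {
    if (b == 0) throw std::runtime_error("integer division by zero");
    return a / b;
  }
};

// DivideSketch<int>()(6, 0)     -> throws
// DivideSketch<double>()(6, 0)  -> inf (IEEE semantics, no guard)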
@@ -15,9 +15,9 @@ limitations under the License. */
 #include "paddle/pten/kernels/math_kernel.h"
 #include "paddle/pten/backends/gpu/gpu_context.h"
+#include "paddle/pten/kernels/funcs/elementwise_functor.h"
 #include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h"
 #include "paddle/pten/kernels/hybird/cuda/reduce/reduce.h"
-#include "paddle/pten/kernels/hybird/general/elementwise_functor.h"
 #include "paddle/pten/kernels/hybird/general/reduce_impl.h"
 #ifdef __NVCC__
@@ -39,21 +39,21 @@ namespace kps = paddle::operators::kernel_primitives;
 namespace pten {
 #define DEFINE_CUDA_ELEMENTWISE_OP(name)                               \
   template <typename T, typename Context>                              \
   void name##Kernel(const Context& dev_ctx,                            \
                     const DenseTensor& x,                              \
                     const DenseTensor& y,                              \
                     int axis,                                          \
                     DenseTensor* out) {                                \
     std::vector<const DenseTensor*> inputs;                            \
     std::vector<DenseTensor*> outputs;                                 \
     inputs.emplace_back(&x);                                           \
     inputs.emplace_back(&y);                                           \
     outputs.emplace_back(out);                                         \
     out->mutable_data<T>();                                            \
     LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(       \
-        dev_ctx, inputs, &outputs, axis, general::name##Functor<T>()); \
+        dev_ctx, inputs, &outputs, axis, funcs::name##Functor<T>());   \
   }
 /**
......
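And the GPU counterpart: a hedged sketch of roughly what DEFINE_CUDA_ELEMENTWISE_OP(Multiply) expands to (the commit's only change here is general:: becoming funcs::):

template <typename T, typename Context>
void MultiplyKernel(const Context& dev_ctx, const DenseTensor& x,
                    const DenseTensor& y, int axis, DenseTensor* out) {
  std::vector<const DenseTensor*> inputs{&x, &y};
  std::vector<DenseTensor*> outputs{out};
  out->mutable_data<T>();
  // one fused broadcast-aware CUDA launch for the binary op
  LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
      dev_ctx, inputs, &outputs, axis, funcs::MultiplyFunctor<T>());
}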
 add_subdirectory(eigen)
-add_subdirectory(blas)
 add_subdirectory(general)
 cc_library(pten_transpose_cpu SRCS transpose.cc DEPS dense_tensor pten_context)
......
--- a/paddle/pten/kernels/hybird/blas/elementwise.h
+++ /dev/null
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
namespace blas {
template <typename DevCtx, typename T>
void ElementwiseAdd(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
blas.VADD(x.numel(), x.data<T>(), y.data<T>(), out->mutable_data<T>());
}
template <typename DevCtx, typename T>
void ElementwiseSub(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
blas.VSUB(x.numel(), x.data<T>(), y.data<T>(), out->mutable_data<T>());
}
template <typename DevCtx, typename T>
void ElementwiseDiv(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
blas.VDIV(x.numel(), x.data<T>(), y.data<T>(), out->mutable_data<T>());
}
template <typename DevCtx, typename T>
void ElementwiseMul(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
blas.VMUL(x.numel(), x.data<T>(), y.data<T>(), out->mutable_data<T>());
}
} // namespace blas
} // namespace pten
@@ -18,7 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/aligned_vector.h"
 #include "paddle/fluid/platform/function_traits.h"
 #include "paddle/pten/core/dense_tensor.h"
-#include "paddle/pten/kernels/hybird/general/elementwise_base.h"
+#include "paddle/pten/kernels/funcs/elementwise_base.h"
 namespace pten {
 namespace kps = paddle::operators::kernel_primitives;
......
--- a/paddle/pten/kernels/hybird/eigen/elementwise.h
+++ /dev/null
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/hybird/eigen/common.h"
namespace pten {
namespace eigen {
template <typename DevCtx, typename T>
void ElementwiseAdd(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
out->mutable_data<T>();
auto eigen_x = pten::EigenVector<T>::Flatten(x);
auto eigen_y = pten::EigenVector<T>::Flatten(y);
auto eigen_z = pten::EigenVector<T>::Flatten(*out);
auto& place = *dev_ctx.eigen_device();
eigen_z.device(place) = eigen_x + eigen_y;
}
template <typename DevCtx, typename T>
void ElementwiseSub(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto eigen_x = pten::EigenVector<T>::Flatten(x);
auto eigen_y = pten::EigenVector<T>::Flatten(y);
auto eigen_z = pten::EigenVector<T>::Flatten(*out);
auto& place = *dev_ctx.eigen_device();
eigen_z.device(place) = eigen_x - eigen_y;
}
template <typename DevCtx, typename T>
void ElementwiseMul(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto eigen_x = pten::EigenVector<T>::Flatten(x);
auto eigen_y = pten::EigenVector<T>::Flatten(y);
auto eigen_z = pten::EigenVector<T>::Flatten(*out);
auto& place = *dev_ctx.eigen_device();
eigen_z.device(place) = eigen_x * eigen_y;
}
} // namespace eigen
} // namespace pten
@@ -342,7 +342,6 @@ def source_include(header_file_path):
 #include "paddle/pten/api/include/kernel_signature.h"
 #include "paddle/pten/api/lib/api_registry.h"
-#include "paddle/pten/api/lib/kernel_declare.h"
 #include "paddle/pten/api/lib/kernel_dispatch.h"
 #include "paddle/pten/api/lib/utils/storage.h"
 #include "paddle/pten/core/kernel_registry.h"
......