未验证 提交 c1310343 编写于 作者: Y YuanRisheng 提交者: GitHub

[Pten]Refactor the Elementwise_add Kernel (#37043)

* elementwise_add kernel refactor

* fix compile bugs in elementwise_add refactor

* fix compile bugs when run in npu/xpu

* fix bugs when run unit test

* fix bugs when run ci-windows

* modify code as recommended

* code format adjust

* fix bugs when run ci

* fix compile bug when run in ci-windwos
上级 6bf208c3
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
......@@ -19,28 +20,17 @@ limitations under the License. */
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
// only can include the headers in paddle/top/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
class ElementwiseAddKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, AddFunctor<T>());
}
};
template <typename T>
static __global__ void SimpleElemwiseAddGradCUDAKernel(
const T* __restrict__ dout, int size, int vec_size, T* dx, T* dy) {
......
......@@ -20,6 +20,13 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/framework/pten_utils.h"
// only can include the headers in paddle/pten/include dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"
namespace paddle {
namespace operators {
......@@ -55,12 +62,14 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto *z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
if (x->dims() == y->dims()) {
SameDimsElemwiseAdd<DeviceContext, T> LaunchElementwiseCpuKernel;
LaunchElementwiseCpuKernel(ctx, x, y, z);
} else {
LaunchBroadcastElementwiseCpuKernel<DeviceContext, T>(ctx, x, y, z);
}
auto &dev_ctx = ctx.device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
pten::ElementwiseAdd<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
pt_z.get());
}
};
......
......@@ -139,6 +139,17 @@ class ElementwiseOp : public framework::OperatorWithKernel {
tensor.place(), tensor.layout());
}
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
if (Type() == "elementwise_add") {
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
return framework::KernelSignature("elementwise_add", {"X", "Y"},
{"axis"}, {"Out"});
}
}
return framework::KernelSignature("None", {"X"}, {}, {"Out"});
}
};
class ElementwiseOpInferVarType
......
......@@ -162,190 +162,36 @@ struct DimensionsTransform {
}
};
template <typename T, int VecSize, int Rank, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
T *dst, const T *__restrict__ src, uint32_t block_offset,
const kps::details::BroadcastConfig<Rank> &config, int numel, int num,
bool need_broadcast) {
// numel : whole num of output
// num: how many data will be deal with in this time
if (need_broadcast) {
kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(dst, src, block_offset,
config, numel);
} else {
kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
}
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
int Rank, bool IsBoundary = false>
__device__ void DealSegment(
const framework::Array<const InT *__restrict__, Arity> &ins, OutT *out,
const framework::Array<bool, Arity> &use_broadcast, uint32_t numel,
const framework::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
int num, Functor func) {
InT args[Arity][VecSize];
OutT result[VecSize];
int block_offset = blockIdx.x * blockDim.x * VecSize;
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
LoadData<InT, VecSize, Rank, IsBoundary>(args[i], ins[i], block_offset,
configs[i], numel, num,
use_broadcast[i]);
}
const bool kCallElementwiseAny =
platform::FunctionTraits<Functor>::has_pointer_args;
ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity,
kCallElementwiseAny>()(func, args, result);
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result,
num);
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
int Rank>
__global__ void BroadcastKernel(
framework::Array<const InT *__restrict__, Arity> ins, OutT *out,
framework::Array<bool, Arity> use_broadcast, uint32_t numel,
framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
int main_tid, int tail_tid, Functor func) {
int block_offset = blockIdx.x * blockDim.x * VecSize;
// data offset of this block
if (blockIdx.x < main_tid) {
int num = blockDim.x * VecSize; // blockIdx.x < main_tid
DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, false>(
ins, out, use_broadcast, numel, configs, num, func);
} else { // reminder
int num = tail_tid;
DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, true>(
ins, out, use_broadcast, numel, configs, num, func);
}
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
int Rank>
void LaunchKernel(const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
framework::Tensor *out, Functor func,
DimensionsTransform merge_dims) {
int numel = out->numel();
const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
int main_tid = numel / (VecSize * threads);
int tail_tid = numel % (VecSize * threads);
auto stream = ctx.stream();
OutT *out_data = out->data<OutT>();
framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
framework::Array<bool, Arity> use_broadcast;
framework::Array<const InT *__restrict__, Arity> ins_data;
for (int i = 0; i < Arity; i++) {
use_broadcast[i] = (ins[i]->numel() != numel);
ins_data[i] = ins[i]->data<InT>();
if (use_broadcast[i]) {
// get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m}
// eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
configs[i] = kps::details::BroadcastConfig<Rank>(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
}
}
BroadcastKernel<InT, OutT, Functor, Arity, VecSize,
Rank><<<blocks, threads, 0, stream>>>(
ins_data, out_data, use_broadcast, numel, configs, main_tid, tail_tid,
func);
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
void LaunchBroadcastKernelForDifferentVecSize(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
int axis, Functor func) {
const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
#define CALL_BROADCAST_FOR_DIM_SIZE(rank) \
case rank: { \
LaunchKernel<InT, OutT, Functor, Arity, VecSize, rank>(ctx, ins, out, \
func, merge_dims); \
} break;
switch (merge_dims.dim_size) {
CALL_BROADCAST_FOR_DIM_SIZE(1);
CALL_BROADCAST_FOR_DIM_SIZE(2);
CALL_BROADCAST_FOR_DIM_SIZE(3);
CALL_BROADCAST_FOR_DIM_SIZE(4);
CALL_BROADCAST_FOR_DIM_SIZE(5);
CALL_BROADCAST_FOR_DIM_SIZE(6);
CALL_BROADCAST_FOR_DIM_SIZE(7);
CALL_BROADCAST_FOR_DIM_SIZE(8);
default: {
PADDLE_THROW(platform::errors::InvalidArgument(
"The maximum dimension of input tensor is expected to be less than "
"%d, but recieved %d.\n",
merge_dims.dim_size, framework::DDim::kMaxRank));
}
}
#undef CALL_BROADCAST_FOR_DIM_SIZE
}
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
using Traits = platform::FunctionTraits<Functor>;
const int kArity =
Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
PADDLE_ENFORCE_EQ(ins.size(), kArity,
platform::errors::InvalidArgument(
"The number of inputs is expected to be equal to the "
"arity of functor. But recieved: the number of inputs "
"is %d, the arity of functor is %d.",
ins.size(), kArity));
PADDLE_ENFORCE_EQ(kArity, 2,
platform::errors::InvalidArgument(
"Currently only broadcast of binary is supported and "
"verified, but received %d.",
kArity));
int in_vec_size = 4;
framework::Tensor *out = (*outs)[0];
for (auto *in : ins) {
auto temp_size = platform::GetVectorizedSize<InT>(in->data<InT>());
in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
: in_vec_size;
std::vector<const pten::DenseTensor *> pt_inputs;
std::vector<pten::DenseTensor *> pt_outputs;
// TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
// DenseTensor obj
// generated by MakePtenDenseTensor can be destroyed when exits loop. *_tmp
// can be deleted
// when DenseTensor support copy constructor.
std::vector<std::unique_ptr<pten::DenseTensor>> pt_inputs_tmp;
std::vector<std::unique_ptr<pten::DenseTensor>> pt_outputs_tmp;
for (auto in : ins) {
pt_inputs_tmp.emplace_back(
std::move(paddle::experimental::MakePtenDenseTensor(*in)));
}
int out_vec_size = platform::GetVectorizedSize<OutT>(out->data<OutT>());
int vec_size = std::min(out_vec_size, in_vec_size);
switch (vec_size) {
case 4: {
LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 4>(
ctx, ins, out, axis, func);
break;
}
case 2: {
LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 2>(
ctx, ins, out, axis, func);
break;
}
case 1: {
LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 1>(
ctx, ins, out, axis, func);
break;
}
default: {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
}
for (auto out : *outs) {
pt_outputs_tmp.emplace_back(
std::move(paddle::experimental::MakePtenDenseTensor(*out)));
}
for (int i = 0; i < pt_inputs_tmp.size(); i++) {
pt_inputs.push_back(pt_inputs_tmp[i].get());
}
for (int i = 0; i < pt_outputs_tmp.size(); i++) {
pt_outputs.push_back(pt_outputs_tmp[i].get());
}
pten::LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(
ctx, pt_inputs, &pt_outputs, axis, func);
}
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
......@@ -353,24 +199,31 @@ void LaunchElementwiseCudaKernel(
const platform::CUDADeviceContext &cuda_ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, int axis, Functor func) {
std::vector<int> dims_size;
bool no_broadcast_flag = true;
for (auto *in : ins) {
no_broadcast_flag = ins[0]->dims() == in->dims();
dims_size.emplace_back(in->dims().size());
std::vector<const pten::DenseTensor *> pt_inputs;
std::vector<pten::DenseTensor *> pt_outputs;
// TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
// DenseTensor obj
// generated by MakePtenDenseTensor can be destroyed when exits loop. *_tmp
// can be deleted
// when DenseTensor support copy constructor.
std::vector<std::unique_ptr<pten::DenseTensor>> pt_inputs_tmp;
std::vector<std::unique_ptr<pten::DenseTensor>> pt_outputs_tmp;
for (auto in : ins) {
pt_inputs_tmp.emplace_back(
std::move(paddle::experimental::MakePtenDenseTensor(*in)));
}
if (no_broadcast_flag) {
LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
func);
} else {
axis = axis == -1
? *std::max_element(dims_size.begin(), dims_size.end()) -
*std::min_element(dims_size.begin(), dims_size.end())
: axis;
LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
axis, func);
for (auto out : *outs) {
pt_outputs_tmp.emplace_back(
std::move(paddle::experimental::MakePtenDenseTensor(*out)));
}
for (int i = 0; i < pt_inputs_tmp.size(); i++) {
pt_inputs.push_back(pt_inputs_tmp[i].get());
}
for (int i = 0; i < pt_outputs_tmp.size(); i++) {
pt_outputs.push_back(pt_outputs_tmp[i].get());
}
pten::LaunchElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, pt_inputs,
&pt_outputs, axis, func);
}
} // namespace operators
......
......@@ -14,11 +14,17 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/function_traits.h"
// only can include the headers in paddle/top/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/kernels/functions/cuda/elementwise/elementwise.h"
#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
#else
......@@ -30,187 +36,38 @@ namespace operators {
namespace kps = paddle::operators::kernel_primitives;
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
/*
* According to NVIDIA, if number of threads per block is 64/128/256/512,
* cuda performs better. And number of blocks should be greater (at least
* 2x~4x) than number of SMs. Hence, SM count is took into account within
* this function to determine the right number of threads per block.
*/
inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
int64_t numel, int vec_size) {
int threads = ELEMENTWISE_BLOCK_SIZE;
int sm_count = ctx.GetSMCount();
int active_threads_num = numel / vec_size;
if (active_threads_num / (sm_count << 1) < ELEMENTWISE_BLOCK_SIZE) {
// Round up threads number into an exponential multiple of 2, while number
// of acitve blocks is about twice of SM, to acquire better performance.
threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 1));
} else if (active_threads_num / (sm_count << 2) < ELEMENTWISE_BLOCK_SIZE) {
// Round up threads number into an exponential multiple of 2, while number
// of acitve blocks is about 4 times of SM, to acquire better performance.
threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 2));
}
// Number of threads per block shall be larger than 64.
return std::max(64, threads);
}
template <typename InT, typename OutT>
int GetVectorizedSizeForTensors(
const std::vector<const framework::Tensor *> &ins,
const std::vector<framework::Tensor *> &outs) {
int vec_size = 4;
for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
vec_size = std::min<int>(vec_size,
platform::GetVectorizedSize((*iter)->data<InT>()));
}
for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
vec_size = std::min<int>(
vec_size, platform::GetVectorizedSize((*iter)->data<OutT>()));
}
return vec_size;
}
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity,
bool CallElementwiseAny = false>
struct ElementwisePrimitiveCaller {
__device__ inline void operator()(Functor func, InT (*args)[VecSize],
OutT *result);
};
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
__device__ inline void operator()(Functor func, InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(result, args,
func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
__device__ inline void operator()(Functor func, InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
__device__ inline void operator()(Functor func, InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(result, args[0],
args[1], func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
__device__ inline void operator()(Functor func, InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
}
};
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
bool IsBoundary>
__device__ void DealSegment(
const framework::Array<const InT *__restrict__, Arity> &in, OutT *out,
int num, Functor func) {
InT args[Arity][VecSize];
OutT result[VecSize];
int data_offset = VecSize * blockIdx.x * blockDim.x;
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(args[i], in[i] + data_offset,
num);
}
const bool kCallElementwiseAny =
platform::FunctionTraits<Functor>::has_pointer_args;
ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity,
kCallElementwiseAny>()(func, args, result);
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + data_offset, result,
num);
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
__global__ void ElementVectorizeKernel(
framework::Array<const InT *__restrict__, Arity> ins, OutT *out, int size,
Functor func) {
int data_offset = VecSize * blockIdx.x * blockDim.x;
int num = size - data_offset;
// the num this time have to deal with
if (VecSize * blockDim.x > num) { // reminder segment
DealSegment<InT, OutT, Functor, Arity, VecSize, true>(ins, out, num, func);
} else { // complete segment
DealSegment<InT, OutT, Functor, Arity, VecSize, false>(ins, out, num, func);
}
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
void ElementwiseCudaKernel(const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs,
Functor func) {
auto numel = ins[0]->numel();
int block_size = GetThreadsConfig(ctx, numel, VecSize);
int grid_size =
((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
auto stream = ctx.stream();
OutT *out_data = (*outs)[0]->data<OutT>();
framework::Array<const InT *__restrict__, Arity> ins_data;
for (int i = 0; i < Arity; i++) {
ins_data[i] = ins[i]->data<InT>();
}
ElementVectorizeKernel<InT, OutT, Functor, Arity,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, func);
}
using ElementwiseType = pten::ElementwiseType;
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchSameDimsElementwiseCudaKernel(
const platform::CUDADeviceContext &ctx,
const std::vector<const framework::Tensor *> &ins,
std::vector<framework::Tensor *> *outs, Functor func) {
using Traits = platform::FunctionTraits<Functor>;
const int kArity =
Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
PADDLE_ENFORCE_EQ(ins.size(), kArity,
platform::errors::InvalidArgument(
"The number of inputs is expected to be equal to the "
"arity of functor. But recieved: the number of inputs "
"is %d, the arity of functor is %d.",
ins.size(), kArity));
// calculate the max vec_size for all ins and outs
int vec_size = GetVectorizedSizeForTensors<InT, OutT>(ins, *outs);
switch (vec_size) {
case 4:
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 4>(ctx, ins, outs,
func);
break;
case 2:
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 2>(ctx, ins, outs,
func);
break;
case 1:
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 1>(ctx, ins, outs,
func);
break;
default: {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
}
std::vector<const pten::DenseTensor *> pt_inputs;
std::vector<pten::DenseTensor *> pt_outputs;
// TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
// DenseTensor obj
// generated by MakePtenDenseTensor can be destroyed when exits loop. *_tmp
// can be deleted
// when DenseTensor support copy constructor.
std::vector<std::unique_ptr<pten::DenseTensor>> pt_inputs_tmp;
std::vector<std::unique_ptr<pten::DenseTensor>> pt_outputs_tmp;
for (auto in : ins) {
pt_inputs_tmp.emplace_back(
std::move(paddle::experimental::MakePtenDenseTensor(*in)));
}
for (auto out : *outs) {
pt_outputs_tmp.emplace_back(
std::move(paddle::experimental::MakePtenDenseTensor(*out)));
}
for (int i = 0; i < pt_inputs_tmp.size(); i++) {
pt_inputs.push_back(pt_inputs_tmp[i].get());
}
for (int i = 0; i < pt_outputs_tmp.size(); i++) {
pt_outputs.push_back(pt_outputs_tmp[i].get());
}
pten::LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(ctx, pt_inputs,
&pt_outputs, func);
}
} // namespace operators
......
......@@ -14,6 +14,8 @@
#pragma once
#include "paddle/fluid/platform/eigen_ext.h"
namespace paddle {
namespace operators {
namespace kernel_primitives {
......
......@@ -23,5 +23,7 @@ namespace experimental {
// TODO(chenweihang): move mean API into stat.h/cc
Tensor mean(const Tensor& x);
Tensor add(const Tensor& x, const Tensor& y);
} // namespace experimental
} // namespace paddle
......@@ -60,5 +60,40 @@ Tensor mean(const Tensor& x) {
return out;
}
Tensor add(const Tensor& x, const Tensor& y) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"elementwise_add", kernel_key);
// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(dev_ctx);
// 3. Auto data transform
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(dense_x);
auto dense_y = std::dynamic_pointer_cast<pten::DenseTensor>(y.impl());
kernel_context.EmplaceBackInput(dense_y);
kernel_context.EmplaceBackAttr(-1);
// 4. InferShape
auto out_meta = ElementwiseInferShape(dense_x->meta(), dense_y->meta(), -1);
// 5. Prepare outputs
Tensor out;
const auto allocator = std::make_shared<DefaultAllocator>(
pten::TransToFluidPlace(kernel_key.backend()));
auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
kernel_context.EmplaceBackOutput(dense_out);
out.set_impl(dense_out);
// 6. Call kernel
kernel(&kernel_context);
return out;
}
} // namespace experimental
} // namespace paddle
......@@ -76,7 +76,7 @@ template <typename T>
T* DenseTensor::mutable_data() {
PADDLE_ENFORCE(
(data_type() == paddle::experimental::CppTypeToDataType<T>::Type()),
paddle::platform::errors::PreconditionNotMet(
paddle::platform::errors::InvalidArgument(
"The type of data (%d) we are trying to retrieve does not match the "
"type of data currently contained in the container (%d).",
static_cast<int>(paddle::experimental::CppTypeToDataType<T>::Type()),
......@@ -88,7 +88,7 @@ template <typename T>
const T* DenseTensor::data() const {
PADDLE_ENFORCE(
(data_type() == paddle::experimental::CppTypeToDataType<T>::Type()),
paddle::platform::errors::PreconditionNotMet(
paddle::platform::errors::InvalidArgument(
"The type of data we are trying to retrieve does not match the "
"type of data currently contained in the container."));
return static_cast<const T*>(data());
......
......@@ -73,4 +73,18 @@ DenseTensor Scale(const ContextT& dev_ctx,
ScaleHost<T>(dev_ctx, x, scale, bias, bias_after_scale, &dense_out);
return dense_out;
}
template <typename T, typename ContextT>
DenseTensor ElementwiseAdd(const ContextT& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis) {
auto out_meta = ElementwiseInferShape(x.meta(), y.meta(), axis);
const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace());
pten::DenseTensor dense_out(allocator, out_meta);
ElementwiseAdd<T>(dev_ctx, x, y, axis, &dense_out);
return dense_out;
}
} // namespace pten
......@@ -14,7 +14,7 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/infershape/binary.h"
#include "paddle/pten/kernels/functions/general/elementwise_base.h"
namespace pten {
DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta,
......@@ -129,4 +129,49 @@ DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta,
return {x_meta.type, ddim_out, x_meta.layout};
}
DenseTensorMeta ElementwiseInferShape(const DenseTensorMeta& x_meta,
const DenseTensorMeta& y_meta,
int axis) {
DenseTensorMeta return_meta(x_meta.type, x_meta.dims, x_meta.layout);
if (x_meta.dims != y_meta.dims) {
auto x_dims = x_meta.dims;
auto y_dims = y_meta.dims;
int max_dim = std::max(x_dims.size(), y_dims.size());
if (x_dims.size() == y_dims.size()) {
PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0),
true,
paddle::platform::errors::InvalidArgument(
"axis should be -1 or 0 while the dimension of "
"tensor X (%s) is equal to the dimension of "
"tensor Y (%s), but received axis: %s",
x_dims.size(),
y_dims.size(),
axis));
}
PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim),
true,
paddle::platform::errors::InvalidArgument(
"The axis range must be [%s, %s), but axis is %s. "
"Please set the axis again.",
-1 * max_dim,
max_dim,
axis));
axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1)
: axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
general::GetBroadcastDimsArrays(x_dims,
y_dims,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
return_meta.dims = paddle::framework::make_ddim(out_dims_array);
}
return_meta.lod = x_meta.lod;
return return_meta;
}
} // namespace pten
......@@ -41,4 +41,7 @@ DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta,
bool trans_x,
bool trans_y);
DenseTensorMeta ElementwiseInferShape(const DenseTensorMeta& x_meta,
const DenseTensorMeta& y_meta,
int axis);
} // namespace pten
cc_library(math_cpu SRCS math.cc DEPS dense_tensor kernel_context kernel_factory eigen_function)
cc_library(math_cpu SRCS math.cc DEPS dense_tensor kernel_context kernel_factory eigen_function blas)
cc_library(linalg_cpu SRCS linalg.cc DEPS dense_tensor kernel_context kernel_factory)
cc_library(creation_cpu SRCS creation.cc DEPS dense_tensor kernel_context kernel_factory eigen_function)
cc_library(utils_cpu SRCS utils.cc DEPS dense_tensor kernel_context kernel_factory memory convert_utils)
......
......@@ -14,13 +14,16 @@
#include "paddle/pten/kernels/cpu/math.h"
#include "paddle/pten/kernels/functions/cpu/elementwise.h"
#include "paddle/pten/kernels/functions/eigen/mean.h"
#include "paddle/pten/kernels/functions/eigen/scale.h"
#include "paddle/pten/kernels/functions/eigen/sign.h"
#include "paddle/pten/kernels/functions/general/elementwise_functor.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
namespace pten {
......@@ -61,11 +64,35 @@ void ScaleHost(const CPUContext& dev_ctx,
out);
}
template <typename T>
void ElementwiseAdd(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
if (x.dims() == y.dims()) {
SameDimsElementwiseCompute<general::SameDimsAddFunctor<CPUContext, T>>()(
dev_ctx, x, y, out);
} else {
auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseCompute<general::AddFunctor<T>, T>(
dev_ctx, x, y, axis, general::AddFunctor<T>(), out);
} else {
ElementwiseCompute<general::InverseAddFunctor<T>, T>(
dev_ctx, x, y, axis, general::InverseAddFunctor<T>(), out);
}
}
}
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(MathCPU);
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
......@@ -97,3 +124,13 @@ PT_REGISTER_KERNEL("scale.host",
int64_t) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
}
PT_REGISTER_KERNEL("elementwise_add",
CPU,
ANY,
pten::ElementwiseAdd,
float,
double,
int,
int64_t,
complex64,
complex128) {}
......@@ -46,4 +46,11 @@ void ScaleHost(const CPUContext& dev_ctx,
bool bias_after_scale,
DenseTensor* out);
template <typename T>
void ElementwiseAdd(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
} // namespace pten
......@@ -14,9 +14,11 @@ limitations under the License. */
#include "paddle/pten/kernels/cuda/math.h"
#include "paddle/pten/kernels/functions/cuda/elementwise/elementwise.h"
#include "paddle/pten/kernels/functions/eigen/mean.h"
#include "paddle/pten/kernels/functions/eigen/scale.h"
#include "paddle/pten/kernels/functions/eigen/sign.h"
#include "paddle/pten/kernels/functions/general/elementwise_functor.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
......@@ -26,6 +28,7 @@ limitations under the License. */
namespace cub = hipcub;
#endif
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h"
......@@ -121,12 +124,30 @@ void ScaleHost(const CUDAContext& dev_ctx,
out);
}
template <typename T>
void ElementwiseAdd(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
inputs.emplace_back(&y);
outputs.emplace_back(out);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, inputs, &outputs, axis, general::AddFunctor<T>());
}
} // namespace pten
// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(MathCUDA);
using float16 = paddle::platform::float16;
using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_KERNEL("sign", CUDA, ANY, pten::Sign, float, double, float16) {}
PT_REGISTER_KERNEL("mean", CUDA, ANY, pten::Mean, float, double, float16) {}
PT_REGISTER_KERNEL("scale",
......@@ -155,3 +176,14 @@ PT_REGISTER_KERNEL("scale.host",
int64_t) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
}
PT_REGISTER_KERNEL("elementwise_add",
CUDA,
ANY,
pten::ElementwiseAdd,
float,
double,
int,
int64_t,
float16,
complex64,
complex128) {}
......@@ -48,6 +48,13 @@ void ScaleHost(const CUDAContext& dev_ctx,
bool bias_after_scale,
DenseTensor* out);
template <typename T>
void ElementwiseAdd(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
} // namespace pten
#endif
add_subdirectory(eigen)
add_subdirectory(blas)
add_subdirectory(general)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
namespace blas {
template <typename DevCtx, typename T>
void ElementwiseAdd(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto blas = paddle::operators::math::GetBlas<DevCtx, T>(dev_ctx);
blas.VADD(x.numel(), x.data<T>(), y.data<T>(), out->mutable_data<T>());
}
} // namespace blas
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/functions/general/elementwise_base.h"
namespace pten {
inline void UpdateElementwiseIndexArray(const int *out_dims_array,
const int max_dim,
int *index_array) {
for (int i = max_dim - 1; i >= 0; --i) {
++index_array[i];
if (index_array[i] >= out_dims_array[i]) {
index_array[i] -= out_dims_array[i];
} else {
break;
}
}
}
inline int GetElementwiseIndex(const int *x_dims_array,
const int max_dim,
const int *index_array) {
int index_ = 0;
for (int i = 0; i < max_dim; i++) {
if (x_dims_array[i] > 1) {
index_ = index_ * x_dims_array[i] + index_array[i];
}
}
return index_;
}
template <typename Functor, typename T, typename OutType = T>
void CommonForwardBroadcastCPU(const DenseTensor &x,
const DenseTensor &y,
DenseTensor *z,
int *x_dims_array,
int *y_dims_array,
int *out_dims_array,
int max_dim,
const paddle::platform::CPUDeviceContext &ctx,
Functor func,
const bool is_xsize_larger = true) {
std::vector<int> index_array(max_dim, 0);
const T *x_data = x.data<T>();
const T *y_data = y.data<T>();
PADDLE_ENFORCE_NOT_NULL(x_data,
paddle::platform::errors::InvalidArgument(
"The input X should not be empty."));
PADDLE_ENFORCE_NOT_NULL(y_data,
paddle::platform::errors::InvalidArgument(
"The input Y should not be empty."));
OutType *out_data = z->mutable_data<OutType>();
const int out_size = std::accumulate(
out_dims_array, out_dims_array + max_dim, 1, std::multiplies<int>());
int x_index, y_index;
for (int out_index = 0; out_index < out_size; ++out_index) {
x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data());
y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data());
if (is_xsize_larger) {
out_data[out_index] = func(x_data[x_index], y_data[y_index]);
} else {
out_data[out_index] = func(y_data[y_index], x_data[x_index]);
}
UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data());
}
}
template <typename Functor, typename T, typename OutType = T>
void CommonElementwiseBroadcastForward(
const paddle::platform::CPUDeviceContext &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
DenseTensor *z,
const DDim &x_dims,
const DDim &y_dims,
Functor func,
int axis,
const bool is_xsize_larger = true) {
int max_dim = (std::max)(x_dims.size(), y_dims.size());
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(
axis,
0,
paddle::platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
max_dim,
paddle::platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
general::GetBroadcastDimsArrays(x_dims,
y_dims,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
CommonForwardBroadcastCPU<Functor, T, OutType>(x,
y,
z,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
dev_ctx,
func,
is_xsize_larger);
}
// It is a common implementation to compute binary calculation with the support
// of broadcast, supporting both CPU and GPU.
// - CPU implementation cannot support the case when x needs broadcast, thus
// this function need to be called with XxxFunctor and XxxInverseFunctor,
// like paddle/fluid/operators/elementwise/elementwise_add_op.h#L49 - L55.
// - GPU implementation supports all the broadcast cases, thus there is no need
// to define and call with XxxInverseFunctor.
// TODO(liuyiqun): optimize the CPU implementation to support all broadcast
// cases and avoid the need of XxxInverseFunctor.
template <typename Functor, typename T, typename OutType = T>
void ElementwiseCompute(const paddle::platform::CPUDeviceContext &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
int axis,
Functor func,
DenseTensor *z) {
auto x_dims = x.dims();
auto y_dims = y.dims();
bool is_xsize_larger = true;
int max_dim = x_dims.size();
if (x_dims.size() < y_dims.size()) {
is_xsize_larger = false;
max_dim = y_dims.size();
}
general::
TransformFunctor<Functor, T, paddle::platform::CPUDeviceContext, OutType>
functor(x, y, z, dev_ctx, func, is_xsize_larger);
if (x_dims == y_dims) {
functor.Run();
return;
}
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(
axis,
0,
paddle::platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
max_dim,
paddle::platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) {
auto y_dims_trimed = general::trim_trailing_singular_dims(y_dims);
axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
general::get_mid_dims(x_dims,
y_dims_trimed,
axis_trim,
&pre,
&n,
&post,
&is_run_common_broadcast);
} else {
auto x_dims_trimed = general::trim_trailing_singular_dims(x_dims);
axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
general::get_mid_dims(y_dims,
x_dims_trimed,
axis_trim,
&pre,
&n,
&post,
&is_run_common_broadcast);
}
// special case for common implementation.
// case 1: x=[2,3,1,5], y=[2,1,4,1]
// case 2: x=[2,3,4], y=[1,1,4]
if (is_run_common_broadcast == 1) {
CommonElementwiseBroadcastForward<Functor, T, OutType>(
dev_ctx, x, y, z, x_dims, y_dims, func, axis, is_xsize_larger);
return;
}
if (post == 1) {
functor.RunRowWise(n, pre);
return;
} else {
functor.RunMidWise(n, pre, post);
return;
}
}
template <typename Functor>
struct SameDimsElementwiseCompute {
void operator()(const paddle::platform::CPUDeviceContext &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
DenseTensor *z) {
Functor()(dev_ctx, x, y, z);
}
};
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/kernels/functions/cuda/elementwise/elementwise_broadcast.cu.h"
#include "paddle/pten/kernels/functions/cuda/elementwise/elementwise_no_broadcast.cu.h"
namespace pten {
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchElementwiseCudaKernel(
const paddle::platform::CUDADeviceContext &cuda_ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
int axis,
Functor func) {
std::vector<int> dims_size;
bool no_broadcast_flag = true;
for (auto *in : ins) {
no_broadcast_flag = ins[0]->dims() == in->dims();
dims_size.emplace_back(in->dims().size());
}
if (no_broadcast_flag) {
LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(
cuda_ctx, ins, outs, func);
} else {
axis = axis == -1
? *std::max_element(dims_size.begin(), dims_size.end()) -
*std::min_element(dims_size.begin(), dims_size.end())
: axis;
LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(
cuda_ctx, ins, outs, axis, func);
}
}
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/functions/cuda/elementwise/elementwise_common.cu.h"
namespace pten {
struct DimensionsTransform {
using DimVector = std::vector<int64_t>;
typedef void (*MergeFunctor)(
bool &, std::vector<DimVector> &, DimVector &, int, int);
int64_t dim_size;
DimVector out_dims;
std::vector<DimVector> in_dims;
private:
// To compensate the lackage of input_tensors` dimension with input variable
// 'axis'
void InputDimensionsExtend(int N, int axis) {
for (auto &in_dim : in_dims) {
int64_t in_idx = 0;
if (in_dim.size() < dim_size) {
DimVector tmp_dim(dim_size, 1);
do {
if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
tmp_dim[axis] = in_dim[in_idx];
in_idx++;
axis++;
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The %d-th dimension of input tensor is expected to be equal "
"with the %d-th dimension of output tensor %d or 1, but "
"recieved %d.",
in_idx + 1,
axis + 1,
out_dims[axis],
in_dim[in_idx]));
}
} while (in_idx < in_dim.size());
in_dim.resize(dim_size);
std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
} else {
do {
if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
in_idx++;
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The %d-th dimension of input tensor is expected to be equal "
"with the %d-th dimension of output tensor %d or 1, but "
"recieved %d.",
in_idx + 1,
in_idx + 1,
out_dims[in_idx],
in_dim[in_idx]));
}
} while (in_idx < dim_size);
}
std::reverse(in_dim.begin(), in_dim.end());
}
std::reverse(out_dims.begin(), out_dims.end());
}
template <typename MergeFunctor>
__inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
(*vec)[m_idx - 1] = std::accumulate(vec->begin() + l_idx,
vec->begin() + m_idx,
1,
std::multiplies<int64_t>());
vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
};
int64_t i = 0;
while (i < dim_size) {
int cnt = 0;
int low_idx = i;
bool equal = true;
do {
merge_func(equal, in_dims, out_dims, i, N);
if (equal) {
i++;
cnt++;
} else {
break;
}
} while (i < dim_size);
if (cnt > 1) {
for (auto &in_dim : in_dims) {
VectorReorganise(&in_dim, low_idx, i);
}
VectorReorganise(&out_dims, low_idx, i);
dim_size -= --cnt;
i -= cnt;
} else if (cnt < 1) {
i++;
}
}
}
public:
explicit DimensionsTransform(const std::vector<const DenseTensor *> &ins,
const paddle::framework::DDim &dims,
int axis) {
const int N = ins.size();
dim_size = dims.size();
out_dims = paddle::framework::vectorize<int64_t>(dims);
in_dims.resize(N);
for (int j = 0; j < N; ++j) {
in_dims[j] = paddle::framework::vectorize<int64_t>(ins[j]->dims());
}
InputDimensionsExtend(N, axis);
auto merge_sequential_dims = [](bool &equal,
std::vector<DimVector> &in_dims,
DimVector &out,
int i,
int num) {
for (int j = 1; j < num; ++j) {
equal = (in_dims[0][i] == in_dims[j][i]) ? true : false;
}
};
auto merge_sequential_one_dims = [](bool &equal,
std::vector<DimVector> &in_dims,
DimVector &out,
int i,
int num) {
equal = in_dims[0][i] == 1;
if (equal) {
for (int j = 1; j < num; ++j) {
equal = in_dims[j][i] == out[i];
}
}
};
// To Merge the dimensions of input_tensors while the consequtive
// equal-dimensions appears.
MergeFunctor merge_ptr = merge_sequential_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N);
int min_idx = 0;
int min_val = std::accumulate(
in_dims[0].begin(), in_dims[0].end(), 1, std::multiplies<int64_t>());
for (int j = 1; j < N; ++j) {
int temp = std::accumulate(
in_dims[j].begin(), in_dims[j].end(), 1, std::multiplies<int64_t>());
min_val = min_val > temp ? temp : min_val;
min_idx = min_val == temp ? j : min_idx;
}
std::swap(in_dims[0], in_dims[min_idx]);
// To Merge the dimension of input_tensors while the consequtive
// 1-value-dimensions appears.
merge_ptr = merge_sequential_one_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N);
std::swap(in_dims[min_idx], in_dims[0]);
}
};
template <typename T, int VecSize, int Rank, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
T *dst,
const T *__restrict__ src,
uint32_t block_offset,
const kps::details::BroadcastConfig<Rank> &config,
int numel,
int num,
bool need_broadcast) {
// numel : whole num of output
// num: how many data will be deal with in this time
if (need_broadcast) {
kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
dst, src, block_offset, config, numel);
} else {
kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
}
}
template <typename InT,
typename OutT,
typename Functor,
int Arity,
int VecSize,
int Rank,
bool IsBoundary = false>
__device__ void DealSegment(
const paddle::framework::Array<const InT *__restrict__, Arity> &ins,
OutT *out,
const paddle::framework::Array<bool, Arity> &use_broadcast,
uint32_t numel,
const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
&configs,
int num,
Functor func) {
InT args[Arity][VecSize];
OutT result[VecSize];
int block_offset = blockIdx.x * blockDim.x * VecSize;
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
ins[i],
block_offset,
configs[i],
numel,
num,
use_broadcast[i]);
}
const bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
ElementwisePrimitiveCaller<InT,
OutT,
VecSize,
Functor,
Arity,
kCallElementwiseAny>()(func, args, result);
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
out + block_offset, result, num);
}
template <typename InT,
typename OutT,
typename Functor,
int Arity,
int VecSize,
int Rank>
__global__ void BroadcastKernel(
paddle::framework::Array<const InT *__restrict__, Arity> ins,
OutT *out,
paddle::framework::Array<bool, Arity> use_broadcast,
uint32_t numel,
paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
configs,
int main_tid,
int tail_tid,
Functor func) {
int block_offset = blockIdx.x * blockDim.x * VecSize;
// data offset of this block
if (blockIdx.x < main_tid) {
int num = blockDim.x * VecSize; // blockIdx.x < main_tid
pten::DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, false>(
ins, out, use_broadcast, numel, configs, num, func);
} else { // reminder
int num = tail_tid;
pten::DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, true>(
ins, out, use_broadcast, numel, configs, num, func);
}
}
template <typename InT,
typename OutT,
typename Functor,
int Arity,
int VecSize,
int Rank>
void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
const std::vector<const DenseTensor *> &ins,
DenseTensor *out,
Functor func,
DimensionsTransform merge_dims) {
int numel = out->numel();
const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
int main_tid = numel / (VecSize * threads);
int tail_tid = numel % (VecSize * threads);
auto stream = ctx.stream();
OutT *out_data = out->mutable_data<OutT>();
paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
paddle::framework::Array<bool, Arity> use_broadcast;
paddle::framework::Array<const InT *__restrict__, Arity> ins_data;
for (int i = 0; i < Arity; i++) {
use_broadcast[i] = (ins[i]->numel() != numel);
ins_data[i] = ins[i]->data<InT>();
if (use_broadcast[i]) {
// get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m}
// eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
configs[i] = kps::details::BroadcastConfig<Rank>(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
}
}
BroadcastKernel<InT,
OutT,
Functor,
Arity,
VecSize,
Rank><<<blocks, threads, 0, stream>>>(ins_data,
out_data,
use_broadcast,
numel,
configs,
main_tid,
tail_tid,
func);
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
void LaunchBroadcastKernelForDifferentVecSize(
const paddle::platform::CUDADeviceContext &ctx,
const std::vector<const DenseTensor *> &ins,
DenseTensor *out,
int axis,
Functor func) {
const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
#define CALL_BROADCAST_FOR_DIM_SIZE(rank) \
case rank: { \
LaunchKernel<InT, OutT, Functor, Arity, VecSize, rank>( \
ctx, ins, out, func, merge_dims); \
} break;
switch (merge_dims.dim_size) {
CALL_BROADCAST_FOR_DIM_SIZE(1);
CALL_BROADCAST_FOR_DIM_SIZE(2);
CALL_BROADCAST_FOR_DIM_SIZE(3);
CALL_BROADCAST_FOR_DIM_SIZE(4);
CALL_BROADCAST_FOR_DIM_SIZE(5);
CALL_BROADCAST_FOR_DIM_SIZE(6);
CALL_BROADCAST_FOR_DIM_SIZE(7);
CALL_BROADCAST_FOR_DIM_SIZE(8);
default: {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The maximum dimension of input tensor is expected to be less than "
"%d, but recieved %d.\n",
merge_dims.dim_size,
paddle::framework::DDim::kMaxRank));
}
}
#undef CALL_BROADCAST_FOR_DIM_SIZE
}
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
const paddle::platform::CUDADeviceContext &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
int axis,
Functor func) {
using Traits = paddle::platform::FunctionTraits<Functor>;
const int kArity =
Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
PADDLE_ENFORCE_EQ(ins.size(),
kArity,
paddle::platform::errors::InvalidArgument(
"The number of inputs is expected to be equal to the "
"arity of functor. But recieved: the number of inputs "
"is %d, the arity of functor is %d.",
ins.size(),
kArity));
PADDLE_ENFORCE_EQ(kArity,
2,
paddle::platform::errors::InvalidArgument(
"Currently only broadcast of binary is supported and "
"verified, but received %d.",
kArity));
int in_vec_size = 4;
DenseTensor *out = (*outs)[0];
for (auto *in : ins) {
auto temp_size = paddle::platform::GetVectorizedSize<InT>(in->data<InT>());
in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
: in_vec_size;
}
int out_vec_size =
paddle::platform::GetVectorizedSize<OutT>(out->data<OutT>());
int vec_size = std::min(out_vec_size, in_vec_size);
switch (vec_size) {
case 4: {
LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 4>(
ctx, ins, out, axis, func);
break;
}
case 2: {
LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 2>(
ctx, ins, out, axis, func);
break;
}
case 1: {
LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 1>(
ctx, ins, out, axis, func);
break;
}
default: {
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
}
}
}
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/function_traits.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/functions/general/elementwise_base.h"
namespace pten {
namespace kps = paddle::operators::kernel_primitives;
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
template <typename InT,
typename OutT,
int VecSize,
typename Functor,
int Arity,
bool CallElementwiseAny = false>
struct ElementwisePrimitiveCaller {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result);
};
template <typename InT, typename OutT, int VecSize, typename Functor, int Arity>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseAny<InT, OutT, VecSize, 1, 1, Arity, Functor>(
result, args, func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseBinary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseTernary<InT, OutT, VecSize, 1, 1, Functor>(
result, args[0], args[1], args[2], func);
}
};
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/kernels/functions/cuda/elementwise/elementwise_common.cu.h"
#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
#else
#define ELEMENTWISE_BLOCK_SIZE 512
#endif
namespace pten {
/*
* According to NVIDIA, if number of threads per block is 64/128/256/512,
* cuda performs better. And number of blocks should be greater (at least
* 2x~4x) than number of SMs. Hence, SM count is took into account within
* this function to determine the right number of threads per block.
*/
inline int GetThreadsConfig(const paddle::platform::CUDADeviceContext &ctx,
int64_t numel,
int vec_size) {
int threads = ELEMENTWISE_BLOCK_SIZE;
int sm_count = ctx.GetSMCount();
int active_threads_num = numel / vec_size;
if (active_threads_num / (sm_count << 1) < ELEMENTWISE_BLOCK_SIZE) {
// Round up threads number into an exponential multiple of 2, while number
// of acitve blocks is about twice of SM, to acquire better performance.
threads = paddle::platform::RoundToPowerOfTwo(active_threads_num /
(sm_count << 1));
} else if (active_threads_num / (sm_count << 2) < ELEMENTWISE_BLOCK_SIZE) {
// Round up threads number into an exponential multiple of 2, while number
// of acitve blocks is about 4 times of SM, to acquire better performance.
threads = paddle::platform::RoundToPowerOfTwo(active_threads_num /
(sm_count << 2));
}
// Number of threads per block shall be larger than 64.
return std::max(64, threads);
}
template <typename InT,
typename OutT,
typename Functor,
int Arity,
int VecSize,
bool IsBoundary>
__device__ void DealSegment(
const paddle::framework::Array<const InT *__restrict__, Arity> &in,
OutT *out,
int num,
Functor func) {
InT args[Arity][VecSize];
OutT result[VecSize];
int data_offset = VecSize * blockIdx.x * blockDim.x;
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(
args[i], in[i] + data_offset, num);
}
const bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
ElementwisePrimitiveCaller<InT,
OutT,
VecSize,
Functor,
Arity,
kCallElementwiseAny>()(func, args, result);
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
out + data_offset, result, num);
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
__global__ void ElementVectorizeKernel(
paddle::framework::Array<const InT *__restrict__, Arity> ins,
OutT *out,
int size,
Functor func) {
int data_offset = VecSize * blockIdx.x * blockDim.x;
int num = size - data_offset;
// the num this time have to deal with
if (VecSize * blockDim.x > num) { // reminder segment
DealSegment<InT, OutT, Functor, Arity, VecSize, true>(ins, out, num, func);
} else { // complete segment
DealSegment<InT, OutT, Functor, Arity, VecSize, false>(ins, out, num, func);
}
}
template <typename InT, typename OutT>
int GetVectorizedSizeForTensors(const std::vector<const DenseTensor *> &ins,
const std::vector<DenseTensor *> &outs) {
int vec_size = 4;
for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
vec_size = std::min<int>(
vec_size, paddle::platform::GetVectorizedSize((*iter)->data<InT>()));
}
for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
vec_size = std::min<int>(
vec_size, paddle::platform::GetVectorizedSize((*iter)->data<OutT>()));
}
return vec_size;
}
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func) {
auto numel = ins[0]->numel();
int block_size = GetThreadsConfig(ctx, numel, VecSize);
int grid_size =
((numel + VecSize - 1) / VecSize + block_size - 1) / block_size;
auto stream = ctx.stream();
OutT *out_data = (*outs)[0]->mutable_data<OutT>();
paddle::framework::Array<const InT *__restrict__, Arity> ins_data;
for (int i = 0; i < Arity; i++) {
ins_data[i] = ins[i]->data<InT>();
}
ElementVectorizeKernel<InT,
OutT,
Functor,
Arity,
VecSize><<<grid_size, block_size, 0, stream>>>(
ins_data, out_data, numel, func);
}
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchSameDimsElementwiseCudaKernel(
const paddle::platform::CUDADeviceContext &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func) {
using Traits = paddle::platform::FunctionTraits<Functor>;
const int kArity =
Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
PADDLE_ENFORCE_EQ(ins.size(),
kArity,
paddle::platform::errors::InvalidArgument(
"The number of inputs is expected to be equal to the "
"arity of functor. But recieved: the number of inputs "
"is %d, the arity of functor is %d.",
ins.size(),
kArity));
// calculate the max vec_size for all ins and outs
int vec_size = GetVectorizedSizeForTensors<InT, OutT>(ins, *outs);
switch (vec_size) {
case 4:
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 4>(
ctx, ins, outs, func);
break;
case 2:
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 2>(
ctx, ins, outs, func);
break;
case 1:
ElementwiseCudaKernel<InT, OutT, Functor, kArity, 1>(
ctx, ins, outs, func);
break;
default: {
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
}
}
}
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/functions/eigen/common.h"
namespace pten {
namespace eigen {
template <typename DevCtx, typename T>
void ElementwiseAdd(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto eigen_x = pten::EigenVector<T>::Flatten(x);
auto eigen_y = pten::EigenVector<T>::Flatten(y);
auto eigen_z = pten::EigenVector<T>::Flatten(*out);
auto& place = *dev_ctx.eigen_device();
eigen_z.device(place) = eigen_x + eigen_y;
}
} // namespace eigen
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/transform.h"
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
namespace general {
using DDim = paddle::framework::DDim;
using CPUContext = paddle::platform::CPUDeviceContext;
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;
template <typename T, typename DeviceContext>
class MidWiseTransformIterator;
// NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17
template <typename T>
class RowwiseTransformIterator<T, CPUContext>
: public std::iterator<std::random_access_iterator_tag,
T,
std::ptrdiff_t,
T *,
T &> {
public:
RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
RowwiseTransformIterator<T, CPUContext> &operator++() {
++i_;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
}
return *this;
}
RowwiseTransformIterator<T, CPUContext> &operator+(int n) {
while (n-- > 0) {
++i_;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
}
}
return *this;
}
bool operator==(const RowwiseTransformIterator<T, CPUContext> &rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(const RowwiseTransformIterator<T, CPUContext> &rhs) const {
return (ptr_ + i_) != &(*rhs);
}
const T &operator*() { return ptr_[i_]; }
private:
const T *ptr_;
int i_;
int64_t n_;
};
template <typename T>
class MidWiseTransformIterator<T, CPUContext>
: public std::iterator<std::random_access_iterator_tag,
T,
std::ptrdiff_t,
T *,
T &> {
public:
MidWiseTransformIterator(const T *ptr, int n, int post)
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
MidWiseTransformIterator<T, CPUContext> &operator++() {
++j_;
if (UNLIKELY(j_ == post_)) {
++i_;
j_ = 0;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
}
}
return *this;
}
MidWiseTransformIterator<T, CPUContext> &operator+(int n) {
while (n-- > 0) {
++j_;
if (UNLIKELY(j_ == post_)) {
++i_;
j_ = 0;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
}
}
}
return *this;
}
bool operator==(const MidWiseTransformIterator<T, CPUContext> &rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(const MidWiseTransformIterator<T, CPUContext> &rhs) const {
return (ptr_ + i_) != &(*rhs);
}
const T &operator*() { return ptr_[i_]; }
private:
const T *ptr_;
int64_t i_;
int64_t j_;
int64_t n_;
int64_t post_;
};
#if defined(__NVCC__) || defined(__HIPCC__)
using CUDAContext = paddle::platform::CUDADeviceContext;
template <typename T>
class RowwiseTransformIterator<T, CUDAContext>
: public thrust::iterator_adaptor<RowwiseTransformIterator<T, CUDAContext>,
const T *> {
public:
typedef thrust::iterator_adaptor<RowwiseTransformIterator<T, CUDAContext>,
const T *>
super_t;
HOSTDEVICE RowwiseTransformIterator(const T *x, int n)
: super_t(x), begin_(x), n_(n) {}
friend class thrust::iterator_core_access;
private:
unsigned int n_;
const T *begin_;
HOSTDEVICE typename super_t::reference dereference() const {
return *(begin_ + (this->base() - begin_) % n_);
}
};
template <typename T>
class MidWiseTransformIterator<T, CUDAContext>
: public thrust::iterator_adaptor<MidWiseTransformIterator<T, CUDAContext>,
const T *> {
public:
typedef thrust::iterator_adaptor<MidWiseTransformIterator<T, CUDAContext>,
const T *>
super_t;
HOSTDEVICE MidWiseTransformIterator(const T *x, int n, int post)
: super_t(x), begin_(x), n_(n), post_(post) {}
friend class thrust::iterator_core_access;
private:
unsigned int post_;
unsigned int n_;
const T *begin_;
HOSTDEVICE typename super_t::reference dereference() const {
return *(begin_ + (((this->base() - begin_) / post_) % n_));
}
};
#endif
template <typename Functor,
typename T,
typename DeviceContext,
typename OutType = T>
class TransformFunctor {
public:
TransformFunctor(const DenseTensor &x,
const DenseTensor &y,
DenseTensor *z,
const DeviceContext &ctx,
Functor func,
const bool is_xsize_larger = true)
: x_(x.data<T>()),
y_(y.data<T>()),
z_(z->mutable_data<OutType>()),
nx_(x.numel()),
ctx_(ctx),
func_(func),
is_xsize_larger_(is_xsize_larger) {
if (is_xsize_larger_ == false) {
nx_ = y.numel();
}
}
inline void Run() const {
paddle::platform::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_, y_, z_, func_);
}
inline void RunRowWise(int n, int pre) const {
paddle::platform::Transform<DeviceContext> trans;
if (is_xsize_larger_) {
trans(ctx_,
x_,
x_ + nx_,
RowwiseTransformIterator<T, DeviceContext>(y_, n),
z_,
func_);
} else {
trans(ctx_,
y_,
y_ + nx_,
RowwiseTransformIterator<T, DeviceContext>(x_, n),
z_,
func_);
}
}
inline void RunMidWise(int n, int pre, int post) const {
paddle::platform::Transform<DeviceContext> trans;
if (is_xsize_larger_) {
trans(ctx_,
x_,
x_ + nx_,
MidWiseTransformIterator<T, DeviceContext>(y_, n, post),
z_,
func_);
} else {
trans(ctx_,
y_,
y_ + nx_,
MidWiseTransformIterator<T, DeviceContext>(x_, n, post),
z_,
func_);
}
}
private:
const T *x_;
const T *y_;
OutType *z_;
int64_t nx_;
const DeviceContext &ctx_;
Functor func_;
bool is_xsize_larger_;
};
inline DDim trim_trailing_singular_dims(const DDim &dims) {
// Remove trailing dimensions of size 1 for y
auto actual_dims_size = dims.size();
for (; actual_dims_size != 0; --actual_dims_size) {
if (dims[actual_dims_size - 1] != 1) break;
}
if (actual_dims_size == dims.size()) return dims;
std::vector<int> trim_dims;
trim_dims.resize(actual_dims_size);
for (int i = 0; i < actual_dims_size; ++i) {
trim_dims[i] = dims[i];
}
if (trim_dims.size() == 0) {
return DDim(paddle::framework::make_dim());
}
DDim actual_dims = paddle::framework::make_ddim(trim_dims);
return actual_dims;
}
/*
* Out = X ⊙ Y
* If Y's shape does not match X' shape, they will be reshaped.
* For example:
* 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
* pre=2, n=3*4, post=5
* x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5)
* 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
* pre=2*3, n=4*5, post=1
* x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
*
* New parameter: *is_run_common_broadcast* is a flag to record whether to run
* common broadcast code.
*/
inline void get_mid_dims(const DDim &x_dims,
const DDim &y_dims,
const int axis,
int *pre,
int *n,
int *post,
int *is_run_common_broadcast) {
*pre = 1;
*n = 1;
*post = 1;
*is_run_common_broadcast = 0;
for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
if (x_dims[i + axis] != y_dims[i]) {
PADDLE_ENFORCE_EQ(y_dims[i] == 1 || x_dims[i + axis] == 1,
true,
paddle::platform::errors::InvalidArgument(
"Broadcast dimension mismatch. Operands "
"could not be broadcast together with the shape of "
"X = [%s] and the shape of Y = [%s]. Received [%d] "
"in X is not equal to [%d] in Y.",
x_dims,
y_dims,
x_dims[i + axis],
y_dims[i]));
*is_run_common_broadcast = 1;
return;
}
(*n) *= y_dims[i];
}
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
}
inline void GetBroadcastDimsArrays(const DDim &x_dims,
const DDim &y_dims,
int *x_dims_array,
int *y_dims_array,
int *out_dims_array,
const int max_dim,
const int axis) {
PADDLE_ENFORCE_GE(
axis,
0,
paddle::platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(axis,
max_dim,
paddle::platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.",
max_dim,
axis));
if (x_dims.size() > y_dims.size()) {
std::fill(y_dims_array, y_dims_array + axis, 1);
if (axis + y_dims.size() < max_dim) {
std::fill(y_dims_array + axis + y_dims.size(), y_dims_array + max_dim, 1);
}
std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array);
std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array + axis);
} else {
std::fill(x_dims_array, x_dims_array + axis, 1);
if (axis + x_dims.size() < max_dim) {
std::fill(x_dims_array + axis + x_dims.size(), x_dims_array + max_dim, 1);
}
std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array + axis);
std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array);
}
for (int i = 0; i < max_dim; i++) {
PADDLE_ENFORCE_EQ(
x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
y_dims_array[i] <= 1,
true,
paddle::platform::errors::InvalidArgument(
"Broadcast dimension mismatch. Operands could "
"not be broadcast together with the shape of X = [%s] and "
"the shape of Y = [%s]. Received [%d] in X is not equal to "
"[%d] in Y at i:%d.",
x_dims,
y_dims,
x_dims_array[i],
y_dims_array[i],
i));
if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) ||
(x_dims_array[i] == 1 && y_dims_array[i] == 1)) {
out_dims_array[i] = (std::max)(x_dims_array[i], y_dims_array[i]);
} else {
out_dims_array[i] = -1;
}
}
}
} // namespace general
} // namespace pten
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/functions/blas/elementwise.h"
#include "paddle/pten/kernels/functions/eigen/elementwise.h"
namespace pten {
namespace general {
// Define the binary functors used in elementwise ops.
// Add
template <typename DevCtx, typename T, class Enable = void>
struct SameDimsAddFunctor {
void operator()(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z);
};
template <typename DevCtx, typename T>
struct SameDimsAddFunctor<
DevCtx,
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
void operator()(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z) {
blas::ElementwiseAdd<DevCtx, T>(dev_ctx, x, y, z);
}
};
template <typename DevCtx, typename T>
struct SameDimsAddFunctor<
DevCtx,
T,
typename std::enable_if<!std::is_floating_point<T>::value>::type> {
void operator()(const DevCtx& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* z) {
eigen::ElementwiseAdd<DevCtx, T>(dev_ctx, x, y, z);
}
};
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; }
};
template <typename T>
struct InverseAddFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b + a; }
};
} // namespace general
} // namespace pten
......@@ -5,3 +5,4 @@ cc_test(test_fill_api SRCS test_fill_api.cc DEPS pten_api pten_api_utils)
cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS pten_api pten_api_utils)
cc_test(test_framework_storage SRCS test_storage.cc DEPS pten_api_utils)
cc_test(test_framework_tensor_utils SRCS test_tensor_utils.cc DEPS pten_api_utils)
cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS pten_api pten_api_utils)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/api/include/math.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
PT_DECLARE_MODULE(MathCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(MathCUDA);
#endif
namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;
// TODO(chenweihang): Remove this test after the API is used in the dygraph
TEST(API, add) {
// 1. create tensor
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}),
pten::DataLayout::NCHW));
auto* dense_x_data = dense_x->mutable_data<float>();
auto dense_y = std::make_shared<pten::DenseTensor>(
alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({10}),
pten::DataLayout::NCHW));
auto* dense_y_data = dense_y->mutable_data<float>();
float sum[3][10] = {0.0};
for (size_t i = 0; i < 3; ++i) {
for (size_t j = 0; j < 10; ++j) {
dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0;
sum[i][j] = (i * 10 + j) * 1.0 + j * 2.0;
}
}
for (size_t i = 0; i < 10; ++i) {
dense_y_data[i] = i * 2.0;
}
paddle::experimental::Tensor x(dense_x);
paddle::experimental::Tensor y(dense_y);
// 2. test API
auto out = paddle::experimental::add(x, y);
// 3. check result
ASSERT_EQ(out.shape().size(), 2);
ASSERT_EQ(out.shape()[0], 3);
ASSERT_EQ(out.numel(), 30);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
auto expect_result = sum;
auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
auto actual_result0 = dense_out->data<float>()[0];
auto actual_result1 = dense_out->data<float>()[1];
auto actual_result2 = dense_out->data<float>()[10];
ASSERT_NEAR(expect_result[0][0], actual_result0, 1e-6f);
ASSERT_NEAR(expect_result[0][1], actual_result1, 1e-6f);
ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f);
}
......@@ -4,3 +4,4 @@ cc_test(test_fill_dev_api SRCS test_fill_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_flatten_dev_api SRCS test_flatten_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_mean_dev_api SRCS test_mean_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_scale_dev_api SRCS test_scale_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_elementwise_dev_api SRCS test_elementwise_dev_api.cc DEPS pten pten_api_utils)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/include/math.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
PT_DECLARE_MODULE(MathCPU);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_MODULE(MathCUDA);
#endif
namespace framework = paddle::framework;
using DDim = paddle::framework::DDim;
TEST(DEV_API, elementwise_add) {
// 1. create tensor
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
pten::DenseTensor dense_x(alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({3, 10}),
pten::DataLayout::NCHW));
auto* dense_x_data = dense_x.mutable_data<float>();
pten::DenseTensor dense_y(alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({10}),
pten::DataLayout::NCHW));
auto* dense_y_data = dense_y.mutable_data<float>();
float sum[3][10] = {0.0};
for (size_t i = 0; i < 3; ++i) {
for (size_t j = 0; j < 10; ++j) {
dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0;
sum[i][j] = (i * 10 + j) * 1.0 + j * 2.0;
}
}
for (size_t i = 0; i < 10; ++i) {
dense_y_data[i] = i * 2.0;
}
int axis = 1;
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
// 2. test API
auto dense_out = pten::ElementwiseAdd<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
dense_y,
axis);
// 3. check result
ASSERT_EQ(dense_out.dims().size(), 2);
ASSERT_EQ(dense_out.dims()[0], 3);
ASSERT_EQ(dense_out.meta().type, pten::DataType::FLOAT32);
ASSERT_EQ(dense_out.meta().layout, pten::DataLayout::NCHW);
auto expect_result = sum;
auto actual_result0 = dense_out.data<float>()[0];
auto actual_result1 = dense_out.data<float>()[1];
auto actual_result2 = dense_out.data<float>()[10];
ASSERT_NEAR(expect_result[0][0], actual_result0, 1e-6f);
ASSERT_NEAR(expect_result[0][1], actual_result1, 1e-6f);
ASSERT_NEAR(expect_result[1][0], actual_result2, 1e-6f);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册