未验证 提交 fcd93b32 编写于 作者: L limingshu 提交者: GitHub

Support Div and FloorDiv functor in elementwise system (#33053)

上级 cd95ea82
......@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
......@@ -23,38 +22,37 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T, typename Enable = void>
struct CudaDivFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
return args[0] / args[1];
}
};
template <typename T>
struct SameDimsElemwiseDiv<platform::CUDADeviceContext, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
DivRangeFunctor<T> functor(x->data<T>(), y->data<T>(), z->data<T>());
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
x->numel());
for_range(functor);
struct CudaDivFunctor<T,
typename std::enable_if_t<std::is_integral<T>::value>> {
inline HOSTDEVICE T operator()(const T* args) const {
PADDLE_ENFORCE(args[1] != 0,
"Invalid Argument Error: Integer division by zero "
"encountered in divide. Please check the input value.");
return args[0] / args[1];
}
};
template <>
struct SameDimsElemwiseDiv<platform::CUDADeviceContext, platform::float16> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto size = x->numel();
dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
const half* x2 =
reinterpret_cast<const half*>(x->data<platform::float16>());
const half* y2 =
reinterpret_cast<const half*>(y->data<platform::float16>());
half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
SameDimsElemwiseDivCUDAKernel<<<
grid_size, block_size, 0,
ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
x2, y2, z2, size);
template <typename T>
class ElementwiseDivKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaDivFunctor<T>());
}
};
......
......@@ -12,11 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
struct CudaFloorDivFunctor {
inline HOSTDEVICE T operator()(const T argv[]) const {
PADDLE_ENFORCE(argv[1] != 0,
"InvalidArgument: divide by zero "
"encountered in floor-divide ops, please check.\n");
return static_cast<T>(std::trunc(argv[0] / argv[1]));
}
};
template <typename T>
class ElementwiseFloorDivKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaFloorDivFunctor<T>());
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
elementwise_floordiv,
ops::ElementwiseFloorDivKernel<plat::CUDADeviceContext, int>,
......
......@@ -16,7 +16,6 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h"
#ifdef __HIPCC__
......@@ -28,19 +28,62 @@ namespace operators {
enum ElementwiseType { kUnary = 1, kBinary = 2 };
/*
* According to NVIDIA, if number of threads per block is 64/128/256/512,
* cuda performs better. And number of blocks should be greater (at least
* 2x~4x) than number of SMs. Hence, SM count is took into account within
* this function to determine the right number of threads per block.
*/
inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
int64_t numel, int vec_size) {
int threads = ELEMENTWISE_BLOCK_SIZE;
int sm_count = ctx.GetSMCount();
int active_threads_num = numel / vec_size;
if (active_threads_num / (sm_count << 1) < ELEMENTWISE_BLOCK_SIZE) {
// Round up threads number into an exponential multiple of 2, while number
// of acitve blocks is about twice of SM, to acquire better performance.
threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 1));
} else if (active_threads_num / (sm_count << 2) < ELEMENTWISE_BLOCK_SIZE) {
// Round up threads number into an exponential multiple of 2, while number
// of acitve blocks is about 4 times of SM, to acquire better performance.
threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 2));
}
// Number of threads per block shall be larger than 64.
return std::max(64, threads);
}
/*
* Only the address of input data is the multiplier of 1,2,4, vectorized load
* with corresponding multiplier-value is possible. Moreover, the maximum length
* of vectorized load is 128 bits once. Hence, valid length of vectorized load
* shall be determined under both former constraints.
*/
template <typename T>
int GetVectorizedSizeImpl(const T *pointer) {
constexpr int max_load_bits = 128;
int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
uint64_t address = reinterpret_cast<uint64_t>(pointer);
constexpr int vec8 =
std::alignment_of<CudaAlignedVector<T, 8>>::value; // NOLINT
constexpr int vec4 =
std::alignment_of<CudaAlignedVector<T, 4>>::value; // NOLINT
constexpr int vec2 =
std::alignment_of<CudaAlignedVector<T, 2>>::value; // NOLINT
if (address % vec4 == 0) {
return 4;
if (address % vec8 == 0) {
/*
* Currently, decide to deal with no more than 4 data once while adopting
* vectorization load/store, if performance test shows that dealing with
* 8 data once in vectorization load/store does get optimized, return code
* below can be changed into " return std::min(8, valid_vec_size); " .
*/
return std::min(4, valid_vec_size);
} else if (address % vec4 == 0) {
return std::min(4, valid_vec_size);
} else if (address % vec2 == 0) {
return 2;
return std::min(2, valid_vec_size);
} else {
return 1;
}
return 1;
}
template <typename InT, typename OutT>
......@@ -96,7 +139,7 @@ struct ElementwiseDataWrapper {
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor>
__device__ void VectorizedKernelImpl(
__device__ inline void VectorizedKernelImpl(
ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
int tid) {
using InVecType = CudaAlignedVector<InT, VecSize>;
......@@ -104,34 +147,30 @@ __device__ void VectorizedKernelImpl(
InVecType ins_vec[ET];
OutVecType out_vec;
InT *ins_ptr[ET];
OutT *out_ptr;
InT ins[ET];
#pragma unroll
for (int i = 0; i < ET; ++i) {
ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
}
out_ptr = reinterpret_cast<OutT *>(&out_vec);
// load
data.load_vector(ins_vec, tid);
// compute
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
InT ins[ET];
#pragma unroll
for (int j = 0; j < ET; ++j) {
ins[j] = ins_ptr[j][i];
}
out_ptr[i] = func(ins);
out_vec.val[i] = func(ins);
}
// store
data.store_vector(out_vec, tid);
}
template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
typename Functor>
__device__ void ScalarKernelImpl(
__device__ inline void ScalarKernelImpl(
ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
int start, int remain) {
InT ins[ET];
......@@ -182,7 +221,7 @@ void LaunchSameDimsElementwiseCudaKernel(
// calculate the max vec_size for all ins and outs
auto size = ins[0]->numel();
int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
int block_size = ELEMENTWISE_BLOCK_SIZE;
int block_size = GetThreadsConfig(ctx, size, vec_size);
int grid_size =
((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
const InT *in0 = ins[0]->data<InT>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册