Unverified commit fcd93b32, authored by limingshu and committed by GitHub

Support Div and FloorDiv functor in elementwise system (#33053)

Parent: cd95ea82
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
@@ -23,38 +22,37 @@ namespace plat = paddle::platform;
 namespace paddle {
 namespace operators {
 
+template <typename T, typename Enable = void>
+struct CudaDivFunctor {
+  inline HOSTDEVICE T operator()(const T* args) const {
+    return args[0] / args[1];
+  }
+};
+
 template <typename T>
-struct SameDimsElemwiseDiv<platform::CUDADeviceContext, T> {
-  void operator()(const framework::ExecutionContext& ctx,
-                  const framework::Tensor* x, const framework::Tensor* y,
-                  framework::Tensor* z) {
-    DivRangeFunctor<T> functor(x->data<T>(), y->data<T>(), z->data<T>());
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
-                                                              x->numel());
-    for_range(functor);
+struct CudaDivFunctor<T,
+                      typename std::enable_if_t<std::is_integral<T>::value>> {
+  inline HOSTDEVICE T operator()(const T* args) const {
+    PADDLE_ENFORCE(args[1] != 0,
+                   "Invalid Argument Error: Integer division by zero "
+                   "encountered in divide. Please check the input value.");
+    return args[0] / args[1];
   }
 };
 
-template <>
-struct SameDimsElemwiseDiv<platform::CUDADeviceContext, platform::float16> {
-  void operator()(const framework::ExecutionContext& ctx,
-                  const framework::Tensor* x, const framework::Tensor* y,
-                  framework::Tensor* z) {
-    auto size = x->numel();
-    dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
-                              PADDLE_CUDA_THREAD_SIZE,
-                          1);
-    dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
-    const half* x2 =
-        reinterpret_cast<const half*>(x->data<platform::float16>());
-    const half* y2 =
-        reinterpret_cast<const half*>(y->data<platform::float16>());
-    half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
-    SameDimsElemwiseDivCUDAKernel<<<
-        grid_size, block_size, 0,
-        ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
-        x2, y2, z2, size);
+template <typename T>
+class ElementwiseDivKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+    const auto& cuda_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+
+    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+        cuda_ctx, ins, &outs, axis, CudaDivFunctor<T>());
   }
 };
......
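Note (not part of the diff): the new CudaDivFunctor above uses a SFINAE partial specialization so that integral types get a divide-by-zero check while floating-point types divide directly. Below is a minimal host-only C++ sketch of that same pattern; the names are hypothetical and a plain assert stands in for PADDLE_ENFORCE.

#include <cassert>
#include <type_traits>

// Primary template: used for non-integral T, divides directly.
template <typename T, typename Enable = void>
struct DivFunctorSketch {
  T operator()(const T* args) const { return args[0] / args[1]; }
};

// Partial specialization selected only when T is an integral type:
// adds a divide-by-zero check before dividing.
template <typename T>
struct DivFunctorSketch<T, std::enable_if_t<std::is_integral<T>::value>> {
  T operator()(const T* args) const {
    assert(args[1] != 0 && "integer division by zero");  // stands in for PADDLE_ENFORCE
    return args[0] / args[1];
  }
};

int main() {
  int i[2] = {7, 2};
  double d[2] = {7.0, 2.0};
  // Integral input goes through the checked specialization,
  // floating point through the primary template.
  return (DivFunctorSketch<int>()(i) == 3 &&
          DivFunctorSketch<double>()(d) == 3.5) ? 0 : 1;
}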
@@ -12,11 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h"
-#include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct CudaFloorDivFunctor {
+  inline HOSTDEVICE T operator()(const T argv[]) const {
+    PADDLE_ENFORCE(argv[1] != 0,
+                   "InvalidArgument: divide by zero "
+                   "encountered in floor-divide ops, please check.\n");
+    return static_cast<T>(std::trunc(argv[0] / argv[1]));
+  }
+};
+
+template <typename T>
+class ElementwiseFloorDivKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+    const auto& cuda_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+
+    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+        cuda_ctx, ins, &outs, axis, CudaFloorDivFunctor<T>());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 REGISTER_OP_CUDA_KERNEL(
     elementwise_floordiv,
     ops::ElementwiseFloorDivKernel<plat::CUDADeviceContext, int>,
......
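Note (not part of the diff): a host-only sketch of what the CudaFloorDivFunctor body computes for the integer types registered above: a divide-by-zero check followed by static_cast<T>(std::trunc(a / b)). The wrapper name and the assert below are illustrative, not Paddle API.

#include <cassert>
#include <cmath>
#include <cstdint>

// Mirrors the functor body: check the divisor, then return the quotient passed
// through std::trunc. For integral T, a / b is already truncated toward zero.
template <typename T>
T FloorDivSketch(T a, T b) {
  assert(b != 0 && "divide by zero encountered in floor-divide");  // stands in for PADDLE_ENFORCE
  return static_cast<T>(std::trunc(a / b));
}

int main() {
  // 7 / 2 -> 3; -7 / 2 -> -3 (truncation toward zero, matching the functor).
  return (FloorDivSketch<int>(7, 2) == 3 &&
          FloorDivSketch<int64_t>(-7, 2) == -3) ? 0 : 1;
}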
@@ -16,7 +16,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
 
 namespace paddle {
......
@@ -14,7 +14,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/fast_divmod.h"
 
 #ifdef __HIPCC__
@@ -28,19 +28,62 @@ namespace operators {
 enum ElementwiseType { kUnary = 1, kBinary = 2 };
 
+/*
+* According to NVIDIA, CUDA performs better when the number of threads per
+* block is 64/128/256/512, and the number of blocks should be at least 2x~4x
+* the number of SMs. Hence, the SM count is taken into account in this
+* function to determine the right number of threads per block.
+*/
+inline int GetThreadsConfig(const platform::CUDADeviceContext &ctx,
+                            int64_t numel, int vec_size) {
+  int threads = ELEMENTWISE_BLOCK_SIZE;
+  int sm_count = ctx.GetSMCount();
+  int active_threads_num = numel / vec_size;
+  if (active_threads_num / (sm_count << 1) < ELEMENTWISE_BLOCK_SIZE) {
+    // Round the thread count up to a power of 2 while keeping the number of
+    // active blocks at about twice the SM count, for better performance.
+    threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 1));
+  } else if (active_threads_num / (sm_count << 2) < ELEMENTWISE_BLOCK_SIZE) {
+    // Round the thread count up to a power of 2 while keeping the number of
+    // active blocks at about four times the SM count, for better performance.
+    threads = platform::RoundToPowerOfTwo(active_threads_num / (sm_count << 2));
+  }
+  // The number of threads per block shall be no less than 64.
+  return std::max(64, threads);
+}
+
+/*
+* A vectorized load is only possible when the address of the input data is a
+* multiple of the 1/2/4-element vector alignment. Moreover, a single
+* vectorized load moves at most 128 bits. The valid vectorization width is
+* therefore determined by both constraints.
+*/
 template <typename T>
 int GetVectorizedSizeImpl(const T *pointer) {
+  constexpr int max_load_bits = 128;
+  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
   uint64_t address = reinterpret_cast<uint64_t>(pointer);
+  constexpr int vec8 =
+      std::alignment_of<CudaAlignedVector<T, 8>>::value;  // NOLINT
   constexpr int vec4 =
      std::alignment_of<CudaAlignedVector<T, 4>>::value;  // NOLINT
   constexpr int vec2 =
      std::alignment_of<CudaAlignedVector<T, 2>>::value;  // NOLINT
-  if (address % vec4 == 0) {
-    return 4;
+  if (address % vec8 == 0) {
+    /*
+    * For now, handle no more than 4 elements per vectorized load/store. If
+    * performance tests show that handling 8 elements at once is faster, the
+    * return statement below can be changed to
+    * " return std::min(8, valid_vec_size); ".
+    */
+    return std::min(4, valid_vec_size);
+  } else if (address % vec4 == 0) {
+    return std::min(4, valid_vec_size);
   } else if (address % vec2 == 0) {
-    return 2;
+    return std::min(2, valid_vec_size);
+  } else {
+    return 1;
   }
-  return 1;
 }
 
 template <typename InT, typename OutT>
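Note (not part of the diff): the reworked GetVectorizedSizeImpl above picks the widest vector width whose alignment the pointer satisfies, capped both at 4 elements and at 128 bits per load. A standalone sketch of that selection rule follows, with alignas modeling CudaAlignedVector's alignment; the names here are hypothetical.

#include <algorithm>
#include <climits>
#include <cstdint>
#include <cstdio>

// Aligned-vector stand-in: alignment equals sizeof(T) * N.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVecSketch {
  T val[N];
};

template <typename T>
int VectorizedSizeSketch(const T *pointer) {
  // At most 128 bits per load, e.g. 4 floats or 2 doubles.
  constexpr int kMaxLoadBits = 128;
  int valid_vec_size = kMaxLoadBits / CHAR_BIT / static_cast<int>(sizeof(T));
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  if (address % alignof(AlignedVecSketch<T, 8>) == 0) {
    return std::min(4, valid_vec_size);  // 8-wide alignment, but cap at 4 elements
  } else if (address % alignof(AlignedVecSketch<T, 4>) == 0) {
    return std::min(4, valid_vec_size);
  } else if (address % alignof(AlignedVecSketch<T, 2>) == 0) {
    return std::min(2, valid_vec_size);
  }
  return 1;
}

int main() {
  alignas(32) float buf[16];
  std::printf("aligned start: %d\n", VectorizedSizeSketch(buf));      // 4
  std::printf("offset by one: %d\n", VectorizedSizeSketch(buf + 1));  // 1
  return 0;
}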
@@ -96,7 +139,7 @@ struct ElementwiseDataWrapper {
 
 template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
           typename Functor>
-__device__ void VectorizedKernelImpl(
+__device__ inline void VectorizedKernelImpl(
     ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
     int tid) {
   using InVecType = CudaAlignedVector<InT, VecSize>;
@@ -104,34 +147,30 @@ __device__ void VectorizedKernelImpl(
   InVecType ins_vec[ET];
   OutVecType out_vec;
   InT *ins_ptr[ET];
-  OutT *out_ptr;
+  InT ins[ET];
 #pragma unroll
   for (int i = 0; i < ET; ++i) {
     ins_ptr[i] = reinterpret_cast<InT *>(&(ins_vec[i]));
   }
-  out_ptr = reinterpret_cast<OutT *>(&out_vec);
   // load
   data.load_vector(ins_vec, tid);
 
 // compute
 #pragma unroll
   for (int i = 0; i < VecSize; ++i) {
-    InT ins[ET];
 #pragma unroll
     for (int j = 0; j < ET; ++j) {
       ins[j] = ins_ptr[j][i];
     }
-    out_ptr[i] = func(ins);
+    out_vec.val[i] = func(ins);
   }
   // store
   data.store_vector(out_vec, tid);
 }
 
 template <ElementwiseType ET, int VecSize, typename InT, typename OutT,
           typename Functor>
-__device__ void ScalarKernelImpl(
+__device__ inline void ScalarKernelImpl(
     ElementwiseDataWrapper<ET, VecSize, InT, OutT> data, Functor func,
     int start, int remain) {
   InT ins[ET];
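Note (not part of the diff): the VectorizedKernelImpl change above writes results through out_vec.val[i] instead of a reinterpret_cast output pointer, and hoists the per-lane input buffer out of the loop. A CPU-side sketch of the same load-as-vector / compute-per-lane / store-as-vector pattern; the vector struct and functor below are illustrative only.

#include <cstdio>

// Stand-in for CudaAlignedVector: a fixed-width vector with a val[] array.
template <typename T, int VecSize>
struct VecSketch {
  T val[VecSize];
};

// One binary elementwise step over a vector of lanes: unpack per-lane inputs,
// apply the functor lane by lane, pack results into an output vector.
template <int VecSize, typename T, typename Functor>
VecSketch<T, VecSize> VectorizedStep(const VecSketch<T, VecSize> &a,
                                     const VecSketch<T, VecSize> &b,
                                     Functor func) {
  VecSketch<T, VecSize> out;
  T ins[2];
  for (int i = 0; i < VecSize; ++i) {
    ins[0] = a.val[i];
    ins[1] = b.val[i];
    out.val[i] = func(ins);  // same per-lane call shape as func(ins) in the kernel
  }
  return out;
}

int main() {
  VecSketch<float, 4> a = {{8.f, 6.f, 4.f, 2.f}};
  VecSketch<float, 4> b = {{2.f, 2.f, 2.f, 2.f}};
  auto div = [](const float *args) { return args[0] / args[1]; };
  auto out = VectorizedStep<4>(a, b, div);
  std::printf("%g %g %g %g\n", out.val[0], out.val[1], out.val[2], out.val[3]);
  return 0;
}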
@@ -182,7 +221,7 @@ void LaunchSameDimsElementwiseCudaKernel(
   // calculate the max vec_size for all ins and outs
   auto size = ins[0]->numel();
   int vec_size = GetVectorizedSize<InT, OutT>(ins, *outs);
-  int block_size = ELEMENTWISE_BLOCK_SIZE;
+  int block_size = GetThreadsConfig(ctx, size, vec_size);
   int grid_size =
       ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
   const InT *in0 = ins[0]->data<InT>();
......
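Note (not part of the diff): taken together, the block size now comes from the GetThreadsConfig heuristic (a power-of-two thread count, at least 64, keeping the active block count around 2x~4x the SM count) and the grid size from rounding the vectorized element count up to whole blocks, as in the hunk above. A host-only sketch of that launch-configuration arithmetic, with a simple stand-in for platform::RoundToPowerOfTwo; all names below are hypothetical.

#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr int kBlockSize = 256;  // stand-in for ELEMENTWISE_BLOCK_SIZE

// Smallest power of two >= n (at least 1); stand-in for platform::RoundToPowerOfTwo.
int RoundUpPowerOfTwo(int n) {
  int p = 1;
  while (p < n) p <<= 1;
  return p;
}

int ThreadsConfigSketch(int sm_count, int64_t numel, int vec_size) {
  int threads = kBlockSize;
  int active_threads = static_cast<int>(numel / vec_size);
  if (active_threads / (sm_count << 1) < kBlockSize) {
    threads = RoundUpPowerOfTwo(active_threads / (sm_count << 1));
  } else if (active_threads / (sm_count << 2) < kBlockSize) {
    threads = RoundUpPowerOfTwo(active_threads / (sm_count << 2));
  }
  return std::max(64, threads);  // never fewer than 64 threads per block
}

int main() {
  int64_t size = 1 << 20;  // number of elements
  int vec_size = 4;        // elements handled per vectorized load
  int sm_count = 80;       // e.g. a V100-class GPU
  int block_size = ThreadsConfigSketch(sm_count, size, vec_size);
  // Same grid formula as the diff: ceil(ceil(size / vec_size) / block_size).
  int grid_size = ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
  std::printf("block=%d grid=%d\n", block_size, grid_size);
  return 0;
}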