Unverified commit 890d6bc0, authored by Zhang Zheng, committed by GitHub

Modify some contents for elementwise op impl (#32414)

Parent: 1064f2b8
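In brief (as reconstructed from the hunks below): the elementwise add CUDA file drops the `elementwise_op_function.cu.h` include and changes `CudaAddFunctor::operator()` to a device-only `__forceinline__` function taking `const T*`; the shared impl header gains its own includes plus an `ELEMENTWISE_BLOCK_SIZE` macro (256 under HIP, 512 under CUDA); `ScalarKernelImpl` is generalized over `VecSize`; the now-unused `size` parameter is removed from both kernel impls; and the `ElementwiseDataWrapper` construction is hoisted out of the branch in `VectorizedKernel`.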
First file (the elementwise add CUDA kernel; `elementwise_add_op.cu`, judging by its includes):

@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
@@ -34,7 +33,9 @@ namespace operators {
  */
 template <typename T>
 struct CudaAddFunctor {
-  inline HOSTDEVICE T operator()(T args[]) const { return args[0] + args[1]; }
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return args[0] + args[1];
+  }
 };
 
 template <typename T>
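For orientation, here is a minimal, self-contained sketch of how a functor with the new device-only `const T*` signature gets consumed. This is not the Paddle sources: `AddFunctor`, `ApplyBinary`, and the launch parameters are illustrative only.

```cuda
#include <cstdio>

// Same shape as CudaAddFunctor after this change: device-only, and it takes
// a pointer to the gathered arguments rather than a mutable array.
template <typename T>
struct AddFunctor {
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] + args[1];
  }
};

// Illustrative consumer: gathers one element per input, applies the functor.
template <typename T, typename Functor>
__global__ void ApplyBinary(const T* in0, const T* in1, T* out, int n,
                            Functor func) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    T args[2] = {in0[i], in1[i]};
    out[i] = func(args);
  }
}

int main() {
  const int n = 4;
  float h0[n] = {1, 2, 3, 4}, h1[n] = {10, 20, 30, 40}, h2[n];
  float *d0, *d1, *d2;
  cudaMalloc(&d0, n * sizeof(float));
  cudaMalloc(&d1, n * sizeof(float));
  cudaMalloc(&d2, n * sizeof(float));
  cudaMemcpy(d0, h0, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d1, h1, n * sizeof(float), cudaMemcpyHostToDevice);
  ApplyBinary<<<1, 32>>>(d0, d1, d2, n, AddFunctor<float>());
  cudaMemcpy(h2, d2, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("%g %g %g %g\n", h2[0], h2[1], h2[2], h2[3]);  // 11 22 33 44
  cudaFree(d0); cudaFree(d1); cudaFree(d2);
  return 0;
}
```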
Second file (`paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h`, per the include above):

@@ -13,6 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/float16.h"
+
+#ifdef __HIPCC__
+#define ELEMENTWISE_BLOCK_SIZE 256
+#else
+#define ELEMENTWISE_BLOCK_SIZE 512
+#endif
+
 namespace paddle {
 namespace operators {
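(The added includes make the header self-contained. `__HIPCC__` is defined by the HIP compiler, so the block size resolves at compile time: presumably the smaller value of 256 reflects tuning for AMD/ROCm devices, while CUDA builds keep 512. The macro replaces `PADDLE_CUDA_THREAD_SIZE` in `LaunchElementwiseCudaKernel`, shown in the last hunk below.)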
@@ -90,8 +101,7 @@ struct ElementwiseDataWrapper {
 template <ElementwiseType ET, int VecSize, typename T, typename Functor>
 __device__ void VectorizedKernelImpl(
-    ElementwiseDataWrapper<ET, VecSize, T> data, int size, Functor func,
-    int tid) {
+    ElementwiseDataWrapper<ET, VecSize, T> data, Functor func, int tid) {
   using VecType = CudaAlignedVector<T, VecSize>;
   VecType ins_vec[ET];
   VecType out_vec;
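The vectorized path loads `VecSize` elements per memory transaction. A sketch of that pattern, under the assumption that `CudaAlignedVector<T, N>` is an over-aligned POD holding `T val[N]` (the names `AlignedVector` and `ApplyVectorized` are mine, not Paddle's):

```cuda
// Over-aligned wrapper so one load/store moves VecSize elements at once.
template <typename T, int N>
struct alignas(sizeof(T) * N) AlignedVector {
  T val[N];
};

template <int VecSize, typename T, typename Functor>
__device__ void ApplyVectorized(const T *in0, const T *in1, T *out,
                                Functor func, int tid) {
  using VecType = AlignedVector<T, VecSize>;
  // One aligned load per operand pulls VecSize elements in a single transaction.
  VecType a = reinterpret_cast<const VecType *>(in0)[tid];
  VecType b = reinterpret_cast<const VecType *>(in1)[tid];
  VecType c;
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    T args[2] = {a.val[i], b.val[i]};
    c.val[i] = func(args);  // functor consumes const T* args, as above
  }
  reinterpret_cast<VecType *>(out)[tid] = c;  // one aligned store back
}
```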
@@ -121,10 +131,9 @@ __device__ void VectorizedKernelImpl(
   data.store_vector(out_vec, tid);
 }
 
-template <ElementwiseType ET, typename T, typename Functor>
-__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, 1, T> data,
-                                 int size, Functor func, int start,
-                                 int remain) {
+template <ElementwiseType ET, int VecSize, typename T, typename Functor>
+__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
+                                 Functor func, int start, int remain) {
   T ins[ET];
   T out;
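`ScalarKernelImpl` is now templated on `VecSize` as well, so the branch in `VectorizedKernel` below can hand it the same `ElementwiseDataWrapper<ET, VecSize, T>` instance instead of rebuilding a `<ET, 1, T>` wrapper, and the redundant `size` argument disappears (`remain` already bounds the loop). A sketch of the tail loop, with `load_scalar`/`store_scalar` as assumed names for the wrapper's per-element accessors:

```cuda
// Sketch of the scalar tail: each of the `remain` trailing elements is
// handled one at a time, starting at flat index `start`.
template <int Arity, typename T, typename Wrapper, typename Functor>
__device__ void ScalarTail(Wrapper data, Functor func, int start, int remain) {
  T ins[Arity];
  for (int i = 0; i < remain; ++i) {
    int idx = start + i;
    data.load_scalar(ins, idx);   // gather one element from each input
    T out = func(ins);
    data.store_scalar(out, idx);  // write the single result back
  }
}
```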
@@ -146,12 +155,11 @@ __global__ void VectorizedKernel(const T *__restrict__ in0,
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = size - VecSize * tid;
   remain = remain > 0 ? remain : 0;
+  auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
   if (remain >= VecSize) {
-    auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
-    VectorizedKernelImpl(data, size, func, tid);
+    VectorizedKernelImpl(data, func, tid);
   } else {
-    auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
-    ScalarKernelImpl(data, size, func, tid * VecSize, remain);
+    ScalarKernelImpl(data, func, tid * VecSize, remain);
   }
 }
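To make the dispatch concrete, a worked example with assumed numbers:

```cuda
// Assumed: size = 1002 elements, VecSize = 4.
// Each thread owns the element range [4*tid, 4*tid + 4):
//   tid = 0:   remain = 1002 -    0 = 1002 >= 4 -> VectorizedKernelImpl, elements 0..3
//   tid = 249: remain = 1002 -  996 =    6 >= 4 -> VectorizedKernelImpl, elements 996..999
//   tid = 250: remain = 1002 - 1000 =    2 <  4 -> ScalarKernelImpl, elements 1000..1001
//   tid = 251: remain clamped to 0              -> ScalarKernelImpl loops zero times
```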
@@ -162,7 +170,7 @@ __global__ void ScalarKernel(const T *__restrict__ in0,
   auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = tid < size ? 1 : 0;
-  ScalarKernelImpl(data, size, func, tid, remain);
+  ScalarKernelImpl(data, func, tid, remain);
 }
 
 template <ElementwiseType ET, typename T, typename Functor>
@@ -173,7 +181,7 @@ void LaunchElementwiseCudaKernel(
   // calculate the max vec_size for all ins and outs
   auto size = ins[0]->numel();
   int vec_size = GetVectorizedSize<T>(ins, *outs);
-  int block_size = PADDLE_CUDA_THREAD_SIZE;
+  int block_size = ELEMENTWISE_BLOCK_SIZE;
   int grid_size =
       ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
   const T *in0 = ins[0]->data<T>();
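The two ceiling divisions first compute how many threads are needed (each covers `vec_size` elements) and then how many blocks of `ELEMENTWISE_BLOCK_SIZE` threads cover them. A host-side check with assumed sizes:

```cuda
#include <cstdio>

int main() {
  long long size = 1000000;  // assumed element count
  int vec_size = 4;          // e.g. float4-style vectorization
  int block_size = 512;      // ELEMENTWISE_BLOCK_SIZE on a CUDA build
  // ceil(size / vec_size): threads needed when each covers vec_size elements.
  long long threads = (size + vec_size - 1) / vec_size;           // 250000
  // ceil(threads / block_size): blocks needed to cover those threads.
  long long grid_size = (threads + block_size - 1) / block_size;  // 489
  printf("threads=%lld grid=%lld\n", threads, grid_size);
  return 0;
}
```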