PaddlePaddle / Paddle — commit b3466387 (unverified)
Authored Mar 01, 2022 by Leo Chen; committed via GitHub on Mar 01, 2022.
[phi] move uniform_random to phi (#39937)
* move uniform_random to phi
* fit selected_rows
* replace mutable_data
Parent: 08b43cce

Showing 11 changed files with 1,011 additions and 7 deletions (+1011 −7).
paddle/fluid/framework/operator.cc                          +3    −0
paddle/fluid/operators/uniform_random_op.cc                 +0    −4
paddle/fluid/operators/uniform_random_op.cu                 +0    −3
paddle/phi/kernels/cpu/uniform_random_kernel.cc             +115  −0
paddle/phi/kernels/funcs/aligned_vector.h                   +75   −0
paddle/phi/kernels/funcs/distribution_helper.h              +249  −0
paddle/phi/kernels/funcs/index_impl.cu.h                    +93   −0
paddle/phi/kernels/gpu/uniform_random_kernel.cu             +163  −0
paddle/phi/kernels/selected_rows/uniform_random_kernel.cc   +88   −0
paddle/phi/kernels/uniform_random_kernel.h                  +66   −0
paddle/phi/ops/compat/uniform_random_sig.cc                 +159  −0
paddle/fluid/framework/operator.cc (+3 −0)

```diff
@@ -2074,6 +2074,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
       }
       pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
     }
+    VLOG(4) << "Done inputs";
 
     for (size_t i = 0; i < output_names.size(); ++i) {
       auto it = ctx.outputs.find(output_names[i]);
@@ -2118,6 +2119,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
       pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
     }
+    VLOG(4) << "Done outputs";
 
     for (size_t i = 0; i < attr_names.size(); ++i) {
       if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) {
@@ -2226,6 +2228,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
       }
     }
   }
+  VLOG(4) << "Done attributes";
 }
 
 }  // namespace framework
```
paddle/fluid/operators/uniform_random_op.cc (+0 −4)

```diff
@@ -281,10 +281,6 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     paddle::operators::UniformRandomOpVarTypeInference);
 
-REGISTER_OP_CPU_KERNEL(
-    uniform_random, paddle::operators::CPUUniformRandomKernel<float>,
-    paddle::operators::CPUUniformRandomKernel<double>,
-    paddle::operators::CPUUniformRandomKernel<paddle::platform::bfloat16>);
 REGISTER_OP_CPU_KERNEL(
     uniform_random_batch_size_like,
     paddle::operators::CPUUniformRandomKernel<float>,
```
paddle/fluid/operators/uniform_random_op.cu (+0 −3)

```diff
@@ -58,9 +58,6 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_CUDA_KERNEL(
-    uniform_random, paddle::operators::GPUUniformRandomKernel<float>,
-    paddle::operators::GPUUniformRandomKernel<double>);
 REGISTER_OP_CUDA_KERNEL(
     uniform_random_batch_size_like,
     paddle::operators::GPUUniformRandomKernel<float>,
     paddle::operators::GPUUniformRandomKernel<double>);
```
paddle/phi/kernels/cpu/uniform_random_kernel.cc (new file, +115 −0)

```cpp
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/uniform_random_kernel.h"

#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T>
inline void UniformRealDistribution(T *data,
                                    const int64_t &size,
                                    const float &min,
                                    const float &max,
                                    std::shared_ptr<std::mt19937_64> engine) {
  std::uniform_real_distribution<T> dist(static_cast<T>(min),
                                         static_cast<T>(max));
  for (int64_t i = 0; i < size; ++i) {
    data[i] = dist(*engine);
  }
}

template <>
inline void UniformRealDistribution(phi::dtype::bfloat16 *data,
                                    const int64_t &size,
                                    const float &min,
                                    const float &max,
                                    std::shared_ptr<std::mt19937_64> engine) {
  std::uniform_real_distribution<float> dist(min, max);
  for (int64_t i = 0; i < size; ++i) {
    data[i] = static_cast<phi::dtype::bfloat16>(dist(*engine));
  }
}

template <typename T, typename Context>
void UniformRandomRawKernel(const Context &dev_ctx,
                            const ScalarArray &shape,
                            DataType dtype,
                            float min,
                            float max,
                            int seed,
                            int diag_num,
                            int diag_step,
                            float diag_val,
                            DenseTensor *out) {
  out->Resize(phi::make_ddim(shape.GetData()));
  VLOG(4) << out->dims();
  T *data = dev_ctx.template Alloc<T>(out);
  auto size = out->numel();
  std::shared_ptr<std::mt19937_64> engine;
  if (seed) {
    engine = std::make_shared<std::mt19937_64>();
    engine->seed(seed);
  } else {
    engine = dev_ctx.GetGenerator()->GetCPUEngine();
  }
  UniformRealDistribution<T>(data, size, min, max, engine);
  if (diag_num > 0) {
    PADDLE_ENFORCE_GT(
        size,
        (diag_num - 1) * (diag_step + 1),
        phi::errors::InvalidArgument(
            "ShapeInvalid: the diagonal's elements is equal (num-1) "
            "* (step-1) with num %d, step %d,"
            "It should be smaller than %d, but received %d",
            diag_num,
            diag_step,
            (diag_num - 1) * (diag_step + 1),
            size));
    for (int64_t i = 0; i < diag_num; ++i) {
      int64_t pos = i * diag_step + i;
      data[pos] = diag_val;
    }
  }
}

template <typename T, typename Context>
void UniformRandomKernel(const Context &dev_ctx,
                         const ScalarArray &shape,
                         DataType dtype,
                         float min,
                         float max,
                         int seed,
                         DenseTensor *out) {
  UniformRandomRawKernel<T>(
      dev_ctx, shape, dtype, min, max, seed, 0, 0, 0.0f, out);
}

}  // namespace phi

PD_REGISTER_KERNEL(uniform_random_raw,
                   CPU,
                   ALL_LAYOUT,
                   phi::UniformRandomRawKernel,
                   float,
                   double,
                   phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(uniform_random,
                   CPU,
                   ALL_LAYOUT,
                   phi::UniformRandomKernel,
                   float,
                   double,
                   phi::dtype::bfloat16) {}
```
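The CPU path above is plain standard-library sampling: an mt19937_64 engine feeding uniform_real_distribution, followed by an optional pass that overwrites the "diagonal" positions i * diag_step + i with diag_val. Below is a minimal standalone sketch of that logic; the names and the concrete sizes are ours for illustration, not Paddle's.

```cpp
#include <cstdint>
#include <iostream>
#include <memory>
#include <random>
#include <vector>

int main() {
  const int64_t size = 12;
  const int diag_num = 3, diag_step = 3;
  const float min = -1.0f, max = 1.0f, diag_val = 0.0f;

  // Fixed-seed branch of the kernel: a fresh mt19937_64 seeded explicitly.
  std::vector<float> data(size);
  auto engine = std::make_shared<std::mt19937_64>(42);
  std::uniform_real_distribution<float> dist(min, max);
  for (int64_t i = 0; i < size; ++i) data[i] = dist(*engine);

  // Diagonal fill: positions 0, diag_step + 1, 2 * (diag_step + 1), ...
  // (the kernel first enforces size > (diag_num - 1) * (diag_step + 1)).
  for (int64_t i = 0; i < diag_num; ++i) data[i * diag_step + i] = diag_val;

  for (float v : data) std::cout << v << ' ';
  std::cout << '\n';
}
```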
paddle/phi/kernels/funcs/aligned_vector.h (new file, +75 −0)

```cpp
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <climits>
#include <cstdint>
#include <type_traits>

#include "paddle/phi/core/hostdevice.h"

namespace phi {

// Aligned vector generates vectorized load/store on CUDA.
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
  T val[Size];

  HOSTDEVICE inline const T& operator[](int i) const { return val[i]; }
  HOSTDEVICE inline T& operator[](int i) { return val[i]; }
};

template <typename T, int Size>
HOSTDEVICE inline void Load(const T* addr, AlignedVector<T, Size>* vec) {
  const AlignedVector<T, Size>* addr_vec =
      reinterpret_cast<const AlignedVector<T, Size>*>(addr);
  *vec = *addr_vec;
}

template <typename T, int Size>
HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
  AlignedVector<T, Size>* addr_vec =
      reinterpret_cast<AlignedVector<T, Size>*>(addr);
  *addr_vec = vec;
}

/*
 * Vectorized load is only possible when the address of the input data is a
 * multiple of the vector width. Moreover, a single vectorized load moves at
 * most 128 bits. The valid vectorized length must therefore satisfy both
 * constraints.
 */
template <typename T>
int GetVectorizedSize(const T* pointer) {
  constexpr int max_load_bits = 128;
  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr int vec8 = std::alignment_of<AlignedVector<T, 8>>::value;  // NOLINT
  constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value;  // NOLINT
  constexpr int vec2 = std::alignment_of<AlignedVector<T, 2>>::value;  // NOLINT
  if (address % vec8 == 0) {
    /*
     * Currently we handle at most 4 elements per vectorized load/store. If
     * performance tests show that handling 8 elements at once pays off, the
     * return below can be changed to "return std::min(8, valid_vec_size);".
     */
    return std::min(4, valid_vec_size);
  } else if (address % vec4 == 0) {
    return std::min(4, valid_vec_size);
  } else if (address % vec2 == 0) {
    return std::min(2, valid_vec_size);
  } else {
    return 1;
  }
}

}  // namespace phi
```
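GetVectorizedSize picks the widest vector the pointer's alignment allows, capped at 128 bits per load. Here is a standalone sketch of the same decision rule, using raw sizeof(T)-based alignments in place of AlignedVector; this mirrors the header's logic for illustration, it is not the header itself.

```cpp
#include <algorithm>
#include <climits>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirror of GetVectorizedSize: for float (4 bytes) and a 128-bit load
// limit, at most 4 elements fit in one load, and the pointer must be
// aligned to the chosen vector width.
template <typename T>
int GetVectorizedSizeSketch(const T* pointer) {
  constexpr int max_load_bits = 128;
  int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);  // 4 for float
  auto address = reinterpret_cast<std::uintptr_t>(pointer);
  if (address % (sizeof(T) * 8) == 0) return std::min(4, valid_vec_size);
  if (address % (sizeof(T) * 4) == 0) return std::min(4, valid_vec_size);
  if (address % (sizeof(T) * 2) == 0) return std::min(2, valid_vec_size);
  return 1;
}

int main() {
  std::vector<float> buf(8);
  std::cout << GetVectorizedSizeSketch(buf.data()) << '\n';      // 4 on typical allocators
  std::cout << GetVectorizedSizeSketch(buf.data() + 1) << '\n';  // 1: misaligned by one float
}
```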
paddle/phi/kernels/funcs/distribution_helper.h (new file, +249 −0)

```cpp
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef __NVCC__
#include <curand_kernel.h>
#endif
#ifdef __HIPCC__
#include <hiprand_kernel.h>
#endif

#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"

#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
#endif

#if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
// there is no equivalent intrinsics in msvc.
#define UNLIKELY(condition) (condition)
#endif

namespace phi {
namespace distribution {

/********************* Transformation Function **********************/
template <typename T>
struct exponential_transform {
  explicit exponential_transform(T lambda) : lambda_(lambda) {}

  HOSTDEVICE inline T operator()(T val) const {
#if defined(__NVCC__) || defined(__HIPCC__)
    if (std::is_same<T, double>::value) {
      return static_cast<T>(-1.0) / lambda_ * log(val);
    } else {
      return static_cast<T>(-1.0) / lambda_ * __logf(val);
    }
#else
    return static_cast<T>(-1.0) / lambda_ *
           std::log(static_cast<T>(1.0) - val);
#endif
  }

 private:
  T lambda_;
};

template <typename T>
struct uniform_transform {
  explicit uniform_transform(T min, T max) : range_(max - min), min_(min) {}

  HOSTDEVICE inline T operator()(T val) const {
    if (UNLIKELY(val == static_cast<T>(1.0))) {
      return min_;
    } else {
      return val * range_ + min_;
    }
  }

 private:
  T range_;
  T min_;
};

template <typename T>
struct normal_transform {
  explicit normal_transform(T mean, T std) : mean_(mean), std_(std) {}

  HOSTDEVICE inline T operator()(T val) const { return val * std_ + mean_; }

 private:
  T mean_;
  T std_;
};

#if defined(__NVCC__) || defined(__HIPCC__)

namespace kps = phi::kps;

/*********************** Distribution Function *************************/
template <typename T>
struct uniform_distribution;

template <typename T>
struct normal_distribution;

#if defined(__NVCC__)
template <>
struct uniform_distribution<float> {
  __device__ inline float4 operator()(
      curandStatePhilox4_32_10_t *state) const {
    return curand_uniform4(state);
  }
  static constexpr int kReturnsCount = 4;
};

template <>
struct uniform_distribution<double> {
  __device__ inline double2 operator()(
      curandStatePhilox4_32_10_t *state) const {
    return curand_uniform2_double(state);
  }
  static constexpr int kReturnsCount = 2;
};

template <>
struct normal_distribution<float> {
  __device__ inline float4 operator()(
      curandStatePhilox4_32_10_t *state) const {
    return curand_normal4(state);
  }
  static constexpr int kReturnsCount = 4;
};

template <>
struct normal_distribution<double> {
  __device__ inline double2 operator()(
      curandStatePhilox4_32_10_t *state) const {
    return curand_normal2_double(state);
  }
  static constexpr int kReturnsCount = 2;
};
#else
template <>
struct uniform_distribution<float> {
  __device__ inline float4 operator()(
      hiprandStatePhilox4_32_10_t *state) const {
    return hiprand_uniform4(state);
  }
  static constexpr int kReturnsCount = 4;
};

template <>
struct uniform_distribution<double> {
  __device__ inline double2 operator()(
      hiprandStatePhilox4_32_10_t *state) const {
    return hiprand_uniform2_double(state);
  }
  static constexpr int kReturnsCount = 2;
};

template <>
struct normal_distribution<float> {
  __device__ inline float4 operator()(
      hiprandStatePhilox4_32_10_t *state) const {
    return hiprand_normal4(state);
  }
  static constexpr int kReturnsCount = 4;
};

template <>
struct normal_distribution<double> {
  __device__ inline double2 operator()(
      hiprandStatePhilox4_32_10_t *state) const {
    return hiprand_normal2_double(state);
  }
  static constexpr int kReturnsCount = 2;
};
#endif

/******** Launch GPU function of distribution and transformation *********/
template <typename T, typename DistOp, typename TransformOp>
__global__ void DistributionKernel(size_t size,
                                   uint64_t seed,
                                   uint64_t offset,
                                   DistOp dist,
                                   TransformOp trans,
                                   T *out_data,
                                   size_t stride) {
  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
  static constexpr int kCount = DistOp::kReturnsCount;
#if defined(__NVCC__)
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx + THREAD_ID_X, offset, &state);
  using SType = curandStatePhilox4_32_10_t;
#else
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx + THREAD_ID_X, offset, &state);
  using SType = hiprandStatePhilox4_32_10_t;
#endif
  size_t total_thread = GRID_NUM_X * BLOCK_NUM_X;
  T args[kCount];
  T result[kCount];
  for (size_t i = idx; i < size; i += total_thread * kCount) {
    kps::ElementwiseRandom<SType, T, kCount, 1, DistOp>(&args[0], dist, &state);
    kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(
        &result[0], &args[0], trans);
    kps::WriteData<T, T, kCount, 1, 1, true>(
        out_data + i, &result[0], size - i, 1, stride, 1);
    __syncthreads();
  }
}

template <typename T, typename DistOp, typename TransformOp>
void distribution_and_transform(const GPUContext &dev_ctx,
                                DenseTensor *out,
                                DistOp dist,
                                TransformOp trans) {
  T *out_data = dev_ctx.template Alloc<T>(out);
  auto size = out->numel();
  int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
  auto gen_cuda = dev_ctx.GetGenerator();

  size_t block_size = 256;
  size_t expect_grid_size = (size + block_size - 1) / block_size;
  const auto &prop = backends::gpu::GetDeviceProperties(device_id);
  size_t max_grid_size = (prop.maxThreadsPerMultiProcessor / block_size) *
                         prop.multiProcessorCount;
  size_t grid_size =
      expect_grid_size > max_grid_size ? max_grid_size : expect_grid_size;

  size_t total_thread = block_size * grid_size;
  size_t curand4_loop_times =
      (size + 4 * total_thread - 1) / (4 * total_thread);
  // 'increment' should be a multiple of 4
  uint64_t increment = curand4_loop_times * 4;

  auto seed_offset = gen_cuda->IncrementOffset(increment);
  uint64_t seed = seed_offset.first;
  uint64_t offset = seed_offset.second;

  DistributionKernel<T, DistOp, TransformOp>
      <<<grid_size, block_size, 0, dev_ctx.stream()>>>(
          size, seed, offset, dist, trans, out_data, total_thread);
}

#endif
}  // namespace distribution
}  // namespace phi
```
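uniform_transform is written for curand/hiprand output, which lies in (0, 1]: the UNLIKELY(val == 1.0) branch folds the closed endpoint back to min so results stay in [min, max). A host-side sketch of just that arithmetic (our names, for illustration):

```cpp
#include <cassert>

// Maps val in (0, 1] to [min, max), remapping the endpoint 1.0 to min,
// exactly as uniform_transform::operator() does.
template <typename T>
T uniform_transform_sketch(T val, T min, T max) {
  const T range = max - min;
  return val == static_cast<T>(1.0) ? min : val * range + min;
}

int main() {
  assert(uniform_transform_sketch(1.0f, -2.0f, 3.0f) == -2.0f);  // endpoint -> min
  assert(uniform_transform_sketch(0.5f, -2.0f, 3.0f) == 0.5f);   // midpoint of range
}
```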
paddle/phi/kernels/funcs/index_impl.cu.h (new file, +93 −0)

```cuda
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>

#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

namespace phi {

template <typename T, typename Functor, int VecSize>
__global__ void VectorizedIndexKernel(T *out,
                                      size_t numel,
                                      size_t main_offset,
                                      Functor func) {
  size_t data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  size_t stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
  size_t args[VecSize];
  T result[VecSize];
  for (; data_offset < main_offset; data_offset += stride) {
    kps::InitWithDataIndex<size_t, VecSize, 1, 1>(&args[0], data_offset);
    kps::ElementwiseUnary<size_t, T, VecSize, 1, 1, Functor>(
        &result[0], &args[0], func);
    kps::WriteData<T, VecSize, 1, 1, false>(
        out + data_offset, &result[0], BLOCK_NUM_X * VecSize);
  }
  size_t num = numel - data_offset;
  if (num > 0) {
    kps::InitWithDataIndex<size_t, VecSize, 1, 1>(&args[0], data_offset);
    kps::ElementwiseUnary<size_t, T, VecSize, 1, 1, Functor>(
        &result[0], &args[0], func);
    kps::WriteData<T, VecSize, 1, 1, true>(out + data_offset, &result[0], num);
  }
}

template <typename T, typename Functor>
void IndexKernel(const KPDevice &dev_ctx, DenseTensor *out, Functor func) {
  int numel = out->numel();
  T *out_data = dev_ctx.template Alloc<T>(out);
  if (numel <= 0) return;
  int vec_size = phi::GetVectorizedSize(out_data);
#ifdef PADDLE_WITH_XPU_KP
  int block = 64;
  int grid = 8;
  auto stream = dev_ctx.x_context()->xpu_stream;
#else
  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
  int grid = config.block_per_grid.x;
  int block = config.thread_per_block.x;
  auto stream = dev_ctx.stream();
#endif
  size_t main_offset = (numel / (vec_size * block)) * vec_size * block;
  switch (vec_size) {
    case 4:
      VectorizedIndexKernel<T, Functor, 4>
          <<<grid, block, 0, stream>>>(out_data, numel, main_offset, func);
      break;
    case 2:
      VectorizedIndexKernel<T, Functor, 2>
          <<<grid, block, 0, stream>>>(out_data, numel, main_offset, func);
      break;
    case 1:
      VectorizedIndexKernel<T, Functor, 1>
          <<<grid, block, 0, stream>>>(out_data, numel, main_offset, func);
      break;
    default: {
      PADDLE_THROW(phi::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
}

}  // namespace phi
```
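IndexKernel splits the element range at main_offset: indices below it are processed in full VecSize-wide chunks without bounds checks, and the remainder goes through one boundary-checked tail write. A standalone sketch of how that split works out numerically:

```cpp
#include <cstddef>
#include <iostream>

int main() {
  // Illustrative sizes (ours): 1000 elements, vector width 4, 128 threads.
  const std::size_t numel = 1000, vec_size = 4, block = 128;

  // Same formula as IndexKernel: round numel down to whole
  // vec_size * block chunks.
  const std::size_t main_offset =
      (numel / (vec_size * block)) * vec_size * block;

  // 1000 / 512 = 1 full chunk -> main body covers [0, 512), tail = 488.
  std::cout << "main body: [0, " << main_offset << "), tail: "
            << numel - main_offset << " elements\n";
}
```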
paddle/phi/kernels/gpu/uniform_random_kernel.cu (new file, +163 −0)

```cuda
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/uniform_random_kernel.h"

#include "gflags/gflags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"

DECLARE_bool(use_curand);

namespace phi {

template <typename T>
struct UniformGenerator {
  T min_, max_;
  unsigned int seed_;
  T diag_val_;
  unsigned int diag_num_;
  unsigned int diag_step_;
  __host__ __device__ UniformGenerator(
      T min, T max, int seed, int diag_num, int diag_step, T diag_val)
      : min_(min),
        max_(max),
        seed_(seed),
        diag_num_(diag_num),
        diag_step_(diag_step),
        diag_val_(diag_val) {}

  __host__ __device__ T operator()(const unsigned int n) const {
    thrust::minstd_rand rng;
    rng.seed(seed_);
    thrust::uniform_real_distribution<T> dist(min_, max_);
    rng.discard(n);
    T out = dist(rng);
    unsigned int remainder = n % (diag_step_ + 1);
    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
      out = diag_val_;
    }
    return out;
  }
};

template <typename T>
struct UniformGeneratorOffset {
  T min_, max_;
  unsigned int seed_;
  T diag_val_;
  unsigned int diag_num_;
  unsigned int diag_step_;
  int offset_;
  __host__ __device__ UniformGeneratorOffset(T min,
                                             T max,
                                             int seed,
                                             int diag_num,
                                             int diag_step,
                                             T diag_val,
                                             int offset)
      : min_(min),
        max_(max),
        seed_(seed),
        diag_num_(diag_num),
        diag_step_(diag_step),
        diag_val_(diag_val),
        offset_(offset) {}

  __host__ __device__ T operator()(const unsigned int n) const {
    thrust::minstd_rand rng;
    rng.seed(seed_);
    thrust::uniform_real_distribution<T> dist(min_, max_);
    rng.discard(n + offset_);
    T out = dist(rng);
    unsigned int remainder = n % (diag_step_ + 1);
    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
      out = diag_val_;
    }
    return out;
  }
};

template <typename T, typename Context>
void UniformRandomRawKernel(const Context& dev_ctx,
                            const ScalarArray& shape,
                            DataType dtype,
                            float min,
                            float max,
                            int seed,
                            int diag_num,
                            int diag_step,
                            float diag_val,
                            DenseTensor* out) {
  out->Resize(phi::make_ddim(shape.GetData()));
  T* data = dev_ctx.template Alloc<T>(out);
  auto size = out->numel();
  bool seed_flag = false;
  if (seed == 0) {
    std::random_device rd;
    seed = rd();
    seed_flag = true;
  }

  auto generator = dev_ctx.GetGenerator();
  if (generator->GetIsInitPy() && seed_flag) {
    if (FLAGS_use_curand) {
      using MT = typename kps::details::MPTypeTrait<T>::Type;
      distribution::uniform_distribution<MT> dist;
      distribution::uniform_transform<MT> trans(min, max);
      distribution::distribution_and_transform<T>(dev_ctx, out, dist, trans);
    } else {
      auto seed_offset = generator->IncrementOffset(1);
      int64_t gen_offset = size * seed_offset.second;
      auto func = UniformGeneratorOffset<T>(
          min, max, seed_offset.first, diag_num, diag_step, diag_val,
          gen_offset);
      IndexKernel<T, UniformGeneratorOffset<T>>(dev_ctx, out, func);
    }
  } else {
    auto func =
        UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val);
    IndexKernel<T, UniformGenerator<T>>(dev_ctx, out, func);
  }
}

template <typename T, typename Context>
void UniformRandomKernel(const Context& dev_ctx,
                         const ScalarArray& shape,
                         DataType dtype,
                         float min,
                         float max,
                         int seed,
                         DenseTensor* out) {
  UniformRandomRawKernel<T>(
      dev_ctx, shape, dtype, min, max, seed, 0, 0, 0.0f, out);
}

}  // namespace phi

PD_REGISTER_KERNEL(uniform_random_raw,
                   GPU,
                   ALL_LAYOUT,
                   phi::UniformRandomRawKernel,
                   float,
                   double) {}

PD_REGISTER_KERNEL(uniform_random,
                   GPU,
                   ALL_LAYOUT,
                   phi::UniformRandomKernel,
                   float,
                   double) {}
```
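The thrust-based generators above are counter-based in effect: re-seeding and then calling discard(n) make element n a pure function of (seed, n), which is what lets IndexKernel evaluate elements in any order across threads. A host sketch of that pattern, with std::minstd_rand standing in for thrust::minstd_rand (same engine family):

```cpp
#include <iostream>
#include <random>

// Value of element n for a given seed: re-seed, skip ahead n draws, sample.
// Deterministic in (seed, n), so order of evaluation does not matter.
float uniform_at(unsigned int seed, unsigned long long n, float min, float max) {
  std::minstd_rand rng(seed);
  rng.discard(n);
  std::uniform_real_distribution<float> dist(min, max);
  return dist(rng);
}

int main() {
  // Same (seed, n) always yields the same value, regardless of call order.
  std::cout << uniform_at(7, 5, -1.f, 1.f) << '\n';
  std::cout << uniform_at(7, 5, -1.f, 1.f) << '\n';
}
```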
paddle/phi/kernels/selected_rows/uniform_random_kernel.cc (new file, +88 −0)

```cpp
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/uniform_random_kernel.h"

#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void UniformRandomRawSRKernel(const Context& dev_ctx,
                              const ScalarArray& shape,
                              DataType dtype,
                              float min,
                              float max,
                              int seed,
                              int diag_num,
                              int diag_step,
                              float diag_val,
                              SelectedRows* out) {
  phi::UniformRandomRawKernel<T>(dev_ctx,
                                 shape,
                                 dtype,
                                 min,
                                 max,
                                 seed,
                                 diag_num,
                                 diag_step,
                                 diag_val,
                                 out->mutable_value());
}

template <typename T, typename Context>
void UniformRandomSRKernel(const Context& dev_ctx,
                           const ScalarArray& shape,
                           DataType dtype,
                           float min,
                           float max,
                           int seed,
                           SelectedRows* out) {
  phi::UniformRandomKernel<T>(
      dev_ctx, shape, dtype, min, max, seed, out->mutable_value());
}

}  // namespace phi

PD_REGISTER_KERNEL(uniform_random_raw_sr,
                   CPU,
                   ALL_LAYOUT,
                   phi::UniformRandomRawSRKernel,
                   float,
                   double,
                   phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(uniform_random_sr,
                   CPU,
                   ALL_LAYOUT,
                   phi::UniformRandomSRKernel,
                   float,
                   double,
                   phi::dtype::bfloat16) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(uniform_random_raw_sr,
                   GPU,
                   ALL_LAYOUT,
                   phi::UniformRandomRawSRKernel,
                   float,
                   double) {}

PD_REGISTER_KERNEL(uniform_random_sr,
                   GPU,
                   ALL_LAYOUT,
                   phi::UniformRandomSRKernel,
                   float,
                   double) {}
#endif
```
paddle/phi/kernels/uniform_random_kernel.h (new file, +66 −0)

```cpp
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/selected_rows.h"

namespace phi {

template <typename T, typename Context>
void UniformRandomRawKernel(const Context& dev_ctx,
                            const ScalarArray& shape,
                            DataType dtype,
                            float min,
                            float max,
                            int seed,
                            int diag_num,
                            int diag_step,
                            float diag_val,
                            DenseTensor* out);

template <typename T, typename Context>
void UniformRandomKernel(const Context& dev_ctx,
                         const ScalarArray& shape,
                         DataType dtype,
                         float min,
                         float max,
                         int seed,
                         DenseTensor* out);

template <typename T, typename Context>
void UniformRandomRawSRKernel(const Context& dev_ctx,
                              const ScalarArray& shape,
                              DataType dtype,
                              float min,
                              float max,
                              int seed,
                              int diag_num,
                              int diag_step,
                              float diag_val,
                              SelectedRows* out);

template <typename T, typename Context>
void UniformRandomSRKernel(const Context& dev_ctx,
                           const ScalarArray& shape,
                           DataType dtype,
                           float min,
                           float max,
                           int seed,
                           SelectedRows* out);

}  // namespace phi
```
paddle/phi/ops/compat/uniform_random_sig.cc (new file, +159 −0)

```cpp
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature UniformRandomOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  int diag_num = paddle::any_cast<int>(ctx.Attr("diag_num"));
  if (ctx.IsDenseTensorOutput("Out")) {
    if (diag_num) {
      if (ctx.InputSize("ShapeTensorList") > 0) {
        return KernelSignature("uniform_random_raw",
                               {},
                               {"ShapeTensorList", "dtype", "min", "max",
                                "seed", "diag_num", "diag_step", "diag_val"},
                               {"Out"});
      } else {
        const auto& shape =
            paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
        if (ctx.HasInput("ShapeTensor") && shape.empty()) {
          return KernelSignature("uniform_random_raw",
                                 {},
                                 {"ShapeTensor", "dtype", "min", "max",
                                  "seed", "diag_num", "diag_step", "diag_val"},
                                 {"Out"});
        } else {
          return KernelSignature("uniform_random_raw",
                                 {},
                                 {"shape", "dtype", "min", "max", "seed",
                                  "diag_num", "diag_step", "diag_val"},
                                 {"Out"});
        }
      }
    } else {
      if (ctx.InputSize("ShapeTensorList") > 0) {
        return KernelSignature(
            "uniform_random",
            {},
            {"ShapeTensorList", "dtype", "min", "max", "seed"},
            {"Out"});
      } else {
        const auto& shape =
            paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
        if (ctx.HasInput("ShapeTensor") && shape.empty()) {
          return KernelSignature("uniform_random",
                                 {},
                                 {"ShapeTensor", "dtype", "min", "max", "seed"},
                                 {"Out"});
        } else {
          return KernelSignature("uniform_random",
                                 {},
                                 {"shape", "dtype", "min", "max", "seed"},
                                 {"Out"});
        }
      }
    }
  } else if (ctx.IsSelectedRowsOutput("Out")) {
    if (diag_num) {
      if (ctx.InputSize("ShapeTensorList") > 0) {
        return KernelSignature("uniform_random_raw_sr",
                               {},
                               {"ShapeTensorList", "dtype", "min", "max",
                                "seed", "diag_num", "diag_step", "diag_val"},
                               {"Out"});
      } else {
        const auto& shape =
            paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
        if (ctx.HasInput("ShapeTensor") && shape.empty()) {
          return KernelSignature("uniform_random_raw_sr",
                                 {},
                                 {"ShapeTensor", "dtype", "min", "max",
                                  "seed", "diag_num", "diag_step", "diag_val"},
                                 {"Out"});
        } else {
          return KernelSignature("uniform_random_raw_sr",
                                 {},
                                 {"shape", "dtype", "min", "max", "seed",
                                  "diag_num", "diag_step", "diag_val"},
                                 {"Out"});
        }
      }
    } else {
      if (ctx.InputSize("ShapeTensorList") > 0) {
        return KernelSignature(
            "uniform_random_sr",
            {},
            {"ShapeTensorList", "dtype", "min", "max", "seed"},
            {"Out"});
      } else {
        const auto& shape =
            paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
        if (ctx.HasInput("ShapeTensor") && shape.empty()) {
          return KernelSignature("uniform_random_sr",
                                 {},
                                 {"ShapeTensor", "dtype", "min", "max", "seed"},
                                 {"Out"});
        } else {
          return KernelSignature("uniform_random_sr",
                                 {},
                                 {"shape", "dtype", "min", "max", "seed"},
                                 {"Out"});
        }
      }
    }
  }
  return KernelSignature("unregistered", {}, {}, {});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(uniform_random, phi::UniformRandomOpArgumentMapping);
```
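For a concrete trace of this mapping (our illustration, not part of the PR):

```cpp
// A dense-output uniform_random op with a static `shape` attribute, no
// ShapeTensor/ShapeTensorList inputs, and diag_num == 0 falls through to
// the last dense branch above and maps to:
//
//   KernelSignature("uniform_random", {},
//                   {"shape", "dtype", "min", "max", "seed"}, {"Out"});
//
// With diag_num != 0, the "uniform_random_raw" signature is chosen instead,
// carrying the extra diag_num/diag_step/diag_val attributes.
```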