Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
be29b8ee
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
be29b8ee
编写于
8月 27, 2021
作者:
J
JYChen
提交者:
GitHub
8月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add uniform_ op and UT (#33934)
上级
5a72cf43
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
673 addition
and
4 deletion
+673
-4
paddle/fluid/operators/uniform_random_inplace_op.cc
paddle/fluid/operators/uniform_random_inplace_op.cc
+181
-0
paddle/fluid/operators/uniform_random_inplace_op.cu
paddle/fluid/operators/uniform_random_inplace_op.cu
+171
-0
paddle/fluid/operators/uniform_random_inplace_op_xpu.cc
paddle/fluid/operators/uniform_random_inplace_op_xpu.cc
+96
-0
python/paddle/fluid/tests/unittests/test_uniform_random_inplace_op.py
...e/fluid/tests/unittests/test_uniform_random_inplace_op.py
+180
-0
python/paddle/tensor/__init__.py
python/paddle/tensor/__init__.py
+2
-0
python/paddle/tensor/random.py
python/paddle/tensor/random.py
+43
-4
未找到文件。
paddle/fluid/operators/uniform_random_inplace_op.cc
0 → 100644
浏览文件 @
be29b8ee
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
operators
{
class
UniformRandomInplaceOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddComment
(
R"DOC(
This operator fills self tensor with random values sampled from a
uniform distribution. The random result is in a range of [min, max).
)DOC"
);
AddInput
(
"X"
,
"The input tensor."
);
AddOutput
(
"Out"
,
"The output tensor of uniform random op"
);
AddAttr
<
float
>
(
"min"
,
"Minimum value of uniform random. [default -1.0]."
)
.
SetDefault
(
-
1.0
f
);
AddAttr
<
float
>
(
"max"
,
"Maximun value of uniform random. [default 1.0]."
)
.
SetDefault
(
1.0
f
);
AddAttr
<
int
>
(
"seed"
,
"Random seed used for generating samples. "
"If seed is 0, it will use the seed of the global default "
"generator (which can be set by paddle.seed). "
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time. [default 0]."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"diag_num"
,
"The number of diag elements. Note that if "
"diag_num is 0, it means without diag init.[default 0]."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"diag_step"
,
"The step between two diag element.[default 0]."
)
.
SetDefault
(
0
);
AddAttr
<
float
>
(
"diag_val"
,
"The value of diag element. [default 1.0]."
)
.
SetDefault
(
1.0
f
);
}
};
class
UniformRandomInplaceOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"UniformRandomInplaceOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"UniformRandomInplaceOp"
);
PADDLE_ENFORCE_LT
(
ctx
->
Attrs
().
Get
<
float
>
(
"min"
),
ctx
->
Attrs
().
Get
<
float
>
(
"max"
),
platform
::
errors
::
InvalidArgument
(
"The uniform_random's min must less then max. But received min = "
"%f great than or equal max = %f."
,
ctx
->
Attrs
().
Get
<
float
>
(
"min"
),
ctx
->
Attrs
().
Get
<
float
>
(
"max"
)));
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
int
>
(
"diag_num"
),
0
,
platform
::
errors
::
InvalidArgument
(
"The uniform_random's diag_num must greater than or "
"equal 0. But recevied diag_num (%d) < 0."
,
ctx
->
Attrs
().
Get
<
int
>
(
"diag_num"
)));
PADDLE_ENFORCE_GE
(
ctx
->
Attrs
().
Get
<
int
>
(
"diag_step"
),
0
,
platform
::
errors
::
InvalidArgument
(
"The uniform_random's diag_step must greater than or "
"equal 0. But recevied diag_step (%d) < 0."
,
ctx
->
Attrs
().
Get
<
int
>
(
"diag_step"
)));
auto
xdim
=
ctx
->
GetInputDim
(
"X"
);
ctx
->
SetOutputDim
(
"Out"
,
xdim
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
};
template
<
typename
T
>
class
CPUUniformRandomInplaceKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
out_var
=
ctx
.
OutputVar
(
"Out"
);
auto
*
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
T
*
data
=
tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
size
=
tensor
->
numel
();
std
::
uniform_real_distribution
<
T
>
dist
(
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"min"
)),
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"max"
)));
auto
engine
=
paddle
::
framework
::
GetCPURandomEngine
(
static_cast
<
unsigned
int
>
(
ctx
.
Attr
<
int
>
(
"seed"
)));
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
*
engine
);
}
}
};
class
UniformRandomInplaceOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{}
};
class
UniformRandomInplaceGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
"Out_Grad"
,
"UniformRandomInplaceGradOp"
);
auto
x_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
if
(
ctx
->
HasOutput
(
x_grad_name
))
{
ctx
->
SetOutputDim
(
x_grad_name
,
x_dims
);
}
}
};
template
<
typename
T
>
class
UniformRandomInplaceGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
retv
)
const
override
{
retv
->
SetType
(
this
->
ForwardOpType
()
+
"_grad"
);
retv
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
retv
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
retv
->
SetAttrMap
(
this
->
Attrs
());
}
};
template
<
typename
T
>
class
CPUUniformRandomInplaceGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
dx
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
if
(
dx
)
{
auto
*
data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
std
::
fill
(
data
,
data
+
dx
->
numel
(),
T
(
0
));
}
}
};
}
// namespace operators
}
// namespace paddle
DECLARE_INPLACE_OP_INFERER
(
UniformRandomInplaceInferer
,
{
"X"
,
"Out"
});
DECLARE_INPLACE_OP_INFERER
(
UniformRandomInplaceGradInplaceInferer
,
{
paddle
::
framework
::
GradVarName
(
"Out"
),
paddle
::
framework
::
GradVarName
(
"X"
)});
REGISTER_OPERATOR
(
uniform_random_inplace
,
paddle
::
operators
::
UniformRandomInplaceOp
,
paddle
::
operators
::
UniformRandomInplaceOpMaker
,
paddle
::
operators
::
UniformRandomInplaceGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
operators
::
UniformRandomInplaceGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
paddle
::
operators
::
UniformRandomInplaceOpVarTypeInference
,
UniformRandomInplaceInferer
);
REGISTER_OPERATOR
(
uniform_random_inplace_grad
,
paddle
::
operators
::
UniformRandomInplaceGradOp
,
UniformRandomInplaceGradInplaceInferer
);
REGISTER_OP_CPU_KERNEL
(
uniform_random_inplace
,
paddle
::
operators
::
CPUUniformRandomInplaceKernel
<
float
>
,
paddle
::
operators
::
CPUUniformRandomInplaceKernel
<
double
>
);
REGISTER_OP_CPU_KERNEL
(
uniform_random_inplace_grad
,
paddle
::
operators
::
CPUUniformRandomInplaceGradKernel
<
float
>
,
paddle
::
operators
::
CPUUniformRandomInplaceGradKernel
<
double
>
);
paddle/fluid/operators/uniform_random_inplace_op.cu
0 → 100644
浏览文件 @
be29b8ee
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
UniformGenerator
{
T
min_
,
max_
;
unsigned
int
seed_
;
T
diag_val_
;
unsigned
int
diag_num_
;
unsigned
int
diag_step_
;
__host__
__device__
UniformGenerator
(
T
min
,
T
max
,
int
seed
,
int
diag_num
,
int
diag_step
,
T
diag_val
)
:
min_
(
min
),
max_
(
max
),
seed_
(
seed
),
diag_num_
(
diag_num
),
diag_step_
(
diag_step
),
diag_val_
(
diag_val
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed_
);
thrust
::
uniform_real_distribution
<
T
>
dist
(
min_
,
max_
);
rng
.
discard
(
n
);
T
out
=
dist
(
rng
);
unsigned
int
remainder
=
n
%
(
diag_step_
+
1
);
if
(
remainder
==
0
&&
diag_num_
>
n
/
(
diag_step_
+
1
))
{
out
=
diag_val_
;
}
return
out
;
}
};
template
<
typename
T
>
struct
UniformGeneratorOffset
{
T
min_
,
max_
;
unsigned
int
seed_
;
T
diag_val_
;
unsigned
int
diag_num_
;
unsigned
int
diag_step_
;
int
offset_
;
__host__
__device__
UniformGeneratorOffset
(
T
min
,
T
max
,
int
seed
,
int
diag_num
,
int
diag_step
,
T
diag_val
,
int
offset
)
:
min_
(
min
),
max_
(
max
),
seed_
(
seed
),
diag_num_
(
diag_num
),
diag_step_
(
diag_step
),
diag_val_
(
diag_val
),
offset_
(
offset
)
{}
__host__
__device__
T
operator
()(
const
unsigned
int
n
)
const
{
thrust
::
minstd_rand
rng
;
rng
.
seed
(
seed_
);
thrust
::
uniform_real_distribution
<
T
>
dist
(
min_
,
max_
);
rng
.
discard
(
n
+
offset_
);
T
out
=
dist
(
rng
);
unsigned
int
remainder
=
n
%
(
diag_step_
+
1
);
if
(
remainder
==
0
&&
diag_num_
>
n
/
(
diag_step_
+
1
))
{
out
=
diag_val_
;
}
return
out
;
}
};
template
<
typename
T
>
__global__
void
fill_value
(
int64_t
size
,
T
*
data
,
float
value
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
size
;
idx
+=
blockDim
.
x
)
{
data
[
idx
]
=
static_cast
<
T
>
(
value
);
}
}
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random as uniform_random_op.cu.
template
<
typename
T
>
class
GPUUniformRandomInplaceKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
out_var
=
ctx
.
OutputVar
(
"Out"
);
auto
*
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
T
*
data
=
tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
ctx
.
Attr
<
int
>
(
"seed"
));
bool
seed_flag
=
false
;
if
(
seed
==
0
)
{
std
::
random_device
rd
;
seed
=
rd
();
seed_flag
=
true
;
}
T
min
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"min"
));
T
max
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"max"
));
unsigned
int
diag_num
=
static_cast
<
unsigned
int
>
(
ctx
.
Attr
<
int
>
(
"diag_num"
));
unsigned
int
diag_step
=
static_cast
<
unsigned
int
>
(
ctx
.
Attr
<
int
>
(
"diag_step"
));
T
diag_val
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"diag_val"
));
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
int64_t
size
=
tensor
->
numel
();
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
ctx
.
GetPlace
()).
GetDeviceId
();
auto
gen_cuda
=
framework
::
GetDefaultCUDAGenerator
(
device_id
);
if
(
gen_cuda
->
GetIsInitPy
()
&&
seed_flag
)
{
auto
seed_offset
=
gen_cuda
->
IncrementOffset
(
1
);
int
gen_offset
=
size
*
seed_offset
.
second
;
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
UniformGeneratorOffset
<
T
>
(
min
,
max
,
seed_offset
.
first
,
diag_num
,
diag_step
,
diag_val
,
gen_offset
));
}
else
{
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
UniformGenerator
<
T
>
(
min
,
max
,
seed
,
diag_num
,
diag_step
,
diag_val
));
}
}
};
template
<
typename
T
>
class
GPUUniformRandomInplaceGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#ifdef __HIPCC__
const
int64_t
kMaxBlockDim
=
256
;
#else
const
int64_t
kMaxBlockDim
=
512
;
#endif
auto
*
dx
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
size
=
dx
->
numel
();
int64_t
kBlockDim
=
std
::
min
(
size
,
kMaxBlockDim
);
fill_value
<
T
><<<
1
,
kBlockDim
,
0
>>>
(
size
,
data
,
static_cast
<
float
>
(
0
));
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
uniform_random_inplace
,
paddle
::
operators
::
GPUUniformRandomInplaceKernel
<
float
>
,
paddle
::
operators
::
GPUUniformRandomInplaceKernel
<
double
>
);
REGISTER_OP_CUDA_KERNEL
(
uniform_random_inplace_grad
,
paddle
::
operators
::
GPUUniformRandomInplaceGradKernel
<
float
>
,
paddle
::
operators
::
GPUUniformRandomInplaceGradKernel
<
double
>
);
paddle/fluid/operators/uniform_random_inplace_op_xpu.cc
0 → 100644
浏览文件 @
be29b8ee
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/uniform_random_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
XPUUniformRandomInplaceKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
out_var
=
ctx
.
OutputVar
(
"Out"
);
auto
*
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
T
*
data
=
tensor
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
size
=
tensor
->
numel
();
std
::
unique_ptr
<
T
[]
>
data_cpu
(
new
T
[
size
]);
std
::
uniform_real_distribution
<
T
>
dist
(
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"min"
)),
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"max"
)));
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
ctx
.
Attr
<
int
>
(
"seed"
));
auto
engine
=
framework
::
GetCPURandomEngine
(
seed
);
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
data_cpu
[
i
]
=
dist
(
*
engine
);
}
unsigned
int
diag_num
=
static_cast
<
unsigned
int
>
(
ctx
.
Attr
<
int
>
(
"diag_num"
));
unsigned
int
diag_step
=
static_cast
<
unsigned
int
>
(
ctx
.
Attr
<
int
>
(
"diag_step"
));
auto
diag_val
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"diag_val"
));
if
(
diag_num
>
0
)
{
PADDLE_ENFORCE_GT
(
size
,
(
diag_num
-
1
)
*
(
diag_step
+
1
),
platform
::
errors
::
InvalidArgument
(
"ShapeInvalid: the diagonal's elements is equal (num-1) "
"* (step-1) with num %d, step %d,"
"It should be smaller than %d, but received %d"
,
diag_num
,
diag_step
,
(
diag_num
-
1
)
*
(
diag_step
+
1
),
size
));
for
(
int64_t
i
=
0
;
i
<
diag_num
;
++
i
)
{
int64_t
pos
=
i
*
diag_step
+
i
;
data_cpu
[
pos
]
=
diag_val
;
}
}
memory
::
Copy
(
BOOST_GET_CONST
(
platform
::
XPUPlace
,
ctx
.
GetPlace
()),
data
,
platform
::
CPUPlace
(),
reinterpret_cast
<
void
*>
(
data_cpu
.
get
()),
size
*
sizeof
(
T
));
}
};
template
<
typename
T
>
class
XPUUniformRandomInplaceGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
dx
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
if
(
dx
)
{
T
*
data
=
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
size
=
dx
->
numel
();
std
::
unique_ptr
<
T
[]
>
data_cpu
(
new
T
[
size
]);
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
data_cpu
[
i
]
=
T
(
0
);
}
memory
::
Copy
(
BOOST_GET_CONST
(
platform
::
XPUPlace
,
ctx
.
GetPlace
()),
data
,
platform
::
CPUPlace
(),
reinterpret_cast
<
void
*>
(
data_cpu
.
get
()),
size
*
sizeof
(
T
));
}
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_XPU_KERNEL
(
uniform_random_inplace
,
paddle
::
operators
::
XPUUniformRandomInplaceKernel
<
float
>
);
REGISTER_OP_XPU_KERNEL
(
uniform_random_inplace_grad
,
paddle
::
operators
::
XPUUniformRandomInplaceGradKernel
<
float
>
);
#endif // PADDLE_WITH_XPU
python/paddle/fluid/tests/unittests/test_uniform_random_inplace_op.py
0 → 100644
浏览文件 @
be29b8ee
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
import
paddle.fluid
as
fluid
import
numpy
as
np
class
TestUniformRandomInplaceOpDtype
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
shape
=
(
1000
,
784
)
def
test_uniform_random_inplace_op_dtype
(
self
):
def
test_fp32
():
tensor_fp32
=
paddle
.
ones
(
self
.
shape
,
dtype
=
paddle
.
float32
)
tensor_fp32
.
uniform_
()
self
.
assertEqual
(
tensor_fp32
.
dtype
,
paddle
.
float32
)
def
test_fp64
():
tensor_fp64
=
paddle
.
ones
(
self
.
shape
,
paddle
.
float64
)
tensor_fp64
.
uniform_
()
self
.
assertEqual
(
tensor_fp64
.
dtype
,
paddle
.
float64
)
places
=
[
'cpu'
]
if
fluid
.
core
.
is_compiled_with_cuda
():
places
.
append
(
'gpu'
)
for
place
in
places
:
paddle
.
set_device
(
place
)
test_fp32
()
test_fp64
()
class
TestUniformRandomInplaceOpIsInplace
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
shape
=
(
1000
,
784
)
def
test_uniform_random_inplace_op_is_inplace
(
self
):
tensor_a
=
paddle
.
ones
(
self
.
shape
)
tensor_b
=
tensor_a
.
uniform_
()
self
.
assertTrue
(
tensor_a
is
tensor_b
)
class
TestUniformRandomInplaceOpSeedIsZero
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
shape
=
(
1000
,
784
)
self
.
seed
=
0
def
test_uniform_random_inplace_op_seed_is_zero
(
self
):
tensor
=
paddle
.
ones
(
self
.
shape
)
tensor
.
uniform_
(
seed
=
self
.
seed
)
tensor_data_first
=
tensor
.
numpy
()
tensor
.
uniform_
(
seed
=
self
.
seed
)
tensor_data_second
=
tensor
.
numpy
()
self
.
assertFalse
((
tensor_data_first
==
tensor_data_second
).
all
())
class
TestUniformRandomInplaceOpSeedIsNotZero
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
shape
=
(
1000
,
784
)
self
.
seed
=
10
def
test_uniform_random_inplace_op_seed_is_not_zero
(
self
):
tensor
=
paddle
.
ones
(
self
.
shape
)
tensor
.
uniform_
(
seed
=
self
.
seed
)
tensor_data_first
=
tensor
.
numpy
()
tensor
.
uniform_
(
seed
=
self
.
seed
)
tensor_data_second
=
tensor
.
numpy
()
self
.
assertTrue
((
tensor_data_first
==
tensor_data_second
).
all
())
class
TestUniformRandomInplaceOpWithinRange
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
shape
=
(
1000
,
784
)
self
.
min
=
-
2
self
.
max
=
1
self
.
seed
=
10
def
test_uniform_random_inplace_op_within_range
(
self
):
tensor
=
paddle
.
ones
(
self
.
shape
)
tensor
.
uniform_
(
min
=
self
.
min
,
max
=
self
.
max
,
seed
=
self
.
seed
)
tensor_data
=
tensor
.
numpy
()
self
.
assertTrue
((
tensor_data
>
self
.
min
).
all
()
and
(
tensor_data
<
self
.
max
).
all
())
class
TestUniformRandomInplaceOpShape
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
shape
=
(
1000
,
784
)
def
test_uniform_random_inplace_op_shape
(
self
):
tensor
=
paddle
.
ones
(
self
.
shape
)
tensor
.
uniform_
()
tensor_shape_np
=
np
.
array
(
tensor
.
shape
)
origin_shape
=
np
.
array
(
self
.
shape
)
self
.
assertTrue
((
tensor_shape_np
==
origin_shape
).
all
())
class
TestUniformRandomInplaceOpDistribution
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
shape
=
(
1000
,
784
)
self
.
min
=
-
3
self
.
max
=
5
self
.
seed
=
10
self
.
bins
=
100
def
test_uniform_random_inplace_op_distribution
(
self
):
tensor
=
paddle
.
ones
(
self
.
shape
)
tensor
.
uniform_
(
self
.
min
,
self
.
max
,
self
.
seed
)
hist
,
_
=
np
.
histogram
(
tensor
.
numpy
()[
0
],
bins
=
self
.
bins
)
prob
=
hist
/
float
(
self
.
shape
[
0
])
prob_expect
=
np
.
ones
((
self
.
bins
,
))
/
float
(
self
.
bins
)
self
.
assertTrue
(
np
.
allclose
(
prob
,
prob_expect
,
rtol
=
0
,
atol
=
1e-2
))
class
TestUniformRandomInplaceOpError
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
shape
=
(
1000
,
784
)
def
test_uniform_random_inplace_op_error
(
self
):
def
test_attr_error
():
tensor
=
paddle
.
ones
(
self
.
shape
)
tensor
.
uniform_
(
shape
=
self
.
shape
,
min
=-
2
,
max
=
2
)
self
.
assertRaises
(
TypeError
,
test_attr_error
)
class
TestUniformRandomInplaceOpEmptyTensor
(
unittest
.
TestCase
):
def
test_uniform_random_inplace_op_empty_tensor
(
self
):
places
=
[
'cpu'
]
if
fluid
.
core
.
is_compiled_with_cuda
():
places
.
append
(
'gpu'
)
test_shapes
=
[(
200
,
0
),
(
0
,
200
)]
for
place
in
places
:
paddle
.
set_device
(
place
)
for
test_shape
in
test_shapes
:
tensor
=
paddle
.
empty
(
shape
=
test_shape
)
tensor
.
uniform_
()
tensor_shape_np
=
np
.
array
(
tensor
.
shape
)
origin_shape
=
np
.
array
(
test_shape
)
self
.
assertTrue
((
tensor_shape_np
==
origin_shape
).
all
())
class
TestUniformRandomInplaceGrad
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
shape
=
(
1000
,
784
)
def
test_uniform_random_inplace_grad
(
self
):
def
test_grad
():
tensor_a
=
paddle
.
ones
(
self
.
shape
)
tensor_a
.
stop_gradient
=
False
tensor_b
=
tensor_a
*
0.5
tensor_b
.
uniform_
(
min
=-
2
,
max
=
2
)
loss
=
tensor_b
.
sum
()
loss
.
backward
()
uniform_grad
=
tensor_b
.
grad
.
numpy
()
self
.
assertTrue
((
uniform_grad
==
0
).
all
())
places
=
[
'cpu'
]
if
fluid
.
core
.
is_compiled_with_cuda
():
places
.
append
(
'gpu'
)
for
place
in
places
:
paddle
.
set_device
(
place
)
test_grad
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/tensor/__init__.py
浏览文件 @
be29b8ee
...
@@ -180,6 +180,7 @@ from .random import multinomial # noqa: F401
...
@@ -180,6 +180,7 @@ from .random import multinomial # noqa: F401
from
.random
import
standard_normal
# noqa: F401
from
.random
import
standard_normal
# noqa: F401
from
.random
import
normal
# noqa: F401
from
.random
import
normal
# noqa: F401
from
.random
import
uniform
# noqa: F401
from
.random
import
uniform
# noqa: F401
from
.random
import
uniform_
# noqa: F401
from
.random
import
randn
# noqa: F401
from
.random
import
randn
# noqa: F401
from
.random
import
rand
# noqa: F401
from
.random
import
rand
# noqa: F401
from
.random
import
randint
# noqa: F401
from
.random
import
randint
# noqa: F401
...
@@ -371,6 +372,7 @@ tensor_method_func = [ #noqa
...
@@ -371,6 +372,7 @@ tensor_method_func = [ #noqa
'bitwise_xor'
,
'bitwise_xor'
,
'bitwise_not'
,
'bitwise_not'
,
'broadcast_tensors'
,
'broadcast_tensors'
,
'uniform_'
,
]
]
#this list used in math_op_patch.py for magic_method bind
#this list used in math_op_patch.py for magic_method bind
...
...
python/paddle/tensor/random.py
浏览文件 @
be29b8ee
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
# TODO: define random functions
# TODO: define random functions
from
..fluid
import
core
from
..fluid
import
core
from
..fluid.framework
import
in_dygraph_mode
,
Variable
,
convert_np_dtype_to_dtype_
from
..fluid.framework
import
in_dygraph_mode
,
Variable
,
convert_np_dtype_to_dtype_
,
dygraph_only
from
..fluid.layer_helper
import
LayerHelper
from
..fluid.layer_helper
import
LayerHelper
from
..fluid.data_feeder
import
check_variable_and_dtype
,
check_type
,
check_dtype
,
check_shape
from
..fluid.data_feeder
import
check_variable_and_dtype
,
check_type
,
check_dtype
,
check_shape
from
..fluid.layers
import
utils
from
..fluid.layers
import
utils
...
@@ -444,9 +444,9 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
...
@@ -444,9 +444,9 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
to generate, ``min`` is included in the range. Default is -1.0.
to generate, ``min`` is included in the range. Default is -1.0.
max(float|int, optional): The upper bound on the range of random values
max(float|int, optional): The upper bound on the range of random values
to generate, ``max`` is excluded in the range. Default is 1.0.
to generate, ``max`` is excluded in the range. Default is 1.0.
seed(int, optional): Random seed used for generating samples.
0 means
seed(int, optional): Random seed used for generating samples.
If seed is 0,
use a seed generated by the system. Note that if seed is not 0,
it will use the seed of the global default generator (which can be set by paddle.seed).
this operator will always generate the same random numbers every
Note that if seed is not 0,
this operator will always generate the same random numbers every
time. Default is 0.
time. Default is 0.
name(str, optional): The default value is None. Normally there is no
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
need for user to set this property. For more information, please
...
@@ -520,6 +520,45 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
...
@@ -520,6 +520,45 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
return
out
return
out
@
dygraph_only
def
uniform_
(
x
,
min
=-
1.0
,
max
=
1.0
,
seed
=
0
,
name
=
None
):
"""
This is the inplace version of OP ``uniform``, which returns a Tensor filled
with random values sampled from a uniform distribution. The output Tensor will
be inplaced with input ``x``. Please refer to :ref:`api_tensor_uniform`.
Args:
x(Tensor): The input tensor to be filled with random values.
min(float|int, optional): The lower bound on the range of random values
to generate, ``min`` is included in the range. Default is -1.0.
max(float|int, optional): The upper bound on the range of random values
to generate, ``max`` is excluded in the range. Default is 1.0.
seed(int, optional): Random seed used for generating samples. If seed is 0,
it will use the seed of the global default generator (which can be set by paddle.seed).
Note that if seed is not 0, this operator will always generate the same random numbers every
time. Default is 0.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor: The input tensor x filled with random values sampled from a uniform
distribution in the range [``min``, ``max``).
Examples:
.. code-block:: python
import paddle
# example:
x = paddle.ones(shape=[3, 4])
x.uniform_()
print(x)
# [[ 0.84524226, 0.6921872, 0.56528175, 0.71690357], # random
# [-0.34646994, -0.45116323, -0.09902662, -0.11397249], # random
# [ 0.433519, 0.39483607, -0.8660099, 0.83664286]] # random
"""
return
core
.
ops
.
uniform_random_inplace_
(
x
,
'min'
,
min
,
'max'
,
max
,
'seed'
,
seed
)
def
randint
(
low
=
0
,
high
=
None
,
shape
=
[
1
],
dtype
=
None
,
name
=
None
):
def
randint
(
low
=
0
,
high
=
None
,
shape
=
[
1
],
dtype
=
None
,
name
=
None
):
"""
"""
This OP returns a Tensor filled with random integers from a discrete uniform
This OP returns a Tensor filled with random integers from a discrete uniform
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录