Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
53bdee64
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
53bdee64
编写于
5月 13, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
5月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add tensor support for gaussian_random_op test=develop (#24389)
上级
da4a1db7
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
354 addition
and
133 deletion
+354
-133
paddle/fluid/operators/fill_constant_op.h
paddle/fluid/operators/fill_constant_op.h
+7
-28
paddle/fluid/operators/gaussian_random_op.cc
paddle/fluid/operators/gaussian_random_op.cc
+74
-10
paddle/fluid/operators/gaussian_random_op.cu
paddle/fluid/operators/gaussian_random_op.cu
+30
-4
paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc
paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc
+6
-1
python/paddle/fluid/layers/distributions.py
python/paddle/fluid/layers/distributions.py
+3
-2
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+55
-22
python/paddle/fluid/tests/unittests/mkldnn/test_gaussian_random_mkldnn_op.py
.../tests/unittests/mkldnn/test_gaussian_random_mkldnn_op.py
+5
-7
python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
...n/paddle/fluid/tests/unittests/test_gaussian_random_op.py
+174
-59
未找到文件。
paddle/fluid/operators/fill_constant_op.h
浏览文件 @
53bdee64
...
...
@@ -20,48 +20,26 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/utils.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
inline
framework
::
DDim
GetShape
(
const
framework
::
ExecutionContext
&
ctx
)
{
inline
framework
::
DDim
GetShape
(
const
framework
::
ExecutionContext
&
ctx
,
std
::
string
op_type
)
{
// 1. shape is a Tensor
if
(
ctx
.
HasInput
(
"ShapeTensor"
))
{
auto
*
shape_tensor
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"ShapeTensor"
);
auto
*
shape_data
=
shape_tensor
->
data
<
int
>
();
framework
::
Tensor
cpu_shape_tensor
;
if
(
platform
::
is_gpu_place
(
shape_tensor
->
place
()))
{
TensorCopySync
(
*
shape_tensor
,
platform
::
CPUPlace
(),
&
cpu_shape_tensor
);
shape_data
=
cpu_shape_tensor
.
data
<
int
>
();
}
auto
vec_shape
=
std
::
vector
<
int
>
(
shape_data
,
shape_data
+
shape_tensor
->
numel
());
auto
vec_shape
=
GetDataFromTensor
<
int
>
(
shape_tensor
);
return
framework
::
make_ddim
(
vec_shape
);
}
// 2. shape is a list/tuple containing Tensor
auto
shape_tensor_list
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"ShapeTensorList"
);
if
(
shape_tensor_list
.
size
()
>
0
)
{
std
::
vector
<
int
>
vec_shape
;
for
(
size_t
i
=
0
;
i
<
shape_tensor_list
.
size
();
++
i
)
{
auto
tensor
=
shape_tensor_list
[
i
];
PADDLE_ENFORCE_EQ
(
tensor
->
dims
(),
framework
::
make_ddim
({
1
}),
platform
::
errors
::
InvalidArgument
(
"If the element type of 'shape'(tensor_list type) in "
"FillConstantOp is Tensor, the shape of this Tensor element must "
"be [1]. But received the Tensor element's shape is [%s]"
,
tensor
->
dims
()));
if
(
platform
::
is_gpu_place
(
tensor
->
place
()))
{
framework
::
Tensor
temp
;
TensorCopySync
(
*
tensor
,
platform
::
CPUPlace
(),
&
temp
);
vec_shape
.
push_back
(
*
temp
.
data
<
int
>
());
}
else
{
vec_shape
.
push_back
(
*
tensor
->
data
<
int
>
());
}
}
auto
vec_shape
=
GetDataFromTensorList
(
shape_tensor_list
);
return
framework
::
make_ddim
(
vec_shape
);
}
...
...
@@ -115,7 +93,8 @@ class FillConstantKernel : public framework::OpKernel<T> {
}
value
=
tensor_data
[
0
];
}
auto
shape
=
GetShape
(
ctx
);
const
std
::
string
op_type
=
"fill_constant"
;
auto
shape
=
GetShape
(
ctx
,
op_type
);
if
(
out_var
->
IsType
<
framework
::
LoDTensor
>
())
{
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
...
...
paddle/fluid/operators/gaussian_random_op.cc
浏览文件 @
53bdee64
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include <random>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
...
...
@@ -22,8 +22,37 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
CPUGaussianRandomKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
float
mean
=
context
.
Attr
<
float
>
(
"mean"
);
float
std
=
context
.
Attr
<
float
>
(
"std"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
std
::
minstd_rand
engine
;
if
(
seed
==
0
)
{
seed
=
std
::
random_device
()();
}
engine
.
seed
(
seed
);
std
::
normal_distribution
<
T
>
dist
(
mean
,
std
);
const
std
::
string
op_type
=
"gaussian_random"
;
auto
shape
=
GetShape
(
context
,
op_type
);
tensor
->
Resize
(
shape
);
int64_t
size
=
tensor
->
numel
();
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
engine
);
}
}
};
template
<
typename
T
>
class
CPUGaussianRandomBatchSizeLikeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
float
mean
=
context
.
Attr
<
float
>
(
"mean"
);
...
...
@@ -58,12 +87,26 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
for
(
auto
dim
:
shape
)
{
temp
.
push_back
(
static_cast
<
int64_t
>
(
dim
));
}
PADDLE_ENFORCE_GT
(
shape
.
size
(),
0UL
,
platform
::
errors
::
InvalidArgument
(
"Attribute(shape) of GaussianRandomOp must be set "
"and shape.size() > 0, but reveived shape.size() is %d"
,
shape
.
size
()));
if
(
shape
.
empty
()
&&
ctx
->
HasInput
(
"ShapeTensor"
))
{
auto
shape_dims
=
ctx
->
GetInputDim
(
"ShapeTensor"
);
int
num_ele
=
1
;
for
(
int
i
=
0
;
i
<
shape_dims
.
size
();
++
i
)
{
num_ele
*=
shape_dims
[
i
];
}
auto
vec_dims
=
std
::
vector
<
int
>
(
num_ele
,
-
1
);
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
vec_dims
));
return
;
}
if
(
!
(
ctx
->
HasInput
(
"ShapeTensor"
)
&&
!
ctx
->
HasInputs
(
"ShapeTensorList"
)))
{
PADDLE_ENFORCE_GT
(
shape
.
size
(),
0UL
,
platform
::
errors
::
InvalidArgument
(
"Attribute(shape) of GaussianRandomOp must be set "
"and shape.size() > 0, but reveived shape.size() is %d"
,
shape
.
size
()));
}
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
temp
));
}
...
...
@@ -85,6 +128,16 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
ctx
.
Attr
<
int
>
(
"dtype"
)),
ctx
.
device_context
(),
layout
,
library
);
}
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
{
if
(
var_name
==
"ShapeTensor"
||
var_name
==
"ShapeTensorList"
)
{
return
expected_kernel_type
;
}
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
};
class
GaussianRandomOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
@@ -94,7 +147,18 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
std
::
vector
<
int64_t
>>
(
"shape"
,
"(vector<int64_t>) "
"The dimension of random tensor."
);
"The dimension of random tensor."
)
.
SetDefault
({});
AddInput
(
"ShapeTensor"
,
"(Tensor<int>), optional). The shape of the output."
"It has a higher priority than Attr(shape)."
)
.
AsDispensable
();
AddInput
(
"ShapeTensorList"
,
"(vector<Tensor<int>>, optional). The shape of the output. "
"It has a higher priority than Attr(shape)."
"The shape of the element in vector must be [1]."
)
.
AsDuplicable
()
.
AsDispensable
();
AddAttr
<
float
>
(
"mean"
,
"(float, default 0.0) "
"mean of random tensor."
)
...
...
@@ -135,5 +199,5 @@ REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
REGISTER_OP_CPU_KERNEL
(
gaussian_random
,
ops
::
CPUGaussianRandomKernel
<
float
>
,
ops
::
CPUGaussianRandomKernel
<
double
>
);
REGISTER_OP_CPU_KERNEL
(
gaussian_random_batch_size_like
,
ops
::
CPUGaussianRandomKernel
<
float
>
,
ops
::
CPUGaussianRandomKernel
<
double
>
);
ops
::
CPUGaussianRandom
BatchSizeLike
Kernel
<
float
>
,
ops
::
CPUGaussianRandom
BatchSizeLike
Kernel
<
double
>
);
paddle/fluid/operators/gaussian_random_op.cu
浏览文件 @
53bdee64
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include <thrust/transform.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fill_constant_op.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -41,7 +42,6 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
if
(
seed
==
0
)
{
std
::
random_device
rd
;
...
...
@@ -50,6 +50,11 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
T
mean
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"mean"
));
T
std
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"std"
));
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
const
std
::
string
op_type
=
"gaussian_random"
;
auto
shape
=
GetShape
(
context
,
op_type
);
tensor
->
Resize
(
shape
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int64_t
size
=
tensor
->
numel
();
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
...
...
@@ -57,12 +62,33 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
}
};
template
<
typename
T
>
class
GPUGaussianRandomBatchSizeLikeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
if
(
seed
==
0
)
{
std
::
random_device
rd
;
seed
=
rd
();
}
T
mean
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"mean"
));
T
std
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"std"
));
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
int64_t
size
=
tensor
->
numel
();
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
GaussianGenerator
<
T
>
(
mean
,
std
,
seed
));
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
gaussian_random
,
paddle
::
operators
::
GPUGaussianRandomKernel
<
float
>
,
paddle
::
operators
::
GPUGaussianRandomKernel
<
double
>
);
REGISTER_OP_CUDA_KERNEL
(
gaussian_random_batch_size_like
,
paddle
::
operators
::
GPUGaussianRandomKernel
<
float
>
,
paddle
::
operators
::
GPUGaussianRandomKernel
<
double
>
);
REGISTER_OP_CUDA_KERNEL
(
gaussian_random_batch_size_like
,
paddle
::
operators
::
GPUGaussianRandomBatchSizeLikeKernel
<
float
>
,
paddle
::
operators
::
GPUGaussianRandomBatchSizeLikeKernel
<
double
>
);
paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc
浏览文件 @
53bdee64
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/operators/mean_op.h"
namespace
paddle
{
...
...
@@ -26,7 +27,6 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
float
mean
=
context
.
Attr
<
float
>
(
"mean"
);
float
std
=
context
.
Attr
<
float
>
(
"std"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
std
::
minstd_rand
engine
;
...
...
@@ -35,6 +35,11 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
}
engine
.
seed
(
seed
);
std
::
normal_distribution
<
T
>
dist
(
mean
,
std
);
const
std
::
string
op_type
=
"gaussian_random"
;
auto
shape
=
GetShape
(
context
,
op_type
);
tensor
->
Resize
(
shape
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int64_t
size
=
tensor
->
numel
();
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
engine
);
...
...
python/paddle/fluid/layers/distributions.py
浏览文件 @
53bdee64
...
...
@@ -357,8 +357,9 @@ class Normal(Distribution):
output_shape
=
shape
+
batch_shape
zero_tmp
=
tensor
.
fill_constant_batch_size_like
(
self
.
loc
+
self
.
scale
,
batch_shape
+
shape
,
self
.
loc
.
dtype
,
0.
)
normal_random_tmp
=
nn
.
gaussian_random_batch_size_like
(
zero_tmp
,
zero_tmp
.
shape
,
mean
=
0.
,
std
=
1.
,
seed
=
seed
)
zero_tmp_shape
=
nn
.
shape
(
zero_tmp
)
normal_random_tmp
=
nn
.
gaussian_random
(
zero_tmp_shape
,
mean
=
0.
,
std
=
1.
,
seed
=
seed
)
output
=
normal_random_tmp
*
(
zero_tmp
+
self
.
scale
)
+
self
.
loc
return
nn
.
reshape
(
output
,
output_shape
)
else
:
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
53bdee64
...
...
@@ -10169,33 +10169,55 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
Generate a random tensor whose data is drawn from a Gaussian distribution.
Args:
shape (
Tuple[int] | List[int
]): Shape of the generated random tensor.
shape (
tuple[int] | list[int] | Variable | list[Variable
]): Shape of the generated random tensor.
mean (float): Mean of the random tensor, defaults to 0.0.
std (float): Standard deviation of the random tensor, defaults to 1.0.
seed (int): ${seed_comment}
dtype(np.dtype | core.VarDesc.VarType | str): Output data type, float32 or float64.
Returns:
Variable: Random tensor whose data is drawn from a Gaussian distribution, dtype: flaot32 or float64 as specified.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
# example 1:
# attr shape is a list which doesn't contain tensor Variable.
result_1 = fluid.layers.gaussian_random(shape=[3, 4])
# example 2:
# attr shape is a list which contains tensor Variable.
dim_1 = fluid.layers.fill_constant([1],"int64",3)
dim_2 = fluid.layers.fill_constant([1],"int32",5)
result_2 = fluid.layers.gaussian_random(shape=[dim_1, dim_2])
# example 3:
# attr shape is a Variable, the data type must be int64 or int32.
var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64")
result_3 = fluid.layers.gaussian_random(var_shape)
var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32")
result_4 = fluid.layers.gaussian_random(var_shape_int32)
.. code-block:: python
# declarative mode
import numpy as np
from paddle import fluid
x = fluid.layers.gaussian_random((2, 3), std=2., seed=10)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
start = fluid.default_startup_program()
main = fluid.default_main_program()
exe.run(start)
x_np, = exe.run(main, feed={}, fetch_list=[x])
...
...
@@ -10209,33 +10231,44 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
place = fluid.CPUPlace()
with dg.guard(place) as g:
x = fluid.layers.gaussian_random((2, 4), mean=2., dtype="float32", seed=10)
x_np = x.numpy()
x_np = x.numpy()
x_np
# array([[2.3060477 , 2.676496 , 3.9911983 , 0.9990833 ],
# [2.8675377 , 2.2279181 , 0.79029655, 2.8447366 ]], dtype=float32)
"""
helper = LayerHelper('gaussian_random', **locals())
check_type(shape, 'shape', (list, tuple), 'fluid.layers.gaussian_random')
check_dtype(dtype, 'dtype', ['float32', 'float64'],
'fluid.layers.gaussian_random')
out = helper.create_variable_for_type_inference(dtype)
if not isinstance(shape, (list, tuple, Variable)):
raise TypeError(
"The type of 'shape' in fill_constant must be Variable, list or tuple, but "
"received %s." % (type(shape)))
c_dtype = convert_np_dtype_to_dtype_(dtype)
attrs = {
'mean': mean,
'std': std,
'seed': seed,
'dtype': c_dtype,
'use_mkldnn': False
}
inputs = {}
utils._get_shape_tensor_inputs(
inputs=inputs,
helper=helper,
attrs=attrs,
shape=shape,
op_type='gaussian_random')
helper.append_op(
type='gaussian_random',
inputs=inputs,
outputs={'Out': out},
attrs={
'shape': shape,
'mean': mean,
'std': std,
'seed': seed,
'dtype': c_dtype,
'use_mkldnn': False
})
attrs=attrs)
return out
...
...
python/paddle/fluid/tests/unittests/mkldnn/test_gaussian_random_mkldnn_op.py
浏览文件 @
53bdee64
...
...
@@ -27,17 +27,15 @@ class TestMKLDNNGaussianRandomOpSeed10(TestGaussianRandomOp):
class
TestMKLDNNGaussianRandomOpSeed0
(
TestGaussianRandomOp
):
def
setUp
(
self
):
TestGaussianRandomOp
.
setUp
(
self
)
self
.
use_mkldnn
=
True
self
.
attrs
=
{
"shape"
:
[
1
000
,
784
],
"mean"
:
.
0
,
"std"
:
1.
,
"seed"
:
0
,
"shape"
:
[
1
23
,
92
],
"mean"
:
1
.0
,
"std"
:
2.0
,
"seed"
:
1
0
,
"use_mkldnn"
:
self
.
use_mkldnn
}
def
init_kernel_type
(
self
):
self
.
use_mkldnn
=
True
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
浏览文件 @
53bdee64
...
...
@@ -15,98 +15,213 @@
from
__future__
import
print_function
import
unittest
import
numpy
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
from
paddle.fluid.executor
import
Executor
from
op_test
import
OpTest
class
TestGaussianRandomOp
(
unittest
.
TestCase
):
class
TestGaussianRandomOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"gaussian_random"
self
.
inputs
=
{}
self
.
use_mkldnn
=
False
self
.
init_kernel_type
()
self
.
attrs
=
{
"shape"
:
[
1
000
,
784
],
"mean"
:
.
0
,
"std"
:
1
.
,
"shape"
:
[
1
23
,
92
],
"mean"
:
1
.0
,
"std"
:
2
.
,
"seed"
:
10
,
"use_mkldnn"
:
self
.
use_mkldnn
}
self
.
outputs
=
[
"Out"
]
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
123
,
92
),
dtype
=
'float32'
)}
def
test_c
pu
(
self
):
self
.
gaussian_random_test
(
place
=
fluid
.
CPUPlace
()
)
def
test_c
heck_output
(
self
):
self
.
check_output_customized
(
self
.
verify_output
)
def
test_gpu
(
self
):
if
core
.
is_compiled_with_cuda
():
self
.
gaussian_random_test
(
place
=
fluid
.
CUDAPlace
(
0
))
def
verify_output
(
self
,
outs
):
self
.
assertEqual
(
outs
[
0
].
shape
,
(
123
,
92
))
hist
,
_
=
np
.
histogram
(
outs
[
0
],
range
=
(
-
3
,
5
))
hist
=
hist
.
astype
(
"float32"
)
hist
/=
float
(
outs
[
0
].
size
)
data
=
np
.
random
.
normal
(
size
=
(
123
,
92
),
loc
=
1
,
scale
=
2
)
hist2
,
_
=
np
.
histogram
(
data
,
range
=
(
-
3
,
5
))
hist2
=
hist2
.
astype
(
"float32"
)
hist2
/=
float
(
outs
[
0
].
size
)
self
.
assertTrue
(
np
.
allclose
(
hist
,
hist2
,
rtol
=
0
,
atol
=
0.01
),
"hist: "
+
str
(
hist
)
+
" hist2: "
+
str
(
hist2
))
def
gaussian_random_test
(
self
,
place
):
program
=
fluid
.
Program
()
block
=
program
.
global_block
()
vout
=
block
.
create_var
(
name
=
"Out"
)
op
=
block
.
append_op
(
type
=
self
.
op_type
,
outputs
=
{
"Out"
:
vout
},
attrs
=
self
.
attrs
)
# Situation 2: Attr(shape) is a list(with tensor)
class
TestGaussianRandomOp_ShapeTensorList
(
TestGaussianRandomOp
):
def
setUp
(
self
):
'''Test gaussian_random op with specified value
'''
self
.
op_type
=
"gaussian_random"
self
.
init_data
()
shape_tensor_list
=
[]
for
index
,
ele
in
enumerate
(
self
.
shape
):
shape_tensor_list
.
append
((
"x"
+
str
(
index
),
np
.
ones
(
(
1
)).
astype
(
'int32'
)
*
ele
))
op
.
desc
.
infer_var_type
(
block
.
desc
)
op
.
desc
.
infer_shape
(
block
.
desc
)
self
.
attrs
=
{
'shape'
:
self
.
infer_shape
,
'mean'
:
self
.
mean
,
'std'
:
self
.
std
,
'seed'
:
self
.
seed
,
'use_mkldnn'
:
self
.
use_mkldnn
}
fetch_list
=
[]
for
var_name
in
self
.
outputs
:
fetch_list
.
append
(
block
.
var
(
var_name
))
self
.
inputs
=
{
"ShapeTensorList"
:
shape_tensor_list
}
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
123
,
92
),
dtype
=
'float32'
)}
exe
=
Executor
(
place
)
outs
=
exe
.
run
(
program
,
fetch_list
=
fetch_list
)
tensor
=
outs
[
0
]
def
init_data
(
self
):
self
.
shape
=
[
123
,
92
]
self
.
infer_shape
=
[
-
1
,
92
]
self
.
use_mkldnn
=
False
self
.
mean
=
1.0
self
.
std
=
2.0
self
.
seed
=
10
self
.
assertAlmostEqual
(
numpy
.
mean
(
tensor
),
.
0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
numpy
.
std
(
tensor
),
1.
,
delta
=
0.1
)
def
test_check_output
(
self
):
self
.
check_output_customized
(
self
.
verify_output
)
def
init_kernel_type
(
self
):
pass
class
TestGaussianRandomOp2_ShapeTensorList
(
TestGaussianRandomOp_ShapeTensorList
):
def
init_data
(
self
):
self
.
shape
=
[
123
,
92
]
self
.
infer_shape
=
[
-
1
,
-
1
]
self
.
use_mkldnn
=
False
self
.
mean
=
1.0
self
.
std
=
2.0
self
.
seed
=
10
class
TestGaussianRandomOp3_ShapeTensorList
(
TestGaussianRandomOp_ShapeTensorList
):
def
init_data
(
self
):
self
.
shape
=
[
123
,
92
]
self
.
infer_shape
=
[
123
,
-
1
]
self
.
use_mkldnn
=
True
self
.
mean
=
1.0
self
.
std
=
2.0
self
.
seed
=
10
class
TestGaussianRandomOp4_ShapeTensorList
(
TestGaussianRandomOp_ShapeTensorList
):
def
init_data
(
self
):
self
.
shape
=
[
123
,
92
]
self
.
infer_shape
=
[
123
,
-
1
]
self
.
use_mkldnn
=
False
self
.
mean
=
1.0
self
.
std
=
2.0
self
.
seed
=
10
class
TestGaussianRandomOpError
(
unittest
.
TestCase
):
# Situation 3: shape is a tensor
class
TestGaussianRandomOp1_ShapeTensor
(
TestGaussianRandomOp
):
def
setUp
(
self
):
'''Test gaussian_random op with specified value
'''
self
.
op_type
=
"gaussian_random"
self
.
in
puts
=
{}
self
.
in
it_data
()
self
.
use_mkldnn
=
False
self
.
inputs
=
{
"ShapeTensor"
:
np
.
array
(
self
.
shape
).
astype
(
"int32"
)}
self
.
attrs
=
{
"shape"
:
[
1000
,
784
],
"mean"
:
.
0
,
"std"
:
1.
,
"seed"
:
10
,
"use_mkldnn"
:
self
.
use_mkldnn
'mean'
:
self
.
mean
,
'std'
:
self
.
std
,
'seed'
:
self
.
seed
,
'use_mkldnn'
:
self
.
use_mkldnn
}
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
123
,
92
),
dtype
=
'float32'
)}
self
.
outputs
=
[
"Out"
]
def
test_errors
(
self
):
program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
fluid
.
Program
(),
program
):
input_data
=
numpy
.
random
.
random
((
2
,
4
)).
astype
(
"float32"
)
block
=
program
.
global_block
()
vout
=
block
.
create_var
(
name
=
"Out"
,
dtype
=
'int32'
)
normal_initializer
=
fluid
.
initializer
.
NormalInitializer
(
loc
=
0.0
,
scale
=
1.0
,
seed
=
0
)
def
test_Variable
():
# the input type must be Variable
normal_initializer
(
input_data
)
self
.
assertRaises
(
TypeError
,
test_Variable
)
def
test_type
():
# dtype must be float32 or float64
normal_initializer
(
vout
)
self
.
assertRaises
(
TypeError
,
test_type
)
def
init_data
(
self
):
self
.
shape
=
[
123
,
92
]
self
.
use_mkldnn
=
False
self
.
mean
=
1.0
self
.
std
=
2.0
self
.
seed
=
10
# Test python API
class
TestGaussianRandomAPI
(
unittest
.
TestCase
):
def
test_api
(
self
):
positive_2_int32
=
fluid
.
layers
.
fill_constant
([
1
],
"int32"
,
2000
)
positive_2_int64
=
fluid
.
layers
.
fill_constant
([
1
],
"int64"
,
500
)
shape_tensor_int32
=
fluid
.
data
(
name
=
"shape_tensor_int32"
,
shape
=
[
2
],
dtype
=
"int32"
)
shape_tensor_int64
=
fluid
.
data
(
name
=
"shape_tensor_int64"
,
shape
=
[
2
],
dtype
=
"int64"
)
out_1
=
fluid
.
layers
.
gaussian_random
(
shape
=
[
2000
,
500
],
dtype
=
"float32"
,
mean
=
0.0
,
std
=
1.0
,
seed
=
10
)
out_2
=
fluid
.
layers
.
gaussian_random
(
shape
=
[
2000
,
positive_2_int32
],
dtype
=
"float32"
,
mean
=
0.
,
std
=
1.0
,
seed
=
10
)
out_3
=
fluid
.
layers
.
gaussian_random
(
shape
=
[
2000
,
positive_2_int64
],
dtype
=
"float32"
,
mean
=
0.
,
std
=
1.0
,
seed
=
10
)
out_4
=
fluid
.
layers
.
gaussian_random
(
shape
=
shape_tensor_int32
,
dtype
=
"float32"
,
mean
=
0.
,
std
=
1.0
,
seed
=
10
)
out_5
=
fluid
.
layers
.
gaussian_random
(
shape
=
shape_tensor_int64
,
dtype
=
"float32"
,
mean
=
0.
,
std
=
1.0
,
seed
=
10
)
out_6
=
fluid
.
layers
.
gaussian_random
(
shape
=
shape_tensor_int64
,
dtype
=
np
.
float32
,
mean
=
0.
,
std
=
1.0
,
seed
=
10
)
exe
=
fluid
.
Executor
(
place
=
fluid
.
CPUPlace
())
res_1
,
res_2
,
res_3
,
res_4
,
res_5
,
res_6
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
"shape_tensor_int32"
:
np
.
array
([
2000
,
500
]).
astype
(
"int32"
),
"shape_tensor_int64"
:
np
.
array
([
2000
,
500
]).
astype
(
"int64"
),
},
fetch_list
=
[
out_1
,
out_2
,
out_3
,
out_4
,
out_5
,
out_6
])
self
.
assertAlmostEqual
(
np
.
mean
(
res_1
),
0.0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
std
(
res_1
),
1.
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
mean
(
res_2
),
0.0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
std
(
res_2
),
1.
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
mean
(
res_3
),
0.0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
std
(
res_3
),
1.
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
mean
(
res_4
),
0.0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
std
(
res_5
),
1.
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
mean
(
res_5
),
0.0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
std
(
res_5
),
1.
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
mean
(
res_6
),
0.0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
std
(
res_6
),
1.
,
delta
=
0.1
)
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录