Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
53bdee64
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
53bdee64
编写于
5月 13, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
5月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add tensor support for gaussian_random_op test=develop (#24389)
上级
da4a1db7
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
354 addition
and
133 deletion
+354
-133
paddle/fluid/operators/fill_constant_op.h
paddle/fluid/operators/fill_constant_op.h
+7
-28
paddle/fluid/operators/gaussian_random_op.cc
paddle/fluid/operators/gaussian_random_op.cc
+74
-10
paddle/fluid/operators/gaussian_random_op.cu
paddle/fluid/operators/gaussian_random_op.cu
+30
-4
paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc
paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc
+6
-1
python/paddle/fluid/layers/distributions.py
python/paddle/fluid/layers/distributions.py
+3
-2
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+55
-22
python/paddle/fluid/tests/unittests/mkldnn/test_gaussian_random_mkldnn_op.py
.../tests/unittests/mkldnn/test_gaussian_random_mkldnn_op.py
+5
-7
python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
...n/paddle/fluid/tests/unittests/test_gaussian_random_op.py
+174
-59
未找到文件。
paddle/fluid/operators/fill_constant_op.h
浏览文件 @
53bdee64
...
...
@@ -20,48 +20,26 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/utils.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
inline
framework
::
DDim
GetShape
(
const
framework
::
ExecutionContext
&
ctx
)
{
inline
framework
::
DDim
GetShape
(
const
framework
::
ExecutionContext
&
ctx
,
std
::
string
op_type
)
{
// 1. shape is a Tensor
if
(
ctx
.
HasInput
(
"ShapeTensor"
))
{
auto
*
shape_tensor
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"ShapeTensor"
);
auto
*
shape_data
=
shape_tensor
->
data
<
int
>
();
framework
::
Tensor
cpu_shape_tensor
;
if
(
platform
::
is_gpu_place
(
shape_tensor
->
place
()))
{
TensorCopySync
(
*
shape_tensor
,
platform
::
CPUPlace
(),
&
cpu_shape_tensor
);
shape_data
=
cpu_shape_tensor
.
data
<
int
>
();
}
auto
vec_shape
=
std
::
vector
<
int
>
(
shape_data
,
shape_data
+
shape_tensor
->
numel
());
auto
vec_shape
=
GetDataFromTensor
<
int
>
(
shape_tensor
);
return
framework
::
make_ddim
(
vec_shape
);
}
// 2. shape is a list/tuple containing Tensor
auto
shape_tensor_list
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"ShapeTensorList"
);
if
(
shape_tensor_list
.
size
()
>
0
)
{
std
::
vector
<
int
>
vec_shape
;
for
(
size_t
i
=
0
;
i
<
shape_tensor_list
.
size
();
++
i
)
{
auto
tensor
=
shape_tensor_list
[
i
];
PADDLE_ENFORCE_EQ
(
tensor
->
dims
(),
framework
::
make_ddim
({
1
}),
platform
::
errors
::
InvalidArgument
(
"If the element type of 'shape'(tensor_list type) in "
"FillConstantOp is Tensor, the shape of this Tensor element must "
"be [1]. But received the Tensor element's shape is [%s]"
,
tensor
->
dims
()));
if
(
platform
::
is_gpu_place
(
tensor
->
place
()))
{
framework
::
Tensor
temp
;
TensorCopySync
(
*
tensor
,
platform
::
CPUPlace
(),
&
temp
);
vec_shape
.
push_back
(
*
temp
.
data
<
int
>
());
}
else
{
vec_shape
.
push_back
(
*
tensor
->
data
<
int
>
());
}
}
auto
vec_shape
=
GetDataFromTensorList
(
shape_tensor_list
);
return
framework
::
make_ddim
(
vec_shape
);
}
...
...
@@ -115,7 +93,8 @@ class FillConstantKernel : public framework::OpKernel<T> {
}
value
=
tensor_data
[
0
];
}
auto
shape
=
GetShape
(
ctx
);
const
std
::
string
op_type
=
"fill_constant"
;
auto
shape
=
GetShape
(
ctx
,
op_type
);
if
(
out_var
->
IsType
<
framework
::
LoDTensor
>
())
{
tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
...
...
paddle/fluid/operators/gaussian_random_op.cc
浏览文件 @
53bdee64
...
...
@@ -14,7 +14,7 @@ limitations under the License. */
#include <random>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
...
...
@@ -22,8 +22,37 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
CPUGaussianRandomKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
float
mean
=
context
.
Attr
<
float
>
(
"mean"
);
float
std
=
context
.
Attr
<
float
>
(
"std"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
std
::
minstd_rand
engine
;
if
(
seed
==
0
)
{
seed
=
std
::
random_device
()();
}
engine
.
seed
(
seed
);
std
::
normal_distribution
<
T
>
dist
(
mean
,
std
);
const
std
::
string
op_type
=
"gaussian_random"
;
auto
shape
=
GetShape
(
context
,
op_type
);
tensor
->
Resize
(
shape
);
int64_t
size
=
tensor
->
numel
();
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
engine
);
}
}
};
template
<
typename
T
>
class
CPUGaussianRandomBatchSizeLikeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
float
mean
=
context
.
Attr
<
float
>
(
"mean"
);
...
...
@@ -58,12 +87,26 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
for
(
auto
dim
:
shape
)
{
temp
.
push_back
(
static_cast
<
int64_t
>
(
dim
));
}
PADDLE_ENFORCE_GT
(
shape
.
size
(),
0UL
,
platform
::
errors
::
InvalidArgument
(
"Attribute(shape) of GaussianRandomOp must be set "
"and shape.size() > 0, but reveived shape.size() is %d"
,
shape
.
size
()));
if
(
shape
.
empty
()
&&
ctx
->
HasInput
(
"ShapeTensor"
))
{
auto
shape_dims
=
ctx
->
GetInputDim
(
"ShapeTensor"
);
int
num_ele
=
1
;
for
(
int
i
=
0
;
i
<
shape_dims
.
size
();
++
i
)
{
num_ele
*=
shape_dims
[
i
];
}
auto
vec_dims
=
std
::
vector
<
int
>
(
num_ele
,
-
1
);
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
vec_dims
));
return
;
}
if
(
!
(
ctx
->
HasInput
(
"ShapeTensor"
)
&&
!
ctx
->
HasInputs
(
"ShapeTensorList"
)))
{
PADDLE_ENFORCE_GT
(
shape
.
size
(),
0UL
,
platform
::
errors
::
InvalidArgument
(
"Attribute(shape) of GaussianRandomOp must be set "
"and shape.size() > 0, but reveived shape.size() is %d"
,
shape
.
size
()));
}
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
temp
));
}
...
...
@@ -85,6 +128,16 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
ctx
.
Attr
<
int
>
(
"dtype"
)),
ctx
.
device_context
(),
layout
,
library
);
}
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
override
{
if
(
var_name
==
"ShapeTensor"
||
var_name
==
"ShapeTensorList"
)
{
return
expected_kernel_type
;
}
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
};
class
GaussianRandomOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
@@ -94,7 +147,18 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
std
::
vector
<
int64_t
>>
(
"shape"
,
"(vector<int64_t>) "
"The dimension of random tensor."
);
"The dimension of random tensor."
)
.
SetDefault
({});
AddInput
(
"ShapeTensor"
,
"(Tensor<int>), optional). The shape of the output."
"It has a higher priority than Attr(shape)."
)
.
AsDispensable
();
AddInput
(
"ShapeTensorList"
,
"(vector<Tensor<int>>, optional). The shape of the output. "
"It has a higher priority than Attr(shape)."
"The shape of the element in vector must be [1]."
)
.
AsDuplicable
()
.
AsDispensable
();
AddAttr
<
float
>
(
"mean"
,
"(float, default 0.0) "
"mean of random tensor."
)
...
...
@@ -135,5 +199,5 @@ REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
REGISTER_OP_CPU_KERNEL
(
gaussian_random
,
ops
::
CPUGaussianRandomKernel
<
float
>
,
ops
::
CPUGaussianRandomKernel
<
double
>
);
REGISTER_OP_CPU_KERNEL
(
gaussian_random_batch_size_like
,
ops
::
CPUGaussianRandomKernel
<
float
>
,
ops
::
CPUGaussianRandomKernel
<
double
>
);
ops
::
CPUGaussianRandom
BatchSizeLike
Kernel
<
float
>
,
ops
::
CPUGaussianRandom
BatchSizeLike
Kernel
<
double
>
);
paddle/fluid/operators/gaussian_random_op.cu
浏览文件 @
53bdee64
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include <thrust/transform.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fill_constant_op.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -41,7 +42,6 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
if
(
seed
==
0
)
{
std
::
random_device
rd
;
...
...
@@ -50,6 +50,11 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
T
mean
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"mean"
));
T
std
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"std"
));
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
const
std
::
string
op_type
=
"gaussian_random"
;
auto
shape
=
GetShape
(
context
,
op_type
);
tensor
->
Resize
(
shape
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int64_t
size
=
tensor
->
numel
();
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
...
...
@@ -57,12 +62,33 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
}
};
template
<
typename
T
>
class
GPUGaussianRandomBatchSizeLikeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
if
(
seed
==
0
)
{
std
::
random_device
rd
;
seed
=
rd
();
}
T
mean
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"mean"
));
T
std
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"std"
));
thrust
::
counting_iterator
<
unsigned
int
>
index_sequence_begin
(
0
);
int64_t
size
=
tensor
->
numel
();
thrust
::
transform
(
index_sequence_begin
,
index_sequence_begin
+
size
,
thrust
::
device_ptr
<
T
>
(
data
),
GaussianGenerator
<
T
>
(
mean
,
std
,
seed
));
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
gaussian_random
,
paddle
::
operators
::
GPUGaussianRandomKernel
<
float
>
,
paddle
::
operators
::
GPUGaussianRandomKernel
<
double
>
);
REGISTER_OP_CUDA_KERNEL
(
gaussian_random_batch_size_like
,
paddle
::
operators
::
GPUGaussianRandomKernel
<
float
>
,
paddle
::
operators
::
GPUGaussianRandomKernel
<
double
>
);
REGISTER_OP_CUDA_KERNEL
(
gaussian_random_batch_size_like
,
paddle
::
operators
::
GPUGaussianRandomBatchSizeLikeKernel
<
float
>
,
paddle
::
operators
::
GPUGaussianRandomBatchSizeLikeKernel
<
double
>
);
paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc
浏览文件 @
53bdee64
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/operators/mean_op.h"
namespace
paddle
{
...
...
@@ -26,7 +27,6 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
float
mean
=
context
.
Attr
<
float
>
(
"mean"
);
float
std
=
context
.
Attr
<
float
>
(
"std"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
std
::
minstd_rand
engine
;
...
...
@@ -35,6 +35,11 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel<T> {
}
engine
.
seed
(
seed
);
std
::
normal_distribution
<
T
>
dist
(
mean
,
std
);
const
std
::
string
op_type
=
"gaussian_random"
;
auto
shape
=
GetShape
(
context
,
op_type
);
tensor
->
Resize
(
shape
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int64_t
size
=
tensor
->
numel
();
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
data
[
i
]
=
dist
(
engine
);
...
...
python/paddle/fluid/layers/distributions.py
浏览文件 @
53bdee64
...
...
@@ -357,8 +357,9 @@ class Normal(Distribution):
output_shape
=
shape
+
batch_shape
zero_tmp
=
tensor
.
fill_constant_batch_size_like
(
self
.
loc
+
self
.
scale
,
batch_shape
+
shape
,
self
.
loc
.
dtype
,
0.
)
normal_random_tmp
=
nn
.
gaussian_random_batch_size_like
(
zero_tmp
,
zero_tmp
.
shape
,
mean
=
0.
,
std
=
1.
,
seed
=
seed
)
zero_tmp_shape
=
nn
.
shape
(
zero_tmp
)
normal_random_tmp
=
nn
.
gaussian_random
(
zero_tmp_shape
,
mean
=
0.
,
std
=
1.
,
seed
=
seed
)
output
=
normal_random_tmp
*
(
zero_tmp
+
self
.
scale
)
+
self
.
loc
return
nn
.
reshape
(
output
,
output_shape
)
else
:
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
53bdee64
...
...
@@ -10169,33 +10169,55 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
Generate a random tensor whose data is drawn from a Gaussian distribution.
Args:
shape (
Tuple[int] | List[int
]): Shape of the generated random tensor.
shape (
tuple[int] | list[int] | Variable | list[Variable
]): Shape of the generated random tensor.
mean (float): Mean of the random tensor, defaults to 0.0.
std (float): Standard deviation of the random tensor, defaults to 1.0.
seed (int): ${seed_comment}
dtype(np.dtype | core.VarDesc.VarType | str): Output data type, float32 or float64.
Returns:
Variable: Random tensor whose data is drawn from a Gaussian distribution, dtype: flaot32 or float64 as specified.
Examples:
.. code-block:: python
# declarative mode
import paddle.fluid as fluid
# example 1:
# attr shape is a list which doesn't contain tensor Variable.
result_1 = fluid.layers.gaussian_random(shape=[3, 4])
# example 2:
# attr shape is a list which contains tensor Variable.
dim_1 = fluid.layers.fill_constant([1],"int64",3)
dim_2 = fluid.layers.fill_constant([1],"int32",5)
result_2 = fluid.layers.gaussian_random(shape=[dim_1, dim_2])
# example 3:
# attr shape is a Variable, the data type must be int64 or int32.
var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64")
result_3 = fluid.layers.gaussian_random(var_shape)
var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32")
result_4 = fluid.layers.gaussian_random(var_shape_int32)
.. code-block:: python
# declarative mode
import numpy as np
from paddle import fluid
x = fluid.layers.gaussian_random((2, 3), std=2., seed=10)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
start = fluid.default_startup_program()
main = fluid.default_main_program()
exe.run(start)
x_np, = exe.run(main, feed={}, fetch_list=[x])
...
...
@@ -10209,33 +10231,44 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
place = fluid.CPUPlace()
with dg.guard(place) as g:
x = fluid.layers.gaussian_random((2, 4), mean=2., dtype="float32", seed=10)
x_np = x.numpy()
x_np = x.numpy()
x_np
# array([[2.3060477 , 2.676496 , 3.9911983 , 0.9990833 ],
# [2.8675377 , 2.2279181 , 0.79029655, 2.8447366 ]], dtype=float32)
"""
helper = LayerHelper('gaussian_random', **locals())
check_type(shape, 'shape', (list, tuple), 'fluid.layers.gaussian_random')
check_dtype(dtype, 'dtype', ['float32', 'float64'],
'fluid.layers.gaussian_random')
out = helper.create_variable_for_type_inference(dtype)
if not isinstance(shape, (list, tuple, Variable)):
raise TypeError(
"The type of 'shape' in fill_constant must be Variable, list or tuple, but "
"received %s." % (type(shape)))
c_dtype = convert_np_dtype_to_dtype_(dtype)
attrs = {
'mean': mean,
'std': std,
'seed': seed,
'dtype': c_dtype,
'use_mkldnn': False
}
inputs = {}
utils._get_shape_tensor_inputs(
inputs=inputs,
helper=helper,
attrs=attrs,
shape=shape,
op_type='gaussian_random')
helper.append_op(
type='gaussian_random',
inputs=inputs,
outputs={'Out': out},
attrs={
'shape': shape,
'mean': mean,
'std': std,
'seed': seed,
'dtype': c_dtype,
'use_mkldnn': False
})
attrs=attrs)
return out
...
...
python/paddle/fluid/tests/unittests/mkldnn/test_gaussian_random_mkldnn_op.py
浏览文件 @
53bdee64
...
...
@@ -27,17 +27,15 @@ class TestMKLDNNGaussianRandomOpSeed10(TestGaussianRandomOp):
class
TestMKLDNNGaussianRandomOpSeed0
(
TestGaussianRandomOp
):
def
setUp
(
self
):
TestGaussianRandomOp
.
setUp
(
self
)
self
.
use_mkldnn
=
True
self
.
attrs
=
{
"shape"
:
[
1
000
,
784
],
"mean"
:
.
0
,
"std"
:
1.
,
"seed"
:
0
,
"shape"
:
[
1
23
,
92
],
"mean"
:
1
.0
,
"std"
:
2.0
,
"seed"
:
1
0
,
"use_mkldnn"
:
self
.
use_mkldnn
}
def
init_kernel_type
(
self
):
self
.
use_mkldnn
=
True
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
浏览文件 @
53bdee64
...
...
@@ -15,98 +15,213 @@
from
__future__
import
print_function
import
unittest
import
numpy
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
from
paddle.fluid.executor
import
Executor
from
op_test
import
OpTest
class
TestGaussianRandomOp
(
unittest
.
TestCase
):
class
TestGaussianRandomOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"gaussian_random"
self
.
inputs
=
{}
self
.
use_mkldnn
=
False
self
.
init_kernel_type
()
self
.
attrs
=
{
"shape"
:
[
1
000
,
784
],
"mean"
:
.
0
,
"std"
:
1
.
,
"shape"
:
[
1
23
,
92
],
"mean"
:
1
.0
,
"std"
:
2
.
,
"seed"
:
10
,
"use_mkldnn"
:
self
.
use_mkldnn
}
self
.
outputs
=
[
"Out"
]
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
123
,
92
),
dtype
=
'float32'
)}
def
test_c
pu
(
self
):
self
.
gaussian_random_test
(
place
=
fluid
.
CPUPlace
()
)
def
test_c
heck_output
(
self
):
self
.
check_output_customized
(
self
.
verify_output
)
def
test_gpu
(
self
):
if
core
.
is_compiled_with_cuda
():
self
.
gaussian_random_test
(
place
=
fluid
.
CUDAPlace
(
0
))
def
verify_output
(
self
,
outs
):
self
.
assertEqual
(
outs
[
0
].
shape
,
(
123
,
92
))
hist
,
_
=
np
.
histogram
(
outs
[
0
],
range
=
(
-
3
,
5
))
hist
=
hist
.
astype
(
"float32"
)
hist
/=
float
(
outs
[
0
].
size
)
data
=
np
.
random
.
normal
(
size
=
(
123
,
92
),
loc
=
1
,
scale
=
2
)
hist2
,
_
=
np
.
histogram
(
data
,
range
=
(
-
3
,
5
))
hist2
=
hist2
.
astype
(
"float32"
)
hist2
/=
float
(
outs
[
0
].
size
)
self
.
assertTrue
(
np
.
allclose
(
hist
,
hist2
,
rtol
=
0
,
atol
=
0.01
),
"hist: "
+
str
(
hist
)
+
" hist2: "
+
str
(
hist2
))
def
gaussian_random_test
(
self
,
place
):
program
=
fluid
.
Program
()
block
=
program
.
global_block
()
vout
=
block
.
create_var
(
name
=
"Out"
)
op
=
block
.
append_op
(
type
=
self
.
op_type
,
outputs
=
{
"Out"
:
vout
},
attrs
=
self
.
attrs
)
# Situation 2: Attr(shape) is a list(with tensor)
class
TestGaussianRandomOp_ShapeTensorList
(
TestGaussianRandomOp
):
def
setUp
(
self
):
'''Test gaussian_random op with specified value
'''
self
.
op_type
=
"gaussian_random"
self
.
init_data
()
shape_tensor_list
=
[]
for
index
,
ele
in
enumerate
(
self
.
shape
):
shape_tensor_list
.
append
((
"x"
+
str
(
index
),
np
.
ones
(
(
1
)).
astype
(
'int32'
)
*
ele
))
op
.
desc
.
infer_var_type
(
block
.
desc
)
op
.
desc
.
infer_shape
(
block
.
desc
)
self
.
attrs
=
{
'shape'
:
self
.
infer_shape
,
'mean'
:
self
.
mean
,
'std'
:
self
.
std
,
'seed'
:
self
.
seed
,
'use_mkldnn'
:
self
.
use_mkldnn
}
fetch_list
=
[]
for
var_name
in
self
.
outputs
:
fetch_list
.
append
(
block
.
var
(
var_name
))
self
.
inputs
=
{
"ShapeTensorList"
:
shape_tensor_list
}
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
123
,
92
),
dtype
=
'float32'
)}
exe
=
Executor
(
place
)
outs
=
exe
.
run
(
program
,
fetch_list
=
fetch_list
)
tensor
=
outs
[
0
]
def
init_data
(
self
):
self
.
shape
=
[
123
,
92
]
self
.
infer_shape
=
[
-
1
,
92
]
self
.
use_mkldnn
=
False
self
.
mean
=
1.0
self
.
std
=
2.0
self
.
seed
=
10
self
.
assertAlmostEqual
(
numpy
.
mean
(
tensor
),
.
0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
numpy
.
std
(
tensor
),
1.
,
delta
=
0.1
)
def
test_check_output
(
self
):
self
.
check_output_customized
(
self
.
verify_output
)
def
init_kernel_type
(
self
):
pass
class
TestGaussianRandomOp2_ShapeTensorList
(
TestGaussianRandomOp_ShapeTensorList
):
def
init_data
(
self
):
self
.
shape
=
[
123
,
92
]
self
.
infer_shape
=
[
-
1
,
-
1
]
self
.
use_mkldnn
=
False
self
.
mean
=
1.0
self
.
std
=
2.0
self
.
seed
=
10
class
TestGaussianRandomOp3_ShapeTensorList
(
TestGaussianRandomOp_ShapeTensorList
):
def
init_data
(
self
):
self
.
shape
=
[
123
,
92
]
self
.
infer_shape
=
[
123
,
-
1
]
self
.
use_mkldnn
=
True
self
.
mean
=
1.0
self
.
std
=
2.0
self
.
seed
=
10
class
TestGaussianRandomOp4_ShapeTensorList
(
TestGaussianRandomOp_ShapeTensorList
):
def
init_data
(
self
):
self
.
shape
=
[
123
,
92
]
self
.
infer_shape
=
[
123
,
-
1
]
self
.
use_mkldnn
=
False
self
.
mean
=
1.0
self
.
std
=
2.0
self
.
seed
=
10
class
TestGaussianRandomOpError
(
unittest
.
TestCase
):
# Situation 3: shape is a tensor
class
TestGaussianRandomOp1_ShapeTensor
(
TestGaussianRandomOp
):
def
setUp
(
self
):
'''Test gaussian_random op with specified value
'''
self
.
op_type
=
"gaussian_random"
self
.
in
puts
=
{}
self
.
in
it_data
()
self
.
use_mkldnn
=
False
self
.
inputs
=
{
"ShapeTensor"
:
np
.
array
(
self
.
shape
).
astype
(
"int32"
)}
self
.
attrs
=
{
"shape"
:
[
1000
,
784
],
"mean"
:
.
0
,
"std"
:
1.
,
"seed"
:
10
,
"use_mkldnn"
:
self
.
use_mkldnn
'mean'
:
self
.
mean
,
'std'
:
self
.
std
,
'seed'
:
self
.
seed
,
'use_mkldnn'
:
self
.
use_mkldnn
}
self
.
outputs
=
{
'Out'
:
np
.
zeros
((
123
,
92
),
dtype
=
'float32'
)}
self
.
outputs
=
[
"Out"
]
def
test_errors
(
self
):
program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
fluid
.
Program
(),
program
):
input_data
=
numpy
.
random
.
random
((
2
,
4
)).
astype
(
"float32"
)
block
=
program
.
global_block
()
vout
=
block
.
create_var
(
name
=
"Out"
,
dtype
=
'int32'
)
normal_initializer
=
fluid
.
initializer
.
NormalInitializer
(
loc
=
0.0
,
scale
=
1.0
,
seed
=
0
)
def
test_Variable
():
# the input type must be Variable
normal_initializer
(
input_data
)
self
.
assertRaises
(
TypeError
,
test_Variable
)
def
test_type
():
# dtype must be float32 or float64
normal_initializer
(
vout
)
self
.
assertRaises
(
TypeError
,
test_type
)
def
init_data
(
self
):
self
.
shape
=
[
123
,
92
]
self
.
use_mkldnn
=
False
self
.
mean
=
1.0
self
.
std
=
2.0
self
.
seed
=
10
# Test python API
class
TestGaussianRandomAPI
(
unittest
.
TestCase
):
def
test_api
(
self
):
positive_2_int32
=
fluid
.
layers
.
fill_constant
([
1
],
"int32"
,
2000
)
positive_2_int64
=
fluid
.
layers
.
fill_constant
([
1
],
"int64"
,
500
)
shape_tensor_int32
=
fluid
.
data
(
name
=
"shape_tensor_int32"
,
shape
=
[
2
],
dtype
=
"int32"
)
shape_tensor_int64
=
fluid
.
data
(
name
=
"shape_tensor_int64"
,
shape
=
[
2
],
dtype
=
"int64"
)
out_1
=
fluid
.
layers
.
gaussian_random
(
shape
=
[
2000
,
500
],
dtype
=
"float32"
,
mean
=
0.0
,
std
=
1.0
,
seed
=
10
)
out_2
=
fluid
.
layers
.
gaussian_random
(
shape
=
[
2000
,
positive_2_int32
],
dtype
=
"float32"
,
mean
=
0.
,
std
=
1.0
,
seed
=
10
)
out_3
=
fluid
.
layers
.
gaussian_random
(
shape
=
[
2000
,
positive_2_int64
],
dtype
=
"float32"
,
mean
=
0.
,
std
=
1.0
,
seed
=
10
)
out_4
=
fluid
.
layers
.
gaussian_random
(
shape
=
shape_tensor_int32
,
dtype
=
"float32"
,
mean
=
0.
,
std
=
1.0
,
seed
=
10
)
out_5
=
fluid
.
layers
.
gaussian_random
(
shape
=
shape_tensor_int64
,
dtype
=
"float32"
,
mean
=
0.
,
std
=
1.0
,
seed
=
10
)
out_6
=
fluid
.
layers
.
gaussian_random
(
shape
=
shape_tensor_int64
,
dtype
=
np
.
float32
,
mean
=
0.
,
std
=
1.0
,
seed
=
10
)
exe
=
fluid
.
Executor
(
place
=
fluid
.
CPUPlace
())
res_1
,
res_2
,
res_3
,
res_4
,
res_5
,
res_6
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
"shape_tensor_int32"
:
np
.
array
([
2000
,
500
]).
astype
(
"int32"
),
"shape_tensor_int64"
:
np
.
array
([
2000
,
500
]).
astype
(
"int64"
),
},
fetch_list
=
[
out_1
,
out_2
,
out_3
,
out_4
,
out_5
,
out_6
])
self
.
assertAlmostEqual
(
np
.
mean
(
res_1
),
0.0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
std
(
res_1
),
1.
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
mean
(
res_2
),
0.0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
std
(
res_2
),
1.
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
mean
(
res_3
),
0.0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
std
(
res_3
),
1.
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
mean
(
res_4
),
0.0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
std
(
res_5
),
1.
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
mean
(
res_5
),
0.0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
std
(
res_5
),
1.
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
mean
(
res_6
),
0.0
,
delta
=
0.1
)
self
.
assertAlmostEqual
(
np
.
std
(
res_6
),
1.
,
delta
=
0.1
)
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录