提交 6752b06f 编写于 作者: E emailweixu 提交者: Yi Wang

Generating random numbers with given batch size (#8337)

* Generating random numbers with given batch size

uniform_random_batch_size_like_op
gaussian_random_batch_size_like_op

* More comments about random seed.

* Move test_*_random_batch_size_like_op to unittests
上级 118d950e
...@@ -176,6 +176,20 @@ op_library(pool_op SRCS pool_op.cc DEPS pooling) ...@@ -176,6 +176,20 @@ op_library(pool_op SRCS pool_op.cc DEPS pooling)
op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col) op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col)
endif() endif()
cc_library(batch_size_like SRCS batch_size_like.cc DEPS op_registry)
op_library(fill_constant_batch_size_like_op
SRCS fill_constant_batch_size_like_op.cc fill_constant_batch_size_like_op.cu.cc
DEPS batch_size_like)
op_library(uniform_random_batch_size_like_op
SRCS uniform_random_batch_size_like_op.cc
DEPS batch_size_like uniform_random_op)
op_library(gaussian_random_batch_size_like_op
SRCS gaussian_random_batch_size_like_op.cc
DEPS batch_size_like gaussian_random_op)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions # FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library(save_op DEPS lod_tensor) op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/batch_size_like.h"
namespace paddle {
namespace operators {
void BatchSizeLikeOp::InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of %s should not be null.", Type());
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of %s should not be null.",
Type());
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(shape.size(), 0);
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto output_dim = framework::make_ddim(shape_int64);
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
PADDLE_ENFORCE_GE(input_dim_idx, 0);
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
PADDLE_ENFORCE_GE(output_dim_idx, 0);
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
ctx->SetOutputDim("Out", output_dim);
}
BatchSizeLikeOpMaker::BatchSizeLikeOpMaker(OpProto *proto,
OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(Tensor) Tensor "
"whose input_dim_idx'th dimension specifies the batch_size");
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<int>("input_dim_idx",
"(int, default 0) The index of input's batch size dimension")
.SetDefault(0);
AddAttr<int>("output_dim_idx",
"(int, default 0) The index of output's batch size dimension")
.SetDefault(0);
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
class BatchSizeLikeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override;
};
class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker);
};
} // namespace operators
} // namespace paddle
...@@ -13,42 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,42 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fill_constant_batch_size_like_op.h" #include "paddle/fluid/operators/fill_constant_batch_size_like_op.h"
#include "paddle/fluid/operators/batch_size_like.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel { class FillConstantBatchSizeLikeOp : public BatchSizeLikeOp {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("Input"),
"Input(Input) of FillConstantBatchSizeLikeOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FillConstantBatchSizeLikeOp should not be null.");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(shape.size(), 0);
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto output_dim = framework::make_ddim(shape_int64);
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
PADDLE_ENFORCE_GE(input_dim_idx, 0);
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
PADDLE_ENFORCE_GE(output_dim_idx, 0);
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
ctx->SetOutputDim("Out", output_dim);
}
protected: protected:
using BatchSizeLikeOp::BatchSizeLikeOp;
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
...@@ -57,28 +29,14 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel { ...@@ -57,28 +29,14 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
} }
}; };
class FillConstantBatchSizeLikeOpMaker class FillConstantBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
: public framework::OpProtoAndCheckerMaker {
public: public:
FillConstantBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker) FillConstantBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : BatchSizeLikeOpMaker(proto, op_checker) {
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5 (FP32)) " "(int, default 5 (FP32)) "
"Output data type") "Output data type")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::DataType::FP32);
AddInput("Input",
"(Tensor) Tensor "
"whose dim_idx th dimension is used to specify the batch_size");
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
AddAttr<int>("input_dim_idx",
"(int, default 0) The index of input's batch size dimension")
.SetDefault(0);
AddAttr<int>("output_dim_idx",
"(int, default 0) The index of output's batch size dimension")
.SetDefault(0);
AddAttr<float>("value", "(float, default 0) The value to be filled") AddAttr<float>("value", "(float, default 0) The value to be filled")
.SetDefault(0.0f); .SetDefault(0.0f);
AddComment(R"DOC( AddComment(R"DOC(
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/batch_size_like.h"
namespace paddle {
namespace operators {
class GaussianRandomBatchSizeLikeOp : public BatchSizeLikeOp {
protected:
using BatchSizeLikeOp::BatchSizeLikeOp;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public:
GaussianRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: BatchSizeLikeOpMaker(proto, op_checker) {
AddAttr<float>("mean",
"(float, default 0.0) "
"mean of random tensor.")
.SetDefault(.0f);
AddAttr<float>("std",
"(float, default 1.0) "
"std of random tensor.")
.SetDefault(1.0f);
AddAttr<int>("seed",
"(int, default 0) "
"Random seed of generator."
"0 means use system wide seed."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.")
.SetDefault(0);
AddAttr<int>("dtype",
"(int, default 5(FP32)) "
"Output data type.")
.SetDefault(framework::proto::DataType::FP32);
AddComment(R"DOC(
GaussianRandom Operator.
Used to initialize tensors with gaussian random generator.
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
gaussian_random_batch_size_like,
paddle::operators::GaussianRandomBatchSizeLikeOp,
paddle::operators::GaussianRandomBatchSizeLikeOpMaker);
// Kernels are registered in gaussian_random_op.cc and gaussian_random_op.cu
...@@ -88,7 +88,9 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -88,7 +88,9 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("seed", AddAttr<int>("seed",
"(int, default 0) " "(int, default 0) "
"Random seed of generator." "Random seed of generator."
"0 means use system wide seed.") "0 means use system wide seed."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("dtype", AddAttr<int>("dtype",
"(int, default 5(FP32)) " "(int, default 5(FP32)) "
...@@ -110,4 +112,8 @@ Used to initialize tensors with gaussian random generator. ...@@ -110,4 +112,8 @@ Used to initialize tensors with gaussian random generator.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
ops::GaussianRandomOpMaker); ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>); REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>,
ops::CPUGaussianRandomKernel<double>);
REGISTER_OP_CPU_KERNEL(gaussian_random_batch_size_like,
ops::CPUGaussianRandomKernel<float>,
ops::CPUGaussianRandomKernel<double>);
...@@ -61,4 +61,8 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> { ...@@ -61,4 +61,8 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(gaussian_random, REGISTER_OP_CUDA_KERNEL(gaussian_random,
paddle::operators::GPUGaussianRandomKernel<float>); paddle::operators::GPUGaussianRandomKernel<float>,
paddle::operators::GPUGaussianRandomKernel<double>);
REGISTER_OP_CUDA_KERNEL(gaussian_random_batch_size_like,
paddle::operators::GPUGaussianRandomKernel<float>,
paddle::operators::GPUGaussianRandomKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/batch_size_like.h"
namespace paddle {
namespace operators {
class UniformRandomBatchSizeLikeOp : public BatchSizeLikeOp {
protected:
using BatchSizeLikeOp::BatchSizeLikeOp;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class UniformRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
public:
UniformRandomBatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: BatchSizeLikeOpMaker(proto, op_checker) {
AddComment(R"DOC(
Uniform random operator
This operator initializes a tensor with the same batch_size as the Input tensor
with random values sampled from a uniform distribution.
)DOC");
AddAttr<float>("min",
"(float, default -1.0) "
"Minimum value of uniform random")
.SetDefault(-1.0f);
AddAttr<float>("max",
"(float, default 1.0) "
"Maximun value of uniform random")
.SetDefault(1.0f);
AddAttr<int>("seed",
"(int, default 0) "
"Random seed used for generating samples. "
"0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.")
.SetDefault(0);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::proto::DataType::FP32);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
uniform_random_batch_size_like,
paddle::operators::UniformRandomBatchSizeLikeOp,
paddle::operators::UniformRandomBatchSizeLikeOpMaker);
// Kernels are registered in uniform_random_op.cc and uniform_random_op.cu
...@@ -96,7 +96,9 @@ uniform distribution. ...@@ -96,7 +96,9 @@ uniform distribution.
AddAttr<int>("seed", AddAttr<int>("seed",
"(int, default 0) " "(int, default 0) "
"Random seed used for generating samples. " "Random seed used for generating samples. "
"0 means use a seed generated by the system.") "0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type") AddAttr<int>("dtype", "(int, default 5(FP32)) Output tensor data type")
.SetDefault(framework::proto::DataType::FP32); .SetDefault(framework::proto::DataType::FP32);
...@@ -110,3 +112,6 @@ REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, ...@@ -110,3 +112,6 @@ REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
REGISTER_OP_CPU_KERNEL(uniform_random, REGISTER_OP_CPU_KERNEL(uniform_random,
paddle::operators::CPUUniformRandomKernel<float>, paddle::operators::CPUUniformRandomKernel<float>,
paddle::operators::CPUUniformRandomKernel<double>); paddle::operators::CPUUniformRandomKernel<double>);
REGISTER_OP_CPU_KERNEL(uniform_random_batch_size_like,
paddle::operators::CPUUniformRandomKernel<float>,
paddle::operators::CPUUniformRandomKernel<double>);
...@@ -66,3 +66,6 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -66,3 +66,6 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
REGISTER_OP_CUDA_KERNEL(uniform_random, REGISTER_OP_CUDA_KERNEL(uniform_random,
paddle::operators::GPUUniformRandomKernel<float>, paddle::operators::GPUUniformRandomKernel<float>,
paddle::operators::GPUUniformRandomKernel<double>); paddle::operators::GPUUniformRandomKernel<double>);
REGISTER_OP_CUDA_KERNEL(uniform_random_batch_size_like,
paddle::operators::GPUUniformRandomKernel<float>,
paddle::operators::GPUUniformRandomKernel<double>);
...@@ -66,6 +66,9 @@ __all__ = [ ...@@ -66,6 +66,9 @@ __all__ = [
'logical_xor', 'logical_xor',
'logical_not', 'logical_not',
'uniform_random', 'uniform_random',
'uniform_random_batch_size_like',
'gaussian_random',
'gaussian_random_batch_size_like',
'cumsum', 'cumsum',
] + __activations__ ] + __activations__
......
...@@ -248,7 +248,11 @@ class OpTest(unittest.TestCase): ...@@ -248,7 +248,11 @@ class OpTest(unittest.TestCase):
return feed_map return feed_map
def check_output_with_place(self, place, atol): def calc_output(self, place):
outs, _ = self._calc_output(place)
return outs
def _calc_output(self, place):
op_proto = OpProtoHolder.instance().get_op_proto(self.op_type) op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
program = Program() program = Program()
...@@ -281,7 +285,10 @@ class OpTest(unittest.TestCase): ...@@ -281,7 +285,10 @@ class OpTest(unittest.TestCase):
feed=feed_map, feed=feed_map,
fetch_list=fetch_list, fetch_list=fetch_list,
return_numpy=False) return_numpy=False)
return outs, fetch_list
def check_output_with_place(self, place, atol):
outs, fetch_list = self._calc_output(place)
for out_name, out_dup in Operator.get_op_outputs(self.op_type): for out_name, out_dup in Operator.get_op_outputs(self.op_type):
if out_name not in self.outputs: if out_name not in self.outputs:
continue continue
...@@ -340,6 +347,15 @@ class OpTest(unittest.TestCase): ...@@ -340,6 +347,15 @@ class OpTest(unittest.TestCase):
for place in places: for place in places:
self.check_output_with_place(place, atol) self.check_output_with_place(place, atol)
def check_output_customized(self, checker):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
places.append(core.CUDAPlace(0))
for place in places:
outs = self.calc_output(place)
outs = [np.array(out) for out in outs]
checker(outs)
def __assert_is_close(self, numeric_grads, analytic_grads, names, def __assert_is_close(self, numeric_grads, analytic_grads, names,
max_relative_error, msg_prefix): max_relative_error, msg_prefix):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestGaussianRandomBatchSizeLike(OpTest):
def setUp(self):
self.op_type = "gaussian_random_batch_size_like"
self.inputs = {'Input': np.zeros((500, 2000), dtype="float32")}
self.attrs = {'mean': 1., 'std': 2., 'shape': [-1, 2000]}
self.outputs = {'Out': np.zeros((500, 2000), dtype='float32')}
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
self.assertEqual(outs[0].shape, (500, 2000))
hist, _ = np.histogram(outs[0], range=(-3, 5))
hist = hist.astype("float32")
hist /= float(outs[0].size)
data = np.random.normal(size=(500, 2000), loc=1, scale=2)
hist2, _ = np.histogram(data, range=(-3, 5))
hist2 = hist2.astype("float32")
hist2 /= float(outs[0].size)
self.assertTrue(
np.allclose(
hist, hist2, rtol=0, atol=0.01),
"hist: " + str(hist) + " hist2: " + str(hist2))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestUniformRandomBatchSizeLike(OpTest):
def setUp(self):
self.op_type = "uniform_random_batch_size_like"
self.inputs = {'Input': np.zeros((500, 2000), dtype="float32")}
self.attrs = {'min': 1., 'max': 2., 'shape': [-1, 2000]}
self.outputs = {'Out': np.zeros((500, 2000), dtype='float32')}
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
self.assertEqual(outs[0].shape, (500, 2000))
hist, _ = np.histogram(outs[0], range=(1, 2))
hist = hist.astype("float32")
hist /= float(outs[0].size)
prob = 0.1 * np.ones((10))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
if __name__ == "__main__":
unittest.main()
...@@ -13,14 +13,11 @@ ...@@ -13,14 +13,11 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
import numpy import numpy as np
from op_test import OpTest
from paddle.v2.fluid.op import Operator
import paddle.v2.fluid.core as core
import paddle.v2.fluid as fluid
class TestUniformRandomOp(OpTest):
class TestUniformRandomOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.op_type = "uniform_random" self.op_type = "uniform_random"
self.inputs = {} self.inputs = {}
...@@ -30,35 +27,20 @@ class TestUniformRandomOp(unittest.TestCase): ...@@ -30,35 +27,20 @@ class TestUniformRandomOp(unittest.TestCase):
"max": 10.0, "max": 10.0,
"seed": 10 "seed": 10
} }
self.outputs = ["Out"] self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
def test_cpu(self):
self.uniform_random_test(place=core.CPUPlace())
def test_gpu(self):
if core.is_compiled_with_cuda():
self.uniform_random_test(place=core.CUDAPlace(0))
def uniform_random_test(self, place):
program = fluid.Program()
block = program.global_block()
vout = block.create_var(name="Out")
op = block.append_op(
type=self.op_type, outputs={"Out": vout}, attrs=self.attrs)
op.desc.infer_var_type(block.desc) def test_check_output(self):
op.desc.infer_shape(block.desc) self.check_output_customized(self.verify_output)
fetch_list = []
for var_name in self.outputs:
fetch_list.append(block.var(var_name))
exe = fluid.Executor(place)
outs = exe.run(program, fetch_list=fetch_list)
def verify_output(self, outs):
tensor = outs[0] tensor = outs[0]
hist, _ = np.histogram(outs[0], range=(-5, 10))
self.assertAlmostEqual(tensor.mean(), 2.5, delta=0.1) hist = hist.astype("float32")
hist /= float(outs[0].size)
prob = 0.1 * np.ones((10))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册