Commit f1eebf75 authored by silingtong123, committed by Tao Luo

Improve op uniform_random: argument shape supports a tensor and tensors in a list (#19786)

* test=develop, argument shape supports tensor and tensor in list

* test=develop, increase the coverage of CI tests

* test=develop, modify the document and update API.spec

* test=develop, modify the doc and update API.spec

* test=develop, modify the doc and update API.spec

* test=develop, modify the interface of UniformInitializer

* test=develop, modify the interface of XavierInitializer and MSRAInitializer

* test=develop, modify based on reviewers' comments

* test=develop, modify based on reviewers' comments

* test=develop, modify based on reviewers' comments
Parent 24010472
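At the Python API level, this change lets the shape argument of fluid.layers.uniform_random be supplied in three forms. A minimal sketch of the new call patterns, based on the docstring examples added in this commit (variable names are illustrative):

.. code-block:: python

    import paddle.fluid as fluid

    # 1. A plain list of ints (the original behavior).
    result_1 = fluid.layers.uniform_random(shape=[3, 4])

    # 2. A list mixing ints and 1-element int64 tensors; the tensor dims
    #    are resolved at runtime via Input(ShapeTensorList).
    dim_1 = fluid.layers.fill_constant([1], "int64", 3)
    result_2 = fluid.layers.uniform_random(shape=[dim_1, 5])

    # 3. A 1-D int64 tensor holding the whole shape, fed as Input(ShapeTensor).
    var_shape = fluid.layers.data(
        name='var_shape', shape=[2], append_batch_size=False, dtype="int64")
    result_3 = fluid.layers.uniform_random(var_shape)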
......@@ -309,6 +309,7 @@ paddle.fluid.layers.filter_by_instag (ArgSpec(args=['ins', 'ins_tag', 'filter_ta
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', 'c4969dd6bf164f9e6a90414ea4f4e5ad'))
paddle.fluid.layers.hard_swish (ArgSpec(args=['x', 'threshold', 'scale', 'offset', 'name'], varargs=None, keywords=None, defaults=(6.0, 6.0, 3.0, None)), ('document', '6a5152a7015c62cb8278fc24cb456459'))
paddle.fluid.layers.mse_loss (ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None), ('document', 'd9ede6469288636e1b3233b461a165c9'))
paddle.fluid.layers.uniform_random (ArgSpec(args=['shape', 'dtype', 'min', 'max', 'seed'], varargs=None, keywords=None, defaults=('float32', -1.0, 1.0, 0)), ('document', '126ede8ce0e751244b1b54cd359c89d7'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '88367daf9a30c9ab83adc5d7221e23ef'))
paddle.fluid.layers.double_buffer (ArgSpec(args=['reader', 'place', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '44724c493f41a124abc7531c2740e2e3'))
......@@ -403,7 +404,6 @@ paddle.fluid.layers.reciprocal (ArgSpec(args=['x', 'name'], varargs=None, keywor
paddle.fluid.layers.square (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '728233aff902803f5f62e2d340c3bcbb'))
paddle.fluid.layers.softplus (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '74c4e6dfbdfc3453301ea11d722ad3d6'))
paddle.fluid.layers.softsign (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'a70e9320b113ca33c1299bbc032f09d4'))
paddle.fluid.layers.uniform_random (ArgSpec(args=['shape', 'dtype', 'min', 'max', 'seed'], varargs=None, keywords=None, defaults=('float32', -1.0, 1.0, 0)), ('document', '6de6775d9e9ed885056e764982130cfd'))
paddle.fluid.layers.softshrink (ArgSpec(args=['x', 'alpha'], varargs=None, keywords=None, defaults=(None,)), ('document', '958c7bfdfb0b5e92af6ca4a90d24e5ef'))
paddle.fluid.layers.hard_shrink (ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,)), ('document', '386a4103d2884b2f1312ebc1e8ee6486'))
paddle.fluid.layers.cumsum (ArgSpec(args=['x', 'axis', 'exclusive', 'reverse'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '5ab9d5721a6734fe127069e4314e1309'))
......
......@@ -11,9 +11,10 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/uniform_random_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
......@@ -26,14 +27,28 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override {
framework::Tensor *tensor = nullptr;
auto out_var = ctx.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();
} else if (out_var->IsType<framework::SelectedRows>()) {
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
std::vector<int64_t> new_shape;
auto list_new_shape_tensor =
ctx.MultiInput<framework::Tensor>("ShapeTensorList");
if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) {
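// Prefer the runtime shape from Input(ShapeTensor); otherwise fall back to Input(ShapeTensorList).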
if (ctx.HasInput("ShapeTensor")) {
auto *shape_tensor = ctx.Input<framework::Tensor>("ShapeTensor");
new_shape = get_new_data_from_shape_tensor(shape_tensor);
} else if (list_new_shape_tensor.size() > 0) {
new_shape = get_new_shape_from_shape_tensorlist(list_new_shape_tensor);
}
}
if (out_var->IsType<framework::SelectedRows>()) {
auto *selected_rows = out_var->GetMutable<framework::SelectedRows>();
tensor = selected_rows->mutable_value();
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
if (!new_shape.empty()) shape = new_shape;
tensor->Resize(framework::make_ddim(shape));
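// Reserve row index storage for the first output dimension.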
selected_rows->mutable_rows()->reserve(shape[0]);
} else if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();
if (!new_shape.empty()) tensor->Resize(framework::make_ddim(new_shape));
} else {
PADDLE_THROW(
"uniform_random_op's output only"
......@@ -80,17 +95,53 @@ class UniformRandomOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LT(ctx->Attrs().Get<float>("min"),
ctx->Attrs().Get<float>("max"),
"uniform_random's min must less then max");
auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_num"), 0,
"diag_num must be greater than or equal to 0");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_step"), 0,
"diag_step must be greater than or equal to 0");
std::vector<int64_t> temp;
temp.reserve(shape.size());
if (ctx->HasInputs("ShapeTensorList")) {
// Input(ShapeTensorList) has top priority for the output shape
auto inputs_name = ctx->Inputs("ShapeTensorList");
PADDLE_ENFORCE_GT(
inputs_name.size(), 0,
"Input(ShapeTensorList)'size of Op(uniform_random) can't be zero."
"Please check the Attr(shape)'s size of"
"Op(fluid.layers.uniform_random).)");
auto out_dims = std::vector<int>(inputs_name.size(), -1);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
return;
}
auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
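// With Input(ShapeTensor), only the output rank (the tensor's element count)
// is known at compile time, so every dim is set to -1.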
if (ctx->HasInput("ShapeTensor") && shape.empty()) {
auto shape_dims = ctx->GetInputDim("ShapeTensor");
PADDLE_ENFORCE_EQ(
shape_dims.size(), 1,
"Input(ShapeTensor)' dimension size of Op(uniform_random) must be 1."
"Please check the Attr(shape)'s dimension size of"
"Op(fluid.layers.uniform_random).)");
int num_ele = 1;
for (int i = 0; i < shape_dims.size(); ++i) {
num_ele *= shape_dims[i];
}
auto vec_dims = std::vector<int64_t>(num_ele, -1);
auto out_dims = framework::make_ddim(vec_dims);
ctx->SetOutputDim("Out", out_dims);
return;
}
PADDLE_ENFORCE_EQ(
shape.empty(), false,
"if there is no Input(ShapeTensorList) and no Input(ShapeTensor),the "
"attr(shape) information must "
"be set by Attr(shape).");
std::vector<int64_t> tensor_shape;
tensor_shape.reserve(shape.size());
for (auto dim : shape) {
temp.push_back(static_cast<int64_t>(dim));
tensor_shape.push_back(static_cast<int64_t>(dim));
}
ctx->SetOutputDim("Out", framework::make_ddim(temp));
ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
}
protected:
......@@ -100,18 +151,44 @@ class UniformRandomOp : public framework::OperatorWithKernel {
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "ShapeTensorList" || var_name == "ShapeTensor") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("ShapeTensor",
"(Tensor<int64_t>, optional). If provided, uniform_ranodom "
"according to "
"this given shape. That is to say it has a higher priority than "
"the shape attribute, while the shape attribute still should be "
"set correctly to gurantee shape inference in compile time.")
.AsDispensable();
AddInput("ShapeTensorList",
"(vector<Tensor<int64_t>>, optional). If provided, uniform_random "
"will use this"
"The shape of the tensor in vector MUST BE [1]"
"it has the highest priority compare with Input(Shape) and "
"attr(shape).")
.AsDuplicable()
.AsDispensable();
AddOutput("Out", "The output tensor of uniform random op");
AddComment(R"DOC(
This operator initializes a tensor with random values sampled from a
uniform distribution. The generated values are in the range [min, max).
)DOC");
AddAttr<std::vector<int64_t>>("shape", "The shape of the output tensor");
AddAttr<std::vector<int64_t>>("shape", "The shape of the output tensor")
.SetDefault({});
AddAttr<float>("min", "Minimum value of uniform random. [default -1.0].")
.SetDefault(-1.0f);
AddAttr<float>("max", "Maximun value of uniform random. [default 1.0].")
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <thrust/transform.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/uniform_random_op.h"
namespace paddle {
namespace operators {
......@@ -58,12 +58,28 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
framework::Tensor* tensor = nullptr;
auto out_var = context.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();
} else if (out_var->IsType<framework::SelectedRows>()) {
std::vector<int64_t> new_shape;
auto list_new_shape_tensor =
context.MultiInput<framework::Tensor>("ShapeTensorList");
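// Resolve the runtime shape the same way as the CPU kernel:
// Input(ShapeTensor) first, then Input(ShapeTensorList).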
if (list_new_shape_tensor.size() > 0 || context.HasInput("ShapeTensor")) {
if (context.HasInput("ShapeTensor")) {
auto* shape_tensor = context.Input<framework::Tensor>("ShapeTensor");
new_shape = get_new_data_from_shape_tensor(shape_tensor);
} else if (list_new_shape_tensor.size() > 0) {
new_shape = get_new_shape_from_shape_tensorlist(list_new_shape_tensor);
}
}
if (out_var->IsType<framework::SelectedRows>()) {
auto* selected_rows = out_var->GetMutable<framework::SelectedRows>();
tensor = selected_rows->mutable_value();
auto shape = context.Attr<std::vector<int64_t>>("shape");
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
if (!new_shape.empty()) shape = new_shape;
tensor->Resize(framework::make_ddim(shape));
selected_rows->mutable_rows()->reserve(shape[0]);
} else if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();
if (!new_shape.empty()) tensor->Resize(framework::make_ddim(new_shape));
} else {
PADDLE_THROW(
"uniform_random_op's output only"
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
inline std::vector<int64_t> get_new_data_from_shape_tensor(
const Tensor *new_data_tensor) {
auto *new_data = new_data_tensor->data<int64_t>();
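// The shape data may live on the GPU; copy it to the CPU before dereferencing.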
if (platform::is_gpu_place(new_data_tensor->place())) {
framework::Tensor cpu_starts_tensor;
TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<int64_t>();
}
std::vector<int64_t> vec_new_data(new_data,
new_data + new_data_tensor->numel());
return vec_new_data;
}
inline std::vector<int64_t> get_new_shape_from_shape_tensorlist(
const std::vector<const Tensor *> &list_new_shape_tensor) {
std::vector<int64_t> vec_new_shape;
vec_new_shape.reserve(list_new_shape_tensor.size());
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
"shape of dim tensor should be [1]");
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_shape.push_back(static_cast<int64_t>(*temp.data<int64_t>()));
} else {
vec_new_shape.push_back(static_cast<int64_t>(*tensor->data<int64_t>()));
}
}
return vec_new_shape;
}
} // namespace operators
} // namespace paddle
......@@ -281,6 +281,7 @@ class UniformInitializer(Initializer):
op = block._prepend_op(
type="uniform_random",
inputs={},
outputs={"Out": out_var},
attrs={
"shape": var.shape,
......@@ -565,6 +566,7 @@ class XavierInitializer(Initializer):
limit = np.sqrt(6.0 / float(fan_in + fan_out))
op = block._prepend_op(
type="uniform_random",
inputs={},
outputs={"Out": out_var},
attrs={
"shape": out_var.shape,
......@@ -691,6 +693,7 @@ class MSRAInitializer(Initializer):
limit = np.sqrt(6.0 / float(fan_in))
op = block._prepend_op(
type="uniform_random",
inputs={},
outputs={"Out": out_var},
attrs={
"shape": out_var.shape,
......
......@@ -208,7 +208,7 @@ class Uniform(Distribution):
return nn.reshape(output, output_shape)
else:
output_shape = shape + batch_shape
output = ops.uniform_random(
output = nn.uniform_random(
output_shape, seed=seed) * (tensor.zeros(
output_shape, dtype=self.low.dtype) +
(self.high - self.low)) + self.low
......
......@@ -222,6 +222,7 @@ __all__ = [
'shard_index',
'hard_swish',
'mse_loss',
'uniform_random',
]
kIgnoreIndex = -100
......@@ -15068,3 +15069,120 @@ def mse_loss(input, label):
"""
return reduce_mean(square_error_cost(input, label))
@templatedoc()
def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
"""
This operator initializes a variable with random values sampled from a
uniform distribution. The generated values are in the half-open range [min, max).
Examples:
::
Input:
shape = [1, 2]
Output:
result=[[0.8505902, 0.8397286]]
Args:
shape (list|tuple|Variable): The shape of the output tensor. If the type is list or tuple,
its elements can be integers or tensors with the shape [1], where the data
type of the tensor is int64. If the type is Variable, it is a 1-D tensor
whose data type is int64.
dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type of the output tensor, such as float32, float64.
Default: float32.
min (float, optional): Minimum value of the uniform random range; the lower bound is inclusive. Default -1.0.
max (float, optional): Maximum value of the uniform random range; the upper bound is exclusive. Default 1.0.
seed (int, optional): Random seed used for generating samples. 0 means use a
seed generated by the system. Note that if seed is not 0, this
operator will always generate the same random numbers every time.
Default 0.
Returns: A Tensor with randomly initialized values whose data type is determined by the dtype parameter
and whose shape is determined by the shape parameter.
Return type: Variable
Raises:
TypeError: If the shape type is not list, tuple, or Variable.
Examples:
.. code-block:: python
import paddle.fluid as fluid
# example 1:
# attr shape is a list which doesn't contain tensor Variable.
result_1 = fluid.layers.uniform_random(shape=[3, 4])
# example 2:
# attr shape is a list which contains tensor Variable.
dim_1 = fluid.layers.fill_constant([1],"int64",3)
result_2 = fluid.layers.uniform_random(shape=[dim_1, 5])
# example 3:
# attr shape is a Variable, the data type must be int64
var_shape = fluid.layers.data(name='var_shape', shape=[2], append_batch_size=False, dtype="int64")
result_3 = fluid.layers.uniform_random(var_shape)
"""
if not (isinstance(shape, (list, tuple, Variable))):
raise TypeError("Input shape must be a python list,Variable or tuple.")
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
def get_new_shape_tensor(list_shape):
new_shape_tensor = []
for dim in list_shape:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_shape_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = helper.create_variable_for_type_inference('int64')
fill_constant([1], 'int64', dim, force_cpu=True, out=temp_out)
new_shape_tensor.append(temp_out)
return new_shape_tensor
def get_attr_shape(list_shape):
unk_dim_idx = -1
attrs_shape = []
for dim_idx, dim_size in enumerate(list_shape):
if isinstance(dim_size, Variable):
attrs_shape.append(-1)
else:
attrs_shape.append(dim_size)
assert dim_size > 0, (
"Each dimension size given in shape must not be negative "
"except one unknown dimension.")
return attrs_shape
helper = LayerHelper("uniform_random", **locals())
inputs = dict()
attrs = dict()
if in_dygraph_mode():
attrs = {'shape': shape}
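# In dygraph mode, shape is passed straight through as the "shape" attribute,
# so it is expected to be a list or tuple of ints here.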
else:
if isinstance(shape, Variable):
shape.stop_gradient = True
inputs["ShapeTensor"] = shape
elif isinstance(shape, (list, tuple)):
assert len(shape) > 0, (
"The size of argument(shape) can't be zero.")
attrs["shape"] = get_attr_shape(shape)
if contain_var(shape):
inputs['ShapeTensorList'] = get_new_shape_tensor(shape)
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="uniform_random", inputs=inputs, attrs=attrs,
outputs={"Out": out})
return helper.append_activation(out)
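For reference, a hedged end-to-end sketch of running the tensor-in-list form with an Executor, mirroring TestUniformRandomOp_attr_tensor_API further below (program and variable names are illustrative):

.. code-block:: python

    import paddle.fluid as fluid

    startup_program = fluid.Program()
    train_program = fluid.Program()
    with fluid.program_guard(train_program, startup_program):
        # The middle dim comes from a 1-element int64 tensor.
        dim_tensor = fluid.layers.fill_constant([1], "int64", 3)
        out = fluid.layers.uniform_random([1, dim_tensor, 2])

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_program)
    outs = exe.run(train_program, fetch_list=[out])
    print(outs[0].shape)  # expected: (1, 3, 2)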
......@@ -58,44 +58,6 @@ __all__ += __activations_noattr__
for _OP in set(__activations_noattr__):
globals()[_OP] = generate_activation_fn(_OP)
__all__ += ["uniform_random"]
_uniform_random_ = generate_layer_fn('uniform_random')
def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
"""
This operator initializes a variable with random values sampled from a
uniform distribution. The generated values are in the range [min, max].
Args:
shape (list): The shape of output variable.
dtype(np.dtype|core.VarDesc.VarType|str): The type of data, such as
float32, float64 etc. Default: float32.
min (float): Minimum value of uniform random. Default -1.0.
max (float): Maximum value of uniform random. Default 1.0.
seed (int): Random seed used for generating samples. 0 means use a
seed generated by the system. Note that if seed is not 0, this
operator will always generate the same random numbers every time.
Default 0.
Examples:
.. code-block:: python
import paddle.fluid as fluid
result = fluid.layers.uniform_random(shape=[32, 784])
"""
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
locals_var = locals().copy()
kwargs = dict()
for name, val in locals_var.items():
if val is not None:
kwargs[name] = val
return _uniform_random_(**kwargs)
__all__ += ['softshrink']
_softshrink_ = generate_layer_fn('softshrink')
......
......@@ -43,6 +43,53 @@ def output_hist_diag(out):
return hist, prob
class TestUniformRandomOp_attr_tensorlist(OpTest):
def setUp(self):
self.op_type = "uniform_random"
self.new_shape = (1000, 784)
shape_tensor = []
for index, ele in enumerate(self.new_shape):
shape_tensor.append(("x" + str(index), np.ones(
(1)).astype("int64") * ele))
self.inputs = {'ShapeTensorList': shape_tensor}
self.init_attrs()
self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
def init_attrs(self):
self.attrs = {"min": -5.0, "max": 10.0, "seed": 10}
self.output_hist = output_hist
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
hist, prob = self.output_hist(np.array(outs[0]))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOp_attr_tensor(OpTest):
def setUp(self):
self.op_type = "uniform_random"
self.inputs = {"ShapeTensor": np.array([1000, 784]).astype("int64")}
self.init_attrs()
self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
def init_attrs(self):
self.attrs = {"min": -5.0, "max": 10.0, "seed": 10}
self.output_hist = output_hist
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
hist, prob = self.output_hist(np.array(outs[0]))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOp(OpTest):
def setUp(self):
self.op_type = "uniform_random"
......@@ -158,5 +205,87 @@ class TestUniformRandomOpApi(unittest.TestCase):
ret = exe.run(feed={'x': x_tensor}, fetch_list=[y], return_numpy=False)
class TestUniformRandomOp_attr_tensor_API(unittest.TestCase):
def test_attr_tensor_API(self):
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
dim_tensor = fluid.layers.fill_constant([1], "int64", 3)
ret = fluid.layers.nn.uniform_random([1, dim_tensor, 2])
use_cuda = False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
outs = exe.run(train_program, fetch_list=[ret])
class TestUniformRandomOpSelectedRowsShapeTensor(unittest.TestCase):
def get_places(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
def test_check_output(self):
for place in self.get_places():
self.check_with_place(place)
def check_with_place(self, place):
scope = core.Scope()
out = scope.var("X").get_selected_rows()
shape_tensor = scope.var("Shape").get_tensor()
shape_tensor.set(np.array([4, 784]).astype("int64"), place)
op = Operator(
"uniform_random",
ShapeTensor="Shape",
Out="X",
min=-5.0,
max=10.0,
seed=10)
op.run(scope, place)
self.assertEqual(out.get_tensor().shape(), [4, 784])
hist, prob = output_hist(np.array(out.get_tensor()))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOpSelectedRowsShapeTensorList(unittest.TestCase):
def get_places(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
def test_check_output(self):
for place in self.get_places():
self.check_with_place(place)
def check_with_place(self, place):
scope = core.Scope()
out = scope.var("X").get_selected_rows()
shape_1 = scope.var("shape1").get_tensor()
shape_1.set(np.array([4]).astype("int64"), place)
shape_2 = scope.var("shape2").get_tensor()
shape_2.set(np.array([784]).astype("int64"), place)
op = Operator(
"uniform_random",
ShapeTensorList=["shape1", "shape2"],
Out="X",
min=-5.0,
max=10.0,
seed=10)
op.run(scope, place)
self.assertEqual(out.get_tensor().shape(), [4, 784])
hist, prob = output_hist(np.array(out.get_tensor()))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
if __name__ == "__main__":
unittest.main()