提交 03d7f3dd 编写于 作者: S silingtong123 提交者: liuwei1031

Make shape tensor support int32 (#20757)

*  Make shape tensor support int32
上级 95ba4bd2
......@@ -172,24 +172,24 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("ShapeTensor",
"(Tensor<int64_t>, optional). If provided, uniform_ranodom "
"(Tensor<int64_t> or Tensor<int32_t>, optional) . If provided, "
"uniform_random "
"according to "
"this given shape. It means that it has a higher priority than "
"the shape attribute, while the shape attribute still should be "
"set correctly to gurantee shape inference in compile time.")
.AsDispensable();
AddInput("ShapeTensorList",
"(vector<Tensor<int64_t>>, optional). If provided, uniform_random "
"use this."
"The shape of the tensor in vector MUST BE [1],"
"it has the highest priority compare with Input(Shape) and "
"attr(shape).")
"(vector<Tensor<int64_t>> or vector<Tensor<int32_t>>, optional). "
"If provided, uniform_random use this. The shape of the tensor "
"must be [1], it has the highest priority comparing with "
"Input(ShapeTensor) and attr(shape).")
.AsDuplicable()
.AsDispensable();
AddOutput("Out", "The output tensor of uniform random op");
AddComment(R"DOC(
This operator initializes a tensor with random values sampled from a
uniform distribution. The random result is in set [min, max].
uniform distribution. The random result is in set [min, max).
)DOC");
AddAttr<std::vector<int64_t>>("shape", "The shape of the output tensor")
......
......@@ -24,15 +24,33 @@ using Tensor = framework::Tensor;
inline std::vector<int64_t> GetNewDataFromShapeTensor(
const Tensor *new_data_tensor) {
auto *new_data = new_data_tensor->data<int64_t>();
if (platform::is_gpu_place(new_data_tensor->place())) {
framework::Tensor cpu_starts_tensor;
TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<int64_t>();
if (new_data_tensor->type() == framework::proto::VarType::INT64) {
auto *new_data = new_data_tensor->data<int64_t>();
if (platform::is_gpu_place(new_data_tensor->place())) {
framework::Tensor cpu_starts_tensor;
TensorCopySync(*new_data_tensor, platform::CPUPlace(),
&cpu_starts_tensor);
new_data = cpu_starts_tensor.data<int64_t>();
}
std::vector<int64_t> vec_new_data(new_data,
new_data + new_data_tensor->numel());
return vec_new_data;
} else if (new_data_tensor->type() == framework::proto::VarType::INT32) {
auto *new_data = new_data_tensor->data<int32_t>();
std::vector<int64_t> vec_new_data;
if (platform::is_gpu_place(new_data_tensor->place())) {
framework::Tensor cpu_starts_tensor;
TensorCopySync(*new_data_tensor, platform::CPUPlace(),
&cpu_starts_tensor);
new_data = cpu_starts_tensor.data<int32_t>();
}
for (size_t i = 0; i < new_data_tensor->numel(); ++i) {
vec_new_data.push_back(static_cast<int64_t>(*(new_data + i)));
}
return vec_new_data;
} else {
PADDLE_THROW("The dtype of shape tensor must be int32 or int64.");
}
std::vector<int64_t> vec_new_data(new_data,
new_data + new_data_tensor->numel());
return vec_new_data;
}
inline std::vector<int64_t> GetNewDataFromShapeTensorList(
......@@ -43,12 +61,25 @@ inline std::vector<int64_t> GetNewDataFromShapeTensorList(
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
"shape of dim tensor should be [1]");
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_shape.push_back(*temp.data<int64_t>());
if (tensor->type() == framework::proto::VarType::INT32) {
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_shape.push_back(static_cast<int64_t>(*temp.data<int32_t>()));
} else {
vec_new_shape.push_back(static_cast<int64_t>(*tensor->data<int32_t>()));
}
} else if (tensor->type() == framework::proto::VarType::INT64) {
if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_shape.push_back(*temp.data<int64_t>());
} else {
vec_new_shape.push_back(*tensor->data<int64_t>());
}
} else {
vec_new_shape.push_back(*tensor->data<int64_t>());
PADDLE_THROW("The dtype of shape tensor must be int32 or int64.");
}
}
......
......@@ -17623,8 +17623,8 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
Args:
shape (list|tuple|Variable): The shape of the output Tensor, if the shape is a list or tuple,
its elements can be an integer
or a Tensor with the shape [1], and the type of the Tensor is int64.
If the shape is a Variable, it is a 1-D Tensor, and the type of the Tensor is int64.
or a Tensor with the shape [1], and the type of the Tensor must be int32 or int64.
If the shape is a Variable, it is a 1-D Tensor, and the type of the Tensor must be int32 or int64.
dtype(np.dtype|core.VarDesc.VarType|str, optional): The type of the output Tensor. Supported data types: float32, float64.
Default: float32.
min (float, optional): The lower bound on the range of random values to generate, the min is included in the range. Default -1.0.
......@@ -17652,12 +17652,17 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
# example 2:
# attr shape is a list which contains tensor Variable.
dim_1 = fluid.layers.fill_constant([1],"int64",3)
result_2 = fluid.layers.uniform_random(shape=[dim_1, 5])
dim_2 = fluid.layers.fill_constant([1],"int32",5)
result_2 = fluid.layers.uniform_random(shape=[dim_1, dim_2])
# example 3:
# attr shape is a Variable, the data type must be int64
# attr shape is a Variable, the data type must be int64 or int32.
var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64")
result_3 = fluid.layers.uniform_random(var_shape)
var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32")
result_4 = fluid.layers.uniform_random(var_shape_int32)
"""
if not (isinstance(shape, (list, tuple, Variable))):
......
......@@ -70,6 +70,32 @@ class TestUniformRandomOp_attr_tensorlist(OpTest):
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOp_attr_tensorlist_int32(OpTest):
def setUp(self):
self.op_type = "uniform_random"
self.new_shape = (1000, 784)
shape_tensor = []
for index, ele in enumerate(self.new_shape):
shape_tensor.append(("x" + str(index), np.ones(
(1)).astype("int32") * ele))
self.inputs = {'ShapeTensorList': shape_tensor}
self.init_attrs()
self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
def init_attrs(self):
self.attrs = {"min": -5.0, "max": 10.0, "seed": 10}
self.output_hist = output_hist
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
hist, prob = self.output_hist(np.array(outs[0]))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOp_attr_tensor(OpTest):
def setUp(self):
self.op_type = "uniform_random"
......@@ -91,6 +117,27 @@ class TestUniformRandomOp_attr_tensor(OpTest):
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOp_attr_tensor_int32(OpTest):
def setUp(self):
self.op_type = "uniform_random"
self.inputs = {"ShapeTensor": np.array([1000, 784]).astype("int32")}
self.init_attrs()
self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
def init_attrs(self):
self.attrs = {"min": -5.0, "max": 10.0, "seed": 10}
self.output_hist = output_hist
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
hist, prob = self.output_hist(np.array(outs[0]))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestUniformRandomOp(OpTest):
def setUp(self):
self.op_type = "uniform_random"
......@@ -235,13 +282,47 @@ class TestUniformRandomOp_attr_tensor_API(unittest.TestCase):
dim_tensor = fluid.layers.fill_constant([1], "int64", 3)
ret = fluid.layers.nn.uniform_random([1, dim_tensor, 2])
use_cuda = False
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_program)
outs = exe.run(train_program, fetch_list=[ret])
def test_attr_tensorlist_int32_API(self):
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
dim_1 = fluid.layers.fill_constant([1], "int64", 3)
dim_2 = fluid.layers.fill_constant([1], "int32", 2)
ret = fluid.layers.nn.uniform_random([1, dim_1, dim_2])
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_program)
outs = exe.run(train_program, fetch_list=[ret])
def test_attr_tensor_int32_API(self):
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
shape = fluid.data(name='shape_tensor', shape=[2], dtype="int32")
ret = fluid.layers.nn.uniform_random(shape)
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
Shape = np.array([2, 3]).astype('int32')
exe.run(startup_program)
outs = exe.run(train_program,
feed={'shape_tensor': Shape},
fetch_list=[ret])
class TestUniformRandomOp_API_seed(unittest.TestCase):
def test_attr_tensor_API(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册