From 03d7f3ddb2504995275e929d45f51c4088b50394 Mon Sep 17 00:00:00 2001 From: silingtong123 <35439432+silingtong123@users.noreply.github.com> Date: Tue, 29 Oct 2019 10:47:16 +0800 Subject: [PATCH] Make shape tensor support int32 (#20757) * Make shape tensor support int32 --- paddle/fluid/operators/uniform_random_op.cc | 14 +-- paddle/fluid/operators/uniform_random_op.h | 57 ++++++++++--- python/paddle/fluid/layers/nn.py | 13 ++- .../tests/unittests/test_uniform_random_op.py | 85 ++++++++++++++++++- 4 files changed, 143 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index 946441ce0b2..e81d8a22fee 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -172,24 +172,24 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("ShapeTensor", - "(Tensor, optional). If provided, uniform_ranodom " + "(Tensor or Tensor, optional) . If provided, " + "uniform_random " "according to " "this given shape. It means that it has a higher priority than " "the shape attribute, while the shape attribute still should be " "set correctly to gurantee shape inference in compile time.") .AsDispensable(); AddInput("ShapeTensorList", - "(vector>, optional). If provided, uniform_random " - "use this." - "The shape of the tensor in vector MUST BE [1]," - "it has the highest priority compare with Input(Shape) and " - "attr(shape).") + "(vector> or vector>, optional). " + "If provided, uniform_random use this. The shape of the tensor " + "must be [1], it has the highest priority comparing with " + "Input(ShapeTensor) and attr(shape).") .AsDuplicable() .AsDispensable(); AddOutput("Out", "The output tensor of uniform random op"); AddComment(R"DOC( This operator initializes a tensor with random values sampled from a -uniform distribution. The random result is in set [min, max]. +uniform distribution. The random result is in set [min, max). )DOC"); AddAttr>("shape", "The shape of the output tensor") diff --git a/paddle/fluid/operators/uniform_random_op.h b/paddle/fluid/operators/uniform_random_op.h index c7593965897..b513656bf7f 100644 --- a/paddle/fluid/operators/uniform_random_op.h +++ b/paddle/fluid/operators/uniform_random_op.h @@ -24,15 +24,33 @@ using Tensor = framework::Tensor; inline std::vector GetNewDataFromShapeTensor( const Tensor *new_data_tensor) { - auto *new_data = new_data_tensor->data(); - if (platform::is_gpu_place(new_data_tensor->place())) { - framework::Tensor cpu_starts_tensor; - TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor); - new_data = cpu_starts_tensor.data(); + if (new_data_tensor->type() == framework::proto::VarType::INT64) { + auto *new_data = new_data_tensor->data(); + if (platform::is_gpu_place(new_data_tensor->place())) { + framework::Tensor cpu_starts_tensor; + TensorCopySync(*new_data_tensor, platform::CPUPlace(), + &cpu_starts_tensor); + new_data = cpu_starts_tensor.data(); + } + std::vector vec_new_data(new_data, + new_data + new_data_tensor->numel()); + return vec_new_data; + } else if (new_data_tensor->type() == framework::proto::VarType::INT32) { + auto *new_data = new_data_tensor->data(); + std::vector vec_new_data; + if (platform::is_gpu_place(new_data_tensor->place())) { + framework::Tensor cpu_starts_tensor; + TensorCopySync(*new_data_tensor, platform::CPUPlace(), + &cpu_starts_tensor); + new_data = cpu_starts_tensor.data(); + } + for (size_t i = 0; i < new_data_tensor->numel(); ++i) { + vec_new_data.push_back(static_cast(*(new_data + i))); + } + return vec_new_data; + } else { + PADDLE_THROW("The dtype of shape tensor must be int32 or int64."); } - std::vector vec_new_data(new_data, - new_data + new_data_tensor->numel()); - return vec_new_data; } inline std::vector GetNewDataFromShapeTensorList( @@ -43,12 +61,25 @@ inline std::vector GetNewDataFromShapeTensorList( auto tensor = list_new_shape_tensor[i]; PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}), "shape of dim tensor should be [1]"); - if (platform::is_gpu_place(tensor->place())) { - framework::Tensor temp; - TensorCopySync(*tensor, platform::CPUPlace(), &temp); - vec_new_shape.push_back(*temp.data()); + + if (tensor->type() == framework::proto::VarType::INT32) { + if (platform::is_gpu_place(tensor->place())) { + framework::Tensor temp; + TensorCopySync(*tensor, platform::CPUPlace(), &temp); + vec_new_shape.push_back(static_cast(*temp.data())); + } else { + vec_new_shape.push_back(static_cast(*tensor->data())); + } + } else if (tensor->type() == framework::proto::VarType::INT64) { + if (platform::is_gpu_place(tensor->place())) { + framework::Tensor temp; + TensorCopySync(*tensor, platform::CPUPlace(), &temp); + vec_new_shape.push_back(*temp.data()); + } else { + vec_new_shape.push_back(*tensor->data()); + } } else { - vec_new_shape.push_back(*tensor->data()); + PADDLE_THROW("The dtype of shape tensor must be int32 or int64."); } } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e74a02ca93d..2e0c1044a6e 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -17623,8 +17623,8 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): Args: shape (list|tuple|Variable): The shape of the output Tensor, if the shape is a list or tuple, its elements can be an integer - or a Tensor with the shape [1], and the type of the Tensor is int64. - If the shape is a Variable, it is a 1-D Tensor, and the type of the Tensor is int64. + or a Tensor with the shape [1], and the type of the Tensor must be int32 or int64. + If the shape is a Variable, it is a 1-D Tensor, and the type of the Tensor must be int32 or int64. dtype(np.dtype|core.VarDesc.VarType|str, optional): The type of the output Tensor. Supported data types: float32, float64. Default: float32. min (float, optional): The lower bound on the range of random values to generate, the min is included in the range. Default -1.0. @@ -17652,12 +17652,17 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): # example 2: # attr shape is a list which contains tensor Variable. dim_1 = fluid.layers.fill_constant([1],"int64",3) - result_2 = fluid.layers.uniform_random(shape=[dim_1, 5]) + dim_2 = fluid.layers.fill_constant([1],"int32",5) + result_2 = fluid.layers.uniform_random(shape=[dim_1, dim_2]) # example 3: - # attr shape is a Variable, the data type must be int64 + # attr shape is a Variable, the data type must be int64 or int32. var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64") result_3 = fluid.layers.uniform_random(var_shape) + var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32") + result_4 = fluid.layers.uniform_random(var_shape_int32) + + """ if not (isinstance(shape, (list, tuple, Variable))): diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py index 92ac6be3c2d..65d534b79b0 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py @@ -70,6 +70,32 @@ class TestUniformRandomOp_attr_tensorlist(OpTest): hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) +class TestUniformRandomOp_attr_tensorlist_int32(OpTest): + def setUp(self): + self.op_type = "uniform_random" + self.new_shape = (1000, 784) + shape_tensor = [] + for index, ele in enumerate(self.new_shape): + shape_tensor.append(("x" + str(index), np.ones( + (1)).astype("int32") * ele)) + self.inputs = {'ShapeTensorList': shape_tensor} + self.init_attrs() + self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")} + + def init_attrs(self): + self.attrs = {"min": -5.0, "max": 10.0, "seed": 10} + self.output_hist = output_hist + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + hist, prob = self.output_hist(np.array(outs[0])) + self.assertTrue( + np.allclose( + hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) + + class TestUniformRandomOp_attr_tensor(OpTest): def setUp(self): self.op_type = "uniform_random" @@ -91,6 +117,27 @@ class TestUniformRandomOp_attr_tensor(OpTest): hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) +class TestUniformRandomOp_attr_tensor_int32(OpTest): + def setUp(self): + self.op_type = "uniform_random" + self.inputs = {"ShapeTensor": np.array([1000, 784]).astype("int32")} + self.init_attrs() + self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")} + + def init_attrs(self): + self.attrs = {"min": -5.0, "max": 10.0, "seed": 10} + self.output_hist = output_hist + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + hist, prob = self.output_hist(np.array(outs[0])) + self.assertTrue( + np.allclose( + hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) + + class TestUniformRandomOp(OpTest): def setUp(self): self.op_type = "uniform_random" @@ -235,13 +282,47 @@ class TestUniformRandomOp_attr_tensor_API(unittest.TestCase): dim_tensor = fluid.layers.fill_constant([1], "int64", 3) ret = fluid.layers.nn.uniform_random([1, dim_tensor, 2]) - use_cuda = False - place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + + exe.run(startup_program) + outs = exe.run(train_program, fetch_list=[ret]) + + def test_attr_tensorlist_int32_API(self): + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + dim_1 = fluid.layers.fill_constant([1], "int64", 3) + dim_2 = fluid.layers.fill_constant([1], "int32", 2) + ret = fluid.layers.nn.uniform_random([1, dim_1, dim_2]) + + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) exe = fluid.Executor(place) exe.run(startup_program) outs = exe.run(train_program, fetch_list=[ret]) + def test_attr_tensor_int32_API(self): + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + shape = fluid.data(name='shape_tensor', shape=[2], dtype="int32") + ret = fluid.layers.nn.uniform_random(shape) + + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + Shape = np.array([2, 3]).astype('int32') + exe.run(startup_program) + outs = exe.run(train_program, + feed={'shape_tensor': Shape}, + fetch_list=[ret]) + class TestUniformRandomOp_API_seed(unittest.TestCase): def test_attr_tensor_API(self): -- GitLab