提交 80537a1d 编写于 作者: P pangyoki

add multinomial python api unittest

上级 c66eec75
...@@ -26,69 +26,17 @@ limitations under the License. */ ...@@ -26,69 +26,17 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
/*
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
*/
/*
template <class T>
__global__ void SumArrayCUDAKernel(T **in, T *out, size_t in_size) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
// T total(read_dst ? out[id] : static_cast<T>(0));
T total(static_cast<T>(0))
for (int i = 0; i < in_size; ++i) {
const T *tmp = in[i];
if (tmp) {
total += tmp[id];
}
}
out[id] = total;
id += blockDim.x * gridDim.x;
}*/
/*
template <typename T>
__global__ void NormalizeProbability(T* probs, int64_t rows, int64_t cols) {
extern __shared__ std::vector<T> sum_rows(rows);
T val;
for (int64_t i = blockId.x; i < rows; i += gridDim.x) {
T sum = static_cast<T>(0);
for (int64_t j = threadIdx.x; j < cols; j += blockDim.x) {
val = probs[i * cols + j];
sum += val;
}
}
}*/
template <typename T> template <typename T>
__global__ void NormalizeProbability(T* norm_probs, const T* in_data, __global__ void NormalizeProbability(T* norm_probs, const T* in_data,
T* sum_rows) { T* sum_rows) {
// int id = blockIdx.x * blockDim.x + threadIdx.x;
// int id = threadIdx.x;
int id = threadIdx.x + blockIdx.x * blockDim.x + int id = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x; blockIdx.y * gridDim.x * blockDim.x;
norm_probs[id] = in_data[id] / sum_rows[blockIdx.y]; norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
} }
template <typename T>
__global__ void yokiFunc(const T* in_data, T* out) {
// int id = blockIdx.x * blockDim.x + threadIdx.x;
// int id = threadIdx.x;
int id = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
out[id] = in_data[id];
}
template <typename T> template <typename T>
__global__ void Cumsum(T* norm_probs_data, int64_t num_distributions, __global__ void Cumsum(T* norm_probs_data, int64_t num_distributions,
int64_t num_categories, T* cumulative_probs) { int64_t num_categories, T* cumulative_probs) {
// int id = blockIdx.x;
for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) { for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) {
thrust::inclusive_scan(thrust::device, thrust::inclusive_scan(thrust::device,
norm_probs_data + id * num_categories, norm_probs_data + id * num_categories,
...@@ -111,52 +59,43 @@ struct RandomGeneratorCudaFunctor { ...@@ -111,52 +59,43 @@ struct RandomGeneratorCudaFunctor {
} }
}; };
/*
template <typename T> template <typename T>
class MultinomialCudaFunctor(T* out_data, const T* in_data, __device__ int binarySearchFunctor(T* cumdist, T* dist, int size, T val) {
const int64_t num_samples, const bool replacement, int left = 0;
const int64_t num_categories, int right = size;
const int64_t num_distributions) {
}*/
template <typename T>
__device__ int binarySearchForMultinomial(T* cumdist, T* dist, int size,
T val) {
int start = 0;
int end = size;
// cumdist[size - 1] = 0 => all zero prob dist // cumdist[size - 1] = 0 => all zero prob dist
// CUDA_KERNEL_ASSERT(cumdist[size - 1] > static_cast<T>(0)); // CUDA_KERNEL_ASSERT(cumdist[size - 1] > static_cast<T>(0));
while (end - start > 0) { while (right - left > 0) {
int mid = start + (end - start) / 2; int mid = left + (right - left) / 2;
T midVal = cumdist[mid]; T midVal = cumdist[mid];
if (midVal < val) { if (midVal < val) {
start = mid + 1; left = mid + 1;
} else { } else {
end = mid; right = mid;
} }
} }
if (start == size) { if (left == size) {
// No probability mass or precision problems; just return the // No probability mass or precision problems; just return the
// first non-zero element by setting start to size-1 here, // first non-zero element by setting left to size-1 here,
// the code below will move it to the last non-zero probability // the code below will move it to the last non-zero probability
// this actually can happen when the random number is 1 // this actually can happen when the random number is 1
// (github pytorch issue #4858). // (github pytorch issue #4858).
start = size - 1; left = size - 1;
} }
while (start >= 1 && dist[start] == 0) start--; while (left >= 1 && dist[left] == 0) left--;
return start; return left;
} }
template <typename T> template <typename T>
__global__ void sampleMultinomialWithReplacement( __global__ void sampleMultinomialWithReplacement(
T* rng, const int64_t totalSamples, T* dest, const int64_t distributions, T* rng_data, const int64_t num_samples, T* out_data,
const int64_t categories, T* normDistPrefixSum, T* normDist) { const int64_t num_distributions, const int64_t num_categories,
T* cumulative_probs, T* norm_probs_data) {
// At the moment, each warp computes one sample value in the binary // At the moment, each warp computes one sample value in the binary
// search due to divergence. It seems possible to compute multiple // search due to divergence. It seems possible to compute multiple
// values and limit divergence though later on. // values and limit divergence though later on.
...@@ -170,22 +109,23 @@ __global__ void sampleMultinomialWithReplacement( ...@@ -170,22 +109,23 @@ __global__ void sampleMultinomialWithReplacement(
int idx = threadIdx.x + blockIdx.x * blockDim.x + int idx = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x; blockIdx.y * gridDim.x * blockDim.x;
for (int curDist = blockIdx.y; curDist < distributions; for (int curDist = blockIdx.y; curDist < num_distributions;
curDist += gridDim.y) { curDist += gridDim.y) {
for (int sample = blockIdx.x * blockDim.x + threadIdx.x; for (int sample = blockIdx.x * blockDim.x + threadIdx.x;
sample < totalSamples; sample += blockDim.x * gridDim.x) { sample < num_samples; sample += blockDim.x * gridDim.x) {
// we are losing 3 out of 4 generated numbers but it's ok // we are losing 3 out of 4 generated numbers but it's ok
// this kernel is not very efficient anyway // this kernel is not very efficient anyway
// T uniform_random = dist(rng); // T uniform_random = dist(rng);
T uniform_random = rng[sample + curDist * totalSamples]; T uniform_random = rng_data[sample + curDist * num_samples];
// Find the bucket that a uniform sample lies in // Find the bucket that a uniform sample lies in
int choice = binarySearchForMultinomial<T>( int choice =
normDistPrefixSum + curDist * categories, binarySearchFunctor<T>(cumulative_probs + curDist * num_categories,
normDist + curDist * categories, categories, uniform_random); norm_probs_data + curDist * num_categories,
num_categories, uniform_random);
dest[sample + curDist * totalSamples] = choice; out_data[sample + curDist * num_samples] = choice;
} }
} }
} }
...@@ -198,14 +138,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T> ...@@ -198,14 +138,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
const auto x = ctx.Input<framework::Tensor>("X"); const auto x = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out"); auto out = ctx.Output<framework::Tensor>("Out");
// auto yokiout = ctx.Output<framework::Tensor>("yokiOut");
const int64_t num_samples = ctx.Attr<int>("num_samples"); const int64_t num_samples = ctx.Attr<int>("num_samples");
const bool replacement = ctx.Attr<bool>("replacement"); const bool replacement = ctx.Attr<bool>("replacement");
auto* in_data = x->data<T>(); auto* in_data = x->data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace()); auto* out_data = out->mutable_data<T>(ctx.GetPlace());
// auto* yokiout_data = yokiout->mutable_data<T>(ctx.GetPlace());
auto in_dims = x->dims(); auto in_dims = x->dims();
int64_t in_rank = in_dims.size(); int64_t in_rank = in_dims.size();
...@@ -215,10 +152,6 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T> ...@@ -215,10 +152,6 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
if (!replacement) { if (!replacement) {
int in_data_numel = x->numel(); int in_data_numel = x->numel();
int out_data_numel = out->numel(); int out_data_numel = out->numel();
// std::vector<T> cpu_in_data(in_data_numel);
// std::vector<T> cpu_out_data(out_data_numel);
// T cpu_in_data[in_data_numel];
// T cpu_out_data[out_data_numel];
T* cpu_in_data = new T[in_data_numel]; T* cpu_in_data = new T[in_data_numel];
T* cpu_out_data = new T[out_data_numel]; T* cpu_out_data = new T[out_data_numel];
...@@ -226,10 +159,6 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T> ...@@ -226,10 +159,6 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
cudaMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T), cudaMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T),
cudaMemcpyDeviceToHost); cudaMemcpyDeviceToHost);
VLOG(3) << "Print cpu_in_data " << cpu_in_data[0] << "\n";
VLOG(3) << "Print in_data_numel " << in_data_numel << "\n";
VLOG(3) << "Print out_data_numel " << out_data_numel << "\n";
MultinomialFunctor<T>(cpu_out_data, cpu_in_data, num_samples, replacement, MultinomialFunctor<T>(cpu_out_data, cpu_in_data, num_samples, replacement,
num_categories, num_distributions); num_categories, num_distributions);
cudaMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(T), cudaMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(T),
...@@ -240,21 +169,9 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T> ...@@ -240,21 +169,9 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
return; return;
} }
// std::vector<T> sum_rows(num_distributions);
// SumArrayCUDAKernel<T>(in_data, sum_rows,)
VLOG(3) << "Print num_distributions " << num_distributions << "\n";
VLOG(3) << "Print num_categories " << num_categories << "\n";
VLOG(3) << "Print in_rank " << in_rank << "\n";
framework::Tensor sum_rows_t; framework::Tensor sum_rows_t;
auto* sum_rows_data = auto* sum_rows_data =
sum_rows_t.mutable_data<T>({num_distributions}, ctx.GetPlace()); sum_rows_t.mutable_data<T>({num_distributions}, ctx.GetPlace());
// auto* sum_rows_data =
// sum_rows_t->mutable_data<T>(framework::make_ddim({num_distributions}),
// ctx.GetPlace());
auto& place = *ctx.template device_context<platform::CUDADeviceContext>() auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device(); .eigen_device();
...@@ -262,58 +179,34 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T> ...@@ -262,58 +179,34 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
if (num_distributions == 1) { if (num_distributions == 1) {
auto eigen_input = framework::EigenVector<T>::Flatten(*x); auto eigen_input = framework::EigenVector<T>::Flatten(*x);
auto eigen_sum_rows = framework::EigenVector<T>::From(sum_rows_t); auto eigen_sum_rows = framework::EigenVector<T>::From(sum_rows_t);
// auto eigen_sum_rows = framework::EigenScalar<T>::From(sum_rows_t);
eigen_sum_rows.device(place) = eigen_sum_rows.device(place) =
eigen_input.sum(Eigen::DSizes<int, 1>(1)) eigen_input.sum(Eigen::DSizes<int, 1>(1))
.eval() .eval()
.reshape(Eigen::DSizes<int, 1>(sum_rows_t.dims()[0])); .reshape(Eigen::DSizes<int, 1>(sum_rows_t.dims()[0]));
} else { } else {
auto eigen_input = framework::EigenMatrix<T>::From(*x); auto eigen_input = framework::EigenMatrix<T>::From(*x);
// auto eigen_sum_rows = framework::EigenVector<T>::From(sum_rows_t);
auto eigen_sum_rows = framework::EigenVector<T>::From(sum_rows_t); auto eigen_sum_rows = framework::EigenVector<T>::From(sum_rows_t);
eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes<int, 1>(1)); eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes<int, 1>(1));
// .eval()
// .reshape(Eigen::DSizes<int, 1>(sum_rows_t.dims()[0]));
// eigen_sum_rows.device(place) =
// eigen_input.sum().eval().reshape(Eigen::DSizes<int, 1>(1));
} }
// std::vector<T> in_data_norm(num_categories);
framework::Tensor norm_probs_t; framework::Tensor norm_probs_t;
auto* norm_probs_data = norm_probs_t.mutable_data<T>( auto* norm_probs_data = norm_probs_t.mutable_data<T>(
{num_distributions, num_categories}, ctx.GetPlace()); {num_distributions, num_categories}, ctx.GetPlace());
// dim3 grid(num_distributions);
// dim3 block(num_categories);
dim3 block(num_categories < 512 ? num_categories : 512); dim3 block(num_categories < 512 ? num_categories : 512);
dim3 grid((num_categories - 1) / block.x + 1, num_distributions); dim3 grid((num_categories - 1) / block.x + 1, num_distributions);
NormalizeProbability< NormalizeProbability<
T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>( T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
norm_probs_data, in_data, sum_rows_data); norm_probs_data, in_data, sum_rows_data);
// num_distributions can only be 1.
// std::vector<T> cumulative_probs(num_categories);
framework::Tensor cumulative_probs_t; framework::Tensor cumulative_probs_t;
auto* cumulative_probs = cumulative_probs_t.mutable_data<T>( auto* cumulative_probs = cumulative_probs_t.mutable_data<T>(
{num_distributions, num_categories}, ctx.GetPlace()); {num_distributions, num_categories}, ctx.GetPlace());
// T cumulative_probs[num_categories];
dim3 block1(1); dim3 block1(1);
dim3 grid1(num_distributions); dim3 grid1(num_distributions);
Cumsum<T><<<grid1, block1, 0, ctx.cuda_device_context().stream()>>>( Cumsum<T><<<grid1, block1, 0, ctx.cuda_device_context().stream()>>>(
norm_probs_data, num_distributions, num_categories, cumulative_probs); norm_probs_data, num_distributions, num_categories, cumulative_probs);
/*
dim3 block2(num_categories < 512 ? num_categories : 512);
dim3 grid2((num_categories-1)/block2.x+1, num_distributions);
yokiFunc<T><<<grid2, block2, 0, ctx.cuda_device_context().stream()>>>(
cumulative_probs, yokiout_data);*/
// int64_t size = num_categories;
// thrust::inclusive_scan(thrust::device, norm_probs_data,
// norm_probs_data + num_categories,
// cumulative_probs);
VLOG(3) << "Print cumsum " << cumulative_probs << "\n"; VLOG(3) << "Print cumsum " << cumulative_probs << "\n";
if (replacement) { if (replacement) {
...@@ -336,24 +229,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T> ...@@ -336,24 +229,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
index_sequence_begin + num_distributions * num_samples, rng_data, index_sequence_begin + num_distributions * num_samples, rng_data,
RandomGeneratorCudaFunctor<T>(seed)); RandomGeneratorCudaFunctor<T>(seed));
VLOG(3) << "Print enter\n";
// VLOG(3) << "Print size in_data " <<
// sizeof(in_data)/sizeof(in_data[num_categories-1]) << "\n";
// VLOG(3) << "Print norm_probs_data0 " <<
// sizeof(norm_probs_data[num_categories-1]) << "\n";
sampleMultinomialWithReplacement< sampleMultinomialWithReplacement<
T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>( T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
rng_data, num_samples, out_data, num_distributions, num_categories, rng_data, num_samples, out_data, num_distributions, num_categories,
cumulative_probs, norm_probs_data); cumulative_probs, norm_probs_data);
VLOG(3) << "Print end\n" << out_data;
} }
VLOG(3) << "Print final end\n";
// MultinomialCudaFunctor<T>(out_data, in_data, num_samples, replacement,
// num_categories, num_distributions);
} }
}; };
......
...@@ -126,6 +126,49 @@ class TestMultinomialApi(unittest.TestCase): ...@@ -126,6 +126,49 @@ class TestMultinomialApi(unittest.TestCase):
sample_prob, prob, rtol=0, atol=0.01), sample_prob, prob, rtol=0, atol=0.01),
"sample_prob: " + str(sample_prob) + "\nprob: " + str(prob)) "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))
def test_dygraph2(self):
paddle.disable_static()
x = paddle.rand([3, 4])
out = paddle.multinomial(x, num_samples=100000, replacement=True)
x_numpy = x.numpy()
out_list = np.split(out.numpy(), 3, axis=0)
count_array = [0] * 3
for i in range(3):
count_array[i] = np.unique(
out_list[i], return_counts=True)[1].astype("float32")
sample_prob = np.stack(count_array, axis=0)
sample_prob /= sample_prob.sum(axis=-1, keepdims=True)
prob = x_numpy / x_numpy.sum(axis=-1, keepdims=True)
self.assertTrue(
np.allclose(
sample_prob, prob, rtol=0, atol=0.01),
"sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))
paddle.enable_static()
def test_dygraph3(self):
paddle.disable_static()
x = paddle.rand([1000])
out = paddle.multinomial(x, num_samples=100, replacement=False)
x_numpy = x.numpy()
unique_out = np.unique(out.numpy())
self.assertEqual(
len(unique_out), 100,
"replacement is False. categories can't be sampled repeatedly")
paddle.enable_static()
"""
def test_replacement_error(self):
def test_error():
paddle.disable_static()
x = paddle.rand([5])
out = paddle.multinomial(x, num_samples=10, replacement=False)
self.assertRaises(OutOfRangeError, test_error) # not OutOfRangeError
"""
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# TODO: define random functions # TODO: define random functions
from ..fluid import core from ..fluid import core
from ..fluid.framework import in_dygraph_mode, Variable, convert_np_dtype_to_dtype_ from ..fluid.framework import in_dygraph_mode, Variable, convert_np_dtype_to_dtype_
...@@ -40,18 +40,18 @@ def bernoulli(x, name=None): ...@@ -40,18 +40,18 @@ def bernoulli(x, name=None):
This OP returns a Tensor filled with random binary(0 or 1) number from a Bernoulli distribution. This OP returns a Tensor filled with random binary(0 or 1) number from a Bernoulli distribution.
The input ``x`` is a tensor with probabilities for generating the random binary number. The input ``x`` is a tensor with probabilities for generating the random binary number.
Each element in ``x`` should be in [0, 1], and the out is generated by: Each element in ``x`` should be in [0, 1], and the out is generated by:
.. math:: .. math::
out_i ~ Bernoulli (x_i) out_i ~ Bernoulli (x_i)
Args: Args:
x(Tensor): A tensor with probabilities for generating the random binary number. The data type x(Tensor): A tensor with probabilities for generating the random binary number. The data type
should be float32, float64. should be float32, float64.
name(str, optional): The default value is None. Normally there is no name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`. refer to :ref:`api_guide_Name`.
Returns: Returns:
Tensor: A Tensor filled with random binary number with the same shape and dtype as ``x``. Tensor: A Tensor filled with random binary number with the same shape and dtype as ``x``.
Examples: Examples:
...@@ -80,7 +80,7 @@ def bernoulli(x, name=None): ...@@ -80,7 +80,7 @@ def bernoulli(x, name=None):
helper = LayerHelper("randint", **locals()) helper = LayerHelper("randint", **locals())
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
dtype=x.dtype) # maybe set out to int32 ? dtype=x.dtype) # maybe set out to int32 ?
helper.append_op( helper.append_op(
type='bernoulli', inputs={"X": x}, outputs={'Out': out}, attrs={}) type='bernoulli', inputs={"X": x}, outputs={'Out': out}, attrs={})
return out return out
...@@ -88,8 +88,23 @@ def bernoulli(x, name=None): ...@@ -88,8 +88,23 @@ def bernoulli(x, name=None):
def multinomial(x, num_samples=1, replacement=False, name=None): def multinomial(x, num_samples=1, replacement=False, name=None):
""" """
This OP returns a Tensor filled with random values sampled from a Multinomical
distribution. The input ``x`` is a tensor with probabilities for generating the
random number. Each element in ``x`` should be larger or equal to 0, but not all
0. ``replacement`` indicates whether it is a replaceable sample. If ``replacement``
is True, a category can be sampled more than once.
Args:
x(Tensor): A tensor with probabilities for generating the random number. The data type
should be float32, float64.
num_samples(int, optional): Number of samples, default is 1.
replacement(bool, optional): whether it is a replaceable sample, default is False.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor: A Tensor filled with sampled category index after ``num_samples`` times samples.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -97,15 +112,24 @@ def multinomial(x, num_samples=1, replacement=False, name=None): ...@@ -97,15 +112,24 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
paddle.disable_static() paddle.disable_static()
x = paddle.rand([2, 3]) x = paddle.rand([2,4])
print(x.numpy()) print(x.numpy())
# [[0.11272584 0.3890902 0.7730957 ] # [[0.7713825 0.4055941 0.433339 0.70706886]
# [0.10351662 0.8510418 0.63806665]] # [0.9223313 0.8519825 0.04574518 0.16560672]]
out = paddle.bernoulli(x) out1 = paddle.multinomial(x, num_samples=5, replacement=True)
print(out.numpy()) print(out1.numpy())
# [[0. 0. 1.] # [[3. 3. 1. 1. 0.]
# [0. 0. 1.]] # [0. 0. 0. 0. 1.]]
out2 = paddle.multinomial(x, num_samples=5)
# OutOfRangeError: When replacement is False, number of samples
# should be less than non-zero categories
out3 = paddle.multinomial(x, num_samples=3)
print(out3.numpy())
# [[0. 2. 3.]
# [0. 1. 3.]]
""" """
...@@ -152,7 +176,7 @@ def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None): ...@@ -152,7 +176,7 @@ def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None):
Returns: Returns:
Tensor: A Tensor filled with random values sampled from a Gaussian Tensor: A Tensor filled with random values sampled from a Gaussian
distribution, with ``shape`` and ``dtype``. distribution, with ``shape`` and ``dtype``.
""" """
op_type_for_check = 'gaussian/standard_normal/randn/normal' op_type_for_check = 'gaussian/standard_normal/randn/normal'
seed = 0 seed = 0
...@@ -393,7 +417,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None): ...@@ -393,7 +417,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
paddle.disable_static() paddle.disable_static()
...@@ -481,7 +505,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): ...@@ -481,7 +505,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
need for user to set this property. For more information, please need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`. refer to :ref:`api_guide_Name`.
Returns: Returns:
Tensor: A Tensor filled with random integers from a discrete uniform Tensor: A Tensor filled with random integers from a discrete uniform
distribution in the range [``low``, ``high``), with ``shape`` and ``dtype``. distribution in the range [``low``, ``high``), with ``shape`` and ``dtype``.
...@@ -591,7 +615,7 @@ def randperm(n, dtype="int64", name=None): ...@@ -591,7 +615,7 @@ def randperm(n, dtype="int64", name=None):
out2 = paddle.randperm(7, 'int32') out2 = paddle.randperm(7, 'int32')
# [1, 6, 2, 0, 4, 3, 5] # random # [1, 6, 2, 0, 4, 3, 5] # random
""" """
if not isinstance(dtype, core.VarDesc.VarType): if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册