未验证 提交 dd28cada 编写于 作者: L Leo Chen 提交者: GitHub

Refine bernoulli and unsqueeze op (#26842) (#26885)

* add check for bernoulli and register bool for unsqueeze

* follow comments
上级 7495b288
...@@ -31,6 +31,10 @@ struct BernoulliCudaFunctor { ...@@ -31,6 +31,10 @@ struct BernoulliCudaFunctor {
__host__ __device__ BernoulliCudaFunctor(int seed) : seed_(seed) {} __host__ __device__ BernoulliCudaFunctor(int seed) : seed_(seed) {}
__host__ __device__ T operator()(const unsigned int n, const T p) const { __host__ __device__ T operator()(const unsigned int n, const T p) const {
// NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several
// lines of error messages if, and it should be refined.
PADDLE_ENFORCE(p >= 0.0 && p <= 1.0,
"The probability should be >=0 and <= 1, but got %f", p);
thrust::minstd_rand rng; thrust::minstd_rand rng;
rng.seed(seed_); rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(0.0, 1.0); thrust::uniform_real_distribution<T> dist(0.0, 1.0);
......
...@@ -25,10 +25,12 @@ namespace operators { ...@@ -25,10 +25,12 @@ namespace operators {
template <typename T> template <typename T>
inline HOSTDEVICE T BernoulliFunctor(T p, T rand) { inline HOSTDEVICE T BernoulliFunctor(T p, T rand) {
PADDLE_ENFORCE_LE(p, 1, platform::errors::OutOfRange( PADDLE_ENFORCE_LE(p, 1.0,
platform::errors::OutOfRange(
"The probability should be <= 1, but got %f", p)); "The probability should be <= 1, but got %f", p));
PADDLE_ENFORCE_GE(p, 0, platform::errors::OutOfRange( PADDLE_ENFORCE_GE(p, 0.0,
"The probability should be >= 1, but got %f", p)); platform::errors::OutOfRange(
"The probability should be >= 0, but got %f", p));
return static_cast<T>(rand < p); return static_cast<T>(rand < p);
} }
......
...@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/unsqueeze_op.h" #include "paddle/fluid/operators/unsqueeze_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
...@@ -327,6 +329,7 @@ REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp, ...@@ -327,6 +329,7 @@ REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp,
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
unsqueeze, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>, unsqueeze, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -334,12 +337,14 @@ REGISTER_OP_CPU_KERNEL( ...@@ -334,12 +337,14 @@ REGISTER_OP_CPU_KERNEL(
unsqueeze_grad, unsqueeze_grad,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>, ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>, ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>, ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>, unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -347,6 +352,7 @@ REGISTER_OP_CPU_KERNEL( ...@@ -347,6 +352,7 @@ REGISTER_OP_CPU_KERNEL(
unsqueeze2_grad, unsqueeze2_grad,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>, ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, double>, ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>, ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL(
unsqueeze, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>, unsqueeze, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
...@@ -30,6 +31,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -30,6 +31,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, double>, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>, plat::float16>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
...@@ -38,6 +40,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -38,6 +40,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
...@@ -47,6 +50,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -47,6 +50,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, double>, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
plat::float16>, plat::float16>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
...@@ -433,8 +433,8 @@ def stack(x, axis=0, name=None): ...@@ -433,8 +433,8 @@ def stack(x, axis=0, name=None):
[5.0, 6.0] ] ] [5.0, 6.0] ] ]
Args: Args:
x (Tensor|list[Tensor]): Input ``x`` can be a single tensor, or a ``list`` of tensors. x (Tensor|list[Tensor]|tuple[Tensor]): Input ``x`` can be a single tensor, or a ``list`` or ``tuple`` of tensors.
If ``x`` is a ``list``, the Tensors in ``x`` If ``x`` is a ``list`` or ``tuple`` , the Tensors in ``x``
must be of the same shape and dtype. Supported data types: float32, float64, int32, int64. must be of the same shape and dtype. Supported data types: float32, float64, int32, int64.
axis (int, optional): The axis along which all inputs are stacked. ``axis`` range is ``[-(R+1), R+1)``, axis (int, optional): The axis along which all inputs are stacked. ``axis`` range is ``[-(R+1), R+1)``,
where ``R`` is the number of dimensions of the first input tensor ``x[0]``. where ``R`` is the number of dimensions of the first input tensor ``x[0]``.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册