Unverified commit 47b6e5ff, authored by xiongkun, committed by GitHub

[Yaml] add yaml for Uniform random and add unit test. (#41517) (#41619)

* gather op

* add mod

* [Yaml] final state for uniform and uniform_random
Parent a0b0a32f
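For context, this change routes the public `paddle.uniform` API through the newly generated final-state `uniform_random` op when eager mode is on. A minimal usage sketch, assuming a Paddle 2.3-era build with dynamic mode enabled (the default):

```python
import paddle

# paddle.uniform draws from the half-open interval [min, max);
# seed=0 means "use the global random generator".
x = paddle.uniform(shape=[2, 3], dtype='float32', min=-1.0, max=1.0, seed=0)
print(x.shape)  # [2, 3]
```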
@@ -16,9 +16,11 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/phi/infermeta/nullary.h"
namespace paddle {
namespace operators {
@@ -122,74 +124,6 @@ class UniformRandomOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "UniformRandomOp");

    PADDLE_ENFORCE_LT(
        ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max"),
        platform::errors::InvalidArgument(
            "The uniform_random's min must be less than max. But received "
            "min = %f, which is greater than or equal to max = %f.",
            ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max")));
    PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_num"), 0,
                      platform::errors::InvalidArgument(
                          "The uniform_random's diag_num must be greater than "
                          "or equal to 0. But received diag_num (%d) < 0.",
                          ctx->Attrs().Get<int>("diag_num")));
    PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_step"), 0,
                      platform::errors::InvalidArgument(
                          "The uniform_random's diag_step must be greater than "
                          "or equal to 0. But received diag_step (%d) < 0.",
                          ctx->Attrs().Get<int>("diag_step")));
    if (ctx->HasInputs("ShapeTensorList")) {
      // ShapeTensorList takes top priority when resolving the output shape.
      auto inputs_name = ctx->Inputs("ShapeTensorList");
      PADDLE_ENFORCE_GT(inputs_name.size(), 0,
                        platform::errors::InvalidArgument(
                            "Input(ShapeTensorList)'s size of "
                            "Op(uniform_random) can't be zero. "
                            "Please check the Attr(shape)'s size of "
                            "Op(fluid.layers.uniform_random)."));
      auto out_dims = std::vector<int>(inputs_name.size(), -1);
      ctx->SetOutputDim("Out", phi::make_ddim(out_dims));

      return;
    }
    auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
    if (ctx->HasInput("ShapeTensor") && shape.empty()) {
      auto shape_dims = ctx->GetInputDim("ShapeTensor");
      PADDLE_ENFORCE_EQ(
          shape_dims.size(), 1,
          platform::errors::InvalidArgument(
              "ShapeError: Input(ShapeTensor)'s dimension size of "
              "Op(uniform_random) must be 1. "
              "But received ShapeTensor's dimensions = %d, shape = [%s]",
              shape_dims.size(), shape_dims));
      int num_ele = 1;
      for (int i = 0; i < shape_dims.size(); ++i) {
        num_ele *= shape_dims[i];
      }
      auto vec_dims = std::vector<int64_t>(num_ele, -1);
      auto out_dims = phi::make_ddim(vec_dims);
      ctx->SetOutputDim("Out", out_dims);
      return;
    }
    PADDLE_ENFORCE_EQ(shape.empty(), false,
                      platform::errors::InvalidArgument(
                          "If there is no Input(ShapeTensorList) and no "
                          "Input(ShapeTensor), the shape information must "
                          "be set by Attr(shape)."));
    std::vector<int64_t> tensor_shape;
    tensor_shape.reserve(shape.size());
    for (auto dim : shape) {
      tensor_shape.push_back(static_cast<int64_t>(dim));
    }
    ctx->SetOutputDim("Out", phi::make_ddim(tensor_shape));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
@@ -274,12 +208,16 @@ class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(uniform_random, UniformRandomInferShapeFunctor,
                            PD_INFER_META(phi::UniformRandomInferMeta));

REGISTER_OPERATOR(
    uniform_random, paddle::operators::UniformRandomOp,
    paddle::operators::UniformRandomOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    paddle::operators::UniformRandomOpVarTypeInference);
    paddle::operators::UniformRandomOpVarTypeInference,
    UniformRandomInferShapeFunctor);

REGISTER_OP_CPU_KERNEL(
    uniform_random_batch_size_like,
......
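The deleted fluid `InferShape` above (now replaced by `UniformRandomInferShapeFunctor`) resolved the output shape from three sources in priority order: `Input(ShapeTensorList)`, then `Input(ShapeTensor)`, then `Attr(shape)`. Below is a rough Python mirror of that selection logic; `resolve_out_dims` is a hypothetical name, not part of the codebase:

```python
def resolve_out_dims(shape_tensor_list=None, shape_tensor=None, shape_attr=None):
    # Values held by shape tensors are unknown at compile time, so any
    # dimension that depends on them is marked -1 (filled in at run time).
    if shape_tensor_list:                     # highest priority
        return [-1] * len(shape_tensor_list)  # one unknown dim per list entry
    if shape_tensor is not None and not shape_attr:
        # ShapeTensor is 1-D; its element count fixes only the output rank.
        return [-1] * len(shape_tensor)
    assert shape_attr, "Attr(shape) must be set when neither input is given"
    return [int(d) for d in shape_attr]

print(resolve_out_dims(shape_attr=[1000, 784]))    # [1000, 784]
print(resolve_out_dims(shape_tensor=[1000, 784]))  # [-1, -1]
```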
@@ -63,6 +63,18 @@ void RandpermInferMeta(int n, DataType dtype, MetaTensor* out) {
  out->set_dtype(dtype);
}

void UniformRandomInferMeta(const IntArray& shape,
                            DataType dtype,
                            float min,
                            float max,
                            int seed,
                            MetaTensor* out) {
  auto out_dims = phi::make_ddim(shape.GetData());
  out->set_dims(out_dims);
  out->set_dtype(dtype);
  out->set_layout(DataLayout::NCHW);
}

void RandintInferMeta(
    int low, int high, const IntArray& shape, DataType dtype, MetaTensor* out) {
  PADDLE_ENFORCE_NOT_NULL(
......
@@ -65,4 +65,11 @@ void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
                                      DataType dtype,
                                      MetaTensor* out);

void UniformRandomInferMeta(const IntArray& shape,
                            DataType dtype,
                            float min,
                            float max,
                            int seed,
                            MetaTensor* out);

}  // namespace phi
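`UniformRandomInferMeta` is the phi-side counterpart of the deleted fluid `InferShape`: with the shape available as an `IntArray`, it only stamps dims, dtype, and layout onto the output `MetaTensor`. A rough Python analogue, using a hypothetical `MetaTensor` stand-in:

```python
from dataclasses import dataclass, field

@dataclass
class MetaTensor:  # hypothetical stand-in for phi::MetaTensor
    dims: list = field(default_factory=list)
    dtype: str = "float32"
    layout: str = "NCHW"

def uniform_random_infer_meta(shape, dtype, min, max, seed, out):
    # min, max, and seed feed the kernel, not the metadata, so they are
    # unused here; only shape, dtype, and layout are inferred.
    out.dims = list(shape)
    out.dtype = dtype
    out.layout = "NCHW"

out = MetaTensor()
uniform_random_infer_meta([1000, 784], "float32", -1.0, 1.0, 0, out)
assert out.dims == [1000, 784]
```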
@@ -26,6 +26,7 @@ import paddle
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard
def output_hist(out):
@@ -52,6 +53,7 @@ def output_hist_diag(out):
class TestUniformRandomOp_attr_tensorlist(OpTest):
    def setUp(self):
        self.op_type = "uniform_random"
        self.python_api = paddle.uniform
        self.new_shape = (1000, 784)
        shape_tensor = []
        for index, ele in enumerate(self.new_shape):
@@ -84,6 +86,7 @@ class TestMaxMinAreInt(TestUniformRandomOp_attr_tensorlist):
class TestUniformRandomOp_attr_tensorlist_int32(OpTest):
    def setUp(self):
        self.op_type = "uniform_random"
        self.python_api = paddle.uniform
        self.new_shape = (1000, 784)
        shape_tensor = []
        for index, ele in enumerate(self.new_shape):
@@ -110,6 +113,7 @@ class TestUniformRandomOp_attr_tensorlist_int32(OpTest):
class TestUniformRandomOp_attr_tensor(OpTest):
    def setUp(self):
        self.op_type = "uniform_random"
        self.python_api = paddle.uniform
        self.inputs = {"ShapeTensor": np.array([1000, 784]).astype("int64")}
        self.init_attrs()
        self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
@@ -131,6 +135,7 @@ class TestUniformRandomOp_attr_tensor(OpTest):
class TestUniformRandomOp_attr_tensor_int32(OpTest):
    def setUp(self):
        self.op_type = "uniform_random"
        self.python_api = paddle.uniform
        self.inputs = {"ShapeTensor": np.array([1000, 784]).astype("int32")}
        self.init_attrs()
        self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
@@ -152,6 +157,7 @@ class TestUniformRandomOp_attr_tensor_int32(OpTest):
class TestUniformRandomOp(OpTest):
    def setUp(self):
        self.op_type = "uniform_random"
        self.python_api = paddle.uniform
        self.inputs = {}
        self.init_attrs()
        self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
@@ -174,6 +180,18 @@ class TestUniformRandomOp(OpTest):
            np.allclose(
                hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))

    def test_check_api(self):
        places = self._get_places()
        for place in places:
            with fluid.dygraph.base.guard(place=place):
                out = self.python_api(self.attrs['shape'], 'float32',
                                      self.attrs['min'], self.attrs['max'],
                                      self.attrs['seed'])

    def test_check_api_eager(self):
        with _test_eager_guard():
            self.test_check_api()


class TestUniformRandomOpError(unittest.TestCase):
    def test_errors(self):
......
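The new `test_check_api_eager` hook re-runs `test_check_api` under `_test_eager_guard()`, so the same assertions cover both the eager final-state path and the default mode. A self-contained sketch of that double-run pattern (not taken from the diff):

```python
import unittest

import paddle
from paddle.fluid.framework import _test_eager_guard


class TestUniformEagerSketch(unittest.TestCase):
    def check_uniform(self):
        out = paddle.uniform([4, 5], 'float32', min=-1.0, max=1.0, seed=10)
        data = out.numpy()
        # Samples must lie in the half-open interval [min, max).
        self.assertTrue(((data >= -1.0) & (data < 1.0)).all())

    def test_uniform(self):
        with _test_eager_guard():   # new eager final-state path
            self.check_uniform()
        self.check_uniform()        # default mode


if __name__ == '__main__':
    unittest.main()
```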
@@ -548,7 +548,14 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
    if not isinstance(dtype, core.VarDesc.VarType):
        dtype = convert_np_dtype_to_dtype_(dtype)

    if paddle.in_dynamic_mode():
    if in_dygraph_mode():
        shape = utils.convert_shape_to_list(shape)
        return _C_ops.final_state_uniform_random(shape, dtype,
                                                 float(min),
                                                 float(max), seed,
                                                 _current_expected_place())
    if _in_legacy_dygraph():
        shape = utils.convert_shape_to_list(shape)
        return _C_ops.uniform_random('shape', shape, 'min',
                                     float(min), 'max',
......
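Condensed, the dygraph branch of `paddle.uniform` now dispatches three ways: new eager mode calls the generated final-state op, legacy dygraph keeps the attr-string calling convention, and everything else falls through to the static-graph `append_op` path. A sketch of that dispatch; the import paths and the tail of the legacy argument list (elided above) are assumptions based on the 2.3-era layout:

```python
from paddle import _C_ops
from paddle.fluid.framework import (_current_expected_place, _in_legacy_dygraph,
                                    in_dygraph_mode)
from paddle.fluid.layers import utils

def uniform_dispatch(shape, dtype, min=-1.0, max=1.0, seed=0):
    if in_dygraph_mode():       # new eager mode -> generated final-state op
        shape = utils.convert_shape_to_list(shape)
        return _C_ops.final_state_uniform_random(
            shape, dtype, float(min), float(max), seed,
            _current_expected_place())
    if _in_legacy_dygraph():    # old dygraph -> attr-string calling convention
        shape = utils.convert_shape_to_list(shape)
        return _C_ops.uniform_random(
            'shape', shape, 'min', float(min), 'max', float(max),
            'seed', seed, 'dtype', dtype)
    raise NotImplementedError("static-graph append_op path elided here")
```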
@@ -1975,6 +1975,18 @@
  func : unfold
  backward : unfold_grad

- api : uniform_random
  args : (IntArray shape, DataType dtype, float min, float max, int seed, Place place={})
  output : Tensor(out)
  infer_meta :
    func : UniformRandomInferMeta
    param: [shape, dtype, min, max, seed]
  kernel :
    func : uniform_random
    param: [shape, dtype, min, max, seed]
    data_type : dtype
    backend : place

# The `axis` argument of Python API paddle.unique is not vector
- api : unique
  args : (Tensor x, bool return_index, bool return_inverse, bool return_counts, int[] axis, DataType dtype=DataType::INT64)
......
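Each `- api` entry in this YAML drives code generation: `args` fixes the positional signature of the generated entry point, `data_type : dtype` keys kernel selection on the dtype argument, and `backend : place` routes the `place` argument to device selection. A sketch of the call the entry produces, assuming an eager-mode build (the `final_state_` prefix comes from the generator):

```python
import paddle
from paddle import _C_ops

out = _C_ops.final_state_uniform_random(
    [1000, 784],        # IntArray shape
    paddle.float32,     # DataType dtype (also the kernel `data_type` key)
    -1.0,               # float min
    1.0,                # float max
    0,                  # int seed
    paddle.CPUPlace())  # Place place={} -> selects the kernel backend
```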