diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 3a451c19ec20ed44b54fe304642ad5ad05b70daf..9ea3f05d64f52a09ed8292e7046952265f18855f 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -490,6 +490,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, infer_meta_context.EmplaceBackAttr( phi::Scalar(PADDLE_GET_CONST(int, attr))); break; + case framework::proto::AttrType::LONG: + infer_meta_context.EmplaceBackAttr( + phi::Scalar(PADDLE_GET_CONST(int64_t, attr))); + break; case framework::proto::AttrType::STRING: infer_meta_context.EmplaceBackAttr( phi::Scalar(PADDLE_GET_CONST(std::string, attr))); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 23fce93ef30a3262f977d4a77bf8a6debf0dcf03..71bd350af6eff599269c4fc86426b525ccecd495 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2753,6 +2753,10 @@ void OperatorWithKernel::BuildPhiKernelContext( phi_kernel_context->EmplaceBackAttr(std::move( phi::Scalar(PADDLE_GET_CONST(int, attr_iter->second)))); break; + case proto::AttrType::LONG: + phi_kernel_context->EmplaceBackAttr(std::move( + phi::Scalar(PADDLE_GET_CONST(int64_t, attr_iter->second)))); + break; case proto::AttrType::STRING: phi_kernel_context->EmplaceBackAttr(std::move(phi::Scalar( PADDLE_GET_CONST(std::string, attr_iter->second)))); diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 1e76757e1c048dcb01f40c2c38a2dcab327dd564..b6c78c47a287c1b6d85dda4510c6b9d207dd30f3 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -420,6 +420,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, kernel_ctx->EmplaceBackAttr( std::move(phi::Scalar(PADDLE_GET_CONST(int, attr)))); break; + case framework::proto::AttrType::LONG: + kernel_ctx->EmplaceBackAttr( + std::move(phi::Scalar(PADDLE_GET_CONST(int64_t, attr)))); + break; case framework::proto::AttrType::STRING: kernel_ctx->EmplaceBackAttr( std::move(phi::Scalar(PADDLE_GET_CONST(std::string, attr)))); diff --git a/paddle/fluid/operators/eye_op.cc b/paddle/fluid/operators/eye_op.cc index 5ff3641e757ce5dc9019e546695e5f912aa939e6..629400a403e461115d8c98ef6ee927fe3af3df44 100644 --- a/paddle/fluid/operators/eye_op.cc +++ b/paddle/fluid/operators/eye_op.cc @@ -50,11 +50,13 @@ class EyeOpMaker : public framework::OpProtoAndCheckerMaker { "Output data type") .SetDefault(framework::proto::VarType::FP32); AddAttr("num_rows", - "(int64_t) the number of rows in output tensor"); + "(int64_t) the number of rows in output tensor") + .SupportTensor(); AddAttr("num_columns", "(int64_t) the number of columns in output tensor." "Default -1 means that num_columns=num_rows") - .SetDefault(-1); + .SetDefault(-1) + .SupportTensor(); AddOutput("Out", "(Tensor) Construct an identity tensor with " "specified shape [num_rows, num_columns]"); diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 840738d4bb682afb9f17c6fc07c0412ef11a9495..165005e808ca9071d616ac4b637be58cca7730fb 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -874,7 +874,7 @@ backward : exponential__grad - api : eye - args : (int64_t num_rows, int64_t num_columns, DataType dtype=DataType::FLOAT32, Place place={}) + args : (Scalar num_rows, Scalar num_columns, DataType dtype=DataType::FLOAT32, Place place={}) output : Tensor(out) infer_meta : func : EyeInferMeta diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index 442e62473238601f2fe06599e811fe252f691bbc..419590cbe67527bc2a11b2fe1dc9846c65ea980c 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -51,12 +51,25 @@ void CreateInferMetaBase(const std::vector& shape, out->set_layout(layout); } -void EyeInferMeta(int64_t num_rows, - int64_t num_columns, +void EyeInferMeta(const Scalar& num_rows, + const Scalar& num_columns, DataType dtype, - MetaTensor* out) { - if (num_columns == -1) num_columns = num_rows; - out->set_dims({num_rows, num_columns}); + MetaTensor* out, + MetaConfig config) { + int64_t rows, columns; + if (!config.is_runtime && num_rows.FromTensor()) { + rows = -1; + } else { + rows = num_rows.to(); + } + + if (!config.is_runtime && num_columns.FromTensor()) { + columns = -1; + } else { + columns = num_columns.to(); + if (columns == -1) columns = rows; + } + out->set_dims({rows, columns}); out->set_dtype(dtype); } diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index 59673ba8bcf3687294f7103838c4308d6ec95bbf..27c89821a319e885ce237056b6391b27b9410086 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/phi/common/int_array.h" +#include "paddle/phi/common/scalar.h" #include "paddle/phi/core/meta_tensor.h" namespace phi { @@ -41,10 +42,11 @@ void CreateInferMetaBase(const std::vector& shape, DataLayout layout, MetaTensor* out); -void EyeInferMeta(int64_t num_rows, - int64_t num_columns, +void EyeInferMeta(const Scalar& num_rows, + const Scalar& num_columns, DataType dtype, - MetaTensor* out); + MetaTensor* out, + MetaConfig config = MetaConfig()); void GaussianRandomInferMeta(const IntArray& shape, float mean, diff --git a/paddle/phi/kernels/eye_kernel.h b/paddle/phi/kernels/eye_kernel.h index e9e1abffd143324a12809fe784c4138a77352930..c7c2a627d705b4064e22f6f5d0064c5ccbd67426 100644 --- a/paddle/phi/kernels/eye_kernel.h +++ b/paddle/phi/kernels/eye_kernel.h @@ -14,14 +14,15 @@ #pragma once +#include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { template void EyeKernel(const Context& ctx, - int64_t num_rows, - int64_t num_columns, + const Scalar& num_rows, + const Scalar& num_columns, DataType dtype, DenseTensor* out); diff --git a/paddle/phi/kernels/impl/eye_kernel_impl.h b/paddle/phi/kernels/impl/eye_kernel_impl.h index f4041f921fd352e63688b24d77987eabd526d4a2..57b9ce73e89ce10c56ffdc88c573a9ae97f012fc 100644 --- a/paddle/phi/kernels/impl/eye_kernel_impl.h +++ b/paddle/phi/kernels/impl/eye_kernel_impl.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/common/scalar.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { @@ -34,20 +35,21 @@ struct EyeFunctor { template void EyeKernel(const Context& ctx, - int64_t num_rows, - int64_t num_columns, + const Scalar& num_rows, + const Scalar& num_columns, DataType dtype, DenseTensor* out) { - auto num = num_columns; - if (num == -1) { - num = num_rows; + auto columns = num_columns.to(); + auto rows = num_rows.to(); + if (columns == -1) { + columns = rows; } T* out_data = ctx.template Alloc(out); phi::funcs::SetConstant set_zero; set_zero(ctx, out, static_cast(0)); - int64_t num_eyes = (std::min)(num_rows, num); + int64_t num_eyes = (std::min)(rows, columns); paddle::platform::ForRange for_range(ctx, num_eyes); - EyeFunctor functor(num, out_data); + EyeFunctor functor(columns, out_data); for_range(functor); } diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 1eefe759c708da6a15d0b0227dc3f779bc518419..b62df9c102485ae056fb0fbc85b6d7391a98c88c 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -1791,11 +1791,17 @@ def eye(num_rows, """ + def _check_attr(attr, message): + if isinstance(attr, ((Variable, core.VarBase, core.eager.Tensor))): + assert len(attr.shape) == 1 and attr.shape[0] in [1, -1] + elif not isinstance(attr, int) or attr < 0: + raise TypeError("{} should be a non-negative int.".format(message)) + + _check_attr(num_rows, "num_rows") if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) if num_columns is not None: - if not isinstance(num_columns, int) or num_columns < 0: - raise TypeError("num_columns should be a non-negative int") + _check_attr(num_columns, "num_columns") else: num_columns = num_rows @@ -1809,8 +1815,6 @@ def eye(num_rows, helper = LayerHelper("eye", **locals()) check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64', 'int32', 'int64'], 'eye') - if not isinstance(num_rows, int) or num_rows < 0: - raise TypeError("num_rows should be a non-negative int") out = helper.create_variable_for_type_inference(dtype=dtype) helper.append_op(type='eye', inputs={}, diff --git a/python/paddle/fluid/tests/unittests/test_eye_op.py b/python/paddle/fluid/tests/unittests/test_eye_op.py index d74cabb4275ad39cb990b6889e73d224510f3103..2b4af0fbdb9ca3091d23266d3dab8c5e6041d255 100644 --- a/python/paddle/fluid/tests/unittests/test_eye_op.py +++ b/python/paddle/fluid/tests/unittests/test_eye_op.py @@ -14,6 +14,7 @@ from __future__ import print_function +import os import unittest import numpy as np from op_test import OpTest @@ -22,6 +23,9 @@ import paddle import paddle.fluid as fluid import paddle.fluid.framework as framework +from paddle.fluid.framework import program_guard, Program +from test_attribute_var import UnittestBase + class TestEyeOp(OpTest): @@ -162,5 +166,69 @@ class API_TestTensorEye(unittest.TestCase): self.assertRaises(TypeError, test_num_columns_type_check1) +class TestEyeRowsCol(UnittestBase): + + def init_info(self): + self.shapes = [[2, 3, 4]] + self.save_path = os.path.join(self.temp_dir.name, self.path_prefix()) + + def test_static(self): + main_prog = Program() + starup_prog = Program() + with program_guard(main_prog, starup_prog): + fc = paddle.nn.Linear(4, 10) + x = paddle.randn([2, 3, 4]) + x.stop_gradient = False + feat = fc(x) # [2,3,10] + + tmp = self.call_func(feat) + out = feat + tmp + + sgd = paddle.optimizer.SGD() + sgd.minimize(paddle.mean(out)) + self.assertTrue(self.var_prefix() in str(main_prog)) + + exe = paddle.static.Executor() + exe.run(starup_prog) + res = exe.run(fetch_list=[tmp, out]) + gt = np.eye(3, 10) + np.testing.assert_allclose(res[0], gt) + paddle.static.save_inference_model(self.save_path, [x], [tmp, out], + exe) + # Test for Inference Predictor + infer_outs = self.infer_prog() + np.testing.assert_allclose(infer_outs[0], gt) + + def path_prefix(self): + return 'eye_rows_cols' + + def var_prefix(self): + return "Var[" + + def call_func(self, x): + rows = paddle.assign(3) + cols = paddle.assign(10) + out = paddle.eye(rows, cols) + return out + + def test_error(self): + with self.assertRaises(TypeError): + paddle.eye(-1) + + +class TestEyeRowsCol2(TestEyeRowsCol): + + def call_func(self, x): + rows = paddle.assign(3) + cols = paddle.assign(10) + out = paddle.fluid.layers.eye(rows, cols) + return out + + def test_error(self): + with self.assertRaises(TypeError): + paddle.fluid.layers.eye(-1) + + if __name__ == "__main__": + paddle.enable_static() unittest.main() diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index b6291a9248890c3eebb4fb95fe1866e1db751973..c8d433002387eb42ebf3f362d9726ab6c50775fa 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -710,16 +710,20 @@ def eye(num_rows, num_columns=None, dtype=None, name=None): # [0 1 0]] """ + def _check_attr(attr, message): + if isinstance(attr, ((Variable, core.VarBase, core.eager.Tensor))): + assert len(attr.shape) == 1 and attr.shape[0] in [1, -1] + elif not isinstance(attr, int) or attr < 0: + raise TypeError("{} should be a non-negative int.".format(message)) + + _check_attr(num_rows, "num_rows") + if dtype is None: dtype = 'float32' - if num_columns is None: - num_columns = num_rows - if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) if num_columns is not None: - if not isinstance(num_columns, int) or num_columns < 0: - raise TypeError("num_columns should be a non-negative int") + _check_attr(num_columns, "num_columns") else: num_columns = num_rows @@ -735,8 +739,6 @@ def eye(num_rows, num_columns=None, dtype=None, name=None): helper = LayerHelper("eye", **locals()) check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64', 'int32', 'int64'], 'eye') - if not isinstance(num_rows, int) or num_rows < 0: - raise TypeError("num_rows should be a non-negative int") out = helper.create_variable_for_type_inference(dtype=dtype) helper.append_op(type='eye', inputs={},