From a2e10930cf2781b58875abb9c475375e1282e575 Mon Sep 17 00:00:00 2001
From: channings
Date: Fri, 3 Apr 2020 17:24:23 +0800
Subject: [PATCH] update linspace, equal operators to API 2.0 (#23274)

* update linspace, equal operators to API 2.0, test=develop

* equal supports a higher-performance CUDA kernel, test=develop

* update comments of the equal & linspace operators, test=develop

* update comments of the equal & linspace operators, test=develop
---
 cmake/operators.cmake                           |   2 +-
 .../operators/controlflow/CMakeLists.txt        |   2 +-
 .../controlflow/compare_reduce_op.cc            | 150 +++++++++++++++++
 .../controlflow/compare_reduce_op.cu            |  88 ++++++++++
 .../operators/controlflow/compare_reduce_op.h   |  43 +++++
 python/paddle/__init__.py                       |   4 +-
 python/paddle/common_ops_import.py              |  27 +++
 .../tests/unittests/test_compare_reduce_op.py   | 156 ++++++++++++++++++
 .../fluid/tests/unittests/test_linspace.py      |  33 ++++
 python/paddle/tensor/__init__.py                |   6 +-
 python/paddle/tensor/creation.py                |  99 +++++++++--
 python/paddle/tensor/logic.py                   | 106 ++++++++++--
 12 files changed, 682 insertions(+), 34 deletions(-)
 create mode 100644 paddle/fluid/operators/controlflow/compare_reduce_op.cc
 create mode 100644 paddle/fluid/operators/controlflow/compare_reduce_op.cu
 create mode 100644 paddle/fluid/operators/controlflow/compare_reduce_op.h
 create mode 100644 python/paddle/common_ops_import.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_compare_reduce_op.py

diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index d3471036c2..961c1b554a 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -114,7 +114,7 @@ function(op_library TARGET)
   endif()
 
   # Define operators that don't need pybind here.
-  foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
+  foreach(manual_pybind_op "compare_reduce_op" "compare_op" "logical_op" "nccl_op"
           "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
           "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
           "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
diff --git a/paddle/fluid/operators/controlflow/CMakeLists.txt b/paddle/fluid/operators/controlflow/CMakeLists.txt
index 7ae54ef480..e1742b03ab 100644
--- a/paddle/fluid/operators/controlflow/CMakeLists.txt
+++ b/paddle/fluid/operators/controlflow/CMakeLists.txt
@@ -9,4 +9,4 @@ cc_test(conditional_block_op_test SRCS conditional_block_op_test.cc DEPS conditi
 
 target_link_libraries(conditional_block_infer_op conditional_block_op)
 
-file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
+file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_reduce);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
diff --git a/paddle/fluid/operators/controlflow/compare_reduce_op.cc b/paddle/fluid/operators/controlflow/compare_reduce_op.cc
new file mode 100644
index 0000000000..316b46b02c
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/compare_reduce_op.cc
@@ -0,0 +1,150 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/controlflow/compare_reduce_op.h"
+#include <string>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename Functor>
+class CompareReduceOpKernel
+    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    using T = typename Functor::ELEM_TYPE;
+    using Tensor = framework::Tensor;
+
+    auto* x = context.Input<Tensor>("X");
+    auto* y = context.Input<Tensor>("Y");
+    auto* z = context.Output<Tensor>("Out");
+    int axis = context.Attr<int>("axis");
+
+    Tensor tmp;
+    framework::DDim x_dims = x->dims();
+    framework::DDim y_dims = y->dims();
+    int max_dim = std::max(x_dims.size(), y_dims.size());
+    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
+    std::vector<int> x_dims_array(max_dim);
+    std::vector<int> y_dims_array(max_dim);
+    std::vector<int> tmp_dims_array(max_dim);
+    GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
+                           y_dims_array.data(), tmp_dims_array.data(), max_dim,
+                           axis);
+    tmp.mutable_data<bool>(framework::make_ddim(tmp_dims_array),
+                           context.GetPlace());
+
+    if (x->numel() == 1 && y->numel() == 1) {
+      bool* z_data = tmp.mutable_data<bool>(context.GetPlace());
+      z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);
+    } else {
+      ElementwiseComputeEx<Functor, DeviceContext, T, bool>(
+          context, x, y, axis, Functor(), &tmp);
+    }
+
+    // Reduce by 'logical and' operator
+    z->mutable_data<bool>(context.GetPlace());
+    auto ipt = framework::EigenVector<bool>::Flatten(tmp);
+    auto out = framework::EigenScalar<bool>::From(*z);
+    auto& place = *context.template device_context<DeviceContext>()
+                       .eigen_device();
+    auto reduce_dim = Eigen::array<int, 1>({{0}});
+    out.device(place) = ipt.all(reduce_dim);
+  }
+};
+
+template <typename OpComment>
+class CompareReduceOpProtoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    OpComment comment;
+    AddInput("X", string::Sprintf("the left hand operand of %s operator",
+                                  comment.type));
+    AddInput("Y", string::Sprintf("the right hand operand of %s operator",
+                                  comment.type));
+    AddAttr<int>(
+        "axis",
+        "The start dimension index for broadcasting Y onto X. [default -1]")
+        .SetDefault(-1)
+        .EqualGreaterThan(-1);
+    AddOutput("Out", string::Sprintf(
+                         "tensor with a single bool element. If all "
+                         "elements %s, the Out tensor is [True], else [False]",
+                         comment.equation));
+    AddComment(string::Sprintf(R"DOC(
+It operates element-wise on X and Y, and returns Out. X and Y are N-dim
+tensors, which could be of any type.
+If all elements $%s$, the Out tensor is [True], else [False].
+)DOC",
+                               comment.equation));
+  }
+};
+
+template <typename OpComment>
+class CompareReduceOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* context) const override {
+    OpComment comment;
+    PADDLE_ENFORCE_EQ(context->HasInput("X"), true,
+                      platform::errors::InvalidArgument(
+                          "%s operator must have input X", comment.type));
+    PADDLE_ENFORCE_EQ(context->HasInput("Y"), true,
+                      platform::errors::InvalidArgument(
+                          "%s operator must have input Y", comment.type));
+    auto dim_x = context->GetInputDim("X");
+    auto dim_y = context->GetInputDim("Y");
+    PADDLE_ENFORCE_GE(
+        dim_x.size(), dim_y.size(),
+        platform::errors::InvalidArgument(
+            "The size of dim_y should not be greater than dim_x's."));
+
+    context->SetOutputDim("Out", {1});
+    context->ShareLoD("X", "Out");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+#define REGISTER_COMPARE_REDUCE_OP(op_type, _equation)                     \
+  struct _##op_type##Comment {                                             \
+    static char type[];                                                    \
+    static char equation[];                                                \
+  };                                                                       \
+  char _##op_type##Comment::type[]{#op_type};                              \
+  char _##op_type##Comment::equation[]{_equation};                         \
+  REGISTER_OPERATOR(                                                       \
+      op_type, ::paddle::operators::CompareReduceOp<_##op_type##Comment>,  \
+      ::paddle::operators::CompareReduceOpProtoMaker<_##op_type##Comment>, \
+      ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,    \
+      ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+
+#define REGISTER_COMPARE_REDUCE_CPU_KERNEL(op_type, functor)             \
+  REGISTER_OP_CPU_KERNEL(                                                \
+      op_type, ::paddle::operators::CompareReduceOpKernel<               \
+                   ::paddle::platform::CPUDeviceContext, functor<int>>,  \
+      ::paddle::operators::CompareReduceOpKernel<                        \
+          ::paddle::platform::CPUDeviceContext, functor<int64_t>>,       \
+      ::paddle::operators::CompareReduceOpKernel<                        \
+          ::paddle::platform::CPUDeviceContext, functor<float>>,         \
+      ::paddle::operators::CompareReduceOpKernel<                        \
+          ::paddle::platform::CPUDeviceContext, functor<double>>);
+REGISTER_COMPARE_REDUCE_OP(equal_reduce, "X == Y");
+
+REGISTER_COMPARE_REDUCE_CPU_KERNEL(equal_reduce,
+                                   paddle::operators::EqualReduceFunctor);
diff --git a/paddle/fluid/operators/controlflow/compare_reduce_op.cu b/paddle/fluid/operators/controlflow/compare_reduce_op.cu
new file mode 100644
index 0000000000..3adac0d966
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/compare_reduce_op.cu
@@ -0,0 +1,88 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/controlflow/compare_reduce_op.h"
+#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct IdentityFunctor {
+  HOSTDEVICE explicit inline IdentityFunctor() {}
+
+  HOSTDEVICE inline T operator()(const T& x) const { return x; }
+};
+
+struct BitwiseAdd {
+  // Bitwise and operator, returns a & b; over bool inputs this acts as the
+  // logical AND needed by the all-equal reduction.
+  template <typename T>
+  __host__ __device__ __forceinline__ T operator()(const T& a,
+                                                   const T& b) const {
+    return a & b;
+  }
+};
+template <typename DeviceContext, typename Functor>
+class CompareReduceOpKernel
+    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    using T = typename Functor::ELEM_TYPE;
+    using Tensor = framework::Tensor;
+
+    auto* x = context.Input<Tensor>("X");
+    auto* y = context.Input<Tensor>("Y");
+    auto* z = context.Output<Tensor>("Out");
+    int axis = context.Attr<int>("axis");
+
+    Tensor tmp;
+    framework::DDim x_dims = x->dims();
+    framework::DDim y_dims = y->dims();
+    int max_dim = std::max(x_dims.size(), y_dims.size());
+    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
+    std::vector<int> x_dims_array(max_dim);
+    std::vector<int> y_dims_array(max_dim);
+    std::vector<int> tmp_dims_array(max_dim);
+    GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
+                           y_dims_array.data(), tmp_dims_array.data(), max_dim,
+                           axis);
+    tmp.mutable_data<bool>(framework::make_ddim(tmp_dims_array),
+                           context.GetPlace());
+    ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis,
+                                                          Functor(), &tmp);
+    // Reduce by 'bitwise and' operator
+    std::vector<int> reduce_dims;
+    reduce_dims.resize(tmp.dims().size());
+    for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
+    auto stream = context.cuda_device_context().stream();
+    TensorReduce<bool, bool, BitwiseAdd, IdentityFunctor<bool>>(
+        tmp, z, reduce_dims, true, BitwiseAdd(), IdentityFunctor<bool>(),
+        stream);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+#define REGISTER_COMPARE_REDUCE_CUDA_KERNEL(op_type, functor)          \
+  REGISTER_OP_CUDA_KERNEL(                                             \
+      op_type, paddle::operators::CompareReduceOpKernel<               \
+                   paddle::platform::CUDADeviceContext, functor<int>>, \
+      paddle::operators::CompareReduceOpKernel<                        \
+          paddle::platform::CUDADeviceContext, functor<int64_t>>,      \
+      paddle::operators::CompareReduceOpKernel<                        \
+          paddle::platform::CUDADeviceContext, functor<float>>,        \
+      paddle::operators::CompareReduceOpKernel<                        \
+          paddle::platform::CUDADeviceContext, functor<double>>);
+REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_reduce,
+                                    paddle::operators::EqualReduceFunctor);
diff --git a/paddle/fluid/operators/controlflow/compare_reduce_op.h b/paddle/fluid/operators/controlflow/compare_reduce_op.h
new file mode 100644
index 0000000000..bcad240601
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/compare_reduce_op.h
@@ -0,0 +1,43 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <math.h>
+#include <algorithm>
+#include <type_traits>
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/fluid/platform/transform.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct EqualReduceFunctor {
+  using ELEM_TYPE = T;
+  HOSTDEVICE bool operator()(const T& a, const T& b) const {
+    if (std::is_floating_point<T>::value) {
+      // This branch is optimized away at compile time when T is an integer
+      // type, so it is safe to cast a and b to double here.
+      return fabs(static_cast<double>(a - b)) < 1e-8;
+    } else {
+      return (a == b);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index aa4761dcb2..d6f18eebd7 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -44,7 +44,7 @@ import paddle.nn
 # from .tensor.creation import eye #DEFINE_ALIAS
 # from .tensor.creation import fill_constant #DEFINE_ALIAS
 # from .tensor.creation import get_tensor_from_selected_rows #DEFINE_ALIAS
-# from .tensor.creation import linspace #DEFINE_ALIAS
+from .tensor.creation import linspace #DEFINE_ALIAS
 # from .tensor.creation import ones #DEFINE_ALIAS
 # from .tensor.creation import ones_like #DEFINE_ALIAS
 # from .tensor.creation import range #DEFINE_ALIAS
@@ -62,7 +62,7 @@ from .tensor.creation import full #DEFINE_ALIAS
 # from .tensor.stat import reduce_mean #DEFINE_ALIAS
 # from .tensor.stat import std #DEFINE_ALIAS
 # from .tensor.stat import var #DEFINE_ALIAS
-# from .tensor.logic import equal #DEFINE_ALIAS
+from .tensor.logic import equal #DEFINE_ALIAS
 # from .tensor.logic import greater_equal #DEFINE_ALIAS
 # from .tensor.logic import greater_than #DEFINE_ALIAS
 # from .tensor.logic import is_empty #DEFINE_ALIAS
diff --git a/python/paddle/common_ops_import.py b/python/paddle/common_ops_import.py
new file mode 100644
index 0000000000..477ff2fe4e
--- /dev/null
+++ b/python/paddle/common_ops_import.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from six.moves import reduce
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator
+from paddle.fluid.framework import Variable, device_guard
+from paddle.fluid.initializer import Constant
+from paddle.fluid.core import VarDesc
+from paddle.fluid import core
+from paddle.fluid.data_feeder import check_type, check_dtype, convert_dtype
+from paddle.fluid.layers import utils
+from paddle.fluid.layers import fill_constant
+import numpy
+import warnings
diff --git a/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py b/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py
new file mode 100644
index 0000000000..d14ff1a4e2
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_compare_reduce_op.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import op_test
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
+
+
+def create_test_broadcast_class(op_type, args, callback):
+    class Cls(op_test.OpTest):
+        def setUp(self):
+            x = np.random.random(size=args['x_size']).astype('int32')
+            y = np.random.random(size=args['y_size']).astype('int32')
+            z = callback(x, y)
+            self.inputs = {'X': x, 'Y': y}
+            self.outputs = {'Out': z}
+            self.op_type = op_type
+            self.attrs = {'axis': args['axis']}
+
+        def test_output(self):
+            self.check_output()
+
+    cls_name = "{0}_{1}".format(op_type, 'broadcast')
+    Cls.__name__ = cls_name
+    globals()[cls_name] = Cls
+
+
+def create_test_not_equal_class(op_type, typename, callback):
+    class Cls(op_test.OpTest):
+        def setUp(self):
+            x = np.random.random(size=(10, 7)).astype(typename)
+            y = np.random.random(size=(10, 7)).astype(typename)
+            z = callback(x, y)
+            self.inputs = {'X': x, 'Y': y}
+            self.outputs = {'Out': z}
+            self.op_type = op_type
+
+        def test_output(self):
+            self.check_output()
+
+    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'not_equal')
+    Cls.__name__ = cls_name
+    globals()[cls_name] = Cls
+
+
+def create_test_equal_class(op_type, typename, callback):
+    class Cls(op_test.OpTest):
+        def setUp(self):
+            x = y = np.random.random(size=(10, 7)).astype(typename)
+            z = callback(x, y)
+            self.inputs = {'X': x, 'Y': y}
+            self.outputs = {'Out': z}
+            self.op_type = op_type
+
+        def test_output(self):
+            self.check_output()
+
+    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal')
+    Cls.__name__ = cls_name
+    globals()[cls_name] = Cls
+
+
+def create_test_dim1_class(op_type, typename, callback):
+    class Cls(op_test.OpTest):
+        def setUp(self):
+            x = y = np.random.random(size=(1)).astype(typename)
+            z = callback(x, y)
+            self.inputs = {'X': x, 'Y': y}
+            self.outputs = {'Out': z}
+            self.op_type = op_type
+
+        def test_output(self):
+            self.check_output()
+
+    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'dim1')
+    Cls.__name__ = cls_name
+    globals()[cls_name] = Cls
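+
+# A sketch of how the factories above are consumed: each call stamps out an
+# OpTest subclass and registers it in the module globals so that unittest
+# discovery picks it up. A hypothetical expansion:
+#
+#   create_test_equal_class('equal_reduce', 'float32', np_equal)
+#   # -> registers a global test class named `equal_reduce_float32_equal`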
+
+
+np_equal = lambda _x, _y: np.array(np.array_equal(_x, _y))
+
+for _type_name in {'float32', 'float64', 'int32', 'int64'}:
+    create_test_not_equal_class('equal_reduce', _type_name, np_equal)
+    create_test_equal_class('equal_reduce', _type_name, np_equal)
+    create_test_dim1_class('equal_reduce', _type_name, np_equal)
+
+broadcast_args = [{
+    'x_size': (100, 2, 3),
+    'y_size': (100),
+    'axis': 0
+}, {
+    'x_size': (2, 100, 3),
+    'y_size': (100),
+    'axis': 1
+}, {
+    'x_size': (2, 3, 100),
+    'y_size': (1, 1),
+    'axis': -1
+}, {
+    'x_size': (2, 10, 12, 3),
+    'y_size': (10, 12),
+    'axis': 1
+}, {
+    'x_size': (100, 2, 3, 4),
+    'y_size': (100, 1),
+    'axis': 0
+}, {
+    'x_size': (10, 3, 12),
+    'y_size': (10, 1, 12),
+    'axis': -1
+}, {
+    'x_size': (2, 12, 3, 5),
+    'y_size': (2, 12, 1, 5),
+    'axis': -1
+}, {
+    'x_size': (2, 12, 3, 5),
+    'y_size': (3, 5),
+    'axis': 2
+}]
+
+
+def np_broadcast_equal(_x, _y):
+    res = np.all(np.equal(_x, _y))
+    return np.array(res)
+
+
+for args in broadcast_args:
+    create_test_broadcast_class('equal_reduce', args, np_broadcast_equal)
+
+
+class TestEqualReduceAPI(unittest.TestCase):
+    def test_name(self):
+        x = fluid.layers.assign(np.array([3, 4], dtype="int32"))
+        y = fluid.layers.assign(np.array([3, 4], dtype="int32"))
+        out = paddle.equal(x, y, name='equal_res')
+        assert 'equal_res' in out.name
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_linspace.py b/python/paddle/fluid/tests/unittests/test_linspace.py
index eeecf17832..9e88541fe1 100644
--- a/python/paddle/fluid/tests/unittests/test_linspace.py
+++ b/python/paddle/fluid/tests/unittests/test_linspace.py
@@ -17,6 +17,9 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import compiler, Program, program_guard
 
 
 class TestLinspaceOpCommonCase(OpTest):
@@ -67,5 +70,35 @@ class TestLinspaceOpNumOneCase(OpTest):
         self.check_output()
 
 
+class TestLinspaceAPI(unittest.TestCase):
+    def test_out(self):
+        with program_guard(fluid.Program()):
+            out_1 = fluid.data(name="out_1", shape=[5], dtype="float32")
+            out_2 = paddle.tensor.linspace(0, 10, 5, dtype='float32', out=out_1)
+            exe = fluid.Executor(place=fluid.CPUPlace())
+            ipt = {'out_1': np.random.random([5]).astype('float32')}
+            res_1, res_2 = exe.run(fluid.default_main_program(),
+                                   feed=ipt,
+                                   fetch_list=[out_1, out_2])
+            assert np.array_equal(res_1, res_2)
+
+    def test_name(self):
+        with fluid.program_guard(fluid.Program()):
+            out = paddle.linspace(
+                0, 10, 5, dtype='float32', name='linspace_res')
+            assert 'linspace_res' in out.name
+
+
+class TestLinspaceOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            # for ci coverage
+            # The value of 'device' must be 'cpu', 'gpu' or None
+            def test_device_value():
+                paddle.linspace(0, 10, 1, dtype="float32", device='xxxpu')
+
+            self.assertRaises(ValueError, test_device_value)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index ac5bb335f5..05598cc860 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -20,8 +20,8 @@
 # from .creation import diag #DEFINE_ALIAS
 # from .creation import eye #DEFINE_ALIAS
 # from .creation import fill_constant #DEFINE_ALIAS
-# from .creation import get_tensor_from_selected_rows #DEFINE_ALIAS
-# from .creation import linspace #DEFINE_ALIAS
+# from .creation import get_tensor_from_selected_rows #DEFINE_ALIAS
+from .creation import linspace #DEFINE_ALIAS
 # from .creation import ones #DEFINE_ALIAS
 # from .creation import ones_like #DEFINE_ALIAS
 # from .creation import range #DEFINE_ALIAS
@@ -39,7 +39,7 @@ from .creation import full #DEFINE_ALIAS
 # from .stat import reduce_mean #DEFINE_ALIAS
 # from .stat import std #DEFINE_ALIAS
 # from .stat import var #DEFINE_ALIAS
-# from .logic import equal #DEFINE_ALIAS
+from .logic import equal #DEFINE_ALIAS
 # from .logic import greater_equal #DEFINE_ALIAS
 # from .logic import greater_than #DEFINE_ALIAS
 # from .logic import is_empty #DEFINE_ALIAS
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index 0af896de7c..aeadb034da 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -12,15 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# TODO: define functions to get create a tensor
+from paddle.common_ops_import import *
 
-from __future__ import print_function
-from ..fluid.framework import Variable
-from ..fluid.initializer import Constant
-from ..fluid.layer_helper import LayerHelper
-from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
-from ..fluid.framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator, device_guard
-from ..fluid.layers import fill_constant
+# TODO: define functions to get create a tensor
 
 __all__ = [
     'create_tensor',
@@ -30,7 +24,7 @@ __all__ = [
     # 'diag', 'eye',
     # 'fill_constant',
    # 'get_tensor_from_selected_rows',
-    # 'linspace',
+    'linspace',
     # 'ones',
     # 'ones_like',
     # 'range',
@@ -39,7 +33,6 @@ __all__ = [
     # 'arrange',
     # 'eye',
     'full',
-    # 'linspace',
     # 'full_like',
     # 'triu',
     # 'tril',
 ]
@@ -47,6 +40,92 @@ ]
 
 
+def linspace(start, stop, num, dtype, out=None, device=None, name=None):
+    """
+    This OP returns a fixed number of evenly spaced values within a given
+    interval.
+
+    **NOTICE**: The output of this OP has no gradient.
+
+    Args:
+        start(float|Variable): The input :attr:`start` is the start of the
+            range. It is a float scalar, or a Tensor of shape [1] with data
+            type float32 or float64.
+        stop(float|Variable): The input :attr:`stop` is the end of the range.
+            It is a float scalar, or a Tensor of shape [1] with data type
+            float32 or float64.
+        num(int|Variable): The input :attr:`num` is the number of items in
+            the sequence. It is an int scalar, or a Tensor of shape [1] with
+            data type int32.
+        dtype(string): The data type of the output tensor, it could be
+            'float32' or 'float64'.
+        out(Variable, optional): Optional output which can be any created
+            Variable that meets the requirements to store the result of the
+            operation. If out is None, a new Variable will be created to
+            store the result. Default: None.
+        device(string, optional): Which device to run the operator. The
+            :attr:`device` must be None, 'cpu' or 'gpu'. If :attr:`device`
+            is None, the device that the user set in the paddle program will
+            be chosen. Default: None.
+        name(str, optional): Normally there is no need for the user to set
+            this property. For more information, please refer to
+            :ref:`api_guide_Name`. Default: None.
+
+    Returns:
+        Variable: The 1-D tensor with a fixed number of evenly spaced values,
+        with data type float32 or float64. The shape of this tensor is
+        :math:`[num]`. If :attr:`num` is set to 1, the output tensor contains
+        only the value of :attr:`start`.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            data = paddle.linspace(0, 10, 5, dtype='float32') # [0.0, 2.5, 5.0, 7.5, 10.0]
+            data = paddle.linspace(0, 10, 1, dtype='float32') # [0.0]
+
+    """
+    helper = LayerHelper("linspace", **locals())
+
+    if not isinstance(start, Variable):
+        start = fill_constant([1], dtype, start)
+    if not isinstance(stop, Variable):
+        stop = fill_constant([1], dtype, stop)
+    if not isinstance(num, Variable):
+        num = fill_constant([1], 'int32', num)
+
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=start.dtype)
+    else:
+        check_dtype(
+            out.dtype, out.name,
+            convert_dtype(start.dtype), 'linspace',
+            "The 'out' data type '%s' in linspace must be the same as the '%s' set by the parameter 'dtype'."
+            % (out.dtype, dtype))
+        if name:
+            warnings.warn(
+                "The output Variable name of the paddle.tensor.linspace operation can only be given by parameter out or name. \
+                When parameter out and name are set at the same time, out has a higher priority than name. \
+                Finally, the output Variable name is the same as the out name %s." %
+                out.name,
+                category=UserWarning,
+                stacklevel=2)
+
+    if device is not None:
+        if device not in ['cpu', 'gpu']:
+            raise ValueError(
+                "The value of 'device' in linspace operation must be cpu or gpu, but received %s."
+                % (device))
+        else:
+            with device_guard(device):
+                helper.append_op(
+                    type='linspace',
+                    inputs={'Start': start,
+                            'Stop': stop,
+                            'Num': num},
+                    outputs={'Out': [out]})
+    else:
+        helper.append_op(
+            type='linspace',
+            inputs={'Start': start,
+                    'Stop': stop,
+                    'Num': num},
+            outputs={'Out': [out]})
+
+    return out
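+
+# A minimal usage sketch of the optional `out` and `device` arguments,
+# mirroring TestLinspaceAPI.test_out in this patch (`out_1` is a
+# hypothetical pre-created variable of matching shape and dtype):
+#
+#   out_1 = fluid.data(name="out_1", shape=[5], dtype="float32")
+#   res = paddle.linspace(0, 10, 5, dtype='float32', out=out_1, device='cpu')
+#   # `res` is the same Variable as `out_1`.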
+
+
 def full(shape,
          fill_value,
          out=None,
diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py
index 74542d7dad..13a4c5cdc6 100644
--- a/python/paddle/tensor/logic.py
+++ b/python/paddle/tensor/logic.py
@@ -12,21 +12,93 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from paddle.common_ops_import import *
+import paddle.fluid as fluid
+
 # TODO: define logic functions of a tensor
-# __all__ = ['equal',
-#            'greater_equal',
-#            'greater_than',
-#            'is_empty',
-#            'isfinite',
-#            'less_equal',
-#            'less_than',
-#            'logical_and',
-#            'logical_not',
-#            'logical_or',
-#            'logical_xor',
-#            'not_equal',
-#            'reduce_all',
-#            'reduce_any',
-#            'allclose',
-#            'elementwise_equal',
-#            'isnan']
+__all__ = [
+    'equal',
+    # 'greater_equal',
+    # 'greater_than',
+    # 'is_empty',
+    # 'isfinite',
+    # 'less_equal',
+    # 'less_than',
+    # 'logical_and',
+    # 'logical_not',
+    # 'logical_or',
+    # 'logical_xor',
+    # 'not_equal',
+    # 'reduce_all',
+    # 'reduce_any',
+    # 'allclose',
+    # 'elementwise_equal',
+    # 'isnan'
+]
+
+
+def equal(x, y, axis=-1, name=None):
+    """
+    This OP returns the truth value of :math:`x == y`: [True] if the two
+    inputs have the same elements, [False] otherwise.
+
+    **NOTICE**: The output of this OP has no gradient, and this OP supports
+    broadcasting by :attr:`axis`.
+
+    Args:
+        x(Variable): Tensor, data type is float32, float64, int32, int64.
+        y(Variable): Tensor, data type is float32, float64, int32, int64.
+        axis(int32, optional): If the dimensions of X and Y differ, Y's
+            dimensions must be a subsequence of X's, and axis is the start
+            dimension index for broadcasting Y onto X. For more detail,
+            please refer to OP:`elementwise_add`. Default: -1.
+        name(str, optional): Normally there is no need for the user to set
+            this property. For more information, please refer to
+            :ref:`api_guide_Name`. Default: None.
+
+    Returns:
+        Variable: output Tensor, data type is bool, its value is [True] or
+        [False].
+
+    Examples:
+        .. code-block:: python
+
+          import paddle.fluid as fluid
+          import paddle
+          import numpy as np
+
+          label = fluid.layers.assign(np.array([3, 4], dtype="int32"))
+          label_1 = fluid.layers.assign(np.array([1, 2], dtype="int32"))
+          limit = fluid.layers.assign(np.array([3, 4], dtype="int32"))
+          out1 = paddle.equal(x=label, y=limit)   # out1=[True]
+          out2 = paddle.equal(x=label_1, y=limit) # out2=[False]
+
+        .. code-block:: python
+
+          import paddle.fluid as fluid
+          import paddle
+          import numpy as np
+
+          def gen_data():
+              return {
+                  "x": np.ones((2, 3, 4, 5)).astype('float32'),
+                  "y": np.zeros((3, 4)).astype('float32')
+              }
+
+          x = fluid.data(name="x", shape=[2, 3, 4, 5], dtype='float32')
+          y = fluid.data(name="y", shape=[3, 4], dtype='float32')
+          out = paddle.equal(x, y, axis=1)
+          place = fluid.CPUPlace()
+          exe = fluid.Executor(place)
+
+          res = exe.run(feed=gen_data(),
+                        fetch_list=[out])
+          print(res[0]) # [False]
+    """
+    helper = LayerHelper("equal_reduce", **locals())
+    out = helper.create_variable_for_type_inference(dtype='bool')
+    attrs = {}
+    attrs['axis'] = axis
+    helper.append_op(
+        type='equal_reduce',
+        inputs={'X': [x],
+                'Y': [y]},
+        attrs=attrs,
+        outputs={'Out': [out]})
+    return out
-- 
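A minimal end-to-end sketch combining the two updated APIs, assuming the
static-graph (fluid) workflow used in the docstrings and tests above:

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    with fluid.program_guard(fluid.Program()):
        pts = paddle.linspace(0, 10, 5, dtype='float32')
        ref = fluid.layers.assign(
            np.array([0.0, 2.5, 5.0, 7.5, 10.0], dtype='float32'))
        same = paddle.equal(pts, ref)  # reduces to a single bool tensor
        exe = fluid.Executor(fluid.CPUPlace())
        print(exe.run(fetch_list=[same])[0])  # [True]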