diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index d78a2c4c21149ef3c800991b9a144ea198f1bdcf..7e88e039611007d17156d10f852eb46f3ee8e7a3 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -52,7 +52,7 @@ struct SizeOfTypeFunctor { }; static inline size_t SizeOfType(std::type_index type) { - SizeOfTypeFunctor functor; + SizeOfTypeFunctor functor; size_t size = functor(type); PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name()); return size; diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f22f86468db7be6c87193f85c3f87aef4a9d3c68..b497c877d1b564434807ac040112020445bd71f1 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -62,6 +62,11 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(pool2d);\n") endif() + if ("${TARGET}" STREQUAL "compare_op") + set(pybind_flag 1) + file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal);\n") + endif() + # pool_with_index_op contains several operators if ("${TARGET}" STREQUAL "pool_with_index_op") set(pybind_flag 1) diff --git a/paddle/operators/compare_op.cc b/paddle/operators/compare_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8b425d14df3bc484437dc72f29abf13b887006bd --- /dev/null +++ b/paddle/operators/compare_op.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/compare_op.h" +#include "paddle/framework/op_registry.h" +namespace paddle { +namespace operators { +template +class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + CompareOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + OpComment comment; + AddInput("X", + string::Sprintf("(LoDTensor) the left hand operand of %s operator", + comment.type)); + AddInput("Y", string::Sprintf( + "(LoDTensor) the right hand operand of %s operator", + comment.type)); + AddOutput("Out", string::Sprintf( + "(LoDTensor) n-dim bool tensor. Each element is %s", + comment.equation)); + AddComment(string::Sprintf(R"DOC(%s Operator + +It operates element-wise on X and Y, and returns the Out. Each of them is a +N-dim tensor. X and Y could be any type. The each element of the Out tensor is +calculated by %s +)DOC", + comment.type, comment.equation)); + } +}; + +template +class CompareOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + OpComment comment; + PADDLE_ENFORCE(context->HasInput("X"), "%s operator must has input X", + comment.type); + PADDLE_ENFORCE(context->HasInput("Y"), "%s operator must has input Y", + comment.type); + auto dim_x = context->GetInputDim("X"); + auto dim_y = context->GetInputDim("Y"); + PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y), + "The number of elements in X and Y should be same"); + + context->SetOutputDim("Out", context->GetInputDim("X")); + context->ShareLoD("X", "Out"); + } +}; + +} // namespace operators +} // namespace paddle + +#define REGISTER_LOGICAL_OP(op_type, _equation) \ + struct _##op_type##Comment { \ + static char type[]; \ + static char equation[]; \ + }; \ + char _##op_type##Comment::type[]{#op_type}; \ + char _##op_type##Comment::equation[]{_equation}; \ + REGISTER_OP_WITH_KERNEL( \ + op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ + ::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \ + ::paddle::framework::EmptyGradOpMaker); + +REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); +REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); +REGISTER_LOGICAL_OP(equal, "Out = X == Y"); +REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); diff --git a/paddle/operators/compare_op.cu b/paddle/operators/compare_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..42a5bb2f45fd389f60c3dc034cade7f56a907e35 --- /dev/null +++ b/paddle/operators/compare_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/compare_op.h" + +REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor); +REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor); diff --git a/paddle/operators/compare_op.h b/paddle/operators/compare_op.h new file mode 100644 index 0000000000000000000000000000000000000000..04e04e347b398abb5fb66876bf801b1eee688ec6 --- /dev/null +++ b/paddle/operators/compare_op.h @@ -0,0 +1,74 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/platform/transform.h" + +namespace paddle { +namespace operators { + +template +struct LessThanFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; } +}; + +template +struct EqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { + if (std::is_floating_point::value) { + // This branch will be optimized while compiling if T is integer. It is + // safe to cast a and b to double. + return fabs(static_cast(a - b)) < 1e-8; + } else { + return (a == b); + } + } +}; + +template +class CompareOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + using T = typename Functor::ELEM_TYPE; + auto* x = context.Input("X"); + auto* y = context.Input("Y"); + auto* out = context.Output("Out"); + Functor binary_func; + platform::Transform trans; + trans(context.device_context(), x->data(), x->data() + x->numel(), + y->data(), out->mutable_data(context.GetPlace()), + binary_func); + } +}; + +} // namespace operators +} // namespace paddle + +#define REGISTER_LOGICAL_KERNEL(op_type, dev, functor) \ + REGISTER_OP_##dev##_KERNEL( \ + op_type, \ + ::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \ + functor>, \ + ::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \ + functor>, \ + ::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \ + functor>, \ + ::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \ + functor>); diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 0c528174b2b2b3a27869ed0083fc96b8d90e723b..0f906e0e470b7f95bb2103ae55330fc1831aa78f 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -113,11 +113,13 @@ PYBIND11_PLUGIN(core) { .def("set", PyCPUTensorSetFromArray) .def("set", PyCPUTensorSetFromArray) .def("set", PyCPUTensorSetFromArray) + .def("set", PyCPUTensorSetFromArray) #ifdef PADDLE_WITH_CUDA .def("set", PyCUDATensorSetFromArray) .def("set", PyCUDATensorSetFromArray) .def("set", PyCUDATensorSetFromArray) .def("set", PyCUDATensorSetFromArray) + .def("set", PyCUDATensorSetFromArray) #endif .def("shape", [](Tensor &self) { return vectorize(self.dims()); }) .def("set_float_element", TensorSetElement) diff --git a/paddle/pybind/tensor_py.h b/paddle/pybind/tensor_py.h index f278e79af60486bce400f313b80ebbe3971f869b..41fa658502d341fe9653a3e99b58498fcaeada47 100644 --- a/paddle/pybind/tensor_py.h +++ b/paddle/pybind/tensor_py.h @@ -85,7 +85,7 @@ struct CastToPyBufferImpl { } // namespace details inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) { auto buffer_info = - details::CastToPyBufferImpl()( + details::CastToPyBufferImpl()( tensor); return buffer_info; } diff --git a/python/paddle/v2/framework/tests/test_compare_op.py b/python/paddle/v2/framework/tests/test_compare_op.py new file mode 100644 index 0000000000000000000000000000000000000000..bb0256694d77323f12c50856533e93b090dc6198 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_compare_op.py @@ -0,0 +1,29 @@ +import op_test +import unittest +import numpy + + +def create_test_class(op_type, typename, callback): + class Cls(op_test.OpTest): + def setUp(self): + a = numpy.random.random(size=(10, 7)).astype(typename) + b = numpy.random.random(size=(10, 7)).astype(typename) + c = callback(a, b) + self.inputs = {'X': a, 'Y': b} + self.outputs = {'Out': c} + self.op_type = op_type + + def test_output(self): + self.check_output() + + cls_name = "{0}_{1}".format(op_type, typename) + Cls.__name__ = cls_name + globals()[cls_name] = Cls + + +for _type_name in {'float32', 'float64', 'int32', 'int64'}: + create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) + create_test_class('equal', _type_name, lambda _a, _b: _a == _b) + +if __name__ == '__main__': + unittest.main()