diff --git a/paddle/operators/compare_op.cc b/paddle/operators/compare_op.cc index daa2c193b48fe216ff284169a3dce1b4cd40a791..930c295a9cb31238954efeb87ff5ac2d3ca7bdc6 100644 --- a/paddle/operators/compare_op.cc +++ b/paddle/operators/compare_op.cc @@ -39,6 +39,11 @@ N-dim tensor. X and Y could be any type. The each element of the Out tensor is calculated by %s )DOC", comment.type, comment.equation)); + AddAttr("axis", + "(int, default -1). The start dimension index " + "for broadcasting Y onto X.") + .SetDefault(-1) + .EqualGreaterThan(-1); } }; @@ -95,11 +100,5 @@ REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); -REGISTER_LOGICAL_OP(greater_than, "Out = X > Y"); -REGISTER_LOGICAL_KERNEL(greater_than, CPU, - paddle::operators::GreaterThanFunctor); -REGISTER_LOGICAL_OP(greater_equal, "Out = X >= Y"); -REGISTER_LOGICAL_KERNEL(greater_equal, CPU, - paddle::operators::GreaterEqualFunctor); REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); diff --git a/paddle/operators/compare_op.cu b/paddle/operators/compare_op.cu index 26049271befd1fe57001659d1a406e73de0004a7..f625824dbc99d603f1e92700b4ad3d7fa25b471d 100644 --- a/paddle/operators/compare_op.cu +++ b/paddle/operators/compare_op.cu @@ -16,8 +16,4 @@ limitations under the License. */ REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); -REGISTER_LOGICAL_KERNEL(greater_than, CUDA, - paddle::operators::GreaterThanFunctor); -REGISTER_LOGICAL_KERNEL(greater_equal, CUDA, - paddle::operators::GreaterEqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); diff --git a/paddle/operators/compare_op.h b/paddle/operators/compare_op.h index 567e89c0a727ad0cdd2add8ec8b2a42c86a58007..9c655d6c0d8e5fe04ee6d85f7e9d9da68105230c 100644 --- a/paddle/operators/compare_op.h +++ b/paddle/operators/compare_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include #include "paddle/framework/op_registry.h" +#include "paddle/operators/elementwise_op_function.h" #include "paddle/platform/transform.h" namespace paddle { @@ -33,18 +34,6 @@ struct LessEqualFunctor { HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; } }; -template -struct GreaterThanFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; } -}; - -template -struct GreaterEqualFunctor { - using ELEM_TYPE = T; - HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; } -}; - template struct EqualFunctor { using ELEM_TYPE = T; @@ -65,14 +54,7 @@ class CompareOpKernel public: void Compute(const framework::ExecutionContext& context) const override { using T = typename Functor::ELEM_TYPE; - auto* x = context.Input("X"); - auto* y = context.Input("Y"); - auto* out = context.Output("Out"); - Functor binary_func; - platform::Transform trans; - trans(context.template device_context(), x->data(), - x->data() + x->numel(), y->data(), - out->mutable_data(context.GetPlace()), binary_func); + ElementwiseComputeEx(context); } }; diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index db5d30c1af286913f8decd7ab74058fd732ead65..d749b8e8757d0d433be05876779ccc22b95ca80b 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -176,14 +176,15 @@ class MidWiseTransformIterator }; #endif -template +template class TransformFunctor { public: TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, framework::Tensor* z, const DeviceContext& ctx, Functor func) : x_(x->data()), y_(y->data()), - z_(z->mutable_data(ctx.GetPlace())), + z_(z->mutable_data(ctx.GetPlace())), nx_(x->numel()), ctx_(ctx), func_(func) {} @@ -208,7 +209,7 @@ class TransformFunctor { private: const T* x_; const T* y_; - T* z_; + OutType* z_; int64_t nx_; const DeviceContext& ctx_; Functor func_; @@ -364,15 +365,16 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { } } -template +template void ElementwiseComputeEx(const framework::ExecutionContext& ctx) { using Tensor = framework::Tensor; auto* x = ctx.Input("X"); auto* y = ctx.Input("Y"); auto* z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - TransformFunctor functor( + z->mutable_data(ctx.GetPlace()); + TransformFunctor functor( x, y, z, ctx.template device_context(), Functor()); auto x_dims = x->dims(); diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py index 1e55adf0a8cdd6a97f2fcdfeb4ec63ae402412c6..5f01fdb076d3bf7d060a805d1431f4973993a843 100644 --- a/python/paddle/v2/fluid/layers/control_flow.py +++ b/python/paddle/v2/fluid/layers/control_flow.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib -from ..layer_helper import LayerHelper, unique_name -from ..framework import Program, Variable, Operator -from .. import core +from layer_function_generator import autodoc from tensor import assign, fill_constant -import contextlib -from ..registry import autodoc +from .. import core +from ..framework import Program, Variable, Operator +from ..layer_helper import LayerHelper, unique_name __all__ = [ 'split_lod_tensor', @@ -1477,7 +1477,7 @@ class DynamicRNN(object): method)) -@autodoc +@autodoc() def reorder_lod_tensor_by_rank(x, rank_table): helper = LayerHelper('reorder_lod_tensor_by_rank', **locals()) helper.is_instance('x', Variable) diff --git a/python/paddle/v2/fluid/layers/device.py b/python/paddle/v2/fluid/layers/device.py index 736813d1b109087da367666d90be9e88dad1860e..107511b5f4ab1108610bc1326f30e5d9ab407853 100644 --- a/python/paddle/v2/fluid/layers/device.py +++ b/python/paddle/v2/fluid/layers/device.py @@ -15,14 +15,14 @@ All util layers. """ -from ..layer_helper import LayerHelper +from layer_function_generator import autodoc from ..framework import unique_name -from ..registry import autodoc +from ..layer_helper import LayerHelper __all__ = ['get_places'] -@autodoc +@autodoc() def get_places(device_count=None, device_type=None): helper = LayerHelper('get_places', **locals()) out_places = helper.create_variable(name=unique_name(helper.name + ".out")) diff --git a/python/paddle/v2/fluid/registry.py b/python/paddle/v2/fluid/layers/layer_function_generator.py similarity index 94% rename from python/paddle/v2/fluid/registry.py rename to python/paddle/v2/fluid/layers/layer_function_generator.py index ff10542d40aabaf31897842754d38b7868472b21..b0e4d1635f7b5d0afdfa677e6ec1e8f9245a9d54 100644 --- a/python/paddle/v2/fluid/registry.py +++ b/python/paddle/v2/fluid/layers/layer_function_generator.py @@ -13,17 +13,19 @@ # limitations under the License. import re import cStringIO -import warnings import functools -import inspect +import warnings + +from .. import proto -import proto.framework_pb2 as framework_pb2 -from framework import OpProtoHolder, Variable, Program, Operator -from paddle.v2.fluid.layer_helper import LayerHelper, unique_name +framework_pb2 = proto.framework_pb2 + +from ..framework import OpProtoHolder, Variable +from ..layer_helper import LayerHelper __all__ = [ 'deprecated', - 'register_layer', + 'generate_layer_fn', 'autodoc', ] @@ -96,7 +98,7 @@ def _generate_doc_string_(op_proto): return buf.getvalue() -def register_layer(op_type): +def generate_layer_fn(op_type): """Register the Python layer for an Operator. Args: @@ -207,7 +209,10 @@ def deprecated(func_or_class): return func_wrapper -def autodoc(func): - func.__doc__ = _generate_doc_string_(OpProtoHolder.instance().get_op_proto( - func.__name__)) - return func +def autodoc(comment=""): + def __impl__(func): + func.__doc__ = _generate_doc_string_(OpProtoHolder.instance( + ).get_op_proto(func.__name__)) + comment + return func + + return __impl__ diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index d29607616266973b3e59dbae660724d0c3874def..b517f8be6a3e5558dd01afe094fb3989cfb3af44 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from ..registry import register_layer +from layer_function_generator import generate_layer_fn __activations__ = [ 'sigmoid', @@ -53,4 +52,4 @@ __all__ = [ ] + __activations__ for _OP in set(__all__): - globals()[_OP] = register_layer(_OP) + globals()[_OP] = generate_layer_fn(_OP) diff --git a/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py b/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py index 6206fcc4be2cdfb29c41270c6cafc350d0c6acfd..cf054bb0fe778d34add4ac456f672a8b47483e84 100644 --- a/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py +++ b/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py @@ -1,3 +1,17 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np import paddle.v2 as paddle import paddle.v2.fluid as fluid diff --git a/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py b/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py index cc37f773c40a89c712cd48fddd3abab506be2c4b..42b3cb81ce67d38494677f3ecbfb1e07f7c0c3ad 100644 --- a/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py +++ b/python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py @@ -1,3 +1,17 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import print_function import sys diff --git a/python/paddle/v2/fluid/tests/test_compare_op.py b/python/paddle/v2/fluid/tests/test_compare_op.py index 08ef90b10eb70eb380db2821415184709602b848..c9be80fc45cd3428937998357b9dd9cbde1547cc 100644 --- a/python/paddle/v2/fluid/tests/test_compare_op.py +++ b/python/paddle/v2/fluid/tests/test_compare_op.py @@ -38,8 +38,6 @@ def create_test_class(op_type, typename, callback): for _type_name in {'float32', 'float64', 'int32', 'int64'}: create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) - create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b) - create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b) create_test_class('equal', _type_name, lambda _a, _b: _a == _b) if __name__ == '__main__': diff --git a/python/paddle/v2/fluid/tests/test_registry.py b/python/paddle/v2/fluid/tests/test_registry.py index 6435e7e243d4e7fa10c99fda48a011523d8cc588..44e50ca55ac609ed2e0a145ff12248fa18479668 100644 --- a/python/paddle/v2/fluid/tests/test_registry.py +++ b/python/paddle/v2/fluid/tests/test_registry.py @@ -11,26 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import unittest -import warnings import paddle.v2.fluid as fluid -import paddle.v2.fluid.framework as framework -import paddle.v2.fluid.layers as layers -import paddle.v2.fluid.registry as registry +import numpy as np +import decorators class TestRegistry(unittest.TestCase): + @decorators.prog_scope() def test_registry_layer(self): - self.layer_type = "mean" - program = framework.Program() - x = fluid.layers.data(name='X', shape=[10, 10], dtype='float32') - output = layers.mean(x) + output = fluid.layers.mean(x=x) + place = fluid.CPUPlace() exe = fluid.Executor(place) - X = np.random.random((10, 10)).astype("float32") - mean_out = exe.run(program, feed={"X": X}, fetch_list=[output]) - self.assertAlmostEqual(np.mean(X), mean_out) + mean_out = exe.run(feed={"X": X}, fetch_list=[output]) + self.assertAlmostEqual(np.mean(X), mean_out[0])