diff --git a/paddle/gserver/tests/sequence_recurrent_group.py b/paddle/gserver/tests/sequence_recurrent_group.py index da182942c9b4d3aa80f5ea90c1dd5adc321c7c07..8b5a3d49838c9bb49321a9d7514fc0241e6d67cd 100644 --- a/paddle/gserver/tests/sequence_recurrent_group.py +++ b/paddle/gserver/tests/sequence_recurrent_group.py @@ -1,16 +1,16 @@ # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. # -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from paddle.trainer_config_helpers import * ######################## data source ################################ diff --git a/paddle/operators/elementwise_add_op.h b/paddle/operators/elementwise_add_op.h index 6478e1e0c2e1cfc8a1be5e8842113ec8ca33d762..a8389429f26c17ceab1db22175c90888546ead6f 100644 --- a/paddle/operators/elementwise_add_op.h +++ b/paddle/operators/elementwise_add_op.h @@ -28,39 +28,7 @@ template class ElementwiseAddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - TransformFunctor, T, DeviceContext> functor( - x, y, z, ctx.template device_context(), AddFunctor()); - - auto x_dims = x->dims(); - auto y_dims = y->dims(); - PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), - "Rank of first input must >= rank of second input."); - - if (x_dims == y_dims) { - functor.Run(); - return; - } - - int axis = ctx.Attr("axis"); - axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); - PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), - "Axis should be in range [0, x_dims)"); - - int pre, n, post; - get_mid_dims(x_dims, y_dims, axis, pre, n, post); - if (post == 1) { - functor.RunRowWise(n, pre); - return; - } else { - functor.RunMidWise(n, pre, post); - return; - } + ElementwiseComputeEx, DeviceContext, T>(ctx); } }; diff --git a/paddle/operators/elementwise_max_op.cc b/paddle/operators/elementwise_max_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..53c27ae5be4cbfe85ce61aa27196594ae152eea4 --- /dev/null +++ b/paddle/operators/elementwise_max_op.cc @@ -0,0 +1,45 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/elementwise_max_op.h" +#include "paddle/operators/elementwise_op.h" + +namespace paddle { +namespace operators { +class ElementwiseMaxOpMaker : public ElementwiseOpMaker { + public: + ElementwiseMaxOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : ElementwiseOpMaker(proto, op_checker) { + SetComment("Max", "Out = max(X, Y)"); + AddComment(comment_); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(elementwise_max, ops::ElementwiseOp, ops::ElementwiseMaxOpMaker, + elementwise_max_grad, ops::ElementwiseOpGrad); +REGISTER_OP_CPU_KERNEL( + elementwise_max, + ops::ElementwiseMaxKernel, + ops::ElementwiseMaxKernel, + ops::ElementwiseMaxKernel, + ops::ElementwiseMaxKernel); +REGISTER_OP_CPU_KERNEL( + elementwise_max_grad, + ops::ElementwiseMaxGradKernel, + ops::ElementwiseMaxGradKernel, + ops::ElementwiseMaxGradKernel, + ops::ElementwiseMaxGradKernel); diff --git a/paddle/operators/elementwise_max_op.cu b/paddle/operators/elementwise_max_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..5ff4af17477cbd35b765cc00d46c95fda620e2df --- /dev/null +++ b/paddle/operators/elementwise_max_op.cu @@ -0,0 +1,32 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/elementwise_max_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + elementwise_max, + ops::ElementwiseMaxKernel, + ops::ElementwiseMaxKernel, + ops::ElementwiseMaxKernel, + ops::ElementwiseMaxKernel); +REGISTER_OP_CUDA_KERNEL( + elementwise_max_grad, + ops::ElementwiseMaxGradKernel, + ops::ElementwiseMaxGradKernel, + ops::ElementwiseMaxGradKernel, + ops::ElementwiseMaxGradKernel); diff --git a/paddle/operators/elementwise_max_op.h b/paddle/operators/elementwise_max_op.h new file mode 100644 index 0000000000000000000000000000000000000000..255728e8e620665a7de225b228c19d6c510da1c8 --- /dev/null +++ b/paddle/operators/elementwise_max_op.h @@ -0,0 +1,120 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/operators/elementwise_op_function.h" + +namespace paddle { +namespace operators { + +template +struct MaxFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { return a > b ? a : b; } +}; + +template +class ElementwiseMaxKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElementwiseComputeEx, DeviceContext, T>(ctx); + } +}; + +template +struct ElementwiseMaxGradFunctor { + template + void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { + auto x_e = framework::EigenVector::Flatten(*x); + auto y_e = framework::EigenVector::Flatten(*y); + auto dz_e = framework::EigenVector::Flatten(*dz); + + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(d) = (x_e > y_e).template cast() * dz_e; + } + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(d) = (x_e <= y_e).template cast() * dz_e; + } + } +}; + +template +struct ElementwiseMaxBroadCastGradFunctor { + template + void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { + auto x_e = framework::EigenVector::Flatten(*x); + auto y_e = framework::EigenVector::Flatten(*y); + auto dz_e = framework::EigenVector::Flatten(*dz); + + auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n)) + .broadcast(Eigen::DSizes(pre, 1)) + .reshape(Eigen::DSizes(x_e.size())); + + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(d) = (x_e > y_e_bcast).template cast() * dz_e; + } + + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(d) = ((x_e <= y_e_bcast).template cast() * dz_e) + .reshape(Eigen::DSizes(pre, n)) + .sum(Eigen::array{{0}}); + } + } +}; + +template +struct ElementwiseMaxBroadCast2GradFunctor { + template + void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, + Post post) { + auto x_e = framework::EigenVector::Flatten(*x); + auto y_e = framework::EigenVector::Flatten(*y); + auto dz_e = framework::EigenVector::Flatten(*dz); + + auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n, 1)) + .broadcast(Eigen::DSizes(pre, 1, post)) + .reshape(Eigen::DSizes(x_e.size())); + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(d) = (x_e > y_e_bcast).template cast() * dz_e; + } + + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(d) = ((x_e <= y_e_bcast).template cast() * dz_e) + .reshape(Eigen::DSizes(pre, n, post)) + .sum(Eigen::array{{0, 2}}); + } + } +}; + +template +class ElementwiseMaxGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElementwiseGradCompute, + ElementwiseMaxBroadCastGradFunctor, + ElementwiseMaxBroadCast2GradFunctor>(ctx); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/elementwise_min_op.cc b/paddle/operators/elementwise_min_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..99482e1bf60c88062087c5fe0105e90aa0a8677c --- /dev/null +++ b/paddle/operators/elementwise_min_op.cc @@ -0,0 +1,45 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/elementwise_min_op.h" +#include "paddle/operators/elementwise_op.h" + +namespace paddle { +namespace operators { +class ElementwiseMinOpMaker : public ElementwiseOpMaker { + public: + ElementwiseMinOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : ElementwiseOpMaker(proto, op_checker) { + SetComment("Max", "Out = min(X, Y)"); + AddComment(comment_); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(elementwise_min, ops::ElementwiseOp, ops::ElementwiseMinOpMaker, + elementwise_min_grad, ops::ElementwiseOpGrad); +REGISTER_OP_CPU_KERNEL( + elementwise_min, + ops::ElementwiseMinKernel, + ops::ElementwiseMinKernel, + ops::ElementwiseMinKernel, + ops::ElementwiseMinKernel); +REGISTER_OP_CPU_KERNEL( + elementwise_min_grad, + ops::ElementwiseMinGradKernel, + ops::ElementwiseMinGradKernel, + ops::ElementwiseMinGradKernel, + ops::ElementwiseMinGradKernel); diff --git a/paddle/operators/elementwise_min_op.cu b/paddle/operators/elementwise_min_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..3547e6ccb77177002b1ecbee4e4604b602f72209 --- /dev/null +++ b/paddle/operators/elementwise_min_op.cu @@ -0,0 +1,32 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/elementwise_min_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + elementwise_min, + ops::ElementwiseMinKernel, + ops::ElementwiseMinKernel, + ops::ElementwiseMinKernel, + ops::ElementwiseMinKernel); +REGISTER_OP_CUDA_KERNEL( + elementwise_min_grad, + ops::ElementwiseMinGradKernel, + ops::ElementwiseMinGradKernel, + ops::ElementwiseMinGradKernel, + ops::ElementwiseMinGradKernel); diff --git a/paddle/operators/elementwise_min_op.h b/paddle/operators/elementwise_min_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e6627a0f1bb468c8e4661b83489cb964b72dddb0 --- /dev/null +++ b/paddle/operators/elementwise_min_op.h @@ -0,0 +1,120 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/operators/elementwise_op_function.h" + +namespace paddle { +namespace operators { + +template +struct MinFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? a : b; } +}; + +template +class ElementwiseMinKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElementwiseComputeEx, DeviceContext, T>(ctx); + } +}; + +template +struct ElementwiseMinGradFunctor { + template + void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { + auto x_e = framework::EigenVector::Flatten(*x); + auto y_e = framework::EigenVector::Flatten(*y); + auto dz_e = framework::EigenVector::Flatten(*dz); + + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(d) = (x_e < y_e).template cast() * dz_e; + } + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(d) = (x_e >= y_e).template cast() * dz_e; + } + } +}; + +template +struct ElementwiseMinBroadCastGradFunctor { + template + void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { + auto x_e = framework::EigenVector::Flatten(*x); + auto y_e = framework::EigenVector::Flatten(*y); + auto dz_e = framework::EigenVector::Flatten(*dz); + + auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n)) + .broadcast(Eigen::DSizes(pre, 1)) + .reshape(Eigen::DSizes(x_e.size())); + + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(d) = (x_e < y_e_bcast).template cast() * dz_e; + } + + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(d) = ((x_e >= y_e_bcast).template cast() * dz_e) + .reshape(Eigen::DSizes(pre, n)) + .sum(Eigen::array{{0}}); + } + } +}; + +template +struct ElementwiseMinBroadCast2GradFunctor { + template + void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, + Post post) { + auto x_e = framework::EigenVector::Flatten(*x); + auto y_e = framework::EigenVector::Flatten(*y); + auto dz_e = framework::EigenVector::Flatten(*dz); + + auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n, 1)) + .broadcast(Eigen::DSizes(pre, 1, post)) + .reshape(Eigen::DSizes(x_e.size())); + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(d) = (x_e < y_e_bcast).template cast() * dz_e; + } + + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(d) = ((x_e >= y_e_bcast).template cast() * dz_e) + .reshape(Eigen::DSizes(pre, n, post)) + .sum(Eigen::array{{0, 2}}); + } + } +}; + +template +class ElementwiseMinGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElementwiseGradCompute, + ElementwiseMinBroadCastGradFunctor, + ElementwiseMinBroadCast2GradFunctor>(ctx); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 0c75276b03140473cee4b57a4022ff3c6989ab4c..be11d5cc9de0c73dd5b3fa127c9f043b6b5d3972 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -356,5 +356,43 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { return; } } + +template +void ElementwiseComputeEx(const framework::ExecutionContext& ctx) { + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + z->mutable_data(ctx.GetPlace()); + TransformFunctor functor( + x, y, z, ctx.template device_context(), Functor()); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), + "Rank of first input must >= rank of second input."); + + if (x_dims == y_dims) { + functor.Run(); + return; + } + + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + int pre, n, post; + get_mid_dims(x_dims, y_dims, axis, pre, n, post); + if (post == 1) { + functor.RunRowWise(n, pre); + return; + } else { + functor.RunMidWise(n, pre, post); + return; + } +} + } // namespace operators } // namespace paddle diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index 73d7c895806ef28ffc98db88809e317a86762769..21945edf0827e7a86ec2f8ce8f84c9093808c68b 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -55,6 +55,8 @@ __all__ = [ 'elementwise_div', 'elementwise_sub', 'elementwise_mul', + 'elementwise_max', + 'elementwise_min', 'clip', 'sequence_softmax', ] + __activations__ diff --git a/python/paddle/v2/fluid/tests/test_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_edit_distance_op.py index 38e87728b387bb70a8921a2fe73a4e69701aabe9..cf118df634bb8288456009ebd4954f08d5eb4323 100644 --- a/python/paddle/v2/fluid/tests/test_edit_distance_op.py +++ b/python/paddle/v2/fluid/tests/test_edit_distance_op.py @@ -1,3 +1,16 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/v2/fluid/tests/test_elementwise_max_op.py b/python/paddle/v2/fluid/tests/test_elementwise_max_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3dfab4dd2fc82d470bacbb4b615bbb3b5d3e6230 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_elementwise_max_op.py @@ -0,0 +1,120 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +import unittest +import numpy as np +from op_test import OpTest + + +class TestElementwiseOp(OpTest): + def setUp(self): + self.op_type = "elementwise_max" + # If x and y have the same value, the max() is not differentiable. + # So we generate test data by the following method + # to avoid them being too close to each other. + x = np.random.uniform(0.1, 1, [13, 17]).astype("float32") + sgn = np.random.choice([-1, 1], [13, 17]).astype("float32") + y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float32") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + + +class TestElementwiseMaxOp_Vector(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_max" + x = np.random.random((32, )).astype("float32") + sgn = np.random.choice([-1, 1], (32, )).astype("float32") + y = x + sgn * np.random.uniform(0.1, 1, (32, )).astype("float32") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])} + + +class TestElementwiseMaxOp_broadcast_0(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_max" + x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32) + sgn = np.random.choice([-1, 1], (2, )).astype(np.float32) + y = x[:, 0, 0] + sgn * \ + np.random.uniform(1, 2, (2, )).astype(np.float32) + self.inputs = {'X': x, 'Y': y} + + self.attrs = {'axis': 0} + self.outputs = { + 'Out': + np.maximum(self.inputs['X'], self.inputs['Y'].reshape(2, 1, 1)) + } + + +class TestElementwiseMaxOp_broadcast_1(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_max" + x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32) + sgn = np.random.choice([-1, 1], (3, )).astype(np.float32) + y = x[0, :, 0] + sgn * \ + np.random.uniform(1, 2, (3, )).astype(np.float32) + self.inputs = {'X': x, 'Y': y} + + self.attrs = {'axis': 1} + self.outputs = { + 'Out': + np.maximum(self.inputs['X'], self.inputs['Y'].reshape(1, 3, 1)) + } + + +class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_max" + x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32) + sgn = np.random.choice([-1, 1], (4, )).astype(np.float32) + y = x[0, 0, :] + sgn * \ + np.random.uniform(1, 2, (4, )).astype(np.float32) + self.inputs = {'X': x, 'Y': y} + + self.outputs = { + 'Out': + np.maximum(self.inputs['X'], self.inputs['Y'].reshape(1, 1, 4)) + } + + +class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_max" + x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float32) + sgn = np.random.choice([-1, 1], (3, 4)).astype(np.float32) + y = x[0, :, :, 0] + sgn * \ + np.random.uniform(1, 2, (3, 4)).astype(np.float32) + self.inputs = {'X': x, 'Y': y} + + self.attrs = {'axis': 1} + self.outputs = { + 'Out': + np.maximum(self.inputs['X'], self.inputs['Y'].reshape(1, 3, 4, 1)) + } + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_elementwise_min_op.py b/python/paddle/v2/fluid/tests/test_elementwise_min_op.py new file mode 100644 index 0000000000000000000000000000000000000000..8422a9cdae70f2b88b9e86d5606b696509060865 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_elementwise_min_op.py @@ -0,0 +1,120 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. +import unittest +import numpy as np +from op_test import OpTest + + +class TestElementwiseOp(OpTest): + def setUp(self): + self.op_type = "elementwise_min" + # If x and y have the same value, the min() is not differentiable. + # So we generate test data by the following method + # to avoid them being too close to each other. + x = np.random.uniform(0.1, 1, [13, 17]).astype("float32") + sgn = np.random.choice([-1, 1], [13, 17]).astype("float32") + y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float32") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + + +class TestElementwiseMaxOp_Vector(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_min" + x = np.random.random((32, )).astype("float32") + sgn = np.random.choice([-1, 1], (32, )).astype("float32") + y = x + sgn * np.random.uniform(0.1, 1, (32, )).astype("float32") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} + + +class TestElementwiseMaxOp_broadcast_0(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_min" + x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32) + sgn = np.random.choice([-1, 1], (2, )).astype(np.float32) + y = x[:, 0, 0] + sgn * \ + np.random.uniform(1, 2, (2, )).astype(np.float32) + self.inputs = {'X': x, 'Y': y} + + self.attrs = {'axis': 0} + self.outputs = { + 'Out': + np.minimum(self.inputs['X'], self.inputs['Y'].reshape(2, 1, 1)) + } + + +class TestElementwiseMaxOp_broadcast_1(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_min" + x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32) + sgn = np.random.choice([-1, 1], (3, )).astype(np.float32) + y = x[0, :, 0] + sgn * \ + np.random.uniform(1, 2, (3, )).astype(np.float32) + self.inputs = {'X': x, 'Y': y} + + self.attrs = {'axis': 1} + self.outputs = { + 'Out': + np.minimum(self.inputs['X'], self.inputs['Y'].reshape(1, 3, 1)) + } + + +class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_min" + x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32) + sgn = np.random.choice([-1, 1], (4, )).astype(np.float32) + y = x[0, 0, :] + sgn * \ + np.random.uniform(1, 2, (4, )).astype(np.float32) + self.inputs = {'X': x, 'Y': y} + + self.outputs = { + 'Out': + np.minimum(self.inputs['X'], self.inputs['Y'].reshape(1, 1, 4)) + } + + +class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp): + def setUp(self): + self.op_type = "elementwise_min" + x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float32) + sgn = np.random.choice([-1, 1], (3, 4)).astype(np.float32) + y = x[0, :, :, 0] + sgn * \ + np.random.uniform(1, 2, (3, 4)).astype(np.float32) + self.inputs = {'X': x, 'Y': y} + + self.attrs = {'axis': 1} + self.outputs = { + 'Out': + np.minimum(self.inputs['X'], self.inputs['Y'].reshape(1, 3, 4, 1)) + } + + +if __name__ == '__main__': + unittest.main()