diff --git a/paddle/operators/elementwise_max_op.cc b/paddle/operators/elementwise_max_op.cc index b5c6b11ba3b8d76330acc8d76fc88fa51bd98dc7..53c27ae5be4cbfe85ce61aa27196594ae152eea4 100644 --- a/paddle/operators/elementwise_max_op.cc +++ b/paddle/operators/elementwise_max_op.cc @@ -42,4 +42,4 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseMaxGradKernel, ops::ElementwiseMaxGradKernel, ops::ElementwiseMaxGradKernel, - ops::ElementwiseMaxGradKernel); \ No newline at end of file + ops::ElementwiseMaxGradKernel); diff --git a/paddle/operators/elementwise_max_op.cu b/paddle/operators/elementwise_max_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..5ff4af17477cbd35b765cc00d46c95fda620e2df --- /dev/null +++ b/paddle/operators/elementwise_max_op.cu @@ -0,0 +1,32 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/elementwise_max_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + elementwise_max, + ops::ElementwiseMaxKernel, + ops::ElementwiseMaxKernel, + ops::ElementwiseMaxKernel, + ops::ElementwiseMaxKernel); +REGISTER_OP_CUDA_KERNEL( + elementwise_max_grad, + ops::ElementwiseMaxGradKernel, + ops::ElementwiseMaxGradKernel, + ops::ElementwiseMaxGradKernel, + ops::ElementwiseMaxGradKernel); diff --git a/paddle/operators/elementwise_max_op.h b/paddle/operators/elementwise_max_op.h index 5c685b75e508fcf4605dd61114a48876821b1582..e370aeb3080a0467e9840119f9a8d690d30b1f59 100644 --- a/paddle/operators/elementwise_max_op.h +++ b/paddle/operators/elementwise_max_op.h @@ -65,43 +65,88 @@ class ElementwiseMaxKernel : public framework::OpKernel { }; template -struct ElementwiseSubGradFunctor { +struct ElementwiseMaxGradFunctor { template void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { - auto dz_e = framework::EigenVector::Flatten(*dz); auto x_e = framework::EigenVector::Flatten(*x); auto y_e = framework::EigenVector::Flatten(*y); + auto dz_e = framework::EigenVector::Flatten(*dz); if (dx) { auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = (x_e > y_e) * dz_e; + dx_e.device(d) = (x_e > y_e).template cast() * dz_e; } if (dy) { auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = (y_e >= x_e) * dz_e; + dy_e.device(d) = (y_e >= x_e).template cast() * dz_e; } } }; template -struct ElementwiseSubOneGradFunctor { +struct ElementwiseMaxBroadCastGradFunctor { template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { + typename dY, typename dZ, typename Pre, typename N> + void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { + auto x_e = framework::EigenVector::Flatten(*x); + auto y_e = framework::EigenVector::Flatten(*y); auto dz_e = framework::EigenVector::Flatten(*dz); + + auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n)) + .broadcast(Eigen::DSizes(pre, 1)) + .reshape(Eigen::DSizes(x_e.size())); + + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(d) = (x_e > y_e_bcast).template cast() * dz_e; + } + + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(d) = ((y_e_bcast >= x_e).template cast() * dz_e) + .reshape(Eigen::DSizes(pre, n)) + .sum(Eigen::array{{0}}); + } + } +}; + +template +struct ElementwiseMaxBroadCast2GradFunctor { + template + void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, + Post post) { auto x_e = framework::EigenVector::Flatten(*x); auto y_e = framework::EigenVector::Flatten(*y); + auto dz_e = framework::EigenVector::Flatten(*dz); + + auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n, 1)) + .broadcast(Eigen::DSizes(pre, 1, post)) + .reshape(Eigen::DSizes(x_e.size())); if (dx) { auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e; + dx_e.device(d) = (x_e > y_e_bcast).template cast() * dz_e; } + if (dy) { auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = (-1.0) * dz_e.sum(); + dy_e.device(d) = ((y_e_bcast >= x_e).template cast() * dz_e) + .reshape(Eigen::DSizes(pre, n, post)) + .sum(Eigen::array{{0, 2}}); } } }; +template +class ElementwiseMaxGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElementwiseGradCompute, + ElementwiseMaxBroadCastGradFunctor, + ElementwiseMaxBroadCast2GradFunctor>(ctx); + } +}; + } // namespace operators } // namespace paddle