From acf37ad67504ee2f4ce5f601906c1b5102ede124 Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Mon, 15 Jan 2018 17:53:36 +0800
Subject: [PATCH] Complete elementwise_max_op

---
 paddle/operators/elementwise_max_op.cc |  2 +-
 paddle/operators/elementwise_max_op.cu | 32 +++++++++++++
 paddle/operators/elementwise_max_op.h  | 63 ++++++++++++++++++++++----
 3 files changed, 87 insertions(+), 10 deletions(-)
 create mode 100644 paddle/operators/elementwise_max_op.cu

diff --git a/paddle/operators/elementwise_max_op.cc b/paddle/operators/elementwise_max_op.cc
index b5c6b11ba..53c27ae5b 100644
--- a/paddle/operators/elementwise_max_op.cc
+++ b/paddle/operators/elementwise_max_op.cc
@@ -42,4 +42,4 @@ REGISTER_OP_CPU_KERNEL(
     ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
\ No newline at end of file
+    ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
diff --git a/paddle/operators/elementwise_max_op.cu b/paddle/operators/elementwise_max_op.cu
new file mode 100644
index 000000000..5ff4af174
--- /dev/null
+++ b/paddle/operators/elementwise_max_op.cu
@@ -0,0 +1,32 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/elementwise_max_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    elementwise_max,
+    ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    elementwise_max_grad,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
diff --git a/paddle/operators/elementwise_max_op.h b/paddle/operators/elementwise_max_op.h
index 5c685b75e..e370aeb30 100644
--- a/paddle/operators/elementwise_max_op.h
+++ b/paddle/operators/elementwise_max_op.h
@@ -65,43 +65,88 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> {
 };
 
 template <typename T>
-struct ElementwiseSubGradFunctor {
+struct ElementwiseMaxGradFunctor {
   template <typename Device, typename X, typename Y, typename Z, typename dX,
             typename dY, typename dZ>
   void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
-    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
     auto x_e = framework::EigenVector<T>::Flatten(*x);
     auto y_e = framework::EigenVector<T>::Flatten(*y);
+    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
 
     if (dx) {
       auto dx_e = framework::EigenVector<T>::Flatten(*dx);
-      dx_e.device(d) = (x_e > y_e) * dz_e;
+      dx_e.device(d) = (x_e > y_e).template cast<T>() * dz_e;
     }
     if (dy) {
       auto dy_e = framework::EigenVector<T>::Flatten(*dy);
-      dy_e.device(d) = (y_e >= x_e) * dz_e;
+      dy_e.device(d) = (y_e >= x_e).template cast<T>() * dz_e;
     }
   }
 };
 
 template <typename T>
-struct ElementwiseSubOneGradFunctor {
+struct ElementwiseMaxBroadCastGradFunctor {
   template <typename Device, typename X, typename Y, typename Z, typename dX,
-            typename dY, typename dZ>
-  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
+            typename dY, typename dZ, typename Pre, typename N>
+  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
+    auto x_e = framework::EigenVector<T>::Flatten(*x);
+    auto y_e = framework::EigenVector<T>::Flatten(*y);
     auto dz_e = framework::EigenVector<T>::Flatten(*dz);
+
+    auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
+                         .broadcast(Eigen::DSizes<int, 2>(pre, 1))
+                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));
+
+    if (dx) {
+      auto dx_e = framework::EigenVector<T>::Flatten(*dx);
+      dx_e.device(d) = (x_e > y_e_bcast).template cast<T>() * dz_e;
+    }
+
+    if (dy) {
+      auto dy_e = framework::EigenVector<T>::Flatten(*dy);
+      dy_e.device(d) = ((y_e_bcast >= x_e).template cast<T>() * dz_e)
+                           .reshape(Eigen::DSizes<int, 2>(pre, n))
+                           .sum(Eigen::array<int, 1>{{0}});
+    }
+  }
+};
+
+template <typename T>
+struct ElementwiseMaxBroadCast2GradFunctor {
+  template <typename Device, typename X, typename Y, typename Z, typename dX,
+            typename dY, typename dZ, typename Pre, typename N, typename Post>
+  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
+                  Post post) {
     auto x_e = framework::EigenVector<T>::Flatten(*x);
     auto y_e = framework::EigenVector<T>::Flatten(*y);
+    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
+
+    auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
+                         .broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
+                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));
 
     if (dx) {
       auto dx_e = framework::EigenVector<T>::Flatten(*dx);
-      dx_e.device(d) = dz_e;
+      dx_e.device(d) = (x_e > y_e_bcast).template cast<T>() * dz_e;
     }
+
     if (dy) {
       auto dy_e = framework::EigenVector<T>::Flatten(*dy);
-      dy_e.device(d) = (-1.0) * dz_e.sum();
+      dy_e.device(d) = ((y_e_bcast >= x_e).template cast<T>() * dz_e)
+                           .reshape(Eigen::DSizes<int, 3>(pre, n, post))
+                           .sum(Eigen::array<int, 2>{{0, 2}});
     }
   }
 };
 
+template <typename DeviceContext, typename T>
+class ElementwiseMaxGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    ElementwiseGradCompute<DeviceContext, T, ElementwiseMaxGradFunctor<T>,
+                           ElementwiseMaxBroadCastGradFunctor<T>,
+                           ElementwiseMaxBroadCast2GradFunctor<T>>(ctx);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
-- 
GitLab