From f5cd9619002e2b392ac5195da0e3feef67c0c4c4 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 15 Jan 2018 18:56:17 +0800 Subject: [PATCH] complete elementwise_min_op --- paddle/operators/elementwise_max_op.h | 6 +- paddle/operators/elementwise_min_op.cc | 2 +- paddle/operators/elementwise_min_op.cu | 32 ++++++ paddle/operators/elementwise_min_op.h | 152 +++++++++++++++++++++++++ 4 files changed, 188 insertions(+), 4 deletions(-) create mode 100644 paddle/operators/elementwise_min_op.cu create mode 100644 paddle/operators/elementwise_min_op.h diff --git a/paddle/operators/elementwise_max_op.h b/paddle/operators/elementwise_max_op.h index e370aeb308..92152f7cb6 100644 --- a/paddle/operators/elementwise_max_op.h +++ b/paddle/operators/elementwise_max_op.h @@ -79,7 +79,7 @@ struct ElementwiseMaxGradFunctor { } if (dy) { auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = (y_e >= x_e).template cast() * dz_e; + dy_e.device(d) = (x_e <= y_e).template cast() * dz_e; } } }; @@ -104,7 +104,7 @@ struct ElementwiseMaxBroadCastGradFunctor { if (dy) { auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = ((y_e_bcast >= x_e).template cast() * dz_e) + dy_e.device(d) = ((x_e <= y_e_bcast).template cast() * dz_e) .reshape(Eigen::DSizes(pre, n)) .sum(Eigen::array{{0}}); } @@ -131,7 +131,7 @@ struct ElementwiseMaxBroadCast2GradFunctor { if (dy) { auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = ((y_e_bcast >= x_e).template cast() * dz_e) + dy_e.device(d) = ((x_e <= y_e_bcast).template cast() * dz_e) .reshape(Eigen::DSizes(pre, n, post)) .sum(Eigen::array{{0, 2}}); } diff --git a/paddle/operators/elementwise_min_op.cc b/paddle/operators/elementwise_min_op.cc index b78846f17a..99482e1bf6 100644 --- a/paddle/operators/elementwise_min_op.cc +++ b/paddle/operators/elementwise_min_op.cc @@ -42,4 +42,4 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseMinGradKernel, ops::ElementwiseMinGradKernel, ops::ElementwiseMinGradKernel, - ops::ElementwiseMinGradKernel); \ No newline at end of file + ops::ElementwiseMinGradKernel); diff --git a/paddle/operators/elementwise_min_op.cu b/paddle/operators/elementwise_min_op.cu new file mode 100644 index 0000000000..3547e6ccb7 --- /dev/null +++ b/paddle/operators/elementwise_min_op.cu @@ -0,0 +1,32 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/elementwise_min_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + elementwise_min, + ops::ElementwiseMinKernel, + ops::ElementwiseMinKernel, + ops::ElementwiseMinKernel, + ops::ElementwiseMinKernel); +REGISTER_OP_CUDA_KERNEL( + elementwise_min_grad, + ops::ElementwiseMinGradKernel, + ops::ElementwiseMinGradKernel, + ops::ElementwiseMinGradKernel, + ops::ElementwiseMinGradKernel); diff --git a/paddle/operators/elementwise_min_op.h b/paddle/operators/elementwise_min_op.h new file mode 100644 index 0000000000..53b7f59fa0 --- /dev/null +++ b/paddle/operators/elementwise_min_op.h @@ -0,0 +1,152 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/operators/elementwise_op_function.h" + +namespace paddle { +namespace operators { + +template +struct MinFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? a : b; } +}; + +template +class ElementwiseMinKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + z->mutable_data(ctx.GetPlace()); + TransformFunctor, T, DeviceContext> functor( + x, y, z, ctx.template device_context(), MinFunctor()); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), + "Rank of first input must >= rank of second input."); + + if (x_dims == y_dims) { + functor.Run(); + return; + } + + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + int pre, n, post; + get_mid_dims(x_dims, y_dims, axis, pre, n, post); + if (post == 1) { + functor.RunRowWise(n, pre); + return; + } else { + functor.RunMidWise(n, pre, post); + return; + } + } +}; + +template +struct ElementwiseMinGradFunctor { + template + void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { + auto x_e = framework::EigenVector::Flatten(*x); + auto y_e = framework::EigenVector::Flatten(*y); + auto dz_e = framework::EigenVector::Flatten(*dz); + + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(d) = (x_e < y_e).template cast() * dz_e; + } + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(d) = (x_e >= y_e).template cast() * dz_e; + } + } +}; + +template +struct ElementwiseMinBroadCastGradFunctor { + template + void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { + auto x_e = framework::EigenVector::Flatten(*x); + auto y_e = framework::EigenVector::Flatten(*y); + auto dz_e = framework::EigenVector::Flatten(*dz); + + auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n)) + .broadcast(Eigen::DSizes(pre, 1)) + .reshape(Eigen::DSizes(x_e.size())); + + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(d) = (x_e < y_e_bcast).template cast() * dz_e; + } + + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(d) = ((x_e >= y_e_bcast).template cast() * dz_e) + .reshape(Eigen::DSizes(pre, n)) + .sum(Eigen::array{{0}}); + } + } +}; + +template +struct ElementwiseMinBroadCast2GradFunctor { + template + void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, + Post post) { + auto x_e = framework::EigenVector::Flatten(*x); + auto y_e = framework::EigenVector::Flatten(*y); + auto dz_e = framework::EigenVector::Flatten(*dz); + + auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n, 1)) + .broadcast(Eigen::DSizes(pre, 1, post)) + .reshape(Eigen::DSizes(x_e.size())); + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(d) = (x_e < y_e_bcast).template cast() * dz_e; + } + + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(d) = ((x_e >= y_e_bcast).template cast() * dz_e) + .reshape(Eigen::DSizes(pre, n, post)) + .sum(Eigen::array{{0, 2}}); + } + } +}; + +template +class ElementwiseMinGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElementwiseGradCompute, + ElementwiseMinBroadCastGradFunctor, + ElementwiseMinBroadCast2GradFunctor>(ctx); + } +}; + +} // namespace operators +} // namespace paddle -- GitLab