diff --git a/paddle/operators/elementwise_add_op.cc b/paddle/operators/elementwise_add_op.cc
index 5f7b654d69f081dfa85b0d61960eb52b7982faa1..d9bc80c869c023caebf0b45ed24f2def3f0b1dd8 100644
--- a/paddle/operators/elementwise_add_op.cc
+++ b/paddle/operators/elementwise_add_op.cc
@@ -13,6 +13,7 @@
 limitations under the License. */
 
 #include "paddle/operators/elementwise_add_op.h"
+#include "paddle/operators/elementwise_op.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/elementwise_add_op.h b/paddle/operators/elementwise_add_op.h
index 9e9f1ffba6fb23f5394713c67aa4363b85717f50..e9f78ef26e05878053d968c35f17b456c128827a 100644
--- a/paddle/operators/elementwise_add_op.h
+++ b/paddle/operators/elementwise_add_op.h
@@ -14,7 +14,7 @@
 
 #pragma once
 
-#include "paddle/operators/elementwise_op.h"
+#include "paddle/operators/elementwise_op_function.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/elementwise_div_op.cc b/paddle/operators/elementwise_div_op.cc
index c6898150d310d0c4fdefae5a58a5792a72f9889e..3f56344d0007b5f14fd9b5b9b44a9b29d3c42f2a 100644
--- a/paddle/operators/elementwise_div_op.cc
+++ b/paddle/operators/elementwise_div_op.cc
@@ -13,6 +13,7 @@
 limitations under the License. */
 
 #include "paddle/operators/elementwise_div_op.h"
+#include "paddle/operators/elementwise_op.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/elementwise_div_op.h b/paddle/operators/elementwise_div_op.h
index 9bd7c8ea548c46ec9b4c5a085e4e70d5dd162f3a..99b6d9c1991edfb0018f8a459dfa373948cec434 100644
--- a/paddle/operators/elementwise_div_op.h
+++ b/paddle/operators/elementwise_div_op.h
@@ -14,7 +14,7 @@
 
 #pragma once
 
-#include "paddle/operators/elementwise_op.h"
+#include "paddle/operators/elementwise_op_function.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/elementwise_mul_op.cc b/paddle/operators/elementwise_mul_op.cc
index f2544b54d6bc543a50d8de03d482333b485bc076..bda5dfe03e974740fe4a07191ae6b68ebfcd5d3a 100644
--- a/paddle/operators/elementwise_mul_op.cc
+++ b/paddle/operators/elementwise_mul_op.cc
@@ -13,6 +13,7 @@
 limitations under the License. */
 
 #include "paddle/operators/elementwise_mul_op.h"
+#include "paddle/operators/elementwise_op.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/elementwise_mul_op.h b/paddle/operators/elementwise_mul_op.h
index 1eaf2e3efc97a32739efcaf37066817ee173fadc..6ab642378bb0af8593ca0677014aede3c03cff8e 100644
--- a/paddle/operators/elementwise_mul_op.h
+++ b/paddle/operators/elementwise_mul_op.h
@@ -13,7 +13,7 @@
 limitations under the License. */
 
 #pragma once
-#include "paddle/operators/elementwise_op.h"
+#include "paddle/operators/elementwise_op_function.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/elementwise_op.h b/paddle/operators/elementwise_op.h
index c4777a00d6781ea751123b2efffb6df8e29630b0..3082f37422faa990bbf03c8a1a87b025d481a290 100644
--- a/paddle/operators/elementwise_op.h
+++ b/paddle/operators/elementwise_op.h
@@ -1,201 +1,24 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #pragma once
-#include <iostream>
-#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/operators/math/math_function.h"
+#include "paddle/framework/operator.h"
 
 namespace paddle {
 namespace operators {
 
-/*
- * Out = X ⊙ Y
- * If Y's shape does not match X' shape, they will be reshaped.
- * For example:
- * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
- *    pre=2, n=3*4, post=5
- *    x.shape(2, 12, 5) * y.shape(1,12,1).broadcast(2,12,5)
- * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
- *    pre=2*3, n=4*5, post=1
- *    x.shape(2, 3, 20) * y.shape(1,1,20).broadcast(2,3,20)
- */
-inline void get_mid_dims(const framework::DDim& x_dims,
-                         const framework::DDim& y_dims, const int axis,
-                         int& pre, int& n, int& post) {
-  pre = 1;
-  n = 1;
-  post = 1;
-  for (int i = 0; i < axis; ++i) {
-    pre *= x_dims[i];
-  }
-
-  for (int i = 0; i < y_dims.size(); ++i) {
-    PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
-                      "Broadcast dimension mismatch.");
-    n *= y_dims[i];
-  }
-
-  for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
-    post *= x_dims[i];
-  }
-}
-
-#define EIGEN_FUNCTOR(name, eigen_op)                                          \
-  struct Eigen##name##Functor {                                                \
-    template <typename Place, typename T>                                      \
-    inline void Run(const framework::Tensor* x, const framework::Tensor* y,    \
-                    framework::Tensor* z,                                      \
-                    const framework::ExecutionContext& ctx) {                  \
-      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
-      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
-      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
-      z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_e);            \
-    }                                                                          \
-    template <typename Place, typename T>                                      \
-    inline void RunBroadCast(const framework::Tensor* x,                       \
-                             const framework::Tensor* y, framework::Tensor* z, \
-                             const framework::ExecutionContext& ctx, int pre,  \
-                             int n) {                                          \
-      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
-      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
-      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
-      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))                  \
-                         .broadcast(Eigen::DSizes<int, 2>(pre, 1))             \
-                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
-      z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_bcast);        \
-    }                                                                          \
-    template <typename Place, typename T>                                      \
-    inline void RunBroadCast2(const framework::Tensor* x,                      \
-                              const framework::Tensor* y,                      \
-                              framework::Tensor* z,                            \
-                              const framework::ExecutionContext& ctx, int pre, \
-                              int n, int post) {                               \
-      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
-      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
-      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
-      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))               \
-                         .broadcast(Eigen::DSizes<int, 3>(pre, 1, post))       \
-                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
-      z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_bcast);        \
-    }                                                                          \
-  }
-
-template <class functor, typename Place, typename T>
-void ElementwiseCompute(const framework::ExecutionContext& ctx) {
-  using Tensor = framework::Tensor;
-
-  auto* x = ctx.Input<Tensor>("X");
-  auto* y = ctx.Input<Tensor>("Y");
-  auto* z = ctx.Output<Tensor>("Out");
-  z->mutable_data<T>(ctx.GetPlace());
-
-  auto x_dims = x->dims();
-  auto y_dims = y->dims();
-  PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
-                    "Rank of first input must >= rank of second input.")
-
-  if (x_dims == y_dims || product(y_dims) == 1) {
-    functor f;
-    f.template Run<Place, T>(x, y, z, ctx);
-    return;
-  }
-
-  int axis = ctx.Attr<int>("axis");
-  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
-  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
-                 "Axis should be in range [0, x_dims)");
-
-  int pre, n, post;
-  get_mid_dims(x_dims, y_dims, axis, pre, n, post);
-  if (post == 1) {
-    functor f;
-    f.template RunBroadCast<Place, T>(x, y, z, ctx, pre, n);
-    return;
-  } else {
-    functor f;
-    f.template RunBroadCast2<Place, T>(x, y, z, ctx, pre, n, post);
-    return;
-  }
-}
-
-#define EIGEN_ADD(x, y) ((x) + (y))
-EIGEN_FUNCTOR(Add, EIGEN_ADD);
-
-#define EIGEN_SUB(x, y) ((x) - (y))
-EIGEN_FUNCTOR(Sub, EIGEN_SUB);
-
-#define EIGEN_MUL(x, y) ((x) * (y))
-EIGEN_FUNCTOR(Mul, EIGEN_MUL);
-
-#define EIGEN_DIV(x, y) ((x) / (y))
-EIGEN_FUNCTOR(Div, EIGEN_DIV);
-
-template <typename Place, typename T, typename functor, typename functor1,
-          typename broadcastfunctor, typename broadcast2functor>
-void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
-  using Tensor = framework::Tensor;
-
-  auto* x = ctx.Input<Tensor>("X");
-  auto* y = ctx.Input<Tensor>("Y");
-  auto* out = ctx.Input<Tensor>("Out");
-  auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-
-  auto place = ctx.GetEigenDevice<Place>();
-
-  auto x_dims = x->dims();
-  auto y_dims = y->dims();
-
-  auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-  auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-  if (dx) {
-    dx->mutable_data<T>(ctx.GetPlace());
-  }
-  if (dy) {
-    dy->mutable_data<T>(ctx.GetPlace());
-  }
-
-  if (x_dims == y_dims) {
-    functor f;
-    f(place, x, y, out, dx, dy, dout);
-    return;
-  }
-
-  if (product(y_dims) == 1) {
-    functor1 f;
-    f(place, x, y, out, dx, dy, dout);
-    return;
-  }
-
-  int axis = ctx.Attr<int>("axis");
-  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
-
-  int pre, n, post;
-  get_mid_dims(x_dims, y_dims, axis, pre, n, post);
-
-  if (post == 1) {
-    broadcastfunctor f;
-    f(place, x, y, out, dx, dy, dout, pre, n);
-    return;
-  } else {
-    broadcast2functor f;
-    f(place, x, y, out, dx, dy, dout, pre, n, post);
-    return;
-  }
-}
-
 class ElementwiseOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h
new file mode 100644
index 0000000000000000000000000000000000000000..3eb97f60b59848d23bcd15ea1e3d2f21b721f6a4
--- /dev/null
+++ b/paddle/operators/elementwise_op_function.h
@@ -0,0 +1,200 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/framework/operator.h"
+
+#include "paddle/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+/*
+ * Out = X ⊙ Y
+ * If Y's shape does not match X' shape, they will be reshaped.
+ * For example:
+ * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
+ *    pre=2, n=3*4, post=5
+ *    x.shape(2, 12, 5) * y.shape(1,12,1).broadcast(2,12,5)
+ * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
+ *    pre=2*3, n=4*5, post=1
+ *    x.shape(2, 3, 20) * y.shape(1,1,20).broadcast(2,3,20)
+ */
+inline void get_mid_dims(const framework::DDim& x_dims,
+                         const framework::DDim& y_dims, const int axis,
+                         int& pre, int& n, int& post) {
+  pre = 1;
+  n = 1;
+  post = 1;
+  for (int i = 0; i < axis; ++i) {
+    pre *= x_dims[i];
+  }
+
+  for (int i = 0; i < y_dims.size(); ++i) {
+    PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
+                      "Broadcast dimension mismatch.");
+    n *= y_dims[i];
+  }
+
+  for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
+    post *= x_dims[i];
+  }
+}
+
+#define EIGEN_FUNCTOR(name, eigen_op)                                          \
+  struct Eigen##name##Functor {                                                \
+    template <typename Place, typename T>                                      \
+    inline void Run(const framework::Tensor* x, const framework::Tensor* y,    \
+                    framework::Tensor* z,                                      \
+                    const framework::ExecutionContext& ctx) {                  \
+      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
+      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
+      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
+      z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_e);            \
+    }                                                                          \
+    template <typename Place, typename T>                                      \
+    inline void RunBroadCast(const framework::Tensor* x,                       \
+                             const framework::Tensor* y, framework::Tensor* z, \
+                             const framework::ExecutionContext& ctx, int pre,  \
+                             int n) {                                          \
+      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
+      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
+      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
+      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))                  \
+                         .broadcast(Eigen::DSizes<int, 2>(pre, 1))             \
+                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
+      z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_bcast);        \
+    }                                                                          \
+    template <typename Place, typename T>                                      \
+    inline void RunBroadCast2(const framework::Tensor* x,                      \
+                              const framework::Tensor* y,                      \
+                              framework::Tensor* z,                            \
+                              const framework::ExecutionContext& ctx, int pre, \
+                              int n, int post) {                               \
+      auto x_e = framework::EigenVector<T>::Flatten(*x);                       \
+      auto y_e = framework::EigenVector<T>::Flatten(*y);                       \
+      auto z_e = framework::EigenVector<T>::Flatten(*z);                       \
+      auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))               \
+                         .broadcast(Eigen::DSizes<int, 3>(pre, 1, post))       \
+                         .reshape(Eigen::DSizes<int, 1>(x_e.size()));          \
+      z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_bcast);        \
+    }                                                                          \
+  }
+
+template <class functor, typename Place, typename T>
+void ElementwiseCompute(const framework::ExecutionContext& ctx) {
+  using Tensor = framework::Tensor;
+
+  auto* x = ctx.Input<Tensor>("X");
+  auto* y = ctx.Input<Tensor>("Y");
+  auto* z = ctx.Output<Tensor>("Out");
+  z->mutable_data<T>(ctx.GetPlace());
+
+  auto x_dims = x->dims();
+  auto y_dims = y->dims();
+  PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
+                    "Rank of first input must >= rank of second input.")
+
+  if (x_dims == y_dims || product(y_dims) == 1) {
+    functor f;
+    f.template Run<Place, T>(x, y, z, ctx);
+    return;
+  }
+
+  int axis = ctx.Attr<int>("axis");
+  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
+  PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
+                 "Axis should be in range [0, x_dims)");
+
+  int pre, n, post;
+  get_mid_dims(x_dims, y_dims, axis, pre, n, post);
+  if (post == 1) {
+    functor f;
+    f.template RunBroadCast<Place, T>(x, y, z, ctx, pre, n);
+    return;
+  } else {
+    functor f;
+    f.template RunBroadCast2<Place, T>(x, y, z, ctx, pre, n, post);
+    return;
+  }
+}
+
+#define EIGEN_ADD(x, y) ((x) + (y))
+EIGEN_FUNCTOR(Add, EIGEN_ADD);
+
+#define EIGEN_SUB(x, y) ((x) - (y))
+EIGEN_FUNCTOR(Sub, EIGEN_SUB);
+
+#define EIGEN_MUL(x, y) ((x) * (y))
+EIGEN_FUNCTOR(Mul, EIGEN_MUL);
+
+#define EIGEN_DIV(x, y) ((x) / (y))
+EIGEN_FUNCTOR(Div, EIGEN_DIV);
+
+template <typename Place, typename T, typename functor, typename functor1,
+          typename broadcastfunctor, typename broadcast2functor>
+void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
+  using Tensor = framework::Tensor;
+
+  auto* x = ctx.Input<Tensor>("X");
+  auto* y = ctx.Input<Tensor>("Y");
+  auto* out = ctx.Input<Tensor>("Out");
+  auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+
+  auto place = ctx.GetEigenDevice<Place>();
+
+  auto x_dims = x->dims();
+  auto y_dims = y->dims();
+
+  auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+  auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+  if (dx) {
+    dx->mutable_data<T>(ctx.GetPlace());
+  }
+  if (dy) {
+    dy->mutable_data<T>(ctx.GetPlace());
+  }
+
+  if (x_dims == y_dims) {
+    functor f;
+    f(place, x, y, out, dx, dy, dout);
+    return;
+  }
+
+  if (product(y_dims) == 1) {
+    functor1 f;
+    f(place, x, y, out, dx, dy, dout);
+    return;
+  }
+
+  int axis = ctx.Attr<int>("axis");
+  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
+
+  int pre, n, post;
+  get_mid_dims(x_dims, y_dims, axis, pre, n, post);
+
+  if (post == 1) {
+    broadcastfunctor f;
+    f(place, x, y, out, dx, dy, dout, pre, n);
+    return;
+  } else {
+    broadcast2functor f;
+    f(place, x, y, out, dx, dy, dout, pre, n, post);
+    return;
+  }
+}
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/elementwise_sub_op.cc b/paddle/operators/elementwise_sub_op.cc
index 31c37ff7ab5595c29f973929387d3945b6f3aaf8..3e4f98fdb35b148931a67d511fe41958eb523f99 100644
--- a/paddle/operators/elementwise_sub_op.cc
+++ b/paddle/operators/elementwise_sub_op.cc
@@ -13,6 +13,7 @@
 limitations under the License. */
 
 #include "paddle/operators/elementwise_sub_op.h"
+#include "paddle/operators/elementwise_op.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/operators/elementwise_sub_op.h b/paddle/operators/elementwise_sub_op.h
index f6bc66cd0e1594a8bc7070e2f182401b92d1c88e..3ca1376c73b3332b76a5973e201f9e4fba77cd21 100644
--- a/paddle/operators/elementwise_sub_op.h
+++ b/paddle/operators/elementwise_sub_op.h
@@ -13,7 +13,7 @@
 limitations under the License. */
 
 #pragma once
-#include "paddle/operators/elementwise_op.h"
+#include "paddle/operators/elementwise_op_function.h"
 
 namespace paddle {
 namespace operators {
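
Note (not part of the diff above): after this split, elementwise_op.h carries only the operator and grad-op definitions that the *.cc registration files need, while elementwise_op_function.h carries the Eigen compute helpers that the kernel headers need. A minimal sketch of a consuming kernel header, assuming the EigenAddFunctor generated by EIGEN_FUNCTOR(Add, EIGEN_ADD) in the new header (illustrative only, reconstructed from the call signatures above, not taken from this diff):

    // elementwise_add_op.h (consumer sketch, hypothetical)
    #pragma once
    #include "paddle/operators/elementwise_op_function.h"

    namespace paddle {
    namespace operators {

    template <typename Place, typename T>
    class ElementwiseAddKernel : public framework::OpKernel {
     public:
      void Compute(const framework::ExecutionContext& ctx) const override {
        // ElementwiseCompute dispatches to Run, RunBroadCast, or
        // RunBroadCast2 depending on how Y's shape lines up with X's
        // (see get_mid_dims in elementwise_op_function.h).
        ElementwiseCompute<EigenAddFunctor, Place, T>(ctx);
      }
    };

    }  // namespace operators
    }  // namespace paddle

This keeps the kernel headers (compiled into both CPU and GPU kernel translation units) free of the op-registration machinery, which only the .cc files pull in via elementwise_op.h.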