diff --git a/paddle/operators/elementwise_mul_op.cc b/paddle/operators/elementwise_mul_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..1742925545d29df5d7df719faaea3b754680ab61 --- /dev/null +++ b/paddle/operators/elementwise_mul_op.cc @@ -0,0 +1,109 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/elementwise_mul_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class ElementWiseMulOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); + auto x_dim = ctx.Input("X")->dims(); + auto y_dim = ctx.Input("Y")->dims(); + PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), + "Rank of first input must >= rank of second input.") + ctx.Output("Out")->Resize(x_dim); + } +}; + +class ElementWiseMulOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ElementWiseMulOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The first input of elementwise mul op"); + AddInput("Y", "The second input of elementwise mul op"); + AddAttr("axis", + R"DOC( +When shape(Y) does not equal shape(X),Y will be broadcasted +to match the shape of X and axis should be dimension index Y in X + )DOC") + .SetDefault(-1) + .EqualGreaterThan(-1); + + AddOutput("Out", "The output of elementwise mul op"); + AddComment(R"DOC( +Limited elementwise multiple operator.The equation is: Out = X ⊙ Y. +1. The shape of Y should be same with X or +2. Y's shape is a subset of X. + Y will be broadcasted to match the shape of X and axis should be dimension index Y in X. + example: + shape(X) = (2, 3, 4, 5), shape(Y) = (,) + shape(X) = (2, 3, 4, 5), shape(Y) = (5,) + shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) + shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 + shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0 +)DOC"); + } +}; + +class ElementWiseMulOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + + auto x_dims = ctx.Input("X")->dims(); + auto y_dims = ctx.Input("Y")->dims(); + auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); + auto *x_grad = ctx.Output(framework::GradVarName("X")); + auto *y_grad = ctx.Output(framework::GradVarName("Y")); + + PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), + "Rank of first input must >= rank of second input.") + + if (x_grad) { + x_grad->Resize(x_dims); + } + + if (y_grad) { + y_grad->Resize(y_dims); + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(elementwise_mul, ops::ElementWiseMulOp, ops::ElementWiseMulOpMaker, + elementwise_mul_grad, ops::ElementWiseMulOpGrad); +REGISTER_OP_CPU_KERNEL( + elementwise_mul, + ops::ElementWiseMulKernel); +REGISTER_OP_CPU_KERNEL( + elementwise_mul_grad, + ops::ElementWiseMulGradKernel); diff --git a/paddle/operators/elementwise_mul_op.cu b/paddle/operators/elementwise_mul_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..56f2087c22c6c599a3c5aef36eb0fe3eac295bef --- /dev/null +++ b/paddle/operators/elementwise_mul_op.cu @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/elementwise_mul_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + elementwise_mul, + ops::ElementWiseMulKernel); +REGISTER_OP_GPU_KERNEL( + elementwise_mul_grad, + ops::ElementWiseMulGradKernel); diff --git a/paddle/operators/elementwise_mul_op.h b/paddle/operators/elementwise_mul_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e9ed6791799240039f9af42c1a4339be7126ee65 --- /dev/null +++ b/paddle/operators/elementwise_mul_op.h @@ -0,0 +1,185 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +/* + * Out = X ⊙ Y + * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 + * pre=2, n=3*4, post=5 + * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5) + * pre=2*3, n=4*5, post=1 + */ + +inline void get_mid_dims(const framework::DDim& x_dims, + const framework::DDim& y_dims, const int axis, + int& pre, int& n, int& post) { + pre = 1; + n = 1; + post = 1; + for (int i = 0; i < axis; ++i) { + pre *= x_dims[i]; + } + + for (int i = 0; i < y_dims.size(); ++i) { + PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i], + "Broadcast dimension mismatch."); + n *= y_dims[i]; + } + + for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { + post *= x_dims[i]; + } +} + +template +class ElementWiseMulKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + z->mutable_data(ctx.GetPlace()); + + auto x_e = framework::EigenVector::Flatten(*x); + auto y_e = framework::EigenVector::Flatten(*y); + auto z_e = framework::EigenVector::Flatten(*z); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), + "Rank of first input must >= rank of second input.") + + if (x_dims == y_dims || product(y_dims) == 1) { + z_e.device(ctx.GetEigenDevice()) = x_e * y_e; + return; + } + + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + int pre, n, post; + get_mid_dims(x_dims, y_dims, axis, pre, n, post); + if (post == 1) { + auto y_bcast = y_e.reshape(Eigen::DSizes(1, n)) + .broadcast(Eigen::DSizes(pre, 1)) + .reshape(Eigen::DSizes(x_e.size())); + z_e.device(ctx.GetEigenDevice()) = x_e * y_bcast; + return; + } else { + auto y_bcast = y_e.reshape(Eigen::DSizes(1, n, 1)) + .broadcast(Eigen::DSizes(pre, 1, post)) + .reshape(Eigen::DSizes(x_e.size())); + z_e.device(ctx.GetEigenDevice()) = x_e * y_bcast; + return; + } + } +}; + +template +class ElementWiseMulGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + + auto x_e = framework::EigenVector::Flatten(*x); + auto y_e = framework::EigenVector::Flatten(*y); + auto dout_e = framework::EigenVector::Flatten(*dout); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + if (dx) { + dx->mutable_data(ctx.GetPlace()); + } + if (dy) { + dy->mutable_data(ctx.GetPlace()); + } + + if (x_dims == y_dims || product(y_dims) == 1) { + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(ctx.GetEigenDevice()) = dout_e * y_e; + } + + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(ctx.GetEigenDevice()) = x_e * dout_e; + } + return; + } + + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + + int pre, n, post; + get_mid_dims(x_dims, y_dims, axis, pre, n, post); + + // TODO(gongweibao): wrap reshape to a function. + if (post == 1) { + auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n)) + .broadcast(Eigen::DSizes(pre, 1)) + .reshape(Eigen::DSizes(x_e.size())); + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(ctx.GetEigenDevice()) = dout_e * y_e_bcast; + } + + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(ctx.GetEigenDevice()) = + (x_e * dout_e) + .reshape(Eigen::DSizes(pre, n)) + .sum(Eigen::array{{0}}); + } + return; + } else { + auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n, 1)) + .broadcast(Eigen::DSizes(pre, 1, post)) + .reshape(Eigen::DSizes(x_e.size())); + if (dx) { + auto dx_e = framework::EigenVector::Flatten(*dx); + dx_e.device(ctx.GetEigenDevice()) = dout_e * y_e_bcast; + } + + if (dy) { + auto dy_e = framework::EigenVector::Flatten(*dy); + dy_e.device(ctx.GetEigenDevice()) = + (x_e * dout_e) + .reshape(Eigen::DSizes(pre, n, post)) + .sum(Eigen::array{{0, 2}}); + } + return; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 16a2368aae5fff7445654161db4fd6a97d5bebfc..ef62d6e997fa0af5c342e0c310c09b840b0b8583 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -35,6 +35,7 @@ USE_OP(add); USE_OP(onehot_cross_entropy); USE_OP(sgd); USE_OP(mul); +USE_OP(elementwise_mul); USE_OP(mean); USE_OP(sigmoid); USE_OP(softmax); diff --git a/python/paddle/v2/framework/tests/test_elementwise_mul_op.py b/python/paddle/v2/framework/tests/test_elementwise_mul_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e268cfddb26721a35ddd2d2cc18f526ff7b2f6d9 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_elementwise_mul_op.py @@ -0,0 +1,157 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestElementwiseMulOp_Matrix(OpTest): + def setUp(self): + self.op_type = "elementwise_mul" + """ Warning + CPU gradient check error! + 'X': np.random.random((32,84)).astype("float32"), + 'Y': np.random.random((32,84)).astype("float32") + """ + self.inputs = { + 'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"), + 'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32") + } + self.outputs = {'Out': np.multiply(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.1) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y')) + + +class TestElementwiseMulOp_Vector(OpTest): + def setUp(self): + self.op_type = "elementwise_mul" + self.inputs = { + 'X': np.random.random((32, )).astype("float32"), + 'Y': np.random.random((32, )).astype("float32") + } + self.outputs = {'Out': np.multiply(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.1) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y')) + + +class TestElementwiseMulOp_broadcast_0(OpTest): + def setUp(self): + self.op_type = "elementwise_mul" + self.inputs = { + 'X': np.random.rand(2, 3, 4).astype(np.float32), + 'Y': np.random.rand(2).astype(np.float32) + } + + self.attrs = {'axis': 0} + self.outputs = { + 'Out': self.inputs['X'] * self.inputs['Y'].reshape(2, 1, 1) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.1) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y')) + + +class TestElementwiseMulOp_broadcast_1(OpTest): + def setUp(self): + self.op_type = "elementwise_mul" + self.inputs = { + 'X': np.random.rand(2, 3, 4).astype(np.float32), + 'Y': np.random.rand(3).astype(np.float32) + } + + self.attrs = {'axis': 1} + self.outputs = { + 'Out': self.inputs['X'] * self.inputs['Y'].reshape(1, 3, 1) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.1) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y')) + + +class TestElementwiseMulOp_broadcast_2(OpTest): + def setUp(self): + self.op_type = "elementwise_mul" + self.inputs = { + 'X': np.random.rand(2, 3, 4).astype(np.float32), + 'Y': np.random.rand(4).astype(np.float32) + } + + self.outputs = { + 'Out': self.inputs['X'] * self.inputs['Y'].reshape(1, 1, 4) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.1) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y')) + + +class TestElementwiseMulOp_broadcast_3(OpTest): + def setUp(self): + self.op_type = "elementwise_mul" + self.inputs = { + 'X': np.random.rand(2, 3, 4, 5).astype(np.float32), + 'Y': np.random.rand(3, 4).astype(np.float32) + } + + self.attrs = {'axis': 1} + self.outputs = { + 'Out': self.inputs['X'] * self.inputs['Y'].reshape(1, 3, 4, 1) + } + + +if __name__ == '__main__': + unittest.main()