diff --git a/paddle/operators/bilinear_tensor_product_op.cc b/paddle/operators/bilinear_tensor_product_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c65ba7eb262f3aabe2c00837b79806c0b40b60fd --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.cc @@ -0,0 +1,159 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/bilinear_tensor_product_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class BilinearTensorProductOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto weight_dims = ctx->GetInputDim("Weight"); + + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The input(X) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input(Y) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(weight_dims.size(), 3UL, + "The input(Weight) must be a 3D tensor."); + PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], + "The first dimension(batch_size) of input(X) must be " + "equal to the first dimension of the input(Y)."); + PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1], + "The second dimension of input(X) must be equal to " + "the second dimension of the input(Weight)."); + PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2], + "The second dimension of input(Y) must be equal to " + "the third dimension of the input(Weight)."); + + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE(bias_dims.size() == 2UL && bias_dims[0] == 1UL, + "The Input(Bias) must be a 2-D tensor with " + "the 2nd dimension fixed to 1 (a row vector)."); + PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0], + "The second dimension of input(Bias) must be equal " + "to the first dimension of the input(Weight)."); + } + + ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]}); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { + public: + BilinearTensorProductOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The first input of bilinear_tensor_product operator."); + AddInput("Y", "The second input of bilinear_tensor_product operator."); + AddInput("Weight", + "The learnable parameters of bilinear_tensor_product operator."); + AddInput("Bias", "The learnable bias of bilinear_tensor_product operator.") + .AsDispensable(); + AddOutput("Out", "The output of bilinear_tensor_product operator."); + AddComment(R"DOC( +Bilinear Tensor Product operator. +Given input X and Y, a 3D tensor weight, and bias. Each column of the +output is computed by one slice i = 1, . . . , k of the tensor: + + M = (X W_i) \cdot Y + Out_i = \sum_i {M_i} + Bias_i + +)DOC"); + } +}; + +class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto weight_dims = ctx->GetInputDim("Weight"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); + + PADDLE_ENFORCE_EQ(out_dims.size(), 2UL, + "The input(Out@GRAD) must be a 2D Tensor."); + PADDLE_ENFORCE_EQ( + x_dims[0], out_dims[0], + "The first dimension(batch_size) of input(Out@GRAD) must be " + "equal to the first dimension of the Input(X)."); + PADDLE_ENFORCE_EQ( + weight_dims[0], out_dims[1], + "The second dimension of input(Out@GRAD) must be equal to " + "the third dimension of the Input(Weight)."); + + if (ctx->HasInput("Bias")) { + auto bias_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ( + bias_dims[1], out_dims[1], + "The second dimension of input(Out@GRAD) must be equal to " + "the second dimension of the Input(Bias)."); + auto bias_grad_name = framework::GradVarName("Bias"); + if (ctx->HasOutput(bias_grad_name)) + ctx->SetOutputDim(bias_grad_name, bias_dims); + } + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + auto weight_grad_name = framework::GradVarName("Weight"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + if (ctx->HasOutput(weight_grad_name)) { + ctx->SetOutputDim(weight_grad_name, weight_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp, + ops::BilinearTensorProductOpMaker, bilinear_tensor_product_grad, + ops::BilinearTensorProductOpGrad); +REGISTER_OP_CPU_KERNEL( + bilinear_tensor_product, + ops::BilinearTensorProductKernel, + ops::BilinearTensorProductKernel); +REGISTER_OP_CPU_KERNEL( + bilinear_tensor_product_grad, + ops::BilinearTensorProductGradKernel, + ops::BilinearTensorProductGradKernel); diff --git a/paddle/operators/bilinear_tensor_product_op.cu b/paddle/operators/bilinear_tensor_product_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..858d2668d01379afe8082cd1eda32a2a5d09bd18 --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/bilinear_tensor_product_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + bilinear_tensor_product, + ops::BilinearTensorProductKernel, + ops::BilinearTensorProductKernel); +REGISTER_OP_GPU_KERNEL( + bilinear_tensor_product_grad, + ops::BilinearTensorProductGradKernel, + ops::BilinearTensorProductGradKernel); diff --git a/paddle/operators/bilinear_tensor_product_op.h b/paddle/operators/bilinear_tensor_product_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ffa4f43a327418498c1f110504127e7d2878409d --- /dev/null +++ b/paddle/operators/bilinear_tensor_product_op.h @@ -0,0 +1,184 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +template +using EigenMatrix = framework::EigenMatrix; + +template +class BilinearTensorProductKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + auto y_mat = EigenMatrix::From(*y); + auto output_mat = EigenMatrix::From(*out); + + auto batch_size = x->dims()[0]; + auto weight_dims = weight->dims(); + int out_dim = weight_dims[0]; + auto x_dim = weight_dims[1]; + auto y_dim = weight_dims[2]; + auto place = ctx.GetEigenDevice(); + + // Create the intermediate variable to caculate the result of + // Input(X) multiplied by Input(Weight_i), the formula is: + // left_mul = X Weight_i. + Tensor left_mul; + left_mul.mutable_data(framework::make_ddim({batch_size, y_dim}), + ctx.GetPlace()); + auto left_mul_mat = EigenMatrix::From(left_mul); + + for (int i = 0; i < out_dim; ++i) { + auto output_col_vec = output_mat.chip(i, 1); + Tensor weight_mat = + weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim})); + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, y_dim, x_dim, 1, x->data(), + weight_mat.data(), 0, left_mul.data()); + output_col_vec.device(place) = + (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); + } + if (bias) { + auto bias_vec = EigenMatrix::From(*bias); + Eigen::DSizes bcast(batch_size, 1); + output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat; + } + } +}; + +template +class BilinearTensorProductGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* x = ctx.Input("X"); + const Tensor* y = ctx.Input("Y"); + const Tensor* weight = ctx.Input("Weight"); + Tensor* d_x = ctx.Output(framework::GradVarName("X")); + Tensor* d_y = ctx.Output(framework::GradVarName("Y")); + Tensor* d_weight = ctx.Output(framework::GradVarName("Weight")); + Tensor* d_bias = ctx.Output(framework::GradVarName("Bias")); + const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); + + auto batch_size = x->dims()[0]; + auto weight_dims = weight->dims(); + int out_dim = weight_dims[0]; + auto x_dim = weight_dims[1]; + auto y_dim = weight_dims[2]; + + auto x_mat = EigenMatrix::From(*x); + auto y_mat = EigenMatrix::From(*y); + auto d_out_mat = EigenMatrix::From(*d_out); + auto place = ctx.GetEigenDevice(); + + // Create the intermediate variable to caculate the Output(Y@Grad). + Tensor x_scale; + x_scale.mutable_data(framework::make_ddim({batch_size, x_dim}), + ctx.GetPlace()); + auto x_scale_mat = EigenMatrix::From(x_scale); + + // Create the intermediate variable to caculate the Output(X@Grad). + Tensor y_scale; + y_scale.mutable_data(framework::make_ddim({batch_size, y_dim}), + ctx.GetPlace()); + auto y_scale_mat = EigenMatrix::From(y_scale); + + math::SetConstant set_zero; + + // Set Output(X@Grad) be zero. + if (d_x) { + d_x->mutable_data(ctx.GetPlace()); + set_zero(ctx.device_context(), d_x, static_cast(0)); + } + + // Set Output(Y@Grad) be zero. + if (d_y) { + d_y->mutable_data(ctx.GetPlace()); + set_zero(ctx.device_context(), d_y, static_cast(0)); + } + + // Caculate the Output(X@Grad) and Output(Y@Grad). + if (d_x || d_y) { + Eigen::DSizes bcast_for_x(1, y_dim); + Eigen::DSizes bcast_for_y(1, x_dim); + for (int i = 0; i < out_dim; ++i) { + Tensor weight_i = weight->Slice(i, i + 1).Resize( + framework::make_ddim({x_dim, y_dim})); + auto output_vec = d_out_mat.chip(i, 1); + if (d_x) { + y_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_x) * + y_mat; + math::gemm(ctx.device_context(), CblasNoTrans, CblasTrans, + batch_size, x_dim, y_dim, 1, y_scale.data(), + weight_i.data(), 1, d_x->data()); + } + if (d_y) { + x_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_y) * + x_mat; + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, y_dim, x_dim, 1, x_scale.data(), + weight_i.data(), 1, d_y->data()); + } + } + } + + // Caculate the gradient of Input(Weight). + if (d_weight) { + d_weight->mutable_data(ctx.GetPlace()); + Eigen::DSizes bcast_for_weight(1, x_dim); + for (int i = 0; i < out_dim; ++i) { + Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize( + framework::make_ddim({x_dim, y_dim})); + auto output_vec = d_out_mat.chip(i, 1); + x_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_weight) * + x_mat; + math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, + x_dim, y_dim, batch_size, 1, x_scale.data(), + y->data(), 0, d_weight_i.data()); + } + } + + // Caculate the gradient of Input(Bias). + if (d_bias) { + d_bias->mutable_data(ctx.GetPlace()); + auto d_bias_mat = EigenMatrix::From(*d_bias); + d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes(0)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py new file mode 100644 index 0000000000000000000000000000000000000000..080ca43b8269e0f6a9f4d0ce3973f4d4a07a8e2a --- /dev/null +++ b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py @@ -0,0 +1,37 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestBilinearTensorProductOp(OpTest): + def setUp(self): + self.op_type = "bilinear_tensor_product" + batch_size = 6 + size0 = 3 + size1 = 4 + size2 = 5 + a = np.random.random((batch_size, size0)).astype("float32") + b = np.random.random((batch_size, size1)).astype("float32") + w = np.random.random((size2, size0, size1)).astype("float32") + bias = np.random.random((1, size2)).astype("float32") + output = np.zeros((batch_size, size2)).astype("float32") + for i in range(size2): + w_i = w[i, :, :] + output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1) + self.inputs = { + 'X': a, + 'Y': b, + 'Weight': w, + 'Bias': bias, + } + self.outputs = {'Out': output + bias} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out') + + +if __name__ == "__main__": + unittest.main()