diff --git a/paddle/operators/bilinear_tensor_product_op.cc b/paddle/operators/bilinear_tensor_product_op.cc index 64569e5fe77bb61c294a65016ae195506df2cad0..3bd2d40cd284aae5a3d4113f993786695274e8c8 100644 --- a/paddle/operators/bilinear_tensor_product_op.cc +++ b/paddle/operators/bilinear_tensor_product_op.cc @@ -34,8 +34,8 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel { auto y_dims = ctx->GetInputDim("Y"); auto weight_dims = ctx->GetInputDim("Weight"); - PADDLE_ENFORCE_EQ(x_dims.size(), 1, "The input X must be a vector."); - PADDLE_ENFORCE_EQ(y_dims.size(), 1, "The input Y must be a vector."); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The input X must be a 2D Tensor."); + PADDLE_ENFORCE_EQ(y_dims.size(), 2, "The input Y must be a 2D Tensor."); PADDLE_ENFORCE_EQ(weight_dims.size(), 3, "The input Weight must be a 3D tensor."); PADDLE_ENFORCE_GT(weight_dims[0], 0, @@ -44,24 +44,29 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel { "The second dimension of Weight must be larger than 0."); PADDLE_ENFORCE_GT(weight_dims[2], 0, "The third dimension of Weight must be larger than 0."); - PADDLE_ENFORCE_EQ(x_dims[0], weight_dims[1], - "The dimension of X must be equal with the second " + PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], + "The first dimension(batch_size) of X must be " + "equal with the first dimension of the Y."); + PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1], + "The second dimension of X must be equal with the second " "dimension of the Weight."); - PADDLE_ENFORCE_EQ(y_dims[0], weight_dims[2], - "The dimension of Y must be equal with the third " + PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2], + "The second dimension of Y must be equal with the third " "dimension of the Weight."); - auto bias = Input("Bias"); - if (bias != framework::kEmptyVarName) { + if (ctx->HasInput("Bias")) { auto bias_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(bias_dims.size(), 1, - "The input Bias must be a vector."); - PADDLE_ENFORCE_EQ(bias_dims[0], weight_dims[0], - "The dimension of Bias must be equal with the first " - "dimension of the Weight."); + PADDLE_ENFORCE_EQ(bias_dims.size(), 2, + "The input Bias must have 2 dimensions."); + PADDLE_ENFORCE_EQ(bias_dims[0], 1, + "The first dimention of input Bias must be 1."); + PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0], + "The second dimension of Bias must be equal with the " + "first dimension of the Weight."); } - ctx->SetOutputDim("Out", {weight_dims[0]}); + ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]}); + ctx->ShareLoD("X", /*->*/ "Out"); } }; @@ -70,19 +75,19 @@ class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { BilinearTensorProductOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The first input of tensor op"); - AddInput("Y", "The second input of tensor op"); - AddInput("Weight", "The input weight of tensor op"); - AddInput("Bias", "The input bias of tensor op"); - AddOutput("Out", "The output of tensor op"); + AddInput("X", "The first input of BilinearTensorProduct op"); + AddInput("Y", "The second input of BilinearTensorProduct op"); + AddInput("Weight", "The input weight of BilinearTensorProduct op"); + AddInput("Bias", "The input bias of BilinearTensorProduct op") + .AsDispensable(); + AddOutput("Out", "The output of BilinearTensorProduct op"); AddComment(R"DOC( Bilinear Tensor Product operator. -Given input X and Y, a 3D tensor weight, and bias. Each entry of the output is -computed by one slice i = 1, . . . , k of the tensor: Out_i = X*W_i*Y + Bias_i . +Given input X and Y, a 3D tensor weight, and bias. Each column of the +output is computed by one slice i = 1, . . . , k of the tensor: -The equation of this operator is: - - Out = \sum_{i} X*W_i*Y + Bias + M = (X W_i) \cdot Y + Out_i = \sum_i {M_i} + Bias_i )DOC"); } @@ -104,19 +109,20 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { auto weight_dims = ctx->GetInputDim("Weight"); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); - PADDLE_ENFORCE_EQ(out_dims.size(), 1, "The Out@GRAD must be a vector."); + PADDLE_ENFORCE_EQ(out_dims.size(), 2, "The Out@GRAD must be a 2D Tensor."); PADDLE_ENFORCE_EQ( - weight_dims[0], out_dims[0], - "The dimension of Out@GRAD must be equal with the third dimension of " - "the Weight."); - - auto bias = Input("Bias"); - if (bias != framework::kEmptyVarName) { + x_dims[0], out_dims[0], + "The first dimension(batch_size) of Out@GRAD must be equal with " + "the first dimension of the X."); + PADDLE_ENFORCE_EQ(weight_dims[0], out_dims[1], + "The second dimension of Out@GRAD must be equal with " + "the third dimension of the Weight."); + + if (ctx->HasInput("Bias")) { auto bias_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(bias_dims.size(), 1, "Input Bias must be a vector."); - PADDLE_ENFORCE_EQ( - bias_dims[0], out_dims[0], - "The dimension of Bias must be equal with the Out@GRAD "); + PADDLE_ENFORCE_EQ(bias_dims[1], out_dims[1], + "The second dimension of Bias must be equal with " + "the second dimension of the Out@GRAD."); auto bias_grad_name = framework::GradVarName("Bias"); if (ctx->HasOutput(bias_grad_name)) ctx->SetOutputDim(bias_grad_name, bias_dims); @@ -150,4 +156,4 @@ REGISTER_OP_CPU_KERNEL( ops::BilinearTensorProductKernel); REGISTER_OP_CPU_KERNEL( bilinear_tensor_product_grad, - ops::BilinearTensorProductGradKernel); + ops::BilinearTensorProductGradKernel); \ No newline at end of file diff --git a/paddle/operators/bilinear_tensor_product_op.cu b/paddle/operators/bilinear_tensor_product_op.cu index a212460560e796a2295a8e03a1d9a57c5470f2ac..1d65c17f8c1a6630215fefe4c26b52e6abe09196 100644 --- a/paddle/operators/bilinear_tensor_product_op.cu +++ b/paddle/operators/bilinear_tensor_product_op.cu @@ -15,10 +15,85 @@ #define EIGEN_USE_GPU #include "paddle/operators/bilinear_tensor_product_op.h" +namespace paddle { +namespace operators { + +template +class BilinearTensorProductCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + auto y_mat = EigenMatrix::From(*y); + auto batch_size = x->dims()[0]; + auto weight_dims = weight->dims(); + + auto place = ctx.GetEigenDevice(); + auto cpu_place = ctx.GetEigenDevice(); + + // Copy the output to cpu. + Tensor output_cpu; + output_cpu.CopyFrom(*out, platform::CPUPlace(), ctx.device_context()); + auto* output_cpu_ptr = output_cpu.data(); + auto output_cpu_mat = EigenMatrix::From(output_cpu); + + // Create the temporary variables. + Tensor left_mul; + left_mul.mutable_data(framework::make_ddim({batch_size, weight_dims[2]}), + ctx.GetPlace()); + auto left_mul_mat = EigenMatrix::From(left_mul); + Tensor output_col; + output_col.mutable_data(framework::make_ddim({batch_size}), + ctx.GetPlace()); + auto output_col_vec = EigenVector::From(output_col); + + for (size_t i = 0; i < weight_dims[0]; ++i) { + Tensor weight_mat = weight->Slice(i, i + 1).Resize( + framework::make_ddim({weight_dims[1], weight_dims[2]})); + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, weight_dims[2], weight_dims[1], 1, + x->data(), weight_mat.data(), 0, + left_mul.data()); + output_col_vec.device(place) = + (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); + + // Copy the output_col to cpu. + Tensor output_col_cpu; + output_col_cpu.CopyFrom(output_col, platform::CPUPlace(), + ctx.device_context()); + auto* output_col_ptr = output_col_cpu.data(); + + for (size_t j = 0; j < batch_size; ++j) { + output_cpu_ptr[i + j * weight_dims[0]] = output_col_ptr[j]; + } + } + + if (bias) { + // Copy the bias to cpu. + Tensor bias_cpu; + bias_cpu.CopyFrom(*bias, platform::CPUPlace(), ctx.device_context()); + auto bias_vec = EigenMatrix::From(bias_cpu); + Eigen::DSizes bcast(batch_size, 1); + output_cpu_mat.device(cpu_place) = + bias_vec.broadcast(bcast) + output_cpu_mat; + } + + // Copy the output to gpu. + out->CopyFrom(output_cpu, platform::GPUPlace(), ctx.device_context()); + } +}; +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( bilinear_tensor_product, - ops::BilinearTensorProductKernel); + ops::BilinearTensorProductCUDAKernel); REGISTER_OP_GPU_KERNEL( bilinear_tensor_product_grad, - ops::BilinearTensorProductGradKernel); + ops::BilinearTensorProductGradKernel); \ No newline at end of file diff --git a/paddle/operators/bilinear_tensor_product_op.h b/paddle/operators/bilinear_tensor_product_op.h index b816d6d7c210d6a7f9aae5ba2dcb524f1b063afa..238d1d7749694d656fa9f3af14d337025a57405b 100644 --- a/paddle/operators/bilinear_tensor_product_op.h +++ b/paddle/operators/bilinear_tensor_product_op.h @@ -14,15 +14,22 @@ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/math/math_function.h" -#include "paddle/platform/transform.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -using platform::Transform; + +template +using EigenMatrix = framework::EigenMatrix; + +template +using EigenVector = framework::EigenVector; template class BilinearTensorProductKernel : public framework::OpKernel { @@ -35,43 +42,45 @@ class BilinearTensorProductKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); + auto y_mat = EigenMatrix::From(*y); + auto output_mat = EigenMatrix::From(*out); + + auto batch_size = x->dims()[0]; auto weight_dims = weight->dims(); - Tensor left_mul_vec; - left_mul_vec.mutable_data(framework::make_ddim({weight_dims[2]}), - ctx.GetPlace()); - if (bias) { - out->CopyFrom(*bias, ctx.GetPlace(), ctx.device_context()); - } - for (int i = 0; i < weight_dims[0]; ++i) { + auto place = ctx.GetEigenDevice(); + + // Create the temporary variables. + Tensor left_mul; + left_mul.mutable_data(framework::make_ddim({batch_size, weight_dims[2]}), + ctx.GetPlace()); + auto left_mul_mat = EigenMatrix::From(left_mul); + Tensor output_col; + output_col.mutable_data(framework::make_ddim({weight_dims[0]}), + ctx.GetPlace()); + auto output_col_vec = EigenVector::From(output_col); + + for (size_t i = 0; i < weight_dims[0]; ++i) { Tensor weight_mat = weight->Slice(i, i + 1).Resize( framework::make_ddim({weight_dims[1], weight_dims[2]})); - math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, 1, - weight_dims[2], weight_dims[1], 1, x->data(), - weight_mat.data(), 0, left_mul_vec.data()); - if (bias) { - math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, - 1, 1, weight_dims[2], 1, left_mul_vec.data(), - y->data(), 1, &(out->data()[i])); - } else { - math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, - 1, 1, weight_dims[2], 1, left_mul_vec.data(), - y->data(), 0, &(out->data()[i])); + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, weight_dims[2], weight_dims[1], 1, + x->data(), weight_mat.data(), 0, + left_mul.data()); + output_col_vec = (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); + for (size_t j = 0; j < batch_size; ++j) { + output_mat(j, i) = output_col_vec(j); } } + if (bias) { + auto bias_vec = EigenMatrix::From(*bias); + Eigen::DSizes bcast(batch_size, 1); + output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat; + } else { + output_mat.device(place) = output_mat; + } } }; -template -class ScaleFunctor { - public: - explicit ScaleFunctor(const T* scale) : scale_(scale) {} - - HOSTDEVICE T operator()(const T& x) const { return x * (*scale_); } - - private: - const T* scale_; -}; - template class BilinearTensorProductGradKernel : public framework::OpKernel { public: @@ -84,66 +93,65 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { Tensor* d_weight = ctx.Output(framework::GradVarName("Weight")); Tensor* d_bias = ctx.Output(framework::GradVarName("Bias")); const Tensor* d_out = ctx.Input(framework::GradVarName("Out")); - auto* d_out_ptr = d_out->data(); + + auto batch_size = x->dims()[0]; auto weight_dims = weight->dims(); - // Get the first matrix of Weight. - Tensor weight_mat_0 = weight->Slice(0, 1).Resize( - framework::make_ddim({weight_dims[1], weight_dims[2]})); + auto x_mat = EigenMatrix::From(*x); + auto y_mat = EigenMatrix::From(*y); + auto d_out_mat = EigenMatrix::From(*d_out); + auto place = ctx.GetEigenDevice(); - // Create the intermediate variable for gradient. - int numel_x = x->numel(); - int numel_y = y->numel(); - const T* x_ptr = x->data(); - const T* y_ptr = y->data(); + // Create the temporary variables for gradient. Tensor x_scale; - T* x_scale_ptr = x_scale.mutable_data( - framework::make_ddim({weight_dims[1]}), ctx.GetPlace()); + x_scale.mutable_data(framework::make_ddim({batch_size, weight_dims[1]}), + ctx.GetPlace()); + auto x_scale_mat = EigenMatrix::From(x_scale); Tensor y_scale; - T* y_scale_ptr = y_scale.mutable_data( - framework::make_ddim({weight_dims[2]}), ctx.GetPlace()); - Transform trans; + y_scale.mutable_data(framework::make_ddim({batch_size, weight_dims[2]}), + ctx.GetPlace()); + auto y_scale_mat = EigenMatrix::From(y_scale); + + math::SetConstant set_zero; - // Caculate the gradient of X according to the first matrix of Weight. + // Set X@Grad be zero at first. if (d_x) { d_x->mutable_data(ctx.GetPlace()); - trans(ctx.device_context(), y_ptr, y_ptr + numel_y, y_scale_ptr, - ScaleFunctor(&d_out_ptr[0])); - math::gemm(ctx.device_context(), CblasNoTrans, CblasTrans, 1, - weight_dims[1], weight_dims[2], 1, y_scale.data(), - weight_mat_0.data(), 0, d_x->data()); + set_zero(ctx.device_context(), d_x, static_cast(0)); } - // Caculate the gradient of Y according to the first matrix of Weight. + // Set Y@Grad be zero at first. if (d_y) { d_y->mutable_data(ctx.GetPlace()); - trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr, - ScaleFunctor(&d_out_ptr[0])); - math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, - weight_dims[2], 1, weight_dims[1], 1, - weight_mat_0.data(), x_scale.data(), 0, - d_y->data()); + set_zero(ctx.device_context(), d_y, static_cast(0)); } - // Caculate the gradient of X and Y completly. + // Caculate the X@Grad and Y@Grad. if (d_x || d_y) { - for (int i = 1; i < weight_dims[0]; ++i) { - Tensor weight_mat = weight->Slice(i, i + 1).Resize( + Eigen::DSizes bcast_for_x(1, weight_dims[2]); + Eigen::DSizes bcast_for_y(1, weight_dims[1]); + for (int i = 0; i < weight_dims[0]; ++i) { + Tensor weight_i = weight->Slice(i, i + 1).Resize( framework::make_ddim({weight_dims[1], weight_dims[2]})); + auto output_vec = d_out_mat.chip(i, 1); if (d_x) { - trans(ctx.device_context(), y_ptr, y_ptr + numel_y, y_scale_ptr, - ScaleFunctor(&d_out_ptr[i])); + y_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_x) * + y_mat; math::gemm(ctx.device_context(), CblasNoTrans, CblasTrans, - 1, weight_dims[1], weight_dims[2], 1, - y_scale.data(), weight_mat.data(), 1, + batch_size, weight_dims[1], weight_dims[2], 1, + y_scale.data(), weight_i.data(), 1, d_x->data()); } if (d_y) { - trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr, - ScaleFunctor(&d_out_ptr[i])); - math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, - weight_dims[2], 1, weight_dims[1], 1, - weight_mat.data(), x_scale.data(), 1, + x_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_y) * + x_mat; + math::gemm(ctx.device_context(), CblasNoTrans, CblasNoTrans, + batch_size, weight_dims[2], weight_dims[1], 1, + x_scale.data(), weight_i.data(), 1, d_y->data()); } } @@ -152,22 +160,27 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { // Caculate the gradient of Weight. if (d_weight) { d_weight->mutable_data(ctx.GetPlace()); + Eigen::DSizes bcast_for_weight(1, weight_dims[1]); for (int i = 0; i < weight_dims[0]; ++i) { - Tensor d_weight_mat = d_weight->Slice(i, i + 1).Resize( + Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize( framework::make_ddim({weight_dims[1], weight_dims[2]})); - trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr, - ScaleFunctor(&d_out_ptr[i])); + auto output_vec = d_out_mat.chip(i, 1); + x_scale_mat.device(place) = + output_vec.reshape(Eigen::DSizes(batch_size, 1)) + .broadcast(bcast_for_weight) * + x_mat; math::gemm(ctx.device_context(), CblasTrans, CblasNoTrans, - weight_dims[1], weight_dims[2], 1, 1, + weight_dims[1], weight_dims[2], batch_size, 1, x_scale.data(), y->data(), 0, - d_weight_mat.data()); + d_weight_i.data()); } } // Caculate the gradient of Bias. if (d_bias) { d_bias->mutable_data(ctx.GetPlace()); - d_bias->CopyFrom(*d_out, ctx.GetPlace(), ctx.device_context()); + auto d_bias_mat = EigenMatrix::From(*d_bias); + d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes(0)); } } }; diff --git a/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py index 10d90a9f0f9f8babe8c072550e0fcd28ebcf4314..1c1f388098065fcdb8f1237288bdf9f566fb463b 100644 --- a/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py +++ b/python/paddle/v2/framework/tests/test_bilinear_tensor_product_op.py @@ -6,24 +6,85 @@ from op_test import OpTest class TestBilinearTensorProductOp(OpTest): def setUp(self): self.op_type = "bilinear_tensor_product" + batch_size = 6 + size0 = 3 + size1 = 4 + size2 = 5 + a = np.random.random((batch_size, size0)).astype("float32") + b = np.random.random((batch_size, size1)).astype("float32") + w = np.random.random((size2, size0, size1)).astype("float32") + bias = np.random.random((1, size2)).astype("float32") + output = np.zeros((batch_size, size2)).astype("float32") + for i in range(size2): + w_i = w[i, :, :] + output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1) self.inputs = { - 'X': np.random.random(3).astype("float32"), - 'Y': np.random.random(4).astype("float32"), - 'Weight': np.random.random((5, 3, 4)).astype("float32"), - 'Bias': np.random.random(5).astype("float32") + 'X': a, + 'Y': b, + 'Weight': w, + 'Bias': bias, } - self.outputs = { - 'Out': np.matmul( - np.matmul(self.inputs['Weight'], self.inputs['Y']), - self.inputs['X']) + self.inputs['Bias'] + self.outputs = {'Out': output + bias} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out') + + +class TestBilinearTensorProductOp2(TestBilinearTensorProductOp): + def setUp(self): + self.op_type = "bilinear_tensor_product" + batch_size = 1 + size0 = 1 + size1 = 1 + size2 = 1 + a = np.random.random((batch_size, size0)).astype("float32") + b = np.random.random((batch_size, size1)).astype("float32") + w = np.random.random((size2, size0, size1)).astype("float32") + bias = np.random.random((1, size2)).astype("float32") + output = np.zeros((batch_size, size2)).astype("float32") + for i in range(size2): + w_i = w[i, :, :] + output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1) + self.inputs = { + 'X': a, + 'Y': b, + 'Weight': w, + 'Bias': bias, } + self.outputs = {'Out': output + bias} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out') + + +class TestBilinearTensorProductOp3(TestBilinearTensorProductOp): + def setUp(self): + self.op_type = "bilinear_tensor_product" + batch_size = 7 + size0 = 4 + size1 = 5 + size2 = 6 + a = np.random.random((batch_size, size0)).astype("float32") + b = np.random.random((batch_size, size1)).astype("float32") + w = np.random.random((size2, size0, size1)).astype("float32") + output = np.zeros((batch_size, size2)).astype("float32") + for i in range(size2): + w_i = w[i, :, :] + output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1) + self.inputs = {'X': a, 'Y': b, 'Weight': w} + self.outputs = {'Out': output} def test_check_output(self): self.check_output() def test_check_grad_normal(self): - self.check_grad( - ['X', 'Y', 'Weight', 'Bias'], 'Out', max_relative_error=0.5) + self.check_grad(['X', 'Y', 'Weight'], 'Out') if __name__ == "__main__":