Commit 3ae14242 authored by peterzhang2029

update for mini-batch

Parent 611ee68b
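For orientation, the mini-batch semantics this commit moves the operator to — one output column per slice of Weight, computed row-wise over the batch — can be sketched in NumPy. The helper name `bilinear_tensor_product` and the array names mirror the updated unit test below; this is an illustrative sketch, not part of the patch.

```python
import numpy as np

# Shapes mirror the updated test: X is (batch_size, size0), Y is (batch_size, size1),
# Weight is (size2, size0, size1), Bias is (1, size2); Out is (batch_size, size2).
def bilinear_tensor_product(a, b, w, bias=None):
    batch_size, size2 = a.shape[0], w.shape[0]
    out = np.zeros((batch_size, size2), dtype=a.dtype)
    for i in range(size2):
        # Column i is the row-wise bilinear form x_n^T * W_i * y_n over the batch.
        out[:, i] = np.sum(np.matmul(a, w[i, :, :]) * b, axis=1)
    return out if bias is None else out + bias

a = np.random.random((6, 3)).astype("float32")
b = np.random.random((6, 4)).astype("float32")
w = np.random.random((5, 3, 4)).astype("float32")
bias = np.random.random((1, 5)).astype("float32")
print(bilinear_tensor_product(a, b, w, bias).shape)  # (6, 5)
```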
@@ -34,8 +34,8 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
     auto y_dims = ctx->GetInputDim("Y");
     auto weight_dims = ctx->GetInputDim("Weight");
-    PADDLE_ENFORCE_EQ(x_dims.size(), 1, "The input X must be a vector.");
-    PADDLE_ENFORCE_EQ(y_dims.size(), 1, "The input Y must be a vector.");
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The input X must be a 2D Tensor.");
+    PADDLE_ENFORCE_EQ(y_dims.size(), 2, "The input Y must be a 2D Tensor.");
     PADDLE_ENFORCE_EQ(weight_dims.size(), 3,
                       "The input Weight must be a 3D tensor.");
     PADDLE_ENFORCE_GT(weight_dims[0], 0,
@@ -44,24 +44,29 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
                       "The second dimension of Weight must be larger than 0.");
     PADDLE_ENFORCE_GT(weight_dims[2], 0,
                       "The third dimension of Weight must be larger than 0.");
-    PADDLE_ENFORCE_EQ(x_dims[0], weight_dims[1],
-                      "The dimension of X must be equal with the second "
-                      "dimension of the Weight.");
-    PADDLE_ENFORCE_EQ(y_dims[0], weight_dims[2],
-                      "The dimension of Y must be equal with the third "
-                      "dimension of the Weight.");
+    PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
+                      "The first dimension(batch_size) of X must be "
+                      "equal with the first dimension of the Y.");
+    PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1],
+                      "The second dimension of X must be equal with the second "
+                      "dimension of the Weight.");
+    PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2],
+                      "The second dimension of Y must be equal with the third "
+                      "dimension of the Weight.");
 
-    auto bias = Input("Bias");
-    if (bias != framework::kEmptyVarName) {
+    if (ctx->HasInput("Bias")) {
       auto bias_dims = ctx->GetInputDim("Bias");
-      PADDLE_ENFORCE_EQ(bias_dims.size(), 1,
-                        "The input Bias must be a vector.");
-      PADDLE_ENFORCE_EQ(bias_dims[0], weight_dims[0],
-                        "The dimension of Bias must be equal with the first "
-                        "dimension of the Weight.");
+      PADDLE_ENFORCE_EQ(bias_dims.size(), 2,
+                        "The input Bias must have 2 dimensions.");
+      PADDLE_ENFORCE_EQ(bias_dims[0], 1,
+                        "The first dimension of input Bias must be 1.");
+      PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0],
+                        "The second dimension of Bias must be equal with the "
+                        "first dimension of the Weight.");
     }
-    ctx->SetOutputDim("Out", {weight_dims[0]});
+    ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]});
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
@@ -70,19 +75,19 @@ class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
   BilinearTensorProductOpMaker(framework::OpProto* proto,
                                framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The first input of tensor op");
-    AddInput("Y", "The second input of tensor op");
-    AddInput("Weight", "The input weight of tensor op");
-    AddInput("Bias", "The input bias of tensor op");
-    AddOutput("Out", "The output of tensor op");
+    AddInput("X", "The first input of BilinearTensorProduct op");
+    AddInput("Y", "The second input of BilinearTensorProduct op");
+    AddInput("Weight", "The input weight of BilinearTensorProduct op");
+    AddInput("Bias", "The input bias of BilinearTensorProduct op")
+        .AsDispensable();
+    AddOutput("Out", "The output of BilinearTensorProduct op");
     AddComment(R"DOC(
 Bilinear Tensor Product operator.
-Given input X and Y, a 3D tensor weight, and bias. Each entry of the output is
-computed by one slice i = 1, . . . , k of the tensor: Out_i = X*W_i*Y + Bias_i .
-The equation of this operator is:
-
-Out = \sum_{i} X*W_i*Y + Bias
+Given input X and Y, a 3D tensor weight, and bias. Each column of the
+output is computed by one slice i = 1, . . . , k of the tensor:
+
+    M = (X W_i) \cdot Y
+    Out_i = \sum_j {M_j} + Bias_i
 )DOC");
   }
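Spelled out with explicit indices, the doc string above corresponds to the following batched form, consistent with the shape checks in InferShape and the Python reference in the updated tests. The symbols N, d_x, d_y, and k are notation introduced for this note: the batch size, the widths of X and Y, and the first dimension of Weight.

```latex
% Shapes: X is N x d_x, Y is N x d_y, Weight is k x d_x x d_y, Bias is 1 x k, Out is N x k.
\[
\mathrm{Out}_{n,i}
  = x_n^{\top} W_i \, y_n + \mathrm{Bias}_i
  = \sum_{p=1}^{d_x} \sum_{q=1}^{d_y} X_{np} \, W_{ipq} \, Y_{nq} + \mathrm{Bias}_i,
\qquad n = 1,\dots,N,\; i = 1,\dots,k.
\]
```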
@@ -104,19 +109,20 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
     auto weight_dims = ctx->GetInputDim("Weight");
     auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-    PADDLE_ENFORCE_EQ(out_dims.size(), 1, "The Out@GRAD must be a vector.");
+    PADDLE_ENFORCE_EQ(out_dims.size(), 2, "The Out@GRAD must be a 2D Tensor.");
     PADDLE_ENFORCE_EQ(
-        weight_dims[0], out_dims[0],
-        "The dimension of Out@GRAD must be equal with the third dimension of "
-        "the Weight.");
+        x_dims[0], out_dims[0],
+        "The first dimension(batch_size) of Out@GRAD must be equal with "
+        "the first dimension of the X.");
+    PADDLE_ENFORCE_EQ(weight_dims[0], out_dims[1],
+                      "The second dimension of Out@GRAD must be equal with "
+                      "the third dimension of the Weight.");
 
-    auto bias = Input("Bias");
-    if (bias != framework::kEmptyVarName) {
+    if (ctx->HasInput("Bias")) {
       auto bias_dims = ctx->GetInputDim("Bias");
-      PADDLE_ENFORCE_EQ(bias_dims.size(), 1, "Input Bias must be a vector.");
-      PADDLE_ENFORCE_EQ(
-          bias_dims[0], out_dims[0],
-          "The dimension of Bias must be equal with the Out@GRAD ");
+      PADDLE_ENFORCE_EQ(bias_dims[1], out_dims[1],
+                        "The second dimension of Bias must be equal with "
+                        "the second dimension of the Out@GRAD.");
       auto bias_grad_name = framework::GradVarName("Bias");
       if (ctx->HasOutput(bias_grad_name))
         ctx->SetOutputDim(bias_grad_name, bias_dims);
@@ -150,4 +156,4 @@ REGISTER_OP_CPU_KERNEL(
     ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     bilinear_tensor_product_grad,
-    ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, float>);
\ No newline at end of file
+    ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, float>);
@@ -15,10 +15,85 @@
 #define EIGEN_USE_GPU
 #include "paddle/operators/bilinear_tensor_product_op.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class BilinearTensorProductCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* weight = ctx.Input<Tensor>("Weight");
+    auto* bias = ctx.Input<Tensor>("Bias");
+    auto* out = ctx.Output<Tensor>("Out");
+    out->mutable_data<T>(ctx.GetPlace());
+
+    auto y_mat = EigenMatrix<T>::From(*y);
+    auto batch_size = x->dims()[0];
+    auto weight_dims = weight->dims();
+    auto place = ctx.GetEigenDevice<Place>();
+    auto cpu_place = ctx.GetEigenDevice<platform::CPUPlace>();
+
+    // Copy the output to cpu.
+    Tensor output_cpu;
+    output_cpu.CopyFrom(*out, platform::CPUPlace(), ctx.device_context());
+    auto* output_cpu_ptr = output_cpu.data<T>();
+    auto output_cpu_mat = EigenMatrix<T>::From(output_cpu);
+
+    // Create the temporary variables.
+    Tensor left_mul;
+    left_mul.mutable_data<T>(
+        framework::make_ddim({batch_size, weight_dims[2]}), ctx.GetPlace());
+    auto left_mul_mat = EigenMatrix<T>::From(left_mul);
+    Tensor output_col;
+    output_col.mutable_data<T>(framework::make_ddim({batch_size}),
+                               ctx.GetPlace());
+    auto output_col_vec = EigenVector<T>::From(output_col);
+
+    for (size_t i = 0; i < weight_dims[0]; ++i) {
+      Tensor weight_mat = weight->Slice(i, i + 1).Resize(
+          framework::make_ddim({weight_dims[1], weight_dims[2]}));
+      math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
+                           batch_size, weight_dims[2], weight_dims[1], 1,
+                           x->data<T>(), weight_mat.data<T>(), 0,
+                           left_mul.data<T>());
+      output_col_vec.device(place) =
+          (left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
+
+      // Copy the output_col to cpu.
+      Tensor output_col_cpu;
+      output_col_cpu.CopyFrom(output_col, platform::CPUPlace(),
+                              ctx.device_context());
+      auto* output_col_ptr = output_col_cpu.data<T>();
+      for (size_t j = 0; j < batch_size; ++j) {
+        output_cpu_ptr[i + j * weight_dims[0]] = output_col_ptr[j];
+      }
+    }
+
+    if (bias) {
+      // Copy the bias to cpu.
+      Tensor bias_cpu;
+      bias_cpu.CopyFrom(*bias, platform::CPUPlace(), ctx.device_context());
+      auto bias_vec = EigenMatrix<T>::From(bias_cpu);
+      Eigen::DSizes<int, 2> bcast(batch_size, 1);
+      output_cpu_mat.device(cpu_place) =
+          bias_vec.broadcast(bcast) + output_cpu_mat;
+    }
+
+    // Copy the output to gpu.
+    out->CopyFrom(output_cpu, platform::GPUPlace(), ctx.device_context());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
     bilinear_tensor_product,
-    ops::BilinearTensorProductKernel<paddle::platform::GPUPlace, float>);
+    ops::BilinearTensorProductCUDAKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
     bilinear_tensor_product_grad,
-    ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, float>);
\ No newline at end of file
+    ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, float>);
@@ -14,15 +14,22 @@
 #pragma once
 
+#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/math_function.h"
-#include "paddle/platform/transform.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-using platform::Transform;
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
 
 template <typename Place, typename T>
 class BilinearTensorProductKernel : public framework::OpKernel<T> {
@@ -35,43 +42,45 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
     auto* out = ctx.Output<Tensor>("Out");
     out->mutable_data<T>(ctx.GetPlace());
+    auto y_mat = EigenMatrix<T>::From(*y);
+    auto output_mat = EigenMatrix<T>::From(*out);
 
+    auto batch_size = x->dims()[0];
     auto weight_dims = weight->dims();
+    auto place = ctx.GetEigenDevice<Place>();
 
-    Tensor left_mul_vec;
-    left_mul_vec.mutable_data<T>(framework::make_ddim({weight_dims[2]}),
-                                 ctx.GetPlace());
-    if (bias) {
-      out->CopyFrom(*bias, ctx.GetPlace(), ctx.device_context());
-    }
-    for (int i = 0; i < weight_dims[0]; ++i) {
+    // Create the temporary variables.
+    Tensor left_mul;
+    left_mul.mutable_data<T>(
+        framework::make_ddim({batch_size, weight_dims[2]}), ctx.GetPlace());
+    auto left_mul_mat = EigenMatrix<T>::From(left_mul);
+    Tensor output_col;
+    output_col.mutable_data<T>(framework::make_ddim({batch_size}),
+                               ctx.GetPlace());
+    auto output_col_vec = EigenVector<T>::From(output_col);
+
+    for (size_t i = 0; i < weight_dims[0]; ++i) {
       Tensor weight_mat = weight->Slice(i, i + 1).Resize(
           framework::make_ddim({weight_dims[1], weight_dims[2]}));
-      math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans, 1,
-                           weight_dims[2], weight_dims[1], 1, x->data<T>(),
-                           weight_mat.data<T>(), 0, left_mul_vec.data<T>());
-      if (bias) {
-        math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
-                             1, 1, weight_dims[2], 1, left_mul_vec.data<T>(),
-                             y->data<T>(), 1, &(out->data<T>()[i]));
-      } else {
-        math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
-                             1, 1, weight_dims[2], 1, left_mul_vec.data<T>(),
-                             y->data<T>(), 0, &(out->data<T>()[i]));
+      math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
+                           batch_size, weight_dims[2], weight_dims[1], 1,
+                           x->data<T>(), weight_mat.data<T>(), 0,
+                           left_mul.data<T>());
+      output_col_vec = (left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
+      for (size_t j = 0; j < batch_size; ++j) {
+        output_mat(j, i) = output_col_vec(j);
       }
     }
+    if (bias) {
+      auto bias_vec = EigenMatrix<T>::From(*bias);
+      Eigen::DSizes<int, 2> bcast(batch_size, 1);
+      output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat;
+    } else {
+      output_mat.device(place) = output_mat;
+    }
   }
 };
 
-template <typename T>
-class ScaleFunctor {
- public:
-  explicit ScaleFunctor(const T* scale) : scale_(scale) {}
-
-  HOSTDEVICE T operator()(const T& x) const { return x * (*scale_); }
-
- private:
-  const T* scale_;
-};
-
 template <typename Place, typename T>
 class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
  public:
@@ -84,66 +93,65 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
     Tensor* d_weight = ctx.Output<Tensor>(framework::GradVarName("Weight"));
     Tensor* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
     const Tensor* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* d_out_ptr = d_out->data<T>();
 
+    auto batch_size = x->dims()[0];
     auto weight_dims = weight->dims();
+    auto x_mat = EigenMatrix<T>::From(*x);
+    auto y_mat = EigenMatrix<T>::From(*y);
+    auto d_out_mat = EigenMatrix<T>::From(*d_out);
+    auto place = ctx.GetEigenDevice<Place>();
 
-    // Get the first matrix of Weight.
-    Tensor weight_mat_0 = weight->Slice(0, 1).Resize(
-        framework::make_ddim({weight_dims[1], weight_dims[2]}));
-
-    // Create the intermediate variable for gradient.
-    int numel_x = x->numel();
-    int numel_y = y->numel();
-    const T* x_ptr = x->data<T>();
-    const T* y_ptr = y->data<T>();
+    // Create the temporary variables for gradient.
     Tensor x_scale;
-    T* x_scale_ptr = x_scale.mutable_data<T>(
-        framework::make_ddim({weight_dims[1]}), ctx.GetPlace());
+    x_scale.mutable_data<T>(
+        framework::make_ddim({batch_size, weight_dims[1]}), ctx.GetPlace());
+    auto x_scale_mat = EigenMatrix<T>::From(x_scale);
     Tensor y_scale;
-    T* y_scale_ptr = y_scale.mutable_data<T>(
-        framework::make_ddim({weight_dims[2]}), ctx.GetPlace());
-    Transform<Place> trans;
+    y_scale.mutable_data<T>(
+        framework::make_ddim({batch_size, weight_dims[2]}), ctx.GetPlace());
+    auto y_scale_mat = EigenMatrix<T>::From(y_scale);
+
+    math::SetConstant<Place, T> set_zero;
 
-    // Caculate the gradient of X according to the first matrix of Weight.
+    // Set X@Grad to zero at first.
     if (d_x) {
       d_x->mutable_data<T>(ctx.GetPlace());
-      trans(ctx.device_context(), y_ptr, y_ptr + numel_y, y_scale_ptr,
-            ScaleFunctor<T>(&d_out_ptr[0]));
-      math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasTrans, 1,
-                           weight_dims[1], weight_dims[2], 1, y_scale.data<T>(),
-                           weight_mat_0.data<T>(), 0, d_x->data<T>());
+      set_zero(ctx.device_context(), d_x, static_cast<T>(0));
     }
 
-    // Caculate the gradient of Y according to the first matrix of Weight.
+    // Set Y@Grad to zero at first.
    if (d_y) {
       d_y->mutable_data<T>(ctx.GetPlace());
-      trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr,
-            ScaleFunctor<T>(&d_out_ptr[0]));
-      math::gemm<Place, T>(ctx.device_context(), CblasTrans, CblasNoTrans,
-                           weight_dims[2], 1, weight_dims[1], 1,
-                           weight_mat_0.data<T>(), x_scale.data<T>(), 0,
-                           d_y->data<T>());
+      set_zero(ctx.device_context(), d_y, static_cast<T>(0));
     }
 
-    // Caculate the gradient of X and Y completly.
+    // Calculate the X@Grad and Y@Grad.
     if (d_x || d_y) {
-      for (int i = 1; i < weight_dims[0]; ++i) {
-        Tensor weight_mat = weight->Slice(i, i + 1).Resize(
+      Eigen::DSizes<int, 2> bcast_for_x(1, weight_dims[2]);
+      Eigen::DSizes<int, 2> bcast_for_y(1, weight_dims[1]);
+      for (int i = 0; i < weight_dims[0]; ++i) {
+        Tensor weight_i = weight->Slice(i, i + 1).Resize(
            framework::make_ddim({weight_dims[1], weight_dims[2]}));
+        auto output_vec = d_out_mat.chip(i, 1);
         if (d_x) {
-          trans(ctx.device_context(), y_ptr, y_ptr + numel_y, y_scale_ptr,
-                ScaleFunctor<T>(&d_out_ptr[i]));
+          y_scale_mat.device(place) =
+              output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
+                  .broadcast(bcast_for_x) *
+              y_mat;
           math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasTrans,
-                               1, weight_dims[1], weight_dims[2], 1,
-                               y_scale.data<T>(), weight_mat.data<T>(), 1,
+                               batch_size, weight_dims[1], weight_dims[2], 1,
+                               y_scale.data<T>(), weight_i.data<T>(), 1,
                                d_x->data<T>());
         }
         if (d_y) {
-          trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr,
-                ScaleFunctor<T>(&d_out_ptr[i]));
-          math::gemm<Place, T>(ctx.device_context(), CblasTrans, CblasNoTrans,
-                               weight_dims[2], 1, weight_dims[1], 1,
-                               weight_mat.data<T>(), x_scale.data<T>(), 1,
+          x_scale_mat.device(place) =
+              output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
+                  .broadcast(bcast_for_y) *
+              x_mat;
+          math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
+                               batch_size, weight_dims[2], weight_dims[1], 1,
+                               x_scale.data<T>(), weight_i.data<T>(), 1,
                                d_y->data<T>());
         }
       }
@@ -152,22 +160,27 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
     // Calculate the gradient of Weight.
     if (d_weight) {
       d_weight->mutable_data<T>(ctx.GetPlace());
+      Eigen::DSizes<int, 2> bcast_for_weight(1, weight_dims[1]);
       for (int i = 0; i < weight_dims[0]; ++i) {
-        Tensor d_weight_mat = d_weight->Slice(i, i + 1).Resize(
+        Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
             framework::make_ddim({weight_dims[1], weight_dims[2]}));
-        trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr,
-              ScaleFunctor<T>(&d_out_ptr[i]));
+        auto output_vec = d_out_mat.chip(i, 1);
+        x_scale_mat.device(place) =
+            output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
+                .broadcast(bcast_for_weight) *
+            x_mat;
         math::gemm<Place, T>(ctx.device_context(), CblasTrans, CblasNoTrans,
-                             weight_dims[1], weight_dims[2], 1, 1,
+                             weight_dims[1], weight_dims[2], batch_size, 1,
                              x_scale.data<T>(), y->data<T>(), 0,
-                             d_weight_mat.data<T>());
+                             d_weight_i.data<T>());
       }
     }
 
     // Calculate the gradient of Bias.
     if (d_bias) {
       d_bias->mutable_data<T>(ctx.GetPlace());
-      d_bias->CopyFrom(*d_out, ctx.GetPlace(), ctx.device_context());
+      auto d_bias_mat = EigenMatrix<T>::From(*d_bias);
+      d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes<int, 1>(0));
     }
   }
 };
......
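For reference, the quantities the rewritten gradient kernel accumulates follow from the forward formula above; the notation below is the same one introduced earlier for this note, not names from the patch.

```latex
\[
\begin{aligned}
\frac{\partial L}{\partial x_n} &= \sum_{i=1}^{k} d\mathrm{Out}_{n,i} \, W_i \, y_n, &
\frac{\partial L}{\partial y_n} &= \sum_{i=1}^{k} d\mathrm{Out}_{n,i} \, W_i^{\top} x_n, \\
\frac{\partial L}{\partial W_i} &= \sum_{n=1}^{N} d\mathrm{Out}_{n,i} \, x_n y_n^{\top}, &
\frac{\partial L}{\partial \mathrm{Bias}_i} &= \sum_{n=1}^{N} d\mathrm{Out}_{n,i},
\end{aligned}
\qquad \text{with } d\mathrm{Out} = \partial L / \partial \mathrm{Out}.
\]
```

The x_scale and y_scale temporaries hold the row-wise products d Out_{n,i} * x_n and d Out_{n,i} * y_n, so each weight slice contributes to the gradients through a single GEMM.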
@@ -6,24 +6,85 @@ from op_test import OpTest
 class TestBilinearTensorProductOp(OpTest):
     def setUp(self):
         self.op_type = "bilinear_tensor_product"
+        batch_size = 6
+        size0 = 3
+        size1 = 4
+        size2 = 5
+        a = np.random.random((batch_size, size0)).astype("float32")
+        b = np.random.random((batch_size, size1)).astype("float32")
+        w = np.random.random((size2, size0, size1)).astype("float32")
+        bias = np.random.random((1, size2)).astype("float32")
+        output = np.zeros((batch_size, size2)).astype("float32")
+        for i in range(size2):
+            w_i = w[i, :, :]
+            output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1)
         self.inputs = {
-            'X': np.random.random(3).astype("float32"),
-            'Y': np.random.random(4).astype("float32"),
-            'Weight': np.random.random((5, 3, 4)).astype("float32"),
-            'Bias': np.random.random(5).astype("float32")
+            'X': a,
+            'Y': b,
+            'Weight': w,
+            'Bias': bias,
         }
-        self.outputs = {
-            'Out': np.matmul(
-                np.matmul(self.inputs['Weight'], self.inputs['Y']),
-                self.inputs['X']) + self.inputs['Bias']
-        }
+        self.outputs = {'Out': output + bias}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out')
+
+
+class TestBilinearTensorProductOp2(TestBilinearTensorProductOp):
+    def setUp(self):
+        self.op_type = "bilinear_tensor_product"
+        batch_size = 1
+        size0 = 1
+        size1 = 1
+        size2 = 1
+        a = np.random.random((batch_size, size0)).astype("float32")
+        b = np.random.random((batch_size, size1)).astype("float32")
+        w = np.random.random((size2, size0, size1)).astype("float32")
+        bias = np.random.random((1, size2)).astype("float32")
+        output = np.zeros((batch_size, size2)).astype("float32")
+        for i in range(size2):
+            w_i = w[i, :, :]
+            output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1)
+        self.inputs = {
+            'X': a,
+            'Y': b,
+            'Weight': w,
+            'Bias': bias,
+        }
+        self.outputs = {'Out': output + bias}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out')
+
+
+class TestBilinearTensorProductOp3(TestBilinearTensorProductOp):
+    def setUp(self):
+        self.op_type = "bilinear_tensor_product"
+        batch_size = 7
+        size0 = 4
+        size1 = 5
+        size2 = 6
+        a = np.random.random((batch_size, size0)).astype("float32")
+        b = np.random.random((batch_size, size1)).astype("float32")
+        w = np.random.random((size2, size0, size1)).astype("float32")
+        output = np.zeros((batch_size, size2)).astype("float32")
+        for i in range(size2):
+            w_i = w[i, :, :]
+            output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1)
+        self.inputs = {'X': a, 'Y': b, 'Weight': w}
+        self.outputs = {'Out': output}
 
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad_normal(self):
-        self.check_grad(
-            ['X', 'Y', 'Weight', 'Bias'], 'Out', max_relative_error=0.5)
+        self.check_grad(['X', 'Y', 'Weight'], 'Out')
 
 
 if __name__ == "__main__":
......