提交 5cf82041 编写于 作者: P peterzhang2029

refine docString

上级 44e1ac38
...@@ -34,35 +34,28 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel { ...@@ -34,35 +34,28 @@ class BilinearTensorProductOp : public framework::OperatorWithKernel {
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
auto weight_dims = ctx->GetInputDim("Weight"); auto weight_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The input X must be a 2D Tensor."); PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "The input(X) must be a 2D Tensor.");
PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input Y must be a 2D Tensor."); PADDLE_ENFORCE_EQ(y_dims.size(), 2UL, "The input(Y) must be a 2D Tensor.");
PADDLE_ENFORCE_EQ(weight_dims.size(), 3UL, PADDLE_ENFORCE_EQ(weight_dims.size(), 3UL,
"The input Weight must be a 3D tensor."); "The input(Weight) must be a 3D tensor.");
PADDLE_ENFORCE(weight_dims[0],
"The first dimension of Weight must be larger than 0.");
PADDLE_ENFORCE(weight_dims[1],
"The second dimension of Weight must be larger than 0.");
PADDLE_ENFORCE(weight_dims[2],
"The third dimension of Weight must be larger than 0.");
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
"The first dimension(batch_size) of X must be " "The first dimension(batch_size) of input(X) must be "
"equal to the first dimension of the Y."); "equal to the first dimension of the input(Y).");
PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1], PADDLE_ENFORCE_EQ(x_dims[1], weight_dims[1],
"The second dimension of X must be equal to the second " "The second dimension of input(X) must be equal to "
"dimension of the Weight."); "the second dimension of the input(Weight).");
PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2], PADDLE_ENFORCE_EQ(y_dims[1], weight_dims[2],
"The second dimension of Y must be equal to the third " "The second dimension of input(Y) must be equal to "
"dimension of the Weight."); "the third dimension of the input(Weight).");
if (ctx->HasInput("Bias")) { if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias"); auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(bias_dims.size(), 2UL, PADDLE_ENFORCE(bias_dims.size() == 2UL && bias_dims[0] == 1UL,
"The input Bias must have 2 dimensions."); "The Input(Bias) must be a 2-D tensor with "
PADDLE_ENFORCE_EQ(bias_dims[0], 1UL, "the 2nd dimension fixed to 1 (a row vector).");
"The first dimention of input Bias must be 1.");
PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0], PADDLE_ENFORCE_EQ(bias_dims[1], weight_dims[0],
"The second dimension of Bias must be equal to the " "The second dimension of input(Bias) must be equal "
"first dimension of the Weight."); "to the first dimension of the input(Weight).");
} }
ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]}); ctx->SetOutputDim("Out", {x_dims[0], weight_dims[0]});
...@@ -75,12 +68,13 @@ class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -75,12 +68,13 @@ class BilinearTensorProductOpMaker : public framework::OpProtoAndCheckerMaker {
BilinearTensorProductOpMaker(framework::OpProto* proto, BilinearTensorProductOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of BilinearTensorProduct op."); AddInput("X", "The first input of bilinear_tensor_product operator.");
AddInput("Y", "The second input of BilinearTensorProduct op."); AddInput("Y", "The second input of bilinear_tensor_product operator.");
AddInput("Weight", "The input weight of BilinearTensorProduct op."); AddInput("Weight",
AddInput("Bias", "The input bias of BilinearTensorProduct op.") "The learnable parameters of bilinear_tensor_product operator.");
AddInput("Bias", "The learnable bias of bilinear_tensor_product operator.")
.AsDispensable(); .AsDispensable();
AddOutput("Out", "The output of BilinearTensorProduct op."); AddOutput("Out", "The output of bilinear_tensor_product operator.");
AddComment(R"DOC( AddComment(R"DOC(
Bilinear Tensor Product operator. Bilinear Tensor Product operator.
Given input X and Y, a 3D tensor weight, and bias. Each column of the Given input X and Y, a 3D tensor weight, and bias. Each column of the
...@@ -104,27 +98,29 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { ...@@ -104,27 +98,29 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("Weight"), PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) should not be null."); "Input(Weight) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input (Out@GRAD) should not be null."); "Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
auto weight_dims = ctx->GetInputDim("Weight"); auto weight_dims = ctx->GetInputDim("Weight");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(out_dims.size(), 2UL, PADDLE_ENFORCE_EQ(out_dims.size(), 2UL,
"The Out@GRAD must be a 2D Tensor."); "The input(Out@GRAD) must be a 2D Tensor.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims[0], out_dims[0], x_dims[0], out_dims[0],
"The first dimension(batch_size) of Out@GRAD must be equal to " "The first dimension(batch_size) of input(Out@GRAD) must be "
"the first dimension of the Input(X)."); "equal to the first dimension of the Input(X).");
PADDLE_ENFORCE_EQ(weight_dims[0], out_dims[1], PADDLE_ENFORCE_EQ(
"The second dimension of Out@GRAD must be equal to " weight_dims[0], out_dims[1],
"the third dimension of the Input(Weight)."); "The second dimension of input(Out@GRAD) must be equal to "
"the third dimension of the Input(Weight).");
if (ctx->HasInput("Bias")) { if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias"); auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(bias_dims[1], out_dims[1], PADDLE_ENFORCE_EQ(
"The second dimension of Out@GRAD must be equal to " bias_dims[1], out_dims[1],
"the second dimension of the Input(Bias)."); "The second dimension of input(Out@GRAD) must be equal to "
"the second dimension of the Input(Bias).");
auto bias_grad_name = framework::GradVarName("Bias"); auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name)) if (ctx->HasOutput(bias_grad_name))
ctx->SetOutputDim(bias_grad_name, bias_dims); ctx->SetOutputDim(bias_grad_name, bias_dims);
...@@ -155,7 +151,9 @@ REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp, ...@@ -155,7 +151,9 @@ REGISTER_OP(bilinear_tensor_product, ops::BilinearTensorProductOp,
ops::BilinearTensorProductOpGrad); ops::BilinearTensorProductOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
bilinear_tensor_product, bilinear_tensor_product,
ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, float>); ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, float>,
ops::BilinearTensorProductKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
bilinear_tensor_product_grad, bilinear_tensor_product_grad,
ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, float>); ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, float>,
ops::BilinearTensorProductGradKernel<paddle::platform::CPUPlace, double>);
...@@ -18,7 +18,9 @@ limitations under the License. */ ...@@ -18,7 +18,9 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
bilinear_tensor_product, bilinear_tensor_product,
ops::BilinearTensorProductKernel<paddle::platform::GPUPlace, float>); ops::BilinearTensorProductKernel<paddle::platform::GPUPlace, float>,
ops::BilinearTensorProductKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
bilinear_tensor_product_grad, bilinear_tensor_product_grad,
ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, float>); ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, float>,
ops::BilinearTensorProductGradKernel<paddle::platform::GPUPlace, double>);
...@@ -33,59 +33,5 @@ class TestBilinearTensorProductOp(OpTest): ...@@ -33,59 +33,5 @@ class TestBilinearTensorProductOp(OpTest):
self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out') self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out')
class TestBilinearTensorProductOp2(TestBilinearTensorProductOp):
def setUp(self):
self.op_type = "bilinear_tensor_product"
batch_size = 1
size0 = 1
size1 = 1
size2 = 1
a = np.random.random((batch_size, size0)).astype("float32")
b = np.random.random((batch_size, size1)).astype("float32")
w = np.random.random((size2, size0, size1)).astype("float32")
bias = np.random.random((1, size2)).astype("float32")
output = np.zeros((batch_size, size2)).astype("float32")
for i in range(size2):
w_i = w[i, :, :]
output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1)
self.inputs = {
'X': a,
'Y': b,
'Weight': w,
'Bias': bias,
}
self.outputs = {'Out': output + bias}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y', 'Weight', 'Bias'], 'Out')
class TestBilinearTensorProductOp3(TestBilinearTensorProductOp):
def setUp(self):
self.op_type = "bilinear_tensor_product"
batch_size = 7
size0 = 4
size1 = 5
size2 = 6
a = np.random.random((batch_size, size0)).astype("float32")
b = np.random.random((batch_size, size1)).astype("float32")
w = np.random.random((size2, size0, size1)).astype("float32")
output = np.zeros((batch_size, size2)).astype("float32")
for i in range(size2):
w_i = w[i, :, :]
output[:, i] = np.sum(np.matmul(a, w_i) * b, axis=1)
self.inputs = {'X': a, 'Y': b, 'Weight': w}
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y', 'Weight'], 'Out')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册