diff --git a/paddle/operators/interp_op.cc b/paddle/operators/interp_op.cc index fc8b9a11b8b16bbc4016137439fd3b958db68af2..d02b01c3f3a1b30ec27253140203b076a98ce0c2 100644 --- a/paddle/operators/interp_op.cc +++ b/paddle/operators/interp_op.cc @@ -30,27 +30,26 @@ class InterpOp : public NetOp { "Input(Y) of InterpOp should not be null."); PADDLE_ENFORCE_NE(Input("W"), framework::kEmptyVarName, "Input(W) of InterpOp should not be null."); - PADDLE_ENFORCE_NE(Output("MinusOut"), framework::kEmptyVarName, - "Output(MinusOut) of InterpOp should not be null."); + PADDLE_ENFORCE_NE(Output("SubOut"), framework::kEmptyVarName, + "Output(SubOut) of InterpOp should not be null."); PADDLE_ENFORCE_NE(Output("MulOut"), framework::kEmptyVarName, "Output(MulOut) of InterpOp should not be null."); PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName, "Output(Out) of InterpOp should not be null."); - // MinusOut = X - Y + // SubOut = X - Y auto x = Input("X"); auto y = Input("Y"); - auto minus_out = Output("MinusOut"); - AppendOp(framework::OpRegistry::CreateOp("elementwise_sub", - {{"X", {x}}, {"Y", {y}}}, - {{"Out", {minus_out}}}, {})); + auto sub_out = Output("SubOut"); + AppendOp(framework::OpRegistry::CreateOp( + "elementwise_sub", {{"X", {x}}, {"Y", {y}}}, {{"Out", {sub_out}}}, {})); - // MulOut = MinusOut * W = (X - Y) * W + // MulOut = SubOut * W = (X - Y) * W auto w = Input("W"); auto mul_out = Output("MulOut"); AppendOp(framework::OpRegistry::CreateOp( - "elementwise_mul", {{"X", {minus_out}}, {"Y", {w}}}, - {{"Out", {mul_out}}}, {{"axis", 0}})); + "elementwise_mul", {{"X", {sub_out}}, {"Y", {w}}}, {{"Out", {mul_out}}}, + {{"axis", 0}})); // Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W) AppendOp(framework::OpRegistry::CreateOp("elementwise_add", @@ -65,18 +64,26 @@ class InterpOpMaker : public framework::OpProtoAndCheckerMaker { public: InterpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "A 2-D Tensor, the first input of interp_op"); - AddInput("Y", "A 2-D Tensor, the second input of interp_op"); - AddInput("W", "A 1-D Tensor, the interpolated values"); - AddOutput("MinusOut", - "A 2-D Tensor, the intermediate outputs, saving X - Y.") + AddInput("X", + "(Tensor), 2-D Matrix of shape [batch_size, data_dim]" + "containing data samples, the first input of interp_op"); + AddInput("Y", + "(Tensor), 2-D Matrix of shape `[batch_size, data_dim]`" + "containing data samples, the second input of interp_op"); + AddInput("W", + "(Tensor), 1-D Vector of shape [batch_size]," + "the interpolated values in the half-open interval [0.0, 1.0)"); + AddOutput("SubOut", + "(Tensor), the intermediate subtraction outputs, saving X - Y.") .AsIntermediate(); AddOutput("MulOut", - "A 2-D Tensor, the intermediate outputs," - "saving the mul mul of (X - Y) and W") + "(Tensor), the intermediate multiplication outputs," + "saving the elementwise multiplication of (X - Y) and W.") .AsIntermediate(); AddOutput("Out", - "A 2-D Tensor, the output of interp_op, same shape with X"); + "(Tensor), the output of interp_op, same shape with X," + "returns the first-dimensional piecewise linear interpolant " + "between X and Y"); AddComment(R"DOC( Linear Interpolation with two inputs, used in NEURAL TURING MACHINE. diff --git a/python/paddle/v2/framework/tests/test_interp_op.py b/python/paddle/v2/framework/tests/test_interp_op.py index f82dcc7f507c8879148645303cc6594892c672c8..066569b96c9611bd20e7192f8bd6caa6e467202f 100644 --- a/python/paddle/v2/framework/tests/test_interp_op.py +++ b/python/paddle/v2/framework/tests/test_interp_op.py @@ -10,12 +10,12 @@ class TestInterpOp(OpTest): y = np.random.random((2, 3)).astype("float32") w = np.random.random(2).astype("float32") - minus_out = x - y - mul_out = minus_out * w.reshape(2, 1) + sub_out = x - y + mul_out = sub_out * w.reshape(2, 1) out = mul_out + y self.inputs = {'X': x, 'Y': y, 'W': w} - self.outputs = {'Out': out, 'MinusOut': minus_out, 'MulOut': mul_out} + self.outputs = {'Out': out, 'SubOut': sub_out, 'MulOut': mul_out} def test_check_output(self): self.check_output()