提交 a06f099d 编写于 作者: L Luo Tao

refine comment of interp_op

上级 5b862fed
......@@ -30,27 +30,26 @@ class InterpOp : public NetOp {
"Input(Y) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Input("W"), framework::kEmptyVarName,
"Input(W) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("MinusOut"), framework::kEmptyVarName,
"Output(MinusOut) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("SubOut"), framework::kEmptyVarName,
"Output(SubOut) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("MulOut"), framework::kEmptyVarName,
"Output(MulOut) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
"Output(Out) of InterpOp should not be null.");
// MinusOut = X - Y
// SubOut = X - Y
auto x = Input("X");
auto y = Input("Y");
auto minus_out = Output("MinusOut");
AppendOp(framework::OpRegistry::CreateOp("elementwise_sub",
{{"X", {x}}, {"Y", {y}}},
{{"Out", {minus_out}}}, {}));
auto sub_out = Output("SubOut");
AppendOp(framework::OpRegistry::CreateOp(
"elementwise_sub", {{"X", {x}}, {"Y", {y}}}, {{"Out", {sub_out}}}, {}));
// MulOut = MinusOut * W = (X - Y) * W
// MulOut = SubOut * W = (X - Y) * W
auto w = Input("W");
auto mul_out = Output("MulOut");
AppendOp(framework::OpRegistry::CreateOp(
"elementwise_mul", {{"X", {minus_out}}, {"Y", {w}}},
{{"Out", {mul_out}}}, {{"axis", 0}}));
"elementwise_mul", {{"X", {sub_out}}, {"Y", {w}}}, {{"Out", {mul_out}}},
{{"axis", 0}}));
// Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W)
AppendOp(framework::OpRegistry::CreateOp("elementwise_add",
......@@ -65,18 +64,26 @@ class InterpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
InterpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "A 2-D Tensor, the first input of interp_op");
AddInput("Y", "A 2-D Tensor, the second input of interp_op");
AddInput("W", "A 1-D Tensor, the interpolated values");
AddOutput("MinusOut",
"A 2-D Tensor, the intermediate outputs, saving X - Y.")
AddInput("X",
"(Tensor), 2-D Matrix of shape [batch_size, data_dim]"
"containing data samples, the first input of interp_op");
AddInput("Y",
"(Tensor), 2-D Matrix of shape `[batch_size, data_dim]`"
"containing data samples, the second input of interp_op");
AddInput("W",
"(Tensor), 1-D Vector of shape [batch_size],"
"the interpolated values in the half-open interval [0.0, 1.0)");
AddOutput("SubOut",
"(Tensor), the intermediate subtraction outputs, saving X - Y.")
.AsIntermediate();
AddOutput("MulOut",
"A 2-D Tensor, the intermediate outputs,"
"saving the mul mul of (X - Y) and W")
"(Tensor), the intermediate multiplication outputs,"
"saving the elementwise multiplication of (X - Y) and W.")
.AsIntermediate();
AddOutput("Out",
"A 2-D Tensor, the output of interp_op, same shape with X");
"(Tensor), the output of interp_op, same shape with X,"
"returns the first-dimensional piecewise linear interpolant "
"between X and Y");
AddComment(R"DOC(
Linear Interpolation with two inputs, used in NEURAL TURING MACHINE.
......
......@@ -10,12 +10,12 @@ class TestInterpOp(OpTest):
y = np.random.random((2, 3)).astype("float32")
w = np.random.random(2).astype("float32")
minus_out = x - y
mul_out = minus_out * w.reshape(2, 1)
sub_out = x - y
mul_out = sub_out * w.reshape(2, 1)
out = mul_out + y
self.inputs = {'X': x, 'Y': y, 'W': w}
self.outputs = {'Out': out, 'MinusOut': minus_out, 'MulOut': mul_out}
self.outputs = {'Out': out, 'SubOut': sub_out, 'MulOut': mul_out}
def test_check_output(self):
self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册