提交 a06f099d 编写于 作者: L Luo Tao

refine comment of interp_op

上级 5b862fed
...@@ -30,27 +30,26 @@ class InterpOp : public NetOp { ...@@ -30,27 +30,26 @@ class InterpOp : public NetOp {
"Input(Y) of InterpOp should not be null."); "Input(Y) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Input("W"), framework::kEmptyVarName, PADDLE_ENFORCE_NE(Input("W"), framework::kEmptyVarName,
"Input(W) of InterpOp should not be null."); "Input(W) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("MinusOut"), framework::kEmptyVarName, PADDLE_ENFORCE_NE(Output("SubOut"), framework::kEmptyVarName,
"Output(MinusOut) of InterpOp should not be null."); "Output(SubOut) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("MulOut"), framework::kEmptyVarName, PADDLE_ENFORCE_NE(Output("MulOut"), framework::kEmptyVarName,
"Output(MulOut) of InterpOp should not be null."); "Output(MulOut) of InterpOp should not be null.");
PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName, PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
"Output(Out) of InterpOp should not be null."); "Output(Out) of InterpOp should not be null.");
// MinusOut = X - Y // SubOut = X - Y
auto x = Input("X"); auto x = Input("X");
auto y = Input("Y"); auto y = Input("Y");
auto minus_out = Output("MinusOut"); auto sub_out = Output("SubOut");
AppendOp(framework::OpRegistry::CreateOp("elementwise_sub", AppendOp(framework::OpRegistry::CreateOp(
{{"X", {x}}, {"Y", {y}}}, "elementwise_sub", {{"X", {x}}, {"Y", {y}}}, {{"Out", {sub_out}}}, {}));
{{"Out", {minus_out}}}, {}));
// MulOut = MinusOut * W = (X - Y) * W // MulOut = SubOut * W = (X - Y) * W
auto w = Input("W"); auto w = Input("W");
auto mul_out = Output("MulOut"); auto mul_out = Output("MulOut");
AppendOp(framework::OpRegistry::CreateOp( AppendOp(framework::OpRegistry::CreateOp(
"elementwise_mul", {{"X", {minus_out}}, {"Y", {w}}}, "elementwise_mul", {{"X", {sub_out}}, {"Y", {w}}}, {{"Out", {mul_out}}},
{{"Out", {mul_out}}}, {{"axis", 0}})); {{"axis", 0}}));
// Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W) // Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W)
AppendOp(framework::OpRegistry::CreateOp("elementwise_add", AppendOp(framework::OpRegistry::CreateOp("elementwise_add",
...@@ -65,18 +64,26 @@ class InterpOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -65,18 +64,26 @@ class InterpOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
InterpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) InterpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "A 2-D Tensor, the first input of interp_op"); AddInput("X",
AddInput("Y", "A 2-D Tensor, the second input of interp_op"); "(Tensor), 2-D Matrix of shape [batch_size, data_dim]"
AddInput("W", "A 1-D Tensor, the interpolated values"); "containing data samples, the first input of interp_op");
AddOutput("MinusOut", AddInput("Y",
"A 2-D Tensor, the intermediate outputs, saving X - Y.") "(Tensor), 2-D Matrix of shape `[batch_size, data_dim]`"
"containing data samples, the second input of interp_op");
AddInput("W",
"(Tensor), 1-D Vector of shape [batch_size],"
"the interpolated values in the half-open interval [0.0, 1.0)");
AddOutput("SubOut",
"(Tensor), the intermediate subtraction outputs, saving X - Y.")
.AsIntermediate(); .AsIntermediate();
AddOutput("MulOut", AddOutput("MulOut",
"A 2-D Tensor, the intermediate outputs," "(Tensor), the intermediate multiplication outputs,"
"saving the mul mul of (X - Y) and W") "saving the elementwise multiplication of (X - Y) and W.")
.AsIntermediate(); .AsIntermediate();
AddOutput("Out", AddOutput("Out",
"A 2-D Tensor, the output of interp_op, same shape with X"); "(Tensor), the output of interp_op, same shape with X,"
"returns the first-dimensional piecewise linear interpolant "
"between X and Y");
AddComment(R"DOC( AddComment(R"DOC(
Linear Interpolation with two inputs, used in NEURAL TURING MACHINE. Linear Interpolation with two inputs, used in NEURAL TURING MACHINE.
......
...@@ -10,12 +10,12 @@ class TestInterpOp(OpTest): ...@@ -10,12 +10,12 @@ class TestInterpOp(OpTest):
y = np.random.random((2, 3)).astype("float32") y = np.random.random((2, 3)).astype("float32")
w = np.random.random(2).astype("float32") w = np.random.random(2).astype("float32")
minus_out = x - y sub_out = x - y
mul_out = minus_out * w.reshape(2, 1) mul_out = sub_out * w.reshape(2, 1)
out = mul_out + y out = mul_out + y
self.inputs = {'X': x, 'Y': y, 'W': w} self.inputs = {'X': x, 'Y': y, 'W': w}
self.outputs = {'Out': out, 'MinusOut': minus_out, 'MulOut': mul_out} self.outputs = {'Out': out, 'SubOut': sub_out, 'MulOut': mul_out}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册