提交 f56f1492 编写于 作者: F fengjiayi

fix_output_name

上级 77cf7d4f
...@@ -217,7 +217,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive( ...@@ -217,7 +217,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
// If part of input gradient of that operator is not calculated, fill // If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient. // zero variables to that input gradient.
net->AppendOp(OpRegistry::CreateOp("fill_zeros_like", {{"X", {prefix}}}, net->AppendOp(OpRegistry::CreateOp("fill_zeros_like", {{"X", {prefix}}},
{{"Y", {grad_input}}}, {{"Out", {grad_input}}},
AttributeMap{})); AttributeMap{}));
} }
return false; return false;
...@@ -396,7 +396,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad( ...@@ -396,7 +396,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
desc->Rename(in_name, new_name); desc->Rename(in_name, new_name);
std::unique_ptr<OpDescBind> fill_zeros_op( std::unique_ptr<OpDescBind> fill_zeros_op(
new OpDescBind("fill_zeros_like", {{"X", {prefix}}}, new OpDescBind("fill_zeros_like", {{"X", {prefix}}},
{{"Y", {new_name}}}, AttributeMap{})); {{"Out", {new_name}}}, AttributeMap{}));
pending_fill_zeros_ops.push_back(std::move(fill_zeros_op)); pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
} }
} }
......
...@@ -430,8 +430,8 @@ TEST(Backward, op_part_of_output_are_not_need) { ...@@ -430,8 +430,8 @@ TEST(Backward, op_part_of_output_are_not_need) {
ASSERT_EQ("fill_zeros_like", fill_zero.Type()); ASSERT_EQ("fill_zeros_like", fill_zero.Type());
ASSERT_EQ(1UL, fill_zero.Inputs("X").size()); ASSERT_EQ(1UL, fill_zero.Inputs("X").size());
ASSERT_EQ("Z", fill_zero.Input("X")); ASSERT_EQ("Z", fill_zero.Input("X"));
ASSERT_EQ(1UL, fill_zero.Outputs("Y").size()); ASSERT_EQ(1UL, fill_zero.Outputs("Out").size());
ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.Output("Y")); ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.Output("Out"));
auto &d_many_out = *net->ops_[1]; auto &d_many_out = *net->ops_[1];
ASSERT_EQ("many_output_op_grad", d_many_out.Type()); ASSERT_EQ("many_output_op_grad", d_many_out.Type());
...@@ -772,7 +772,7 @@ TEST(Backward, var_no_grad) { ...@@ -772,7 +772,7 @@ TEST(Backward, var_no_grad) {
ASSERT_EQ(fill_zero_op->InputNames().size(), 1UL); ASSERT_EQ(fill_zero_op->InputNames().size(), 1UL);
ASSERT_EQ(fill_zero_op->OutputNames().size(), 1UL); ASSERT_EQ(fill_zero_op->OutputNames().size(), 1UL);
EXPECT_EQ(fill_zero_op->Input("X"), std::vector<std::string>({"z1"})); EXPECT_EQ(fill_zero_op->Input("X"), std::vector<std::string>({"z1"}));
EXPECT_EQ(fill_zero_op->Output("Y"), EXPECT_EQ(fill_zero_op->Output("Out"),
std::vector<std::string>({std::string("z1") + f::kZeroVarSuffix})); std::vector<std::string>({std::string("z1") + f::kZeroVarSuffix}));
f::OpDescBind *grad_op1 = block->AllOps()[5]; f::OpDescBind *grad_op1 = block->AllOps()[5];
......
...@@ -24,9 +24,9 @@ class FillZerosLikeOp : public framework::OperatorWithKernel { ...@@ -24,9 +24,9 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FillZerosLikeOp should not be null."); "Input(X) of FillZerosLikeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Y) of FillZerosLikeOp should not be null."); "Output(Out) of FillZerosLikeOp should not be null.");
ctx->SetOutputDim("Y", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Y"); ctx->ShareLoD("X", /*->*/ "Y");
} }
}; };
...@@ -37,7 +37,7 @@ class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -37,7 +37,7 @@ class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of fill-zeros-like op."); AddInput("X", "The input of fill-zeros-like op.");
AddOutput("Y", "The variable will be filled up with zeros."); AddOutput("Out", "The variable will be filled up with zeros.");
AddComment(R"DOC( AddComment(R"DOC(
FillZerosLike Operator. FillZerosLike Operator.
......
...@@ -23,7 +23,7 @@ template <typename DeviceContext, typename T> ...@@ -23,7 +23,7 @@ template <typename DeviceContext, typename T>
class FillZerosLikeKernel : public framework::OpKernel<T> { class FillZerosLikeKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<framework::Tensor>("Y"); auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> setter; math::SetConstant<DeviceContext, T> setter;
......
...@@ -7,7 +7,7 @@ class TestFillZerosLikeOp(OpTest): ...@@ -7,7 +7,7 @@ class TestFillZerosLikeOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "fill_zeros_like" self.op_type = "fill_zeros_like"
self.inputs = {'X': np.random.random((219, 232)).astype("float32")} self.inputs = {'X': np.random.random((219, 232)).astype("float32")}
self.outputs = {'Y': np.zeros_like(self.inputs["X"])} self.outputs = {'Out': np.zeros_like(self.inputs["X"])}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册