提交 f3ff790b 编写于 作者: Y Yi Wang

Update usage of Scope in operator_test.cc

上级 d1000623
...@@ -18,7 +18,7 @@ proto_library(op_desc SRCS op_desc.proto DEPS attr_type) ...@@ -18,7 +18,7 @@ proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor) cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator)
......
...@@ -30,9 +30,9 @@ class OpWithoutKernelTest : public OperatorBase { ...@@ -30,9 +30,9 @@ class OpWithoutKernelTest : public OperatorBase {
op_run_num++; op_run_num++;
ASSERT_EQ((int)inputs_.size(), 1); ASSERT_EQ((int)inputs_.size(), 1);
ASSERT_EQ((int)outputs_.size(), 1); ASSERT_EQ((int)outputs_.size(), 1);
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); ASSERT_EQ(scope->FindVar(inputs_[0]), nullptr);
ASSERT_EQ(x, 1); ASSERT_EQ(x, 1);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); ASSERT_NE(scope->FindVar(outputs_[0]), nullptr);
} }
public: public:
...@@ -71,7 +71,7 @@ TEST(OperatorBase, all) { ...@@ -71,7 +71,7 @@ TEST(OperatorBase, all) {
auto scope = std::make_shared<paddle::framework::Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope->CreateVariable("OUT1"); scope->NewVar("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0); ASSERT_EQ(paddle::framework::op_run_num, 0);
op->Run(scope, device_context); op->Run(scope, device_context);
ASSERT_EQ(paddle::framework::op_run_num, 1); ASSERT_EQ(paddle::framework::op_run_num, 1);
...@@ -120,9 +120,9 @@ class OperatorMultiInputsTest : public OperatorBase { ...@@ -120,9 +120,9 @@ class OperatorMultiInputsTest : public OperatorBase {
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); ASSERT_EQ(scope->FindVar(inputs_[0]), nullptr);
ASSERT_EQ(x, 1); ASSERT_EQ(x, 1);
ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); ASSERT_NE(scope->FindVar(outputs_[0]), nullptr);
ASSERT_EQ(Input("x"), "IN1"); ASSERT_EQ(Input("x"), "IN1");
ASSERT_EQ(Input("y"), "OUT1"); ASSERT_EQ(Input("y"), "OUT1");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册