未验证 提交 1363ddb6 编写于 作者: Y Yu Yang 提交者: GitHub

Feature/executor use program bind (#5196)

* Init commit

* Make executor use ProgramDescBind

* Change Attribute from BlockDesc to BlockDescBind

* Since we will get the program desc in RNN, just BlockDesc is not
  enough.
上级 ee11f006
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) {
case framework::AttrType::BOOLEAN: {
return attr_desc.b();
......@@ -61,13 +61,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) {
}
return val;
}
case framework::AttrType::BLOCK: {
PADDLE_ENFORCE(program != nullptr,
"Need to specify ProgramDesc when get a block attr");
return program->mutable_blocks(attr_desc.block_idx());
}
default:
PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
}
PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
return boost::blank();
}
......
......@@ -32,7 +32,7 @@ inline AttrType AttrTypeID() {
return static_cast<AttrType>(tmp.which() - 1);
}
Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* desc);
Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
class AttrReader {
public:
......
......@@ -368,7 +368,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind& program_desc, int block_idx,
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var) {
BlockDescBind* cur_block = program_desc.Block(block_idx);
BlockDescBind* cur_block = program_desc.MutableBlock(block_idx);
std::vector<OpDescBind*> op_descs = cur_block->AllOps();
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
size_t grad_desc_idx = 0;
......@@ -443,7 +443,7 @@ ParamGradInfoMap AppendBackward(
}
const int root_block_idx = 0;
auto root_block = program_desc.Block(root_block_idx);
auto root_block = program_desc.MutableBlock(root_block_idx);
// insert fill one op for target
// TODO(qiao) add some check to the target.
......@@ -492,7 +492,7 @@ ParamGradInfoMap AppendBackward(
CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
for (size_t block_index = forward_block_num;
block_index < program_desc.Size(); ++block_index) {
CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index),
CreateGradVarInBlock(0, grad_to_var, program_desc.MutableBlock(block_index),
&retv);
}
return retv;
......
......@@ -499,7 +499,7 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
TEST(Backward, simple_single_op) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::BlockDescBind *block = program.MutableBlock(0);
f::OpDescBind *op = block->AppendOp();
op->SetType("rowwise_add");
......@@ -535,7 +535,7 @@ TEST(Backward, simple_single_op) {
TEST(Backward, default_attribute) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::BlockDescBind *block = program.MutableBlock(0);
f::OpDescBind *op = block->AppendOp();
op->SetType("mul");
op->SetInput("X", {"x"});
......@@ -561,7 +561,7 @@ TEST(Backward, default_attribute) {
TEST(Backward, simple_mult_op) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::BlockDescBind *block = program.MutableBlock(0);
f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add");
op1->SetInput("X", {"x1"});
......@@ -644,7 +644,7 @@ TEST(Backward, simple_mult_op) {
TEST(Backward, intermedia_var_no_grad) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::BlockDescBind *block = program.MutableBlock(0);
f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add");
op1->SetInput("X", {"x1"});
......@@ -714,7 +714,7 @@ TEST(Backward, intermedia_var_no_grad) {
TEST(Backward, var_no_grad) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::BlockDescBind *block = program.MutableBlock(0);
f::OpDescBind *op1 = block->AppendOp();
op1->SetType("mult_in_out");
op1->SetInput("X", {"x1"});
......@@ -790,7 +790,7 @@ TEST(Backward, var_no_grad) {
TEST(Backward, shared_var) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::BlockDescBind *block = program.MutableBlock(0);
f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add");
op1->SetInput("X", {"x1"});
......@@ -880,7 +880,7 @@ TEST(Backward, shared_var) {
TEST(Backward, half_backward) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::BlockDescBind *block = program.MutableBlock(0);
auto *op1 = block->AppendOp();
op1->SetType("minus");
op1->SetInput("X", {"a"});
......
......@@ -113,7 +113,7 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
if (this->desc_->parent_idx() == kNoneBlockIndex) {
return nullptr;
}
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
}
BlockDesc *BlockDescBind::Proto() {
......
......@@ -73,33 +73,32 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
}
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id) {
// TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication)
// - will change to use multiple blocks for RNN op and Cond Op
PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id);
auto& block = pdesc.blocks(block_id);
PADDLE_ENFORCE_LT(block_id, pdesc.Size());
auto& block = pdesc.Block(block_id);
auto& device = device_contexts_[0];
Scope& local_scope = scope->NewScope();
for (auto& var : block.vars()) {
if (var.persistable()) {
auto* ptr = scope->Var(var.name());
CreateTensor(ptr, var.type());
VLOG(3) << "Create Variable " << var.name()
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope.Var(var.name());
CreateTensor(ptr, var.type());
VLOG(3) << "Create Variable " << var.name()
auto* ptr = local_scope.Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
for (auto& op_desc : block.ops()) {
auto op = paddle::framework::OpRegistry::CreateOp(
op_desc, const_cast<ProgramDesc*>(&pdesc));
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
op->Run(local_scope, *device);
}
......
......@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
......@@ -34,7 +34,7 @@ class Executor {
* ProgramDesc
* Scope
*/
void Run(const ProgramDesc&, Scope*, int);
void Run(const ProgramDescBind&, Scope*, int);
private:
std::vector<platform::DeviceContext*> device_contexts_;
......
......@@ -114,7 +114,12 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
// restore attrs_
for (const OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name();
attrs_[attr_name] = GetAttrValue(attr, prog->Proto());
if (attr.type() != AttrType::BLOCK) {
attrs_[attr_name] = GetAttrValue(attr);
} else {
auto bid = attr.block_idx();
attrs_[attr_name] = prog->MutableBlock(bid);
}
}
}
......@@ -188,8 +193,7 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
}
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
BlockDesc *desc = block.Proto();
this->attrs_[name] = desc;
this->attrs_[name] = &block;
need_update_ = true;
}
......@@ -208,7 +212,7 @@ Attribute OpDescBind::GetAttr(const std::string &name) const {
int OpDescBind::GetBlockAttr(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return boost::get<BlockDesc *>(it->second)->idx();
return boost::get<BlockDescBind *>(it->second)->ID();
}
const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
......
......@@ -43,13 +43,15 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
return ret_val;
}
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc,
ProgramDesc* program) {
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be"
"used in unit tests. Use CreateOp(const OpDescBind& op_desc) "
"instead.";
VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs;
for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = GetAttrValue(attr, program);
attrs[attr.name()] = GetAttrValue(attr);
}
return CreateOp(op_desc.type(), inputs, outputs, attrs);
......
......@@ -77,8 +77,7 @@ class OpRegistry {
const VariableNameMap& outputs,
AttributeMap attrs);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc,
ProgramDesc* program);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
};
......
......@@ -74,7 +74,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
......@@ -95,7 +95,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "larger_than check fail";
......@@ -115,7 +115,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE(op_desc.IsInitialized());
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
......@@ -131,7 +131,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set
bool caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "Attribute 'test_attr' is required!";
......@@ -149,7 +149,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(3);
caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "'test_attr' must be even!";
......@@ -166,7 +166,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx;
paddle::framework::Scope scope;
op->Run(scope, dev_ctx);
......
......@@ -83,7 +83,7 @@ TEST(OperatorBase, all) {
paddle::platform::CPUDeviceContext device_context;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope.Var("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0);
op->Run(scope, device_context);
......@@ -208,7 +208,7 @@ TEST(OpKernel, all) {
paddle::platform::CPUDeviceContext cpu_device_context;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
op->Run(scope, cpu_device_context);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
......@@ -244,7 +244,7 @@ TEST(OpKernel, multi_inputs) {
scope.Var("y0")->GetMutable<LoDTensor>();
scope.Var("y1")->GetMutable<LoDTensor>();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context);
}
......
......@@ -37,7 +37,9 @@ class ProgramDescBind {
BlockDescBind *AppendBlock(const BlockDescBind &parent);
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
BlockDescBind *MutableBlock(size_t idx) { return blocks_[idx].get(); }
const BlockDescBind &Block(size_t idx) const { return *blocks_[idx]; }
size_t Size() const { return blocks_.size(); }
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace framework {
TEST(ProgramDesc, copy_ctor) {
ProgramDescBind program;
auto* global_block = program.Block(0);
auto* global_block = program.MutableBlock(0);
auto* x = global_block->Var("X");
x->SetType(VarDesc_VarType_LOD_TENSOR);
x->SetLoDLevel(0);
......@@ -44,7 +44,7 @@ TEST(ProgramDesc, copy_ctor) {
ProgramDescBind program_copy(program);
auto* global_block_copy = program_copy.Block(0);
auto* global_block_copy = program_copy.MutableBlock(0);
ASSERT_NE(global_block, global_block_copy);
auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
......@@ -82,7 +82,7 @@ TEST(ProgramDesc, copy_ctor) {
TEST(ProgramDescBind, serialize_and_deserialize) {
ProgramDescBind program_origin;
auto* global_block = program_origin.Block(0);
auto* global_block = program_origin.MutableBlock(0);
auto* x = global_block->Var("X");
x->SetType(VarDesc_VarType_LOD_TENSOR);
x->SetLoDLevel(0);
......@@ -108,7 +108,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
program_origin.Proto()->SerializeToString(&binary_str);
ProgramDescBind program_restored(binary_str);
auto* global_block_restored = program_restored.Block(0);
auto* global_block_restored = program_restored.MutableBlock(0);
ASSERT_NE(global_block, global_block_restored);
auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
......
......@@ -52,7 +52,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
TEST(Prune, one_operator) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::BlockDescBind *block = program.MutableBlock(0);
AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block);
......@@ -69,7 +69,7 @@ TEST(Prune, one_operator) {
TEST(Prune, forward) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::BlockDescBind *block = program.MutableBlock(0);
AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block);
AddOp("one_one", {{"input", {"b"}}}, {{"output", {"c"}}}, {}, block);
......@@ -88,7 +88,7 @@ TEST(Prune, forward) {
TEST(Prune, multi_input_op) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::BlockDescBind *block = program.MutableBlock(0);
AddOp("one_one", {{"input", {"a0"}}}, {{"output", {"b0"}}}, {}, block);
AddOp("one_one", {{"input", {"a1"}}}, {{"output", {"b1"}}}, {}, block);
......@@ -106,7 +106,7 @@ TEST(Prune, multi_input_op) {
TEST(Prune, multi_output_op) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::BlockDescBind *block = program.MutableBlock(0);
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block);
AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block);
......@@ -122,7 +122,7 @@ TEST(Prune, multi_output_op) {
TEST(Prune, multi_target) {
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::BlockDescBind *block = program.MutableBlock(0);
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block);
AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block);
......
......@@ -36,7 +36,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*>;
std::vector<bool>, BlockDescBind*>;
using AttributeMap = std::unordered_map<std::string, Attribute>;
......
......@@ -63,41 +63,43 @@ namespace framework {
TEST(InferVarType, sum_op) {
ProgramDescBind prog;
auto *op = prog.Block(0)->AppendOp();
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"});
prog.Block(0)->Var("test_a")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->Var("test_b")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->Var("test_c")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->Var("test_out");
prog.MutableBlock(0)->Var("test_a")->SetType(VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_b")->SetType(VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_c")->SetType(VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test_out");
op->InferVarType(prog.Block(0));
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(VarDesc::SELECTED_ROWS, prog.Block(0)->Var("test_out")->GetType());
ASSERT_EQ(VarDesc::SELECTED_ROWS,
prog.MutableBlock(0)->Var("test_out")->GetType());
prog.Block(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR);
op->InferVarType(prog.Block(0));
ASSERT_EQ(VarDesc::LOD_TENSOR, prog.Block(0)->Var("test_out")->GetType());
prog.MutableBlock(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR);
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(VarDesc::LOD_TENSOR,
prog.MutableBlock(0)->Var("test_out")->GetType());
}
TEST(InferVarType, sum_op_without_infer_var_type) {
ProgramDescBind prog;
auto *op = prog.Block(0)->AppendOp();
auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum_without_infer_var_type");
op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"});
prog.Block(0)->Var("test2_a")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->Var("test2_b")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->Var("test2_c")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->Var("test2_out");
prog.MutableBlock(0)->Var("test2_a")->SetType(VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_b")->SetType(VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_c")->SetType(VarDesc::SELECTED_ROWS);
prog.MutableBlock(0)->Var("test2_out");
op->InferVarType(prog.Block(0));
op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(VarDesc_VarType_LOD_TENSOR,
prog.Block(0)->Var("test2_out")->GetType());
prog.MutableBlock(0)->Var("test2_out")->GetType());
}
} // namespace framework
......
......@@ -51,7 +51,7 @@ class RNNAlgorithmTestHelper : public ::testing::Test {
CreateGlobalVariables();
auto op_desc = CreateOpDesc();
op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
op = paddle::framework::OpRegistry::CreateOp(op_desc);
dop = &(dynamic_cast<DynamicRecurrentOp*>(op.get())->rnn);
InitCacheManually();
InitStepNet();
......
......@@ -129,7 +129,8 @@ void BindProgramDesc(py::module &m) {
}
return retv;
})
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
.def("block", &ProgramDescBind::MutableBlock,
py::return_value_policy::reference)
.def("num_blocks", &ProgramDescBind::Size)
.def("serialize_to_string",
[](ProgramDescBind &program_desc) -> py::bytes {
......
......@@ -275,7 +275,7 @@ All parameter, weight, gradient are variables in Paddle.
const std::vector<std::array<size_t, 2>> &targets) {
ProgramDescBind prog_with_targets(origin);
for (const auto &t : targets) {
prog_with_targets.Block(t[0])->Op(t[1])->MarkAsTarget();
prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget();
}
ProgramDesc pruned_desc;
Prune(*prog_with_targets.Proto(), &pruned_desc);
......@@ -335,7 +335,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc, nullptr);
return OpRegistry::CreateOp(desc);
})
.def("backward",
[](const OperatorBase &forwardOp,
......@@ -439,7 +439,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc, nullptr);
auto rnn_op = OpRegistry::CreateOp(desc);
return static_cast<operators::RecurrentOp *>(rnn_op.release());
})
.def("set_stepnet", [](operators::RecurrentOp &self,
......@@ -457,7 +457,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc, nullptr);
auto rnn_op = OpRegistry::CreateOp(desc);
return static_cast<operators::DynamicRecurrentOp *>(
rnn_op.release());
})
......@@ -484,7 +484,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto cond_op = OpRegistry::CreateOp(desc, nullptr);
auto cond_op = OpRegistry::CreateOp(desc);
return static_cast<operators::CondOp *>(cond_op.release());
})
.def("set_truenet",
......@@ -498,10 +498,7 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor")
.def(py::init<std::vector<platform::Place> &>())
.def("run", [](Executor &self, ProgramDescBind *program_bind,
Scope *scope, int block_id) {
self.Run(*program_bind->Proto(), scope, block_id);
});
.def("run", &Executor::Run);
m.def("unique_integer", UniqueIntegerGenerator);
m.def("init_gflags", InitGflags);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册